init: 整合基础框架,fastapi+sqlmodel(增强逻辑删除),搭建部分基础环境

main
chenzhirong 5 months ago
commit e25b78f77b

@ -0,0 +1,3 @@
SERVICE_CONF=application.yaml
DEBUG=true
DATABASE_URL=mysql+aiomysql://username:password@host:port/database_name

173
.gitignore vendored

@ -0,0 +1,173 @@
# volumes
docker/volumes/**
docker/nginx/dist
web/dist
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
coverage.json
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
.python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.env-local
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
.conda/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
.idea/'
.DS_Store
web/.vscode/settings.json
# Intellij IDEA Files
.idea/*
!.idea/vcs.xml
!.idea/icon.png
.ideaDataSources/
*.iml
.vscode/*
!.vscode/launch.json.template
!.vscode/README.md
pyrightconfig.json
api/.vscode
.idea/
# pnpm
/.pnpm-store
# plugin migrate
plugins.jsonl
# mise
mise.toml
# Next.js build output
.next/
# AI Assistant
.roo/
api/.env.backup
application.yaml

@ -0,0 +1,3 @@
zhipu:
api-key:
base-url:

@ -0,0 +1,9 @@
class MetaConst(type):
def __setattr__(cls, name, value):
if name in cls.__dict__:
raise TypeError(f"Cannot rebind constant {name}")
super().__setattr__(name, value)
class Constant(metaclass=MetaConst):
LOGICAL_DELETE_FIELD = "is_deleted"

@ -0,0 +1,6 @@
from enum import IntEnum
class IsDelete(IntEnum):
NO_DELETE = 0
DELETE = 1

@ -0,0 +1,8 @@
from config.settings import Settings, get_settings
settings = get_settings()
def get_yaml_conf(k):
return settings.yaml_config.get(k, None)
def show_configs():
return settings

@ -0,0 +1,119 @@
import secrets
import sys
from contextlib import asynccontextmanager
from importlib.util import module_from_spec, spec_from_file_location
from pathlib import Path
from typing import List
import anyio.to_thread
from fastapi import FastAPI
from fastapi.openapi.utils import get_openapi
from starlette.middleware.cors import CORSMiddleware
from starlette.middleware.sessions import SessionMiddleware
from config import settings
__all__ = ["app"]
@asynccontextmanager
async def lifespan(app: FastAPI):
"""
current_default_thread_limiter()官方文档里的唯一合法入口
total_tokens不是线程数而是并发通行证令牌用光就排队
官方建议不要盲目拉满CPU核心数*5是一个经验上限
"""
# 1. 拿到当前全局限速器
limiter = anyio.to_thread.current_default_thread_limiter()
# 2. 把40个线程改成80
limiter.total_tokens = 80
yield
# FastAPI应用初始化
app = FastAPI(
title="DataBuilder API",
description="数据工厂 api",
version="1.0.0",
lifespan=lifespan
)
def custom_openapi():
if app.openapi_schema:
return app.openapi_schema
openapi_schema = get_openapi(
title="数据工厂",
version="1.0.0",
description="数据工厂接口文档",
routes=app.routes,
)
# # 添加全局安全方案
# openapi_schema["components"]["securitySchemes"] = {
# "global_auth": {
# "type": "apiKey",
# "in": "header",
# "name": "Authorization" # 这里可以改为任何需要的请求头名称
# }
# }
# 应用全局安全要求
openapi_schema["security"] = [
{"global_auth": []}
]
app.openapi_schema = openapi_schema
return app.openapi_schema
app.openapi = custom_openapi
app.add_middleware(SessionMiddleware, secret_key=secrets.token_hex(32))
# CORS配置
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
max_age=2592000
)
# 全局异常处理
# configure_exception(app)
white_list = ["/docs", "/openapi.json", "/redoc"]
# 动态路由注册
def search_pages_path(pages_dir: Path) -> List[Path]:
return [path for path in pages_dir.glob("*.py") if not path.name.startswith(".") and not path.name.startswith("_")]
def register_controller(page_path: Path, prefix=settings.api_version):
module_name = f"router.{page_path.stem}"
spec = spec_from_file_location(module_name, page_path)
if spec is None or spec.loader is None:
raise ImportError(f"Cannot load module from {page_path}")
page = module_from_spec(spec)
sys.modules[module_name] = page
spec.loader.exec_module(page)
# 注册路由
if hasattr(page, "router"):
app.include_router(page.router, prefix=prefix)
return page.router if page.router else prefix
# 注册所有控制器
pages_dirs = [
Path(__file__).parent.parent / "router",
]
client_urls_prefix = [
register_controller(path)
for pages_dir in pages_dirs
for path in search_pages_path(pages_dir)
]

@ -0,0 +1,57 @@
from functools import lru_cache
from dotenv import load_dotenv
from pydantic.v1 import BaseSettings as Base
from utils import file_utils
class BaseSettings(Base):
"""配置基类"""
class Config:
env_file = '.env'
env_file_encoding = 'utf-8'
extra = 'allow'
class Settings(BaseSettings):
"""应用配置
server目录为后端项目根目录, 在该目录下创建 "config.env" 文件, 写入环境变量(默认大写)会自动加载, 并覆盖同名配置(小写)
eg.
config.env 文件内写入
REDIS_URL='redis://localhost:6379'
上述环境变量会覆盖 redis_url
"""
# 模式
mode: str = 'dev' # dev, prod
debug: bool = False # dev, prod
service_conf: str = 'application.yaml' # dev, prod
# 版本
api_version: str = '/v1'
# 时区
timezone: str = 'Asia/Shanghai'
# 日期时间格式
datetime_fmt: str = '%Y-%m-%d %H:%M:%S'
# Redis键前缀
redis_prefix: str = 'agent:'
# 当前域名
host_ip:str = '0.0.0.0'
host_port: int = 8080
#sql驱动连接
database_url: str = ''
# yaml配置
yaml_config: dict = {}
@lru_cache()
def get_settings() -> Settings:
"""获取并缓存应用配置"""
# 读取server目录下的配置
load_dotenv()
settings = Settings()
yaml_config = file_utils.load_yaml_conf(settings.service_conf)
# 将YAML配置存储到Settings实例中
settings.yaml_config = yaml_config
for k, v in settings.yaml_config.items():
if not hasattr(settings, k) or getattr(settings, k) == settings.__fields__[k].default:
setattr(settings, k, v)
return settings

@ -0,0 +1,6 @@
from contextvars import ContextVar
from typing import Optional
from sqlalchemy.ext.asyncio import AsyncSession
current_session: ContextVar[Optional[AsyncSession]] = ContextVar("current_session", default=None)

@ -0,0 +1,140 @@
import asyncio
import inspect
from typing import Any
from sqlalchemy import Executable, Result, Select, Delete, Update, column, and_
from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker, AsyncSession
from sqlmodel import SQLModel
from common.constant import Constant
from common.global_enums import IsDelete
from config import settings
from entity.base_entity import DbBaseModel
engine = create_async_engine(
settings.database_url,
echo=True, # 打印SQL日志生产环境建议关闭
pool_size=10, # 连接池大小
max_overflow=20, # 最大溢出连接数
pool_recycle=3600, # 连接回收时间解决MySQL超时断开问题【4†source】【5†source】
)
# 创建异步会话工厂
class EnhanceAsyncSession(AsyncSession):
async def execute(
self,
statement: Executable,
params=None,
*,
execution_options=None,
bind_arguments=None,
**kw: Any,
) -> Result[Any]:
sig = inspect.signature(super().execute)
if execution_options is None:
default_execution_options = sig.parameters['execution_options'].default
execution_options = default_execution_options
if isinstance(statement, Select):
print("这是查询语句,过滤逻辑删除")
delete_condition = column(Constant.LOGICAL_DELETE_FIELD) == IsDelete.NO_DELETE
# 获取现有条件
existing_condition = statement.whereclause
# 组合条件
if existing_condition is not None:
# 使用and_组合现有条件和逻辑删除条件
new_condition = and_(existing_condition, delete_condition)
else:
new_condition = delete_condition
# 应用新条件创建新的Select对象
statement = statement.where(new_condition)
if isinstance(statement, Delete):
# 检查是否跳过软删除通过execution_options控制
skip_soft_delete = execution_options and execution_options.get("skip_soft_delete", False)
if not skip_soft_delete:
# 获取表对象
table = statement.table
# 构建更新语句
update_stmt = (
Update(table)
.where(statement.whereclause) # 保留原删除条件
.values(**{Constant.LOGICAL_DELETE_FIELD: IsDelete.DELETE}) # 设置软删除标记
)
# 如果原删除语句有RETURNING子句也添加到更新语句中
if statement._returning:
update_stmt = update_stmt.returning(*statement._returning)
# 执行更新语句
return await super().execute(
update_stmt,
params=params,
execution_options=execution_options,
bind_arguments=bind_arguments,
**kw
)
result = await super().execute(
statement,
params=params,
execution_options=execution_options,
bind_arguments=bind_arguments,
**kw,
)
return result
# 重写delete方法处理单个对象删除
def delete(self, instance):
from sqlalchemy import inspect
# 检查是否有逻辑删除属性
if hasattr(instance, Constant.LOGICAL_DELETE_FIELD):
# 设置软删除标记
instance.__setattr__(Constant.LOGICAL_DELETE_FIELD, IsDelete.DELETE)
# 确保对象在会话中(如果已分离则重新关联)
# 检查对象状态
insp = inspect(instance)
if insp.detached:
# 如果对象是分离的,则重新加入会话
self.add(instance)
elif insp.transient:
# 如果是瞬态对象,也添加到会话
self.add(instance)
# 标记对象为已修改(触发更新)
# self.expire(instance, [Constant.LOGICAL_DELETE_FIELD])
else:
# 如果没有逻辑删除属性,执行标准删除
super().delete(instance)
AsyncSessionLocal = async_sessionmaker(
bind=engine,
class_=EnhanceAsyncSession,
expire_on_commit=False, # 提交后不使对象过期
autoflush=False # 禁用自动刷新
)
# 关闭引擎
async def close_engine():
await engine.dispose()
# 初始化数据库表(异步执行)
async def init_db():
async with engine.begin() as conn:
await conn.run_sync(SQLModel.metadata.create_all)
if __name__ == '__main__':
import user
async def main():
try:
await init_db()
finally:
await close_engine() # 确保引擎关闭
asyncio.run(main())

@ -0,0 +1,21 @@
from sqlalchemy import Column, BigInteger
from common.global_enums import IsDelete
from utils import current_timestamp
# 上下文变量,控制是否启用逻辑删除过滤
_SOFT_DELETE_ENABLED = True
from sqlmodel import SQLModel, Field
class DbBaseModel(SQLModel, table=False):
id: str = Field(default=None, max_length=32, primary_key=True)
created_time: int = Field(sa_column=Column(BigInteger), default_factory=current_timestamp)
# created_by = CharField(max_length=32, index=True)
updated_time: int = Field(sa_column=Column(BigInteger, onupdate=current_timestamp),
default_factory=current_timestamp)
# updated_by = CharField(max_length=32)
is_deleted: int = Field(default=IsDelete.NO_DELETE)
# class Config:
# arbitrary_types_allowed = True

@ -0,0 +1,7 @@
from entity.base_entity import DbBaseModel
class User(DbBaseModel,table=True):
__tablename__ = "user" # 可以显式指定数据库表名,默认实体名转小写
username: str
password: str

@ -0,0 +1,70 @@
import logging
import os
import signal
import sys
import threading
import time
import traceback
from config import settings, show_configs
from utils import file_utils
stop_event = threading.Event()
RAGFLOW_DEBUGPY_LISTEN = int(os.environ.get('RAGFLOW_DEBUGPY_LISTEN', "0"))
def signal_handler(sig, frame):
logging.info("Received interrupt signal, shutting down...")
stop_event.set()
time.sleep(1)
sys.exit(0)
if __name__ == '__main__':
logging.info(r"""
_____ _ ____ _ _ _
| __ \ | | | _ \ (_) | | |
| | | | __ _| |_ __ _| |_) |_ _ _| | __| | ___ _ __
| | | |/ _` | __/ _` | _ <| | | | | |/ _` |/ _ \ '__|
| |__| | (_| | || (_| | |_) | |_| | | | (_| | __/ |
|_____/ \__,_|\__\__,_|____/ \__,_|_|_|\__,_|\___|_|
""")
logging.info(
f'project base: {file_utils.get_project_base_directory()}'
)
show_configs()
# import argparse
# parser = argparse.ArgumentParser()
#
# args = parser.parse_args()
signal.signal(signal.SIGINT, signal_handler)
signal.signal(signal.SIGTERM, signal_handler)
try:
logging.info("服务启动ing...")
import uvicorn
# 配置Uvicorn参数
uvicorn_config = {
# "app": app, # FastAPI应用实例
"app": "config.fastapi_config:app", # FastAPI应用实例
"host": settings.host_ip,
"port": settings.host_port,
"reload": settings.debug, # 开发模式启用热重载
"log_level": "debug" if settings.debug else "info",
"access_log": True,
}
# 如果是调试模式,添加额外配置
if settings.debug:
uvicorn_config.update({
"reload_dirs": ["."], # 监视当前目录变化
"reload_delay": 0.5, # 重载延迟
})
# 启动Uvicorn服务器
uvicorn.run(**uvicorn_config)
except Exception:
traceback.print_exc()
stop_event.set()
time.sleep(1)
os.kill(os.getpid(), signal.SIGKILL)

@ -0,0 +1,20 @@
from fastapi import Request
from core.global_context import current_session
from entity import AsyncSessionLocal
async def db_session_middleware(request: Request, call_next):
async with AsyncSessionLocal() as session:
# 设置会话到上下文变量
token = current_session.set(session)
try:
response = await call_next(request)
await session.commit()
except Exception:
await session.rollback()
raise
finally:
# 重置上下文变量
current_session.reset(token)
return response

@ -0,0 +1,24 @@
[project]
name = "rg-agno-agent"
version = "0.1.0"
description = "Add your description here"
readme = "README.md"
requires-python = ">=3.12"
dependencies = [
"fastapi[standard]>=0.116.1",
"langchain>=0.3.27",
"mcp>=1.14.0",
"openai>=1.107.3",
"sqlalchemy[asyncio]>=2.0.43",
"ruamel-yaml>=0.18.6,<0.19.0", # YAML处理
"cachetools==5.3.3", # 缓存工具
"filelock==3.15.4", # 文件锁
"itsdangerous==2.1.2", # 安全签名,用于 SessionMiddleware
"langchain-openai>=0.3.33",
"langchainhub>=0.1.21",
"httpx-sse>=0.4.1",
"sqlmodel>=0.0.25",
"aiomysql>=0.2.0",
]
[[tool.uv.index]]
url = "https://mirrors.aliyun.com/pypi/simple"

@ -0,0 +1,13 @@
from fastapi import APIRouter
from config import settings
from utils.file_utils import get_project_base_directory
router = APIRouter()
@router.get("/test")
async def test():
return {
"settings":settings,
"base_path": get_project_base_directory(),
}

@ -0,0 +1,461 @@
from typing import Union, Type, List, Any, TypeVar, Generic
from pydantic import BaseModel
from sqlalchemy.ext.asyncio import AsyncSession
from core.global_context import current_session
from utils import get_uuid, current_timestamp
T = TypeVar('T')
class BaseService(Generic[T]):
model: Type[T] # 子类必须指定模型
@classmethod
def get_db(cls) -> AsyncSession:
"""获取当前请求的会话"""
session = current_session.get()
if session is None:
raise RuntimeError("No database session in context. "
"Make sure to use this service within a request context.")
return session
@classmethod
async def create(cls, **kwargs) -> T:
"""通用创建方法"""
obj = cls.model(**kwargs)
db = cls.get_db()
db.add(obj)
await db.flush()
return obj
@classmethod
def entity_conversion_dto(cls, entity_data: Union[list, model], dto: Type[BaseModel]) -> Union[
BaseModel, List[BaseModel]]:
dto_list = []
if not isinstance(entity_data, list):
return dto(**entity_data.to_dict())
for entity in entity_data:
temp = entity
if not isinstance(entity, dict):
temp = entity.to_dict()
dto_list.append(dto(**temp))
return dto_list
@classmethod
def query(cls, cols=None, reverse=None, order_by=None, **kwargs):
"""Execute a database query with optional column selection and ordering.
This method provides a flexible way to query the database with various filters
and sorting options. It supports column selection, sort order control, and
additional filter conditions.
Args:
cols (list, optional): List of column names to select. If None, selects all columns.
reverse (bool, optional): If True, sorts in descending order. If False, sorts in ascending order.
order_by (str, optional): Column name to sort results by.
**kwargs: Additional filter conditions passed as keyword arguments.
Returns:
peewee.ModelSelect: A query result containing matching records.
"""
return cls.model.query(cols=cols, reverse=reverse, order_by=order_by, **kwargs)
@classmethod
def get_all(cls, cols=None, reverse=None, order_by=None):
"""Retrieve all records from the database with optional column selection and ordering.
This method fetches all records from the model's table with support for
column selection and result ordering. If no order_by is specified and reverse
is True, it defaults to ordering by created_time.
Args:
cols (list, optional): List of column names to select. If None, selects all columns.
reverse (bool, optional): If True, sorts in descending order. If False, sorts in ascending order.
order_by (str, optional): Column name to sort results by. Defaults to 'created_time' if reverse is specified.
Returns:
peewee.ModelSelect: A query containing all matching records.
"""
if cols:
query_records = cls.model.select(*cols)
else:
query_records = cls.model.select()
if reverse is not None:
if not order_by or not hasattr(cls, order_by):
order_by = "created_time"
if reverse is True:
query_records = query_records.order_by(cls.model.getter_by(order_by).desc())
elif reverse is False:
query_records = query_records.order_by(cls.model.getter_by(order_by).asc())
return query_records
@classmethod
def get(cls, **kwargs):
"""Get a single record matching the given criteria.
This method retrieves a single record from the database that matches
the specified filter conditions.
Args:
**kwargs: Filter conditions as keyword arguments.
Returns:
Model instance: Single matching record.
Raises:
peewee.DoesNotExist: If no matching record is found.
"""
return cls.model.get(**kwargs)
@classmethod
def get_or_none(cls, **kwargs):
"""Get a single record or None if not found.
This method attempts to retrieve a single record matching the given criteria,
returning None if no match is found instead of raising an exception.
Args:
**kwargs: Filter conditions as keyword arguments.
Returns:
Model instance or None: Matching record if found, None otherwise.
"""
try:
return cls.model.get(**kwargs)
except peewee.DoesNotExist:
return None
@classmethod
def get_by_page(cls, query_params: Union[dict, BasePageQueryReq]):
if not isinstance(query_params, dict):
query_params = query_params.model_dump()
query_params = {k: v for k, v in query_params.items() if v is not None}
sessions = cls.get_query_session(query_params)
return cls.auto_page(sessions, query_params)
@classmethod
def auto_page(cls, sessions, query_params: Union[dict, BasePageQueryReq] = None,
dto_model_class: Type[BaseModel] = None):
if not query_params:
query_params = {}
if not isinstance(query_params, dict):
query_params = query_params.model_dump()
page_number = query_params.get("page_number", 1)
page_size = query_params.get("page_size", 12)
desc = query_params.get("desc", "desc")
orderby = query_params.get("orderby", "created_time")
data_count = sessions.count()
if data_count == 0:
return BasePageResp(**{
"page_number": page_number,
"page_size": page_size,
"count": data_count,
"desc": desc,
"orderby": orderby,
"data": [],
})
if desc == "desc":
sessions = sessions.order_by(cls.model.getter_by(orderby).desc())
else:
sessions = sessions.order_by(cls.model.getter_by(orderby).asc())
sessions = sessions.paginate(int(page_number), int(page_size))
datas = list(sessions.dicts())
result = datas
if dto_model_class is not None:
result = [dto_model_class(**item) for item in datas]
return BasePageResp(**{
"page_number": page_number,
"page_size": page_size,
"count": data_count,
"desc": desc,
"orderby": orderby,
"data": result,
})
@classmethod
def get_list(cls, query_params: Union[dict, BaseQueryReq]):
if not isinstance(query_params, dict):
query_params = query_params.model_dump()
query_params = {k: v for k, v in query_params.items() if v is not None}
desc = query_params.get("desc", "desc")
orderby = query_params.get("orderby", "created_time")
sessions = cls.get_query_session(query_params)
if desc == "desc":
sessions = sessions.order_by(cls.model.getter_by(orderby).desc())
else:
sessions = sessions.order_by(cls.model.getter_by(orderby).asc())
return list(sessions.dicts())
@classmethod
def get_id_list(cls, query_params: Union[dict, BaseQueryReq]) -> List[Any]:
if not isinstance(query_params, dict):
query_params = query_params.model_dump()
query_params = {k: v for k, v in query_params.items() if v is not None}
desc = query_params.get("desc", "desc")
orderby = query_params.get("orderby", "created_time")
sessions = cls.model.select(cls.model.id)
sessions = cls.get_query_session(query_params, sessions)
if desc == "desc":
sessions = sessions.order_by(cls.model.getter_by(orderby).desc())
else:
sessions = sessions.order_by(cls.model.getter_by(orderby).asc())
return [item["id"] for item in list(sessions.dicts())]
@classmethod
def save(cls, **kwargs):
"""Save a new record to database.
This method creates a new record in the database with the provided field values,
forcing an insert operation rather than an update.
Args:
**kwargs: Record field values as keyword arguments.
Returns:
Model instance: The created record object.
"""
sample_obj = cls.model(**kwargs).save(force_insert=True)
return sample_obj > 0
@classmethod
def insert(cls, **kwargs):
"""Insert a new record with automatic ID and timestamps.
This method creates a new record with automatically generated ID and timestamp fields.
It handles the creation of created_time, create_date, updated_time, and update_date fields.
Args:
**kwargs: Record field values as keyword arguments.
Returns:
Model instance: The newly created record object.
"""
if "id" not in kwargs:
kwargs["id"] = get_uuid()
kwargs["created_time"] = current_timestamp()
# kwargs["create_date"] = datetime_format(datetime.now())
kwargs["updated_time"] = current_timestamp()
# kwargs["update_date"] = datetime_format(datetime.now())
sample_obj = cls.model(**kwargs).save(force_insert=True)
return sample_obj > 0
@classmethod
def insert_many(cls, data_list, batch_size=100):
"""Insert multiple records in batches.
This method efficiently inserts multiple records into the database using batch processing.
It automatically sets creation timestamps for all records.
Args:
data_list (list): List of dictionaries containing record data to insert.
batch_size (int, optional): Number of records to insert in each batch. Defaults to 100.
"""
with DB.atomic():
for d in data_list:
if not d.get("id", None):
d["id"] = get_uuid()
d["created_time"] = current_timestamp()
# d["create_date"] = datetime_format(datetime.now())
for i in range(0, len(data_list), batch_size):
cls.model.insert_many(data_list[i: i + batch_size]).execute()
@classmethod
def update_many_by_id(cls, data_list):
"""Update multiple records by their IDs.
This method updates multiple records in the database, identified by their IDs.
It automatically updates the updated_time and update_date fields for each record.
Args:
data_list (list): List of dictionaries containing record data to update.
Each dictionary must include an 'id' field.
"""
with DB.atomic():
for data in data_list:
data["updated_time"] = current_timestamp()
# data["update_date"] = datetime_format(datetime.now())
cls.model.update(data).where(cls.model.id == data["id"]).execute()
@classmethod
def updated_by_id(cls, pid, data):
# Update a single record by ID
# Args:
# pid: Record ID
# data: Updated field values
# Returns:
# Number of records updated
data["updated_time"] = current_timestamp()
# data["update_date"] = datetime_format(datetime.now())
num = cls.model.update(data).where(cls.model.id == pid).execute()
return num > 0
@classmethod
def get_by_id(cls, pid):
# Get a record by ID
# Args:
# pid: Record ID
# Returns:
# Tuple of (success, record)
try:
obj = cls.model.get_or_none(cls.model.id == pid)
if obj:
return True, obj
except Exception:
pass
return False, None
@classmethod
def get_by_ids(cls, pids, cols=None):
# Get multiple records by their IDs
# Args:
# pids: List of record IDs
# cols: List of columns to select
# Returns:
# Query of matching records
if cols:
objs = cls.model.select(*cols)
else:
objs = cls.model.select()
return objs.where(cls.model.id.in_(pids))
@classmethod
def get_last_by_create_time(cls):
# Get multiple records by their IDs
# Args:
# pids: List of record IDs
# cols: List of columns to select
# Returns:
# Query of matching records
latest = cls.model.select().order_by(cls.model.created_time.desc()).first()
return latest
@classmethod
def delete_by_id(cls, pid):
# Delete a record by ID
# Args:
# pid: Record ID
# Returns:
# Number of records deleted
return cls.model.delete().where(cls.model.id == pid).execute()
@classmethod
def delete_by_ids(cls, pids):
# Delete multiple records by their IDs
# Args:
# pids: List of record IDs
# Returns:
# Number of records deleted
with DB.atomic():
res = cls.model.delete().where(cls.model.id.in_(pids)).execute()
return res
@classmethod
def filter_delete(cls, filters):
# Delete records matching given filters
# Args:
# filters: List of filter conditions
# Returns:
# Number of records deleted
with DB.atomic():
num = cls.model.delete().where(*filters).execute()
return num
@classmethod
def filter_update(cls, filters, update_data):
# Update records matching given filters
# Args:
# filters: List of filter conditions
# update_data: Updated field values
# Returns:
# Number of records updated
with DB.atomic():
return cls.model.update(update_data).where(*filters).execute()
@staticmethod
def cut_list(tar_list, n):
# Split a list into chunks of size n
# Args:
# tar_list: List to split
# n: Chunk size
# Returns:
# List of tuples containing chunks
length = len(tar_list)
arr = range(length)
result = [tuple(tar_list[x: (x + n)]) for x in arr[::n]]
return result
@classmethod
def filter_scope_list(cls, in_key, in_filters_list, filters=None, cols=None):
# Get records matching IN clause filters with optional column selection
# Args:
# in_key: Field name for IN clause
# in_filters_list: List of values for IN clause
# filters: Additional filter conditions
# cols: List of columns to select
# Returns:
# List of matching records
in_filters_tuple_list = cls.cut_list(in_filters_list, 20)
if not filters:
filters = []
res_list = []
if cols:
for i in in_filters_tuple_list:
query_records = cls.model.select(*cols).where(getattr(cls.model, in_key).in_(i), *filters)
if query_records:
res_list.extend([query_record for query_record in query_records])
else:
for i in in_filters_tuple_list:
query_records = cls.model.select().where(getattr(cls.model, in_key).in_(i), *filters)
if query_records:
res_list.extend([query_record for query_record in query_records])
return res_list
@classmethod
def get_query_session(cls, query_params, sessions=None):
if sessions is None:
sessions = cls.model.select()
for key, value in query_params.items():
if hasattr(cls.model, key):
field = getattr(cls.model, key)
sessions = sessions.where(field == value)
return sessions
@classmethod
def get_data_count(cls, query_params: dict = None):
if not query_params:
raise Exception("参数为空")
sessions = cls.get_query_session(query_params)
return sessions.count()
@classmethod
def is_exist(cls, query_params: dict = None):
return cls.get_data_count(query_params) > 0
@classmethod
def update_by_id(cls, pid, data):
# Update a single record by ID
# Args:
# pid: Record ID
# data: Updated field values
# Returns:
# Number of records updated
data["updated_time"] = current_timestamp()
num = cls.model.update(data).where(cls.model.id == pid).execute()
return num
@classmethod
def check_base_permission(cls, model_data):
if isinstance(model_data, dict):
if model_data.get("created_by") != get_current_user().id:
raise RuntimeError("无操作权限,该操作仅创建者有此权限")
if model_data.created_by != get_current_user().id:
raise RuntimeError("无操作权限,该操作仅创建者有此权限")

@ -0,0 +1,16 @@
from langchain_openai import ChatOpenAI
from config import settings
class ChatService:
llm = ChatOpenAI(
temperature=0.6,
model="glm-4.5",
openai_api_key=settings.yaml_config.get("zhipu").get("api-key"),
openai_api_base=settings.yaml_config.get("zhipu").get("base-url")
)
@classmethod
def chat(cls,message):
resp = cls.llm.invoke(message)
print(resp)

@ -0,0 +1,62 @@
from contextvars import ContextVar
from contextvars import ContextVar
from typing import Optional, TypeVar, Generic, Type
from fastapi import Request
from sqlalchemy.ext.asyncio import AsyncSession
from entity import AsyncSessionLocal
from entity.user import User
# 1. 创建上下文变量存储当前会话
current_session: ContextVar[Optional[AsyncSession]] = ContextVar("current_session", default=None)
# 3. 中间件:管理请求生命周期和会话
async def db_session_middleware(request: Request, call_next):
async with AsyncSessionLocal() as session:
# 设置会话到上下文变量
token = current_session.set(session)
try:
response = await call_next(request)
await session.commit()
except Exception:
await session.rollback()
raise
finally:
# 重置上下文变量
current_session.reset(token)
return response
# 4. 服务基类
T = TypeVar('T')
class BaseService(Generic[T]):
model: Type[T] = None # 子类必须指定模型
@classmethod
def get_db(cls) -> AsyncSession:
"""获取当前请求的会话"""
session = current_session.get()
if session is None:
raise RuntimeError("No database session in context. "
"Make sure to use this service within a request context.")
return session
@classmethod
async def create(cls, **kwargs) -> T:
"""通用创建方法"""
obj = cls.model(**kwargs)
db = cls.get_db()
db.add(obj)
await db.flush()
return obj
@classmethod
async def get(cls, id: int) -> Optional[T]:
"""通用获取方法"""
db = cls.get_db()
return await db.get(cls.model, id)
# 5. 具体服务类
class UserService(BaseService[User]):
model = User # 指定模型

@ -0,0 +1,66 @@
import asyncio
from sqlalchemy import delete
from sqlmodel import select
from entity import AsyncSessionLocal, close_engine
from entity.user import User
async def add():
add_list = []
add_list.append(User(id="a460359e960311f09677c922f415afd9", username="u1", password="密码"))
add_list.append(User(id="a460416a960311f09677c922f415afd9", username="u2", password="密码"))
add_list.append(User(id="a46042be960311f09677c922f415afd9", username="u3", password="密码"))
# session = get_db_session()
async with AsyncSessionLocal() as session:
for user in add_list:
session.add(user)
await session.commit()
await close_engine()
print("完成")
async def remove():
add_list = []
add_list.append(User(id="a460359e960311f09677c922f415afd9", username="u1", password="密码"))
add_list.append(User(id="a460416a960311f09677c922f415afd9", username="u2", password="密码"))
add_list.append(User(id="a46042be960311f09677c922f415afd9", username="u3", password="密码"))
async with AsyncSessionLocal() as session:
for i in add_list:
delete_sql = delete(User).where(User.id == i.id)
await session.execute(delete_sql)
await session.commit()
await close_engine()
print("======================删除完成=================")
async def remove_one():
add_list = []
add_list.append(User(id="a460359e960311f09677c922f415afd9", username="u1", password="密码"))
add_list.append(User(id="a460416a960311f09677c922f415afd9", username="u2", password="密码"))
add_list.append(User(id="a46042be960311f09677c922f415afd9", username="u3", password="密码"))
async with AsyncSessionLocal() as session:
# for i in add_list:
# delete_sql = delete(User).where(User.id == i.id)
user=await session.execute(select(User))
user = user.scalars().first()
session.delete(user)
await session.commit()
await close_engine()
print("======================删除完成=================")
async def get_all():
async with AsyncSessionLocal() as session:
sql = select(User)
result = await session.execute(sql)
print(result.scalars().all())
await session.commit()
await close_engine()
print("====================get_all==================")
def test(aa="ddd"):
print("aa=", aa)
if __name__ == "__main__":
# asyncio.run(add())
asyncio.run(get_all())
# asyncio.run(remove_one())
asyncio.run(remove())
asyncio.run(get_all())
# test(None)

@ -0,0 +1,377 @@
# import base64
# import copy
# import datetime
# import hashlib
# import importlib
# import io
# import json
# import logging
# import os
# import pickle
# import socket
import time
import uuid
# from enum import Enum, IntEnum
#
#
# import requests
#
# from fastapi.encoders import jsonable_encoder
# from filelock import FileLock
#
# from common.constants import SERVICE_CONF
# from . import file_utils
#
def current_timestamp() -> int:
"""获取时间戳"""
return int(time.time() * 1000)
def get_uuid():
return uuid.uuid1().hex
# def conf_realpath(conf_name):
# conf_path = f"{conf_name}"
# return os.path.join(file_utils.get_project_base_directory(), conf_path)
#
#
# # def read_config(conf_name=SERVICE_CONF):
# # local_config = {}
# # local_path = conf_realpath(f'local.{conf_name}')
# #
# # # load local config file
# # if os.path.exists(local_path):
# # local_config = file_utils.load_yaml_conf(local_path)
# # if not isinstance(local_config, dict):
# # raise ValueError(f'Invalid config file: "{local_path}".')
# #
# # global_config_path = conf_realpath(conf_name)
# # global_config = file_utils.load_yaml_conf(global_config_path)
# #
# # if not isinstance(global_config, dict):
# # raise ValueError(f'Invalid config file: "{global_config_path}".')
# #
# # global_config.update(local_config)
# # return global_config
# #
# #
# # CONFIGS = read_config()
# #
# #
# # def show_configs():
# # msg = f"Current configs, from {conf_realpath(SERVICE_CONF)}:"
# # for k, v in CONFIGS.items():
# # if isinstance(v, dict):
# # if "password" in v:
# # v = copy.deepcopy(v)
# # v["password"] = "*" * 8
# # if "access_key" in v:
# # v = copy.deepcopy(v)
# # v["access_key"] = "*" * 8
# # if "secret_key" in v:
# # v = copy.deepcopy(v)
# # v["secret_key"] = "*" * 8
# # if "secret" in v:
# # v = copy.deepcopy(v)
# # v["secret"] = "*" * 8
# # if "sas_token" in v:
# # v = copy.deepcopy(v)
# # v["sas_token"] = "*" * 8
# # if "oauth" in k:
# # v = copy.deepcopy(v)
# # for key, val in v.items():
# # if "client_secret" in val:
# # val["client_secret"] = "*" * 8
# # if "authentication" in k:
# # v = copy.deepcopy(v)
# # for key, val in v.items():
# # if "http_secret_key" in val:
# # val["http_secret_key"] = "*" * 8
# # msg += f"\n\t{k}: {v}"
# # logging.info(msg)
#
#
# # def get_base_config(key, default=None):
# # if key is None:
# # return None
# # if default is None:
# # default = os.environ.get(key.upper())
# # return CONFIGS.get(key, default)
#
#
# use_deserialize_safe_module = get_base_config(
# 'use_deserialize_safe_module', False)
#
#
# class BaseType:
# def to_dict(self):
# return dict([(k.lstrip("_"), v) for k, v in self.__dict__.items()])
#
# def to_dict_with_type(self):
# def _dict(obj):
# module = None
# if issubclass(obj.__class__, BaseType):
# data = {}
# for attr, v in obj.__dict__.items():
# k = attr.lstrip("_")
# data[k] = _dict(v)
# module = obj.__module__
# elif isinstance(obj, (list, tuple)):
# data = []
# for i, vv in enumerate(obj):
# data.append(_dict(vv))
# elif isinstance(obj, dict):
# data = {}
# for _k, vv in obj.items():
# data[_k] = _dict(vv)
# else:
# data = obj
# return {"type": obj.__class__.__name__,
# "data": data, "module": module}
#
# return _dict(self)
#
#
# class CustomJSONEncoder(json.JSONEncoder):
# def __init__(self, **kwargs):
# self._with_type = kwargs.pop("with_type", False)
# super().__init__(**kwargs)
#
# def default(self, obj):
# if isinstance(obj, datetime.datetime):
# return obj.strftime('%Y-%m-%d %H:%M:%S')
# elif isinstance(obj, datetime.date):
# return obj.strftime('%Y-%m-%d')
# elif isinstance(obj, datetime.timedelta):
# return str(obj)
# elif issubclass(type(obj), Enum) or issubclass(type(obj), IntEnum):
# return obj.value
# elif isinstance(obj, set):
# return list(obj)
# elif issubclass(type(obj), BaseType):
# if not self._with_type:
# return obj.to_dict()
# else:
# return obj.to_dict_with_type()
# elif isinstance(obj, type):
# return obj.__name__
# else:
# # return json.JSONEncoder.default(self, obj)
# return jsonable_encoder(obj)
#
#
# def rag_uuid():
# return uuid.uuid1().hex
#
#
# def string_to_bytes(string):
# return string if isinstance(
# string, bytes) else string.encode(encoding="utf-8")
#
#
# def bytes_to_string(byte):
# return byte.decode(encoding="utf-8")
#
#
# def json_dumps(src, byte=False, indent=None, with_type=False):
# dest = json.dumps(
# src,
# indent=indent,
# cls=CustomJSONEncoder,
# with_type=with_type)
# if byte:
# dest = string_to_bytes(dest)
# return dest
#
#
# def json_loads(src, object_hook=None, object_pairs_hook=None):
# if isinstance(src, bytes):
# src = bytes_to_string(src)
# return json.loads(src, object_hook=object_hook,
# object_pairs_hook=object_pairs_hook)
#
#
#
#
# def timestamp_to_date(timestamp, format_string="%Y-%m-%d %H:%M:%S"):
# if not timestamp:
# timestamp = time.time()
# timestamp = int(timestamp) / 1000
# time_array = time.localtime(timestamp)
# str_date = time.strftime(format_string, time_array)
# return str_date
#
#
# def date_string_to_timestamp(time_str, format_string="%Y-%m-%d %H:%M:%S"):
# time_array = time.strptime(time_str, format_string)
# time_stamp = int(time.mktime(time_array) * 1000)
# return time_stamp
#
#
# def serialize_b64(src, to_str=False):
# dest = base64.b64encode(pickle.dumps(src))
# if not to_str:
# return dest
# else:
# return bytes_to_string(dest)
#
#
# def deserialize_b64(src):
# src = base64.b64decode(
# string_to_bytes(src) if isinstance(
# src, str) else src)
# if use_deserialize_safe_module:
# return restricted_loads(src)
# return pickle.loads(src)
#
#
# safe_module = {
# 'numpy',
# 'rag_flow'
# }
#
#
# class RestrictedUnpickler(pickle.Unpickler):
# def find_class(self, module, name):
# import importlib
# if module.split('.')[0] in safe_module:
# _module = importlib.import_module(module)
# return getattr(_module, name)
# # Forbid everything else.
# raise pickle.UnpicklingError("global '%s.%s' is forbidden" %
# (module, name))
#
#
# def restricted_loads(src):
# """Helper function analogous to pickle.loads()."""
# return RestrictedUnpickler(io.BytesIO(src)).load()
#
#
# def get_lan_ip():
# if os.name != "nt":
# import fcntl
# import struct
#
# def get_interface_ip(ifname):
# s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
# return socket.inet_ntoa(
# fcntl.ioctl(s.fileno(), 0x8915, struct.pack('256s', string_to_bytes(ifname[:15])))[20:24])
#
# ip = socket.gethostbyname(socket.getfqdn())
# if ip.startswith("127.") and os.name != "nt":
# interfaces = [
# "bond1",
# "eth0",
# "eth1",
# "eth2",
# "wlan0",
# "wlan1",
# "wifi0",
# "ath0",
# "ath1",
# "ppp0",
# ]
# for ifname in interfaces:
# try:
# ip = get_interface_ip(ifname)
# break
# except IOError:
# pass
# return ip or ''
#
#
# def from_dict_hook(in_dict: dict):
# if "type" in in_dict and "data" in in_dict:
# if in_dict["module"] is None:
# return in_dict["data"]
# else:
# return getattr(importlib.import_module(
# in_dict["module"]), in_dict["type"])(**in_dict["data"])
# else:
# return in_dict
#
#
# def decrypt_database_password(password):
# encrypt_password = get_base_config("encrypt_password", False)
# encrypt_module = get_base_config("encrypt_module", False)
# private_key = get_base_config("private_key", None)
#
# if not password or not encrypt_password:
# return password
#
# if not private_key:
# raise ValueError("No private key")
#
# module_fun = encrypt_module.split("#")
# pwdecrypt_fun = getattr(
# importlib.import_module(
# module_fun[0]),
# module_fun[1])
#
# return pwdecrypt_fun(private_key, password)
#
#
# def decrypt_database_config(
# database=None, passwd_key="password", name="database"):
# if not database:
# database = get_base_config(name, {})
#
# database[passwd_key] = decrypt_database_password(database[passwd_key])
# return database
#
#
# def update_config(key, value, conf_name=SERVICE_CONF):
# conf_path = conf_realpath(conf_name=conf_name)
# if not os.path.isabs(conf_path):
# conf_path = os.path.join(
# file_utils.get_project_base_directory(), conf_path)
#
# with FileLock(os.path.join(os.path.dirname(conf_path), ".lock")):
# config = file_utils.load_yaml_conf(conf_path=conf_path) or {}
# config[key] = value
# file_utils.rewrite_yaml_conf(conf_path=conf_path, config=config)
#
#
#
#
#
# def datetime_format(date_time: datetime.datetime) -> datetime.datetime:
# return datetime.datetime(date_time.year, date_time.month, date_time.day,
# date_time.hour, date_time.minute, date_time.second)
#
#
# def get_format_time() -> datetime.datetime:
# return datetime_format(datetime.datetime.now())
#
#
# def str2date(date_time: str):
# return datetime.datetime.strptime(date_time, '%Y-%m-%d')
#
#
# def elapsed2time(elapsed):
# seconds = elapsed / 1000
# minuter, second = divmod(seconds, 60)
# hour, minuter = divmod(minuter, 60)
# return '%02d:%02d:%02d' % (hour, minuter, second)
#
#
#
# def download_img(url):
# if not url:
# return ""
# response = requests.get(url)
# return "data:" + \
# response.headers.get('Content-Type', 'image/jpg') + ";" + \
# "base64," + base64.b64encode(response.content).decode("utf-8")
#
#
# def delta_seconds(date_string: str):
# dt = datetime.datetime.strptime(date_string, "%Y-%m-%d %H:%M:%S")
# return (datetime.datetime.now() - dt).total_seconds()
#
#
# def hash_str2int(line: str, mod: int = 10 ** 8) -> int:
# return int(hashlib.sha1(line.encode("utf-8")).hexdigest(), 16) % mod
#
#
#
#

@ -0,0 +1,102 @@
import json
import os
from cachetools import LRUCache, cached
from ruamel.yaml import YAML
PROJECT_BASE = os.getenv("PROJECT_BASE") or os.getenv("DEPLOY_BASE")
RAG_BASE = os.getenv("BASE")
def get_project_base_directory(*args):
global PROJECT_BASE
if PROJECT_BASE is None:
PROJECT_BASE = os.path.dirname(os.path.abspath(os.path.join(__file__, '..')))
if args:
return os.path.join(PROJECT_BASE, *args)
return PROJECT_BASE
def join_project_base_path(relative_path):
base_path=get_project_base_directory()
return os.path.join(base_path, relative_path)
def get_rag_directory(*args):
global RAG_BASE
if RAG_BASE is None:
RAG_BASE = os.path.abspath(
os.path.join(
os.path.dirname(os.path.realpath(__file__)),
os.pardir,
os.pardir,
os.pardir,
)
)
if args:
return os.path.join(RAG_BASE, *args)
return RAG_BASE
@cached(cache=LRUCache(maxsize=10))
def load_json_conf(conf_path):
if os.path.isabs(conf_path):
json_conf_path = conf_path
else:
json_conf_path = os.path.join(get_project_base_directory(), conf_path)
try:
with open(json_conf_path) as f:
return json.load(f)
except BaseException:
raise EnvironmentError("loading json file config from '{}' failed!".format(json_conf_path))
def dump_json_conf(config_data, conf_path):
if os.path.isabs(conf_path):
json_conf_path = conf_path
else:
json_conf_path = os.path.join(get_project_base_directory(), conf_path)
try:
with open(json_conf_path, "w") as f:
json.dump(config_data, f, indent=4)
except BaseException:
raise EnvironmentError("loading json file config from '{}' failed!".format(json_conf_path))
def load_json_conf_real_time(conf_path):
if os.path.isabs(conf_path):
json_conf_path = conf_path
else:
json_conf_path = os.path.join(get_project_base_directory(), conf_path)
try:
with open(json_conf_path) as f:
return json.load(f)
except BaseException:
raise EnvironmentError("loading json file config from '{}' failed!".format(json_conf_path))
def load_yaml_conf(conf_path):
if not os.path.isabs(conf_path):
conf_path = os.path.join(get_project_base_directory(), conf_path)
try:
with open(conf_path) as f:
yaml = YAML(typ="safe", pure=True)
return yaml.load(f)
except Exception as e:
raise EnvironmentError("loading yaml file config from {} failed:".format(conf_path), e)
def rewrite_yaml_conf(conf_path, config):
if not os.path.isabs(conf_path):
conf_path = os.path.join(get_project_base_directory(), conf_path)
try:
with open(conf_path, "w") as f:
yaml = YAML(typ="safe")
yaml.dump(config, f)
except Exception as e:
raise EnvironmentError("rewrite yaml file config {} failed:".format(conf_path), e)
def rewrite_json_file(filepath, json_data):
with open(filepath, "w", encoding="utf-8") as f:
json.dump(json_data, f, indent=4, separators=(",", ": "))
f.close()

1431
uv.lock

File diff suppressed because it is too large Load Diff
Loading…
Cancel
Save