init: integrate the base framework, FastAPI + SQLModel (with enhanced logical delete), and set up part of the base environment
commit
e25b78f77b
@@ -0,0 +1,3 @@
SERVICE_CONF=application.yaml
DEBUG=true
DATABASE_URL=mysql+aiomysql://username:password@host:port/database_name
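# Hedged note (not part of the commit): these three keys are placeholders that
# map onto the Settings fields service_conf, debug and database_url defined
# below; load_dotenv() reads this file before pydantic builds the Settings instance.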
@@ -0,0 +1,173 @@
# volumes
docker/volumes/**
docker/nginx/dist
web/dist

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
coverage.json
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.env-local
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
.conda/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/
.idea/

.DS_Store
web/.vscode/settings.json

# Intellij IDEA Files
.idea/*
!.idea/vcs.xml
!.idea/icon.png
.ideaDataSources/
*.iml

.vscode/*
!.vscode/launch.json.template
!.vscode/README.md
pyrightconfig.json
api/.vscode

.idea/
# pnpm
/.pnpm-store

# plugin migrate
plugins.jsonl

# mise
mise.toml

# Next.js build output
.next/

# AI Assistant
.roo/
api/.env.backup

application.yaml
@@ -0,0 +1,3 @@
zhipu:
  api-key:
  base-url:
@@ -0,0 +1,9 @@
class MetaConst(type):
    def __setattr__(cls, name, value):
        if name in cls.__dict__:
            raise TypeError(f"Cannot rebind constant {name}")
        super().__setattr__(name, value)


class Constant(metaclass=MetaConst):
    LOGICAL_DELETE_FIELD = "is_deleted"
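# A minimal usage sketch (not part of the commit) showing what the metaclass
# guard buys: the first assignment to a new class attribute succeeds, while
# rebinding an existing one raises. MAX_RETRIES is a hypothetical name.
Constant.MAX_RETRIES = 3                     # new constant: allowed once
try:
    Constant.LOGICAL_DELETE_FIELD = "other"  # rebinding an existing constant
except TypeError as exc:
    print(exc)                               # -> Cannot rebind constant LOGICAL_DELETE_FIELD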
@@ -0,0 +1,6 @@
from enum import IntEnum


class IsDelete(IntEnum):
    NO_DELETE = 0
    DELETE = 1
@@ -0,0 +1,8 @@
from config.settings import Settings, get_settings
settings = get_settings()

def get_yaml_conf(k):
    return settings.yaml_config.get(k, None)

def show_configs():
    return settings
@@ -0,0 +1,57 @@
from functools import lru_cache

from dotenv import load_dotenv
from pydantic.v1 import BaseSettings as Base

from utils import file_utils


class BaseSettings(Base):
    """Base settings class."""

    class Config:
        env_file = '.env'
        env_file_encoding = 'utf-8'
        extra = 'allow'

class Settings(BaseSettings):
    """Application settings.

    The server directory is the backend project root. Create a "config.env" file
    there; environment variables written into it (uppercase by convention) are
    loaded automatically and override the same-named (lowercase) settings.
    e.g. write in config.env:
        REDIS_URL='redis://localhost:6379'
    and the environment variable above overrides redis_url.
    """
    # Run mode
    mode: str = 'dev'  # dev, prod
    debug: bool = False
    service_conf: str = 'application.yaml'
    # API version prefix
    api_version: str = '/v1'
    # Timezone
    timezone: str = 'Asia/Shanghai'
    # Datetime format
    datetime_fmt: str = '%Y-%m-%d %H:%M:%S'
    # Redis key prefix
    redis_prefix: str = 'agent:'
    # Bind address and port
    host_ip: str = '0.0.0.0'
    host_port: int = 8080
    # SQL driver connection string
    database_url: str = ''
    # YAML configuration
    yaml_config: dict = {}

@lru_cache()
def get_settings() -> Settings:
    """Load and cache the application settings."""
    # Read the .env configuration from the server directory
    load_dotenv()
    settings = Settings()
    yaml_config = file_utils.load_yaml_conf(settings.service_conf) or {}
    # Store the YAML config on the Settings instance
    settings.yaml_config = yaml_config
    for k, v in settings.yaml_config.items():
        if not hasattr(settings, k) or getattr(settings, k) == settings.__fields__[k].default:
            setattr(settings, k, v)
    return settings
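# Hedged sketch (not in the commit) of the override order get_settings()
# implements: .env values land on the instance first, and the YAML loop only
# overwrites fields still at their declared defaults.
#   .env:              DEBUG=true
#   application.yaml:  debug: false
s = get_settings()
assert s.debug is True          # the env value survives the YAML merge
assert get_settings() is s      # @lru_cache returns the same instance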
@@ -0,0 +1,6 @@
from contextvars import ContextVar
from typing import Optional

from sqlalchemy.ext.asyncio import AsyncSession

current_session: ContextVar[Optional[AsyncSession]] = ContextVar("current_session", default=None)
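# Why a ContextVar rather than a module global (illustrative sketch, not in
# the commit): each asyncio task gets its own copy of the context, so two
# concurrent requests never see each other's session.
import asyncio

async def handler(value: str) -> None:
    token = current_session.set(value)      # stand-in for a real AsyncSession
    await asyncio.sleep(0)                  # yield so tasks interleave
    assert current_session.get() == value   # still this task's own value
    current_session.reset(token)

async def main() -> None:
    await asyncio.gather(handler("a"), handler("b"))

asyncio.run(main())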
@@ -0,0 +1,21 @@
from sqlalchemy import Column, BigInteger
from sqlmodel import SQLModel, Field

from common.global_enums import IsDelete
from utils import current_timestamp

# Flag controlling whether logical-delete filtering is enabled
_SOFT_DELETE_ENABLED = True


class DbBaseModel(SQLModel, table=False):
    id: str = Field(default=None, max_length=32, primary_key=True)
    created_time: int = Field(sa_column=Column(BigInteger), default_factory=current_timestamp)
    # created_by = CharField(max_length=32, index=True)
    updated_time: int = Field(sa_column=Column(BigInteger, onupdate=current_timestamp),
                              default_factory=current_timestamp)
    # updated_by = CharField(max_length=32)
    is_deleted: int = Field(default=IsDelete.NO_DELETE)

    # class Config:
    #     arbitrary_types_allowed = True
@@ -0,0 +1,7 @@
from entity.base_entity import DbBaseModel


class User(DbBaseModel, table=True):
    __tablename__ = "user"  # Explicit table name; defaults to the lowercased class name
    username: str
    password: str
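# Illustrative sketch (not in the commit): the base-model fields come along
# for free on every entity (get_uuid is from utils).
user = User(id=get_uuid(), username="u1", password="secret")
print(user.created_time)  # epoch milliseconds from current_timestamp()
print(user.is_deleted)    # 0, i.e. IsDelete.NO_DELETE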
@@ -0,0 +1,70 @@
import logging
import os
import signal
import sys
import threading
import time
import traceback

from config import settings, show_configs
from utils import file_utils

stop_event = threading.Event()

RAGFLOW_DEBUGPY_LISTEN = int(os.environ.get('RAGFLOW_DEBUGPY_LISTEN', "0"))


def signal_handler(sig, frame):
    logging.info("Received interrupt signal, shutting down...")
    stop_event.set()
    time.sleep(1)
    sys.exit(0)


if __name__ == '__main__':
    logging.info(r"""
      _____        _        ____        _ _     _
     |  __ \      | |      |  _ \      (_) |   | |
     | |  | | __ _| |_ __ _| |_) |_   _ _| | __| | ___ _ __
     | |  | |/ _` | __/ _` |  _ <| | | | | |/ _` |/ _ \ '__|
     | |__| | (_| | || (_| | |_) | |_| | | | (_| |  __/ |
     |_____/ \__,_|\__\__,_|____/ \__,_|_|_|\__,_|\___|_|
    """)

    logging.info(
        f'project base: {file_utils.get_project_base_directory()}'
    )
    show_configs()
    # import argparse
    # parser = argparse.ArgumentParser()
    #
    # args = parser.parse_args()
    signal.signal(signal.SIGINT, signal_handler)
    signal.signal(signal.SIGTERM, signal_handler)
    try:
        logging.info("Starting service...")
        import uvicorn
        # Uvicorn configuration
        uvicorn_config = {
            # "app": app,  # FastAPI application instance
            "app": "config.fastapi_config:app",  # FastAPI application (import string)
            "host": settings.host_ip,
            "port": settings.host_port,
            "reload": settings.debug,  # hot reload in development mode
            "log_level": "debug" if settings.debug else "info",
            "access_log": True,
        }
        # Extra options in debug mode
        if settings.debug:
            uvicorn_config.update({
                "reload_dirs": ["."],  # watch the current directory for changes
                "reload_delay": 0.5,  # reload delay in seconds
            })

        # Start the Uvicorn server
        uvicorn.run(**uvicorn_config)
    except Exception:
        traceback.print_exc()
        stop_event.set()
        time.sleep(1)
        os.kill(os.getpid(), signal.SIGKILL)
@@ -0,0 +1,20 @@
from fastapi import Request

from core.global_context import current_session
from entity import AsyncSessionLocal


async def db_session_middleware(request: Request, call_next):
    async with AsyncSessionLocal() as session:
        # Bind the session to the context variable
        token = current_session.set(session)
        try:
            response = await call_next(request)
            await session.commit()
        except Exception:
            await session.rollback()
            raise
        finally:
            # Reset the context variable
            current_session.reset(token)
        return response
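# Hedged wiring sketch (not in the commit): registering the middleware on the
# app; the real application object lives in config.fastapi_config here.
from fastapi import FastAPI

app = FastAPI()
app.middleware("http")(db_session_middleware)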
@@ -0,0 +1,24 @@
[project]
name = "rg-agno-agent"
version = "0.1.0"
description = "Add your description here"
readme = "README.md"
requires-python = ">=3.12"
dependencies = [
    "fastapi[standard]>=0.116.1",
    "langchain>=0.3.27",
    "mcp>=1.14.0",
    "openai>=1.107.3",
    "sqlalchemy[asyncio]>=2.0.43",
    "ruamel-yaml>=0.18.6,<0.19.0",  # YAML handling
    "cachetools==5.3.3",  # caching utilities
    "filelock==3.15.4",  # file locking
    "itsdangerous==2.1.2",  # secure signing, used by SessionMiddleware
    "langchain-openai>=0.3.33",
    "langchainhub>=0.1.21",
    "httpx-sse>=0.4.1",
    "sqlmodel>=0.0.25",
    "aiomysql>=0.2.0",
]

[[tool.uv.index]]
url = "https://mirrors.aliyun.com/pypi/simple"
@@ -0,0 +1,13 @@
from fastapi import APIRouter

from config import settings
from utils.file_utils import get_project_base_directory

router = APIRouter()

@router.get("/test")
async def test():
    return {
        "settings": settings,
        "base_path": get_project_base_directory(),
    }
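# Hedged sketch (not in the commit): mounting this router on the app under the
# configured version prefix, assuming the FastAPI instance from
# config.fastapi_config:
#     app.include_router(router, prefix=settings.api_version)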
@@ -0,0 +1,461 @@
from typing import Union, Type, List, Any, TypeVar, Generic

import peewee  # required by the legacy sync helpers kept below
from pydantic import BaseModel
from sqlalchemy.ext.asyncio import AsyncSession

from core.global_context import current_session
from utils import get_uuid, current_timestamp

# NOTE (editorial): DB, BasePageQueryReq, BaseQueryReq, BasePageResp and
# get_current_user are referenced by the legacy peewee-style methods below
# but are not defined anywhere in this commit.

T = TypeVar('T')


class BaseService(Generic[T]):
    model: Type[T]  # subclasses must set the model

    @classmethod
    def get_db(cls) -> AsyncSession:
        """Get the session bound to the current request."""
        session = current_session.get()
        if session is None:
            raise RuntimeError("No database session in context. "
                               "Make sure to use this service within a request context.")
        return session

    @classmethod
    async def create(cls, **kwargs) -> T:
        """Generic create."""
        obj = cls.model(**kwargs)
        db = cls.get_db()
        db.add(obj)
        await db.flush()
        return obj

    @classmethod
    def entity_conversion_dto(cls, entity_data: Union[list, Any], dto: Type[BaseModel]) -> Union[
            BaseModel, List[BaseModel]]:
        dto_list = []
        if not isinstance(entity_data, list):
            return dto(**entity_data.to_dict())
        for entity in entity_data:
            temp = entity
            if not isinstance(entity, dict):
                temp = entity.to_dict()
            dto_list.append(dto(**temp))
        return dto_list

    @classmethod
    def query(cls, cols=None, reverse=None, order_by=None, **kwargs):
        """Execute a database query with optional column selection and ordering.

        This method provides a flexible way to query the database with various filters
        and sorting options. It supports column selection, sort order control, and
        additional filter conditions.

        Args:
            cols (list, optional): List of column names to select. If None, selects all columns.
            reverse (bool, optional): If True, sorts in descending order. If False, sorts in ascending order.
            order_by (str, optional): Column name to sort results by.
            **kwargs: Additional filter conditions passed as keyword arguments.

        Returns:
            peewee.ModelSelect: A query result containing matching records.
        """
        return cls.model.query(cols=cols, reverse=reverse, order_by=order_by, **kwargs)

    @classmethod
    def get_all(cls, cols=None, reverse=None, order_by=None):
        """Retrieve all records from the database with optional column selection and ordering.

        This method fetches all records from the model's table with support for
        column selection and result ordering. If no order_by is specified and reverse
        is True, it defaults to ordering by created_time.

        Args:
            cols (list, optional): List of column names to select. If None, selects all columns.
            reverse (bool, optional): If True, sorts in descending order. If False, sorts in ascending order.
            order_by (str, optional): Column name to sort results by. Defaults to 'created_time' if reverse is specified.

        Returns:
            peewee.ModelSelect: A query containing all matching records.
        """
        if cols:
            query_records = cls.model.select(*cols)
        else:
            query_records = cls.model.select()
        if reverse is not None:
            if not order_by or not hasattr(cls.model, order_by):
                order_by = "created_time"
            if reverse is True:
                query_records = query_records.order_by(cls.model.getter_by(order_by).desc())
            elif reverse is False:
                query_records = query_records.order_by(cls.model.getter_by(order_by).asc())
        return query_records

    @classmethod
    def get(cls, **kwargs):
        """Get a single record matching the given criteria.

        This method retrieves a single record from the database that matches
        the specified filter conditions.

        Args:
            **kwargs: Filter conditions as keyword arguments.

        Returns:
            Model instance: Single matching record.

        Raises:
            peewee.DoesNotExist: If no matching record is found.
        """
        return cls.model.get(**kwargs)

    @classmethod
    def get_or_none(cls, **kwargs):
        """Get a single record or None if not found.

        This method attempts to retrieve a single record matching the given criteria,
        returning None if no match is found instead of raising an exception.

        Args:
            **kwargs: Filter conditions as keyword arguments.

        Returns:
            Model instance or None: Matching record if found, None otherwise.
        """
        try:
            return cls.model.get(**kwargs)
        except peewee.DoesNotExist:
            return None

    @classmethod
    def get_by_page(cls, query_params: Union[dict, BasePageQueryReq]):
        if not isinstance(query_params, dict):
            query_params = query_params.model_dump()
        query_params = {k: v for k, v in query_params.items() if v is not None}
        sessions = cls.get_query_session(query_params)
        return cls.auto_page(sessions, query_params)

    @classmethod
    def auto_page(cls, sessions, query_params: Union[dict, BasePageQueryReq] = None,
                  dto_model_class: Type[BaseModel] = None):
        if not query_params:
            query_params = {}
        if not isinstance(query_params, dict):
            query_params = query_params.model_dump()
        page_number = query_params.get("page_number", 1)
        page_size = query_params.get("page_size", 12)
        desc = query_params.get("desc", "desc")
        orderby = query_params.get("orderby", "created_time")
        data_count = sessions.count()
        if data_count == 0:
            return BasePageResp(**{
                "page_number": page_number,
                "page_size": page_size,
                "count": data_count,
                "desc": desc,
                "orderby": orderby,
                "data": [],
            })
        if desc == "desc":
            sessions = sessions.order_by(cls.model.getter_by(orderby).desc())
        else:
            sessions = sessions.order_by(cls.model.getter_by(orderby).asc())
        sessions = sessions.paginate(int(page_number), int(page_size))
        datas = list(sessions.dicts())
        result = datas
        if dto_model_class is not None:
            result = [dto_model_class(**item) for item in datas]
        return BasePageResp(**{
            "page_number": page_number,
            "page_size": page_size,
            "count": data_count,
            "desc": desc,
            "orderby": orderby,
            "data": result,
        })

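    # Editorial note on auto_page: the count() runs first and short-circuits to
    # an empty BasePageResp, so ordering and pagination only execute when rows
    # exist; page_number is 1-based, matching peewee's paginate().
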
    @classmethod
    def get_list(cls, query_params: Union[dict, BaseQueryReq]):
        if not isinstance(query_params, dict):
            query_params = query_params.model_dump()
        query_params = {k: v for k, v in query_params.items() if v is not None}
        desc = query_params.get("desc", "desc")
        orderby = query_params.get("orderby", "created_time")
        sessions = cls.get_query_session(query_params)
        if desc == "desc":
            sessions = sessions.order_by(cls.model.getter_by(orderby).desc())
        else:
            sessions = sessions.order_by(cls.model.getter_by(orderby).asc())

        return list(sessions.dicts())

    @classmethod
    def get_id_list(cls, query_params: Union[dict, BaseQueryReq]) -> List[Any]:
        if not isinstance(query_params, dict):
            query_params = query_params.model_dump()
        query_params = {k: v for k, v in query_params.items() if v is not None}
        desc = query_params.get("desc", "desc")
        orderby = query_params.get("orderby", "created_time")
        sessions = cls.model.select(cls.model.id)
        sessions = cls.get_query_session(query_params, sessions)
        if desc == "desc":
            sessions = sessions.order_by(cls.model.getter_by(orderby).desc())
        else:
            sessions = sessions.order_by(cls.model.getter_by(orderby).asc())
        return [item["id"] for item in list(sessions.dicts())]

    @classmethod
    def save(cls, **kwargs):
        """Save a new record to database.

        This method creates a new record in the database with the provided field values,
        forcing an insert operation rather than an update.

        Args:
            **kwargs: Record field values as keyword arguments.

        Returns:
            bool: True if a row was inserted.
        """

        sample_obj = cls.model(**kwargs).save(force_insert=True)
        return sample_obj > 0

    @classmethod
    def insert(cls, **kwargs):
        """Insert a new record with automatic ID and timestamps.

        This method creates a new record with automatically generated ID and timestamp fields.
        It handles the creation of created_time, create_date, updated_time, and update_date fields.

        Args:
            **kwargs: Record field values as keyword arguments.

        Returns:
            bool: True if a row was inserted.
        """
        if "id" not in kwargs:
            kwargs["id"] = get_uuid()

        kwargs["created_time"] = current_timestamp()
        # kwargs["create_date"] = datetime_format(datetime.now())
        kwargs["updated_time"] = current_timestamp()
        # kwargs["update_date"] = datetime_format(datetime.now())
        sample_obj = cls.model(**kwargs).save(force_insert=True)
        return sample_obj > 0

    @classmethod
    def insert_many(cls, data_list, batch_size=100):
        """Insert multiple records in batches.

        This method efficiently inserts multiple records into the database using batch processing.
        It automatically sets creation timestamps for all records.

        Args:
            data_list (list): List of dictionaries containing record data to insert.
            batch_size (int, optional): Number of records to insert in each batch. Defaults to 100.
        """

        with DB.atomic():
            for d in data_list:

                if not d.get("id", None):
                    d["id"] = get_uuid()
                d["created_time"] = current_timestamp()
                # d["create_date"] = datetime_format(datetime.now())
            for i in range(0, len(data_list), batch_size):
                cls.model.insert_many(data_list[i: i + batch_size]).execute()

    @classmethod
    def update_many_by_id(cls, data_list):
        """Update multiple records by their IDs.

        This method updates multiple records in the database, identified by their IDs.
        It automatically updates the updated_time and update_date fields for each record.

        Args:
            data_list (list): List of dictionaries containing record data to update.
                Each dictionary must include an 'id' field.
        """

        with DB.atomic():
            for data in data_list:
                data["updated_time"] = current_timestamp()
                # data["update_date"] = datetime_format(datetime.now())
                cls.model.update(data).where(cls.model.id == data["id"]).execute()

    @classmethod
    def updated_by_id(cls, pid, data):
        # Update a single record by ID
        # Args:
        #     pid: Record ID
        #     data: Updated field values
        # Returns:
        #     True if a row was updated
        data["updated_time"] = current_timestamp()
        # data["update_date"] = datetime_format(datetime.now())
        num = cls.model.update(data).where(cls.model.id == pid).execute()
        return num > 0

    @classmethod
    def get_by_id(cls, pid):
        # Get a record by ID
        # Args:
        #     pid: Record ID
        # Returns:
        #     Tuple of (success, record)
        try:
            obj = cls.model.get_or_none(cls.model.id == pid)
            if obj:
                return True, obj
        except Exception:
            pass
        return False, None

    @classmethod
    def get_by_ids(cls, pids, cols=None):
        # Get multiple records by their IDs
        # Args:
        #     pids: List of record IDs
        #     cols: List of columns to select
        # Returns:
        #     Query of matching records
        if cols:
            objs = cls.model.select(*cols)
        else:
            objs = cls.model.select()
        return objs.where(cls.model.id.in_(pids))

    @classmethod
    def get_last_by_create_time(cls):
        # Get the most recently created record
        # Returns:
        #     The record with the largest created_time, or None
        latest = cls.model.select().order_by(cls.model.created_time.desc()).first()
        return latest

    @classmethod
    def delete_by_id(cls, pid):
        # Delete a record by ID
        # Args:
        #     pid: Record ID
        # Returns:
        #     Number of records deleted
        return cls.model.delete().where(cls.model.id == pid).execute()

    @classmethod
    def delete_by_ids(cls, pids):
        # Delete multiple records by their IDs
        # Args:
        #     pids: List of record IDs
        # Returns:
        #     Number of records deleted
        with DB.atomic():
            res = cls.model.delete().where(cls.model.id.in_(pids)).execute()
            return res

    @classmethod
    def filter_delete(cls, filters):
        # Delete records matching given filters
        # Args:
        #     filters: List of filter conditions
        # Returns:
        #     Number of records deleted
        with DB.atomic():
            num = cls.model.delete().where(*filters).execute()
            return num

    @classmethod
    def filter_update(cls, filters, update_data):
        # Update records matching given filters
        # Args:
        #     filters: List of filter conditions
        #     update_data: Updated field values
        # Returns:
        #     Number of records updated
        with DB.atomic():
            return cls.model.update(update_data).where(*filters).execute()

    @staticmethod
    def cut_list(tar_list, n):
        # Split a list into chunks of size n
        # Args:
        #     tar_list: List to split
        #     n: Chunk size
        # Returns:
        #     List of tuples containing chunks
        length = len(tar_list)
        arr = range(length)
        result = [tuple(tar_list[x: (x + n)]) for x in arr[::n]]
        return result

    @classmethod
    def filter_scope_list(cls, in_key, in_filters_list, filters=None, cols=None):
        # Get records matching IN clause filters with optional column selection
        # Args:
        #     in_key: Field name for IN clause
        #     in_filters_list: List of values for IN clause
        #     filters: Additional filter conditions
        #     cols: List of columns to select
        # Returns:
        #     List of matching records
        in_filters_tuple_list = cls.cut_list(in_filters_list, 20)
        if not filters:
            filters = []
        res_list = []
        if cols:
            for i in in_filters_tuple_list:
                query_records = cls.model.select(*cols).where(getattr(cls.model, in_key).in_(i), *filters)
                if query_records:
                    res_list.extend([query_record for query_record in query_records])
        else:
            for i in in_filters_tuple_list:
                query_records = cls.model.select().where(getattr(cls.model, in_key).in_(i), *filters)
                if query_records:
                    res_list.extend([query_record for query_record in query_records])
        return res_list

    @classmethod
    def get_query_session(cls, query_params, sessions=None):

        if sessions is None:
            sessions = cls.model.select()
        for key, value in query_params.items():
            if hasattr(cls.model, key):
                field = getattr(cls.model, key)
                sessions = sessions.where(field == value)
        return sessions

    @classmethod
    def get_data_count(cls, query_params: dict = None):
        if not query_params:
            raise Exception("query_params is empty")
        sessions = cls.get_query_session(query_params)
        return sessions.count()

    @classmethod
    def is_exist(cls, query_params: dict = None):
        return cls.get_data_count(query_params) > 0

    @classmethod
    def update_by_id(cls, pid, data):
        # Update a single record by ID
        # Args:
        #     pid: Record ID
        #     data: Updated field values
        # Returns:
        #     Number of records updated
        data["updated_time"] = current_timestamp()
        num = cls.model.update(data).where(cls.model.id == pid).execute()
        return num

    @classmethod
    def check_base_permission(cls, model_data):
        if isinstance(model_data, dict):
            if model_data.get("created_by") != get_current_user().id:
                raise RuntimeError("No permission: only the creator may perform this operation")
            return  # fixed: dicts must not fall through to the attribute access below
        if model_data.created_by != get_current_user().id:
            raise RuntimeError("No permission: only the creator may perform this operation")
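# Hedged usage sketch (not in the commit) for the async half of BaseService;
# assumes `from entity.user import User` and the request context set up by
# db_session_middleware. The peewee-style helpers above still expect the
# legacy sync stack (DB, BasePageResp, ...) to be wired in.
class UserService(BaseService[User]):
    model = User

async def register(username: str, password: str) -> User:
    return await UserService.create(id=get_uuid(), username=username, password=password)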
@@ -0,0 +1,16 @@
from langchain_openai import ChatOpenAI

from config import settings


class ChatService:
    llm = ChatOpenAI(
        temperature=0.6,
        model="glm-4.5",
        openai_api_key=settings.yaml_config.get("zhipu").get("api-key"),
        openai_api_base=settings.yaml_config.get("zhipu").get("base-url")
    )

    @classmethod
    def chat(cls, message):
        resp = cls.llm.invoke(message)
        print(resp)
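# Editorial note: the ChatOpenAI client above is built at import time from the
# zhipu section of application.yaml; if that section is missing,
# settings.yaml_config.get("zhipu") returns None and the chained .get("api-key")
# raises AttributeError on import. Hedged usage once the YAML is filled in:
#     ChatService.chat("hello")   # prints the model's response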
@@ -0,0 +1,62 @@
from contextvars import ContextVar
from typing import Optional, TypeVar, Generic, Type

from fastapi import Request
from sqlalchemy.ext.asyncio import AsyncSession

from entity import AsyncSessionLocal
from entity.user import User

# 1. Context variable holding the current session
current_session: ContextVar[Optional[AsyncSession]] = ContextVar("current_session", default=None)

# 3. Middleware: manages the request lifecycle and the session
async def db_session_middleware(request: Request, call_next):
    async with AsyncSessionLocal() as session:
        # Bind the session to the context variable
        token = current_session.set(session)
        try:
            response = await call_next(request)
            await session.commit()
        except Exception:
            await session.rollback()
            raise
        finally:
            # Reset the context variable
            current_session.reset(token)
        return response

# 4. Service base class
T = TypeVar('T')

class BaseService(Generic[T]):
    model: Type[T] = None  # subclasses must set the model

    @classmethod
    def get_db(cls) -> AsyncSession:
        """Get the session bound to the current request."""
        session = current_session.get()
        if session is None:
            raise RuntimeError("No database session in context. "
                               "Make sure to use this service within a request context.")
        return session

    @classmethod
    async def create(cls, **kwargs) -> T:
        """Generic create."""
        obj = cls.model(**kwargs)
        db = cls.get_db()
        db.add(obj)
        await db.flush()
        return obj

    @classmethod
    async def get(cls, id: str) -> Optional[T]:
        """Generic get-by-primary-key (ids are 32-char hex strings here)."""
        db = cls.get_db()
        return await db.get(cls.model, id)

# 5. Concrete service class
class UserService(BaseService[User]):
    model = User  # bind the model
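# Hedged sketch (not in the commit): fetching a row through the generic get();
# must run inside a request handled by db_session_middleware:
#     user = await UserService.get("a460359e960311f09677c922f415afd9")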
@@ -0,0 +1,66 @@
import asyncio

from sqlalchemy import delete
from sqlmodel import select

from entity import AsyncSessionLocal, close_engine
from entity.user import User


async def add():
    add_list = []
    add_list.append(User(id="a460359e960311f09677c922f415afd9", username="u1", password="password"))
    add_list.append(User(id="a460416a960311f09677c922f415afd9", username="u2", password="password"))
    add_list.append(User(id="a46042be960311f09677c922f415afd9", username="u3", password="password"))
    # session = get_db_session()
    async with AsyncSessionLocal() as session:
        for user in add_list:
            session.add(user)
        await session.commit()
    await close_engine()
    print("done")


async def remove():
    add_list = []
    add_list.append(User(id="a460359e960311f09677c922f415afd9", username="u1", password="password"))
    add_list.append(User(id="a460416a960311f09677c922f415afd9", username="u2", password="password"))
    add_list.append(User(id="a46042be960311f09677c922f415afd9", username="u3", password="password"))
    async with AsyncSessionLocal() as session:
        for i in add_list:
            delete_sql = delete(User).where(User.id == i.id)
            await session.execute(delete_sql)
        await session.commit()
    await close_engine()
    print("======================delete complete=================")

async def remove_one():
    add_list = []
    add_list.append(User(id="a460359e960311f09677c922f415afd9", username="u1", password="password"))
    add_list.append(User(id="a460416a960311f09677c922f415afd9", username="u2", password="password"))
    add_list.append(User(id="a46042be960311f09677c922f415afd9", username="u3", password="password"))
    async with AsyncSessionLocal() as session:
        # for i in add_list:
        #     delete_sql = delete(User).where(User.id == i.id)
        user = await session.execute(select(User))
        user = user.scalars().first()
        await session.delete(user)  # fixed: AsyncSession.delete is a coroutine
        await session.commit()
    await close_engine()
    print("======================delete complete=================")

async def get_all():
    async with AsyncSessionLocal() as session:
        sql = select(User)
        result = await session.execute(sql)
        print(result.scalars().all())
        await session.commit()
    await close_engine()
    print("====================get_all==================")

def test(aa="ddd"):
    print("aa=", aa)

if __name__ == "__main__":
    # asyncio.run(add())
    asyncio.run(get_all())
    # asyncio.run(remove_one())
    asyncio.run(remove())
    asyncio.run(get_all())
    # test(None)
@@ -0,0 +1,377 @@
# import base64
# import copy
# import datetime
# import hashlib
# import importlib
# import io
# import json
# import logging
# import os
# import pickle
# import socket
import time
import uuid
# from enum import Enum, IntEnum
#
#
# import requests
#
# from fastapi.encoders import jsonable_encoder
# from filelock import FileLock
#
# from common.constants import SERVICE_CONF
# from . import file_utils
#
def current_timestamp() -> int:
    """Return the current timestamp in milliseconds."""
    return int(time.time() * 1000)

def get_uuid():
    return uuid.uuid1().hex
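# Editorial note: uuid1().hex is exactly 32 hex characters (time/node based),
# which is why the id columns in DbBaseModel use max_length=32.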
# def conf_realpath(conf_name):
#     conf_path = f"{conf_name}"
#     return os.path.join(file_utils.get_project_base_directory(), conf_path)
#
#
# # def read_config(conf_name=SERVICE_CONF):
# #     local_config = {}
# #     local_path = conf_realpath(f'local.{conf_name}')
# #
# #     # load local config file
# #     if os.path.exists(local_path):
# #         local_config = file_utils.load_yaml_conf(local_path)
# #         if not isinstance(local_config, dict):
# #             raise ValueError(f'Invalid config file: "{local_path}".')
# #
# #     global_config_path = conf_realpath(conf_name)
# #     global_config = file_utils.load_yaml_conf(global_config_path)
# #
# #     if not isinstance(global_config, dict):
# #         raise ValueError(f'Invalid config file: "{global_config_path}".')
# #
# #     global_config.update(local_config)
# #     return global_config
# #
# #
# # CONFIGS = read_config()
# #
# #
# # def show_configs():
# #     msg = f"Current configs, from {conf_realpath(SERVICE_CONF)}:"
# #     for k, v in CONFIGS.items():
# #         if isinstance(v, dict):
# #             if "password" in v:
# #                 v = copy.deepcopy(v)
# #                 v["password"] = "*" * 8
# #             if "access_key" in v:
# #                 v = copy.deepcopy(v)
# #                 v["access_key"] = "*" * 8
# #             if "secret_key" in v:
# #                 v = copy.deepcopy(v)
# #                 v["secret_key"] = "*" * 8
# #             if "secret" in v:
# #                 v = copy.deepcopy(v)
# #                 v["secret"] = "*" * 8
# #             if "sas_token" in v:
# #                 v = copy.deepcopy(v)
# #                 v["sas_token"] = "*" * 8
# #         if "oauth" in k:
# #             v = copy.deepcopy(v)
# #             for key, val in v.items():
# #                 if "client_secret" in val:
# #                     val["client_secret"] = "*" * 8
# #         if "authentication" in k:
# #             v = copy.deepcopy(v)
# #             for key, val in v.items():
# #                 if "http_secret_key" in val:
# #                     val["http_secret_key"] = "*" * 8
# #         msg += f"\n\t{k}: {v}"
# #     logging.info(msg)
#
#
# # def get_base_config(key, default=None):
# #     if key is None:
# #         return None
# #     if default is None:
# #         default = os.environ.get(key.upper())
# #     return CONFIGS.get(key, default)
#
#
# use_deserialize_safe_module = get_base_config(
#     'use_deserialize_safe_module', False)
#
#
# class BaseType:
#     def to_dict(self):
#         return dict([(k.lstrip("_"), v) for k, v in self.__dict__.items()])
#
#     def to_dict_with_type(self):
#         def _dict(obj):
#             module = None
#             if issubclass(obj.__class__, BaseType):
#                 data = {}
#                 for attr, v in obj.__dict__.items():
#                     k = attr.lstrip("_")
#                     data[k] = _dict(v)
#                 module = obj.__module__
#             elif isinstance(obj, (list, tuple)):
#                 data = []
#                 for i, vv in enumerate(obj):
#                     data.append(_dict(vv))
#             elif isinstance(obj, dict):
#                 data = {}
#                 for _k, vv in obj.items():
#                     data[_k] = _dict(vv)
#             else:
#                 data = obj
#             return {"type": obj.__class__.__name__,
#                     "data": data, "module": module}
#
#         return _dict(self)
#
#
# class CustomJSONEncoder(json.JSONEncoder):
#     def __init__(self, **kwargs):
#         self._with_type = kwargs.pop("with_type", False)
#         super().__init__(**kwargs)
#
#     def default(self, obj):
#         if isinstance(obj, datetime.datetime):
#             return obj.strftime('%Y-%m-%d %H:%M:%S')
#         elif isinstance(obj, datetime.date):
#             return obj.strftime('%Y-%m-%d')
#         elif isinstance(obj, datetime.timedelta):
#             return str(obj)
#         elif issubclass(type(obj), Enum) or issubclass(type(obj), IntEnum):
#             return obj.value
#         elif isinstance(obj, set):
#             return list(obj)
#         elif issubclass(type(obj), BaseType):
#             if not self._with_type:
#                 return obj.to_dict()
#             else:
#                 return obj.to_dict_with_type()
#         elif isinstance(obj, type):
#             return obj.__name__
#         else:
#             # return json.JSONEncoder.default(self, obj)
#             return jsonable_encoder(obj)
#
#
# def rag_uuid():
#     return uuid.uuid1().hex
#
#
# def string_to_bytes(string):
#     return string if isinstance(
#         string, bytes) else string.encode(encoding="utf-8")
#
#
# def bytes_to_string(byte):
#     return byte.decode(encoding="utf-8")
#
#
# def json_dumps(src, byte=False, indent=None, with_type=False):
#     dest = json.dumps(
#         src,
#         indent=indent,
#         cls=CustomJSONEncoder,
#         with_type=with_type)
#     if byte:
#         dest = string_to_bytes(dest)
#     return dest
#
#
# def json_loads(src, object_hook=None, object_pairs_hook=None):
#     if isinstance(src, bytes):
#         src = bytes_to_string(src)
#     return json.loads(src, object_hook=object_hook,
#                       object_pairs_hook=object_pairs_hook)
#
#
#
#
# def timestamp_to_date(timestamp, format_string="%Y-%m-%d %H:%M:%S"):
#     if not timestamp:
#         timestamp = time.time()
#     timestamp = int(timestamp) / 1000
#     time_array = time.localtime(timestamp)
#     str_date = time.strftime(format_string, time_array)
#     return str_date
#
#
# def date_string_to_timestamp(time_str, format_string="%Y-%m-%d %H:%M:%S"):
#     time_array = time.strptime(time_str, format_string)
#     time_stamp = int(time.mktime(time_array) * 1000)
#     return time_stamp
#
#
# def serialize_b64(src, to_str=False):
#     dest = base64.b64encode(pickle.dumps(src))
#     if not to_str:
#         return dest
#     else:
#         return bytes_to_string(dest)
#
#
# def deserialize_b64(src):
#     src = base64.b64decode(
#         string_to_bytes(src) if isinstance(
#             src, str) else src)
#     if use_deserialize_safe_module:
#         return restricted_loads(src)
#     return pickle.loads(src)
#
#
# safe_module = {
#     'numpy',
#     'rag_flow'
# }
#
#
# class RestrictedUnpickler(pickle.Unpickler):
#     def find_class(self, module, name):
#         import importlib
#         if module.split('.')[0] in safe_module:
#             _module = importlib.import_module(module)
#             return getattr(_module, name)
#         # Forbid everything else.
#         raise pickle.UnpicklingError("global '%s.%s' is forbidden" %
#                                      (module, name))
#
#
# def restricted_loads(src):
#     """Helper function analogous to pickle.loads()."""
#     return RestrictedUnpickler(io.BytesIO(src)).load()
#
#
# def get_lan_ip():
#     if os.name != "nt":
#         import fcntl
#         import struct
#
#         def get_interface_ip(ifname):
#             s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
#             return socket.inet_ntoa(
#                 fcntl.ioctl(s.fileno(), 0x8915, struct.pack('256s', string_to_bytes(ifname[:15])))[20:24])
#
#     ip = socket.gethostbyname(socket.getfqdn())
#     if ip.startswith("127.") and os.name != "nt":
#         interfaces = [
#             "bond1",
#             "eth0",
#             "eth1",
#             "eth2",
#             "wlan0",
#             "wlan1",
#             "wifi0",
#             "ath0",
#             "ath1",
#             "ppp0",
#         ]
#         for ifname in interfaces:
#             try:
#                 ip = get_interface_ip(ifname)
#                 break
#             except IOError:
#                 pass
#     return ip or ''
#
#
# def from_dict_hook(in_dict: dict):
#     if "type" in in_dict and "data" in in_dict:
#         if in_dict["module"] is None:
#             return in_dict["data"]
#         else:
#             return getattr(importlib.import_module(
#                 in_dict["module"]), in_dict["type"])(**in_dict["data"])
#     else:
#         return in_dict
#
#
# def decrypt_database_password(password):
#     encrypt_password = get_base_config("encrypt_password", False)
#     encrypt_module = get_base_config("encrypt_module", False)
#     private_key = get_base_config("private_key", None)
#
#     if not password or not encrypt_password:
#         return password
#
#     if not private_key:
#         raise ValueError("No private key")
#
#     module_fun = encrypt_module.split("#")
#     pwdecrypt_fun = getattr(
#         importlib.import_module(
#             module_fun[0]),
#         module_fun[1])
#
#     return pwdecrypt_fun(private_key, password)
#
#
# def decrypt_database_config(
#         database=None, passwd_key="password", name="database"):
#     if not database:
#         database = get_base_config(name, {})
#
#     database[passwd_key] = decrypt_database_password(database[passwd_key])
#     return database
#
#
# def update_config(key, value, conf_name=SERVICE_CONF):
#     conf_path = conf_realpath(conf_name=conf_name)
#     if not os.path.isabs(conf_path):
#         conf_path = os.path.join(
#             file_utils.get_project_base_directory(), conf_path)
#
#     with FileLock(os.path.join(os.path.dirname(conf_path), ".lock")):
#         config = file_utils.load_yaml_conf(conf_path=conf_path) or {}
#         config[key] = value
#         file_utils.rewrite_yaml_conf(conf_path=conf_path, config=config)
#
#

#
#
#
# def datetime_format(date_time: datetime.datetime) -> datetime.datetime:
#     return datetime.datetime(date_time.year, date_time.month, date_time.day,
#                              date_time.hour, date_time.minute, date_time.second)
#
#
# def get_format_time() -> datetime.datetime:
#     return datetime_format(datetime.datetime.now())
#
#
# def str2date(date_time: str):
#     return datetime.datetime.strptime(date_time, '%Y-%m-%d')
#
#
# def elapsed2time(elapsed):
#     seconds = elapsed / 1000
#     minuter, second = divmod(seconds, 60)
#     hour, minuter = divmod(minuter, 60)
#     return '%02d:%02d:%02d' % (hour, minuter, second)
#
#
#
# def download_img(url):
#     if not url:
#         return ""
#     response = requests.get(url)
#     return "data:" + \
#            response.headers.get('Content-Type', 'image/jpg') + ";" + \
#            "base64," + base64.b64encode(response.content).decode("utf-8")
#
#
# def delta_seconds(date_string: str):
#     dt = datetime.datetime.strptime(date_string, "%Y-%m-%d %H:%M:%S")
#     return (datetime.datetime.now() - dt).total_seconds()
#
#
# def hash_str2int(line: str, mod: int = 10 ** 8) -> int:
#     return int(hashlib.sha1(line.encode("utf-8")).hexdigest(), 16) % mod
#
#
#
#
@@ -0,0 +1,102 @@
import json
import os

from cachetools import LRUCache, cached
from ruamel.yaml import YAML

PROJECT_BASE = os.getenv("PROJECT_BASE") or os.getenv("DEPLOY_BASE")
RAG_BASE = os.getenv("BASE")


def get_project_base_directory(*args):
    global PROJECT_BASE
    if PROJECT_BASE is None:
        PROJECT_BASE = os.path.dirname(os.path.abspath(os.path.join(__file__, '..')))

    if args:
        return os.path.join(PROJECT_BASE, *args)
    return PROJECT_BASE
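# Editorial note: os.path.join(__file__, '..') plus abspath() resolves to this
# file's own directory, and the outer dirname() steps up one more level, so
# PROJECT_BASE defaults to the parent of the utils package (the project root).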
def join_project_base_path(relative_path):
    base_path = get_project_base_directory()
    return os.path.join(base_path, relative_path)


def get_rag_directory(*args):
    global RAG_BASE
    if RAG_BASE is None:
        RAG_BASE = os.path.abspath(
            os.path.join(
                os.path.dirname(os.path.realpath(__file__)),
                os.pardir,
                os.pardir,
                os.pardir,
            )
        )
    if args:
        return os.path.join(RAG_BASE, *args)
    return RAG_BASE


@cached(cache=LRUCache(maxsize=10))
def load_json_conf(conf_path):
    if os.path.isabs(conf_path):
        json_conf_path = conf_path
    else:
        json_conf_path = os.path.join(get_project_base_directory(), conf_path)
    try:
        with open(json_conf_path) as f:
            return json.load(f)
    except BaseException:
        raise EnvironmentError("loading json file config from '{}' failed!".format(json_conf_path))


def dump_json_conf(config_data, conf_path):
    if os.path.isabs(conf_path):
        json_conf_path = conf_path
    else:
        json_conf_path = os.path.join(get_project_base_directory(), conf_path)
    try:
        with open(json_conf_path, "w") as f:
            json.dump(config_data, f, indent=4)
    except BaseException:
        raise EnvironmentError("dumping json file config to '{}' failed!".format(json_conf_path))


def load_json_conf_real_time(conf_path):
    if os.path.isabs(conf_path):
        json_conf_path = conf_path
    else:
        json_conf_path = os.path.join(get_project_base_directory(), conf_path)
    try:
        with open(json_conf_path) as f:
            return json.load(f)
    except BaseException:
        raise EnvironmentError("loading json file config from '{}' failed!".format(json_conf_path))


def load_yaml_conf(conf_path):
    if not os.path.isabs(conf_path):
        conf_path = os.path.join(get_project_base_directory(), conf_path)
    try:
        with open(conf_path) as f:
            yaml = YAML(typ="safe", pure=True)
            return yaml.load(f)
    except Exception as e:
        raise EnvironmentError("loading yaml file config from {} failed:".format(conf_path), e)


def rewrite_yaml_conf(conf_path, config):
    if not os.path.isabs(conf_path):
        conf_path = os.path.join(get_project_base_directory(), conf_path)
    try:
        with open(conf_path, "w") as f:
            yaml = YAML(typ="safe")
            yaml.dump(config, f)
    except Exception as e:
        raise EnvironmentError("rewrite yaml file config {} failed:".format(conf_path), e)


def rewrite_json_file(filepath, json_data):
    with open(filepath, "w", encoding="utf-8") as f:
        json.dump(json_data, f, indent=4, separators=(",", ": "))