You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

129 lines
3.6 KiB
Python

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

import argparse
import os
import sys
from dotenv import load_dotenv
from functools import lru_cache
from pydantic import computed_field
from pydantic_settings import BaseSettings
from typing import Literal
class GetConfig:
    """
    Determine the runtime environment and hand out configuration objects.

    Instantiating this class runs ``parse_cli_args`` which picks the runtime
    environment name, loads the matching ``.env.<env>`` file into
    ``os.environ`` and prints the chosen environment. The ``get_*_config``
    methods then build the individual settings objects.
    """

    def __init__(self):
        self.parse_cli_args()

    # NOTE: lru_cache on a method keys on ``self`` and keeps the instance alive
    # for the lifetime of the cache. That is acceptable here because the module
    # creates one long-lived GetConfig instance.
    @lru_cache()
    def get_llm_config(self):
        """Return the cached LLM settings object."""
        return LlmSettings()

    @lru_cache()
    def get_emb_config(self):
        """Return the cached embedding (``emb_*``) settings object."""
        return EmbSettings()

    @lru_cache()
    def get_mysql_database_config(self):
        """Return the cached MySQL database settings object."""
        return MysqlDataBaseSettings()

    @lru_cache()
    def get_pgvector_database_config(self):
        """Return the cached pgvector (PostgreSQL) database settings object."""
        return PgvectorDataBaseSettings()

    @staticmethod
    def parse_cli_args():
        """
        Decide the runtime environment and load the matching .env file.

        Precedence for the environment name is:

        1. the ``--env`` command-line argument,
        2. an already-set ``APP_ENV`` environment variable,
        3. ``'dev'``.

        When the process was launched by uvicorn, the command line belongs to
        uvicorn and cannot carry custom arguments. In that case only
        ``APP_ENV`` or the ``'dev'`` default is used.

        The resolved name is written back to ``APP_ENV``, and
        ``.env.<name>`` is loaded via ``load_dotenv``.
        """
        # Guard against an empty argv (embedded interpreters, some launchers).
        launcher = sys.argv[0] if sys.argv else ''
        if 'uvicorn' not in launcher:
            parser = argparse.ArgumentParser(description='命令行参数')
            parser.add_argument('--env', type=str, default='', help='运行环境')
            # parse_known_args ignores arguments meant for another tool (test
            # runners, process managers) instead of exiting with an error.
            args, _ = parser.parse_known_args()
            # Respect an APP_ENV exported by the caller, as the uvicorn path does.
            os.environ['APP_ENV'] = args.env or os.environ.get('APP_ENV') or 'dev'
        # Default to 'dev' here too, so the printed env always matches the file loaded.
        run_env = os.environ.get('APP_ENV') or 'dev'
        os.environ['APP_ENV'] = run_env
        env_file = f'.env.{run_env}'
        if not os.path.isfile(env_file):
            # load_dotenv silently does nothing for a missing file; make it visible.
            print(f'Warning: env file {env_file} not found; using existing environment variables only')
        load_dotenv(env_file)
        print(f'当前运行环境为:{run_env}')
# Module-level configuration loader. Constructing it resolves the runtime
# environment and loads the matching .env file. It must run before the
# settings classes below are defined, because their class bodies read
# environment variables at class-definition time.
get_config: GetConfig = GetConfig()
class LlmSettings:
    """
    LLM service settings.

    Every value is looked up in the environment once, when this class body
    executes, and is kept as a class attribute. An unset variable yields ''.
    """

    # Service base URL (env: llm_base_url)
    base_url: str = os.environ.get('llm_base_url', '')
    # API credential (env: llm_api_key)
    api_key: str = os.environ.get('llm_api_key', '')
    # Model identifier (env: llm_model)
    model_name: str = os.environ.get('llm_model', '')
class EmbSettings:
    """
    Settings for the ``emb_*`` service (presumably the embedding model).

    Every value is looked up in the environment once, when this class body
    executes, and is kept as a class attribute. An unset variable yields ''.
    """

    # Service base URL (env: emb_base_url)
    base_url: str = os.environ.get('emb_base_url', '')
    # API credential (env: emb_api_key)
    api_key: str = os.environ.get('emb_api_key', '')
    # Model identifier (env: emb_model)
    model_name: str = os.environ.get('emb_model', '')
class MysqlDataBaseSettings:
    """
    MySQL connection and pool settings.

    Connection values are read from ``MYSQL_DB_*`` environment variables when
    this class body executes (import time) and stored as class attributes.
    An unset or empty ``MYSQL_DB_PORT`` falls back to the default port
    instead of failing on ``int('')``.
    """

    db_host: str = os.getenv('MYSQL_DB_HOST', '')
    # An empty value (e.g. "MYSQL_DB_PORT=" in a .env file) is treated as unset.
    # NOTE(review): 33306 differs from MySQL's standard 3306; presumably a
    # mapped container port, so confirm it is intended.
    db_port: int = int(os.getenv('MYSQL_DB_PORT') or 33306)
    db_username: str = os.getenv('MYSQL_DB_USERNAME', '')
    db_password: str = os.getenv('MYSQL_DB_PASSWORD', '')
    db_database: str = os.getenv('MYSQL_DB_DATABASE', '')
    # Fixed engine/pool tuning values (presumably forwarded to the DB engine;
    # confirm at the call site).
    db_echo: bool = True
    db_max_overflow: int = 10
    db_pool_size: int = 50
    db_pool_recycle: int = 3600
    db_pool_timeout: int = 30
class PgvectorDataBaseSettings:
    """
    pgvector (PostgreSQL) connection and pool settings.

    Connection values are read from ``PG_DB_*`` environment variables when
    this class body executes (import time) and stored as class attributes.
    ``db_port`` is always an ``int``. An unset or empty ``PG_DB_PORT`` falls
    back to 5432.
    """

    db_host: str = os.getenv('PG_DB_HOST', '')
    # Convert to int: os.getenv returns a str whenever the variable is set.
    # An empty value is treated as unset.
    db_port: int = int(os.getenv('PG_DB_PORT') or 5432)
    db_username: str = os.getenv('PG_DB_USERNAME', '')
    db_password: str = os.getenv('PG_DB_PASSWORD', '')
    db_database: str = os.getenv('PG_DB_DATABASE', '')
    # Fixed engine/pool tuning values (presumably forwarded to the DB engine;
    # confirm at the call site).
    db_echo: bool = True
    db_max_overflow: int = 10
    db_pool_size: int = 50
    db_pool_recycle: int = 3600
    db_pool_timeout: int = 30
# Module-level settings objects, each built once through GetConfig's cached getters.
# MySQL database settings
MysqlDataBaseConfig: MysqlDataBaseSettings = get_config.get_mysql_database_config()
# pgvector (PostgreSQL) database settings
PgvectorDataBaseConfig: PgvectorDataBaseSettings = get_config.get_pgvector_database_config()
# LLM service settings
LlmBaseConfig: LlmSettings = get_config.get_llm_config()
# emb_* (embedding) service settings
EmbBaseConfig: EmbSettings = get_config.get_emb_config()