import numpy as np
import psycopg2
from psycopg2.extras import execute_values

from env import PgvectorDataBaseConfig
from util.log_util import logger

# Database connection settings, read from the project's env configuration.
DB_CONFIG = {
    "host": PgvectorDataBaseConfig.db_host,
    "database": PgvectorDataBaseConfig.db_database,
    "user": PgvectorDataBaseConfig.db_username,
    "password": PgvectorDataBaseConfig.db_password,
    "port": PgvectorDataBaseConfig.db_port,
}

# Dimension of the `embedding` columns created by setup_vector_extension.
# Query/insert vectors must have exactly this many components, otherwise
# pgvector rejects them with a dimension-mismatch error.
EMBEDDING_DIM = 1024


def _to_vector_literal(vector):
    """Render a vector as a pgvector text literal such as ``'[1.0,2.0]'``.

    Accepts any iterable of numbers (list, tuple, 1-D numpy array). psycopg2
    cannot adapt numpy arrays on its own, so converting to the textual form
    lets both plain lists and numpy embeddings be passed to the queries below.
    A string is assumed to already be a pgvector literal and is returned as-is.
    """
    if isinstance(vector, str):
        return vector
    return "[" + ",".join(str(float(x)) for x in vector) + "]"


def _rollback_quietly(conn):
    """Roll back the current transaction, logging instead of raising on failure.

    After a failed statement psycopg2 leaves the connection in an aborted
    transaction; rolling back makes it usable for subsequent statements.
    """
    try:
        conn.rollback()
    except psycopg2.Error as e:
        logger.error(f"回滚事务失败: {e}")


def connect_to_db():
    """Open a connection using DB_CONFIG.

    Returns:
        A psycopg2 connection, or None if connecting failed (the error is logged).
    """
    try:
        conn = psycopg2.connect(**DB_CONFIG)
        logger.info("成功连接到数据库")
        return conn
    except Exception as e:
        logger.error(f"连接数据库失败: {e}")
        return None


def setup_vector_extension(conn):
    """Enable the pgvector extension and create the `documents` and `case` tables.

    Both tables are created only if missing and store EMBEDDING_DIM-dimensional
    embeddings.

    Returns:
        True on success, False if any statement failed (the error is logged and
        the transaction rolled back). Callers that ignore the result keep the
        previous behavior.
    """
    try:
        with conn.cursor() as cur:
            cur.execute("CREATE EXTENSION IF NOT EXISTS vector")
            # EMBEDDING_DIM is a trusted module constant, so interpolating it
            # into the DDL is safe (type modifiers cannot be bound as params).
            cur.execute(f"""
                CREATE TABLE IF NOT EXISTS documents(
                    id SERIAL PRIMARY KEY,
                    table_name_eng varchar(50),
                    description TEXT,
                    embedding vector({EMBEDDING_DIM})
                );
            """)
            # "case" is an SQL reserved word, hence the quoted identifier.
            cur.execute(f"""
                CREATE TABLE IF NOT EXISTS "case"(
                    id SERIAL PRIMARY KEY,
                    question TEXT,
                    answer TEXT,
                    embedding vector({EMBEDDING_DIM})
                );
            """)
        conn.commit()
        logger.info("成功设置向量扩展和表结构")
        return True
    except Exception as e:
        logger.error(f"设置数据库结构失败: {e}")
        _rollback_quietly(conn)
        return False


def insert_vectors(conn, data):
    """Bulk-insert rows into `documents`.

    Args:
        conn: Open psycopg2 connection.
        data: Sized sequence of ``(table_name_eng, description, embedding)``
            tuples; each embedding is a list/tuple/numpy array of
            EMBEDDING_DIM numbers or a pgvector literal string.
    """
    try:
        rows = [
            (table_name_eng, description, _to_vector_literal(embedding))
            for table_name_eng, description, embedding in data
        ]
        with conn.cursor() as cur:
            execute_values(
                cur,
                "INSERT INTO documents (table_name_eng, description, embedding) VALUES %s",
                rows,
                template="(%s, %s, %s::vector)",
            )
        conn.commit()
        logger.info(f"成功插入 {len(data)} 条向量数据")
    except Exception as e:
        logger.error(f"插入数据失败: {e}")
        _rollback_quietly(conn)


def search_similar_table(conn, query_vector, limit=5):
    """Return the `limit` documents closest to `query_vector` by cosine distance (`<=>`).

    Returns:
        List of ``(table_name_eng, description, distance)`` tuples, ordered by
        ascending distance; an empty list on error (logged, transaction rolled back).
    """
    try:
        with conn.cursor() as cur:
            cur.execute("""
                SELECT table_name_eng, description, embedding <=> %s::vector AS distance
                FROM documents
                ORDER BY distance
                LIMIT %s
            """, (_to_vector_literal(query_vector), limit))
            return cur.fetchall()
    except Exception as e:
        logger.error(f"搜索相似向量失败: {e}")
        _rollback_quietly(conn)
        return []


def search_similar_case(conn, query_vector, limit=5):
    """Return the `limit` cases closest to `query_vector` by Euclidean (L2) distance (`<->`).

    NOTE(review): this uses L2 distance while search_similar_table uses cosine
    distance — confirm the difference is intentional.

    Returns:
        List of ``(question, answer, distance)`` tuples, ordered by ascending
        distance; an empty list on error (logged, transaction rolled back).
    """
    try:
        with conn.cursor() as cur:
            cur.execute("""
                SELECT question, answer, embedding <-> %s::vector AS distance
                FROM "case"
                ORDER BY distance
                LIMIT %s
            """, (_to_vector_literal(query_vector), limit))
            return cur.fetchall()
    except Exception as e:
        logger.error(f"搜索相似向量失败: {e}")
        _rollback_quietly(conn)
        return []


def main():
    """Demo: ensure the schema exists, then run a sample similarity search."""
    conn = connect_to_db()
    if conn is None:
        return
    try:
        if not setup_vector_extension(conn):
            return

        # Example insert — rows are (table_name_eng, description, embedding)
        # with EMBEDDING_DIM-dimensional embeddings:
        # insert_vectors(conn, [("orders", "order table", [0.0] * EMBEDDING_DIM)])

        # The query vector must match the column dimension.
        query_vector = [1.0] * EMBEDDING_DIM
        similar_docs = search_similar_table(conn, query_vector)
        logger.info("\n相似文档搜索结果:")
        # Each row is (table_name_eng, description, distance).
        for _table_name_eng, content, distance in similar_docs:
            logger.info(f"内容: {content}, 距离: {distance}")
    finally:
        # Always release the connection, even if something above raised.
        conn.close()


if __name__ == "__main__":
    main()