You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
146 lines
4.3 KiB
Python
146 lines
4.3 KiB
Python
import psycopg2
|
|
from psycopg2.extras import execute_values
|
|
import numpy as np
|
|
|
|
from env import PgvectorDataBaseConfig
|
|
|
|
# Database connection settings, read from the project's pgvector configuration
# object and unpacked as keyword arguments into psycopg2.connect().
DB_CONFIG = dict(
    host=PgvectorDataBaseConfig.db_host,
    database=PgvectorDataBaseConfig.db_database,
    user=PgvectorDataBaseConfig.db_username,
    password=PgvectorDataBaseConfig.db_password,
    port=PgvectorDataBaseConfig.db_port,
)
|
|
|
|
|
|
def connect_to_db():
    """Open a database connection using the settings in ``DB_CONFIG``.

    Returns:
        The psycopg2 connection on success, or ``None`` when anything
        raised; the error is printed instead of being propagated.
    """
    try:
        connection = psycopg2.connect(**DB_CONFIG)
        print("成功连接到数据库")
    except Exception as exc:
        print(f"连接数据库失败: {exc}")
        connection = None
    return connection
|
|
|
|
|
|
def setup_vector_extension(conn):
    """Enable the pgvector extension and create the vector tables.

    Executes three idempotent (``IF NOT EXISTS``) statements on ``conn`` --
    the ``vector`` extension, the ``documents`` table and the ``"case"``
    table -- and commits.  On any error the failure is printed and the
    transaction is rolled back; the exception is not re-raised.

    Args:
        conn: Open database connection providing ``cursor()``,
            ``commit()`` and ``rollback()``.
    """
    try:
        with conn.cursor() as cur:
            # Enable the pgvector extension.
            cur.execute("CREATE EXTENSION IF NOT EXISTS vector")

            # Table of described tables, each with its embedding.
            cur.execute("""
                CREATE TABLE IF NOT EXISTS documents(
                    id SERIAL PRIMARY KEY,
                    table_name_eng varchar(50),
                    description TEXT,
                    embedding vector(1024) -- fixed at 1024 dimensions
                );
            """)

            # "case" is a reserved word in SQL, hence the quoted identifier.
            cur.execute("""
                CREATE TABLE IF NOT EXISTS "case"(
                    id SERIAL PRIMARY KEY,
                    question TEXT,
                    answer TEXT,
                    embedding vector(1024) -- fixed at 1024 dimensions
                );
            """)
        conn.commit()
        print("成功设置向量扩展和表结构")
    except Exception as e:
        print(f"设置数据库结构失败: {e}")
        conn.rollback()
|
|
|
|
|
|
def insert_vectors(conn, data):
    """Bulk-insert rows into the ``documents`` table.

    On any error the failure is printed and the transaction is rolled back;
    the exception is not re-raised.

    Args:
        conn: Open database connection providing ``cursor()``,
            ``commit()`` and ``rollback()``.
        data: Iterable of ``(table_name_eng, description, embedding)``
            tuples, one per row, in the order of the INSERT column list.
    """
    try:
        # Materialize once: this lets generators be passed in and guarantees
        # the len() in the success message cannot fail after the commit.
        rows = list(data)
        with conn.cursor() as cur:
            execute_values(
                cur,
                "INSERT INTO documents (table_name_eng, description, embedding) VALUES %s",
                rows,
            )
        conn.commit()
        print(f"成功插入 {len(rows)} 条向量数据")
    except Exception as e:
        print(f"插入数据失败: {e}")
        conn.rollback()
|
|
|
|
|
|
def search_similar_table(conn, query_vector, limit=5):
    """Return the ``documents`` rows closest to ``query_vector``.

    Rows are ranked with pgvector's ``<=>`` operator (cosine distance),
    smallest distance first.

    Args:
        conn: Open database connection providing ``cursor()`` and
            ``rollback()``.
        query_vector: Query embedding; it is cast to ``vector`` in SQL.
        limit: Maximum number of rows to return.

    Returns:
        A list of ``(table_name_eng, description, distance)`` tuples, or an
        empty list if the query failed (the error is printed).
    """
    try:
        with conn.cursor() as cur:
            cur.execute("""
                SELECT table_name_eng, description, embedding <=> %s::vector AS distance
                FROM documents
                ORDER BY distance
                LIMIT %s
            """, (query_vector, limit))

            results = cur.fetchall()
            return results
    except Exception as e:
        print(f"搜索相似向量失败: {e}")
        # A failed statement leaves the transaction aborted; roll back so the
        # connection stays usable for later queries.  The rollback is guarded
        # so this function keeps its "always returns a list" contract.
        try:
            conn.rollback()
        except Exception as rollback_error:
            print(f"回滚事务失败: {rollback_error}")
        return []
|
|
|
|
def search_similar_case(conn, query_vector, limit=5):
    """Return the ``"case"`` rows closest to ``query_vector``.

    Rows are ranked with pgvector's ``<->`` operator (Euclidean / L2
    distance), smallest distance first.
    NOTE(review): ``search_similar_table`` ranks with ``<=>`` (cosine
    distance) instead -- confirm the different metric is intentional.

    Args:
        conn: Open database connection providing ``cursor()`` and
            ``rollback()``.
        query_vector: Query embedding; it is cast to ``vector`` in SQL.
        limit: Maximum number of rows to return.

    Returns:
        A list of ``(question, answer, distance)`` tuples, or an empty list
        if the query failed (the error is printed).
    """
    try:
        with conn.cursor() as cur:
            cur.execute("""
                SELECT question, answer, embedding <-> %s::vector AS distance
                FROM "case"
                ORDER BY distance
                LIMIT %s
            """, (query_vector, limit))

            results = cur.fetchall()
            return results
    except Exception as e:
        print(f"搜索相似向量失败: {e}")
        # A failed statement leaves the transaction aborted; roll back so the
        # connection stays usable for later queries.  The rollback is guarded
        # so this function keeps its "always returns a list" contract.
        try:
            conn.rollback()
        except Exception as rollback_error:
            print(f"回滚事务失败: {rollback_error}")
        return []
|
|
|
|
def main():
    """Demo entry point: connect, ensure the schema exists, run one search.

    Returns early when no connection could be established.  Once a
    connection is open it is always closed before returning, even if a
    step raises.
    """
    # Connect to the database.
    conn = connect_to_db()
    if conn is None:
        return

    try:
        # Make sure the extension and tables exist.
        setup_vector_extension(conn)

        # Sample data, left disabled.
        # NOTE(review): these rows are (content, embedding) pairs with
        # 3-element vectors; the INSERT in insert_vectors takes three columns
        # and the embedding column is vector(1024), so they need updating
        # before being re-enabled.
        # sample_data = [
        #     ("第一个文档内容", [1.0, 2.0, 3.0]),
        #     ("第二个文档内容", [4.0, 5.0, 6.0]),
        #     ("第三个文档内容", [7.0, 8.0, 9.0]),
        #     ("第四个文档内容", [2.0, 3.0, 4.0]),
        #     ("第五个文档内容", [5.0, 6.0, 7.0]),
        # ]

        # Insert the data.
        # insert_vectors(conn, sample_data)

        # Search for similar vectors.  The query must have the same number of
        # dimensions as the vector(1024) embedding column.
        query_vector = [1.0] * 1024
        similar_docs = search_similar_table(conn, query_vector)

        print("\n相似文档搜索结果:")
        # Each row is (table_name_eng, description, distance).
        for table_name, description, distance in similar_docs:
            print(f"表名: {table_name}, 内容: {description}, 距离: {distance}")
    finally:
        # Close the connection on every exit path.
        conn.close()
|
|
|
|
|
|
# Run the demo only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|