You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

144 lines
4.1 KiB
Python

import psycopg2
from psycopg2.extras import execute_values
import numpy as np
# Database connection settings; connect_to_db() unpacks this mapping as
# keyword arguments to psycopg2.connect().
# NOTE(review): the host and credentials are hard-coded here -- consider
# loading them from the environment before sharing or deploying this script.
DB_CONFIG = {
    "host": "192.168.5.30",
    "database": "vectordb",
    "user": "myuser",
    "password": "mypassword",
    "port": 5432,
}
def connect_to_db():
    """Open a database connection from the module-level ``DB_CONFIG``.

    Returns:
        The psycopg2 connection on success, or ``None`` if anything in the
        attempt fails. Failures are printed, never raised.
    """
    connection = None
    try:
        connection = psycopg2.connect(**DB_CONFIG)
        print("成功连接到数据库")
    except Exception as err:
        print(f"连接数据库失败: {err}")
        return None
    return connection
def setup_vector_extension(conn, dimension=1024):
    """Enable pgvector and create the ``documents`` and ``case`` tables.

    Every statement uses ``IF NOT EXISTS``, so an extension or table that is
    already present is left untouched; in particular an existing table keeps
    the embedding dimension it was originally created with.

    Args:
        conn: Open database connection providing ``cursor()``, ``commit()``
            and ``rollback()``.
        dimension: Number of dimensions of the ``embedding`` columns.
            Defaults to 1024, the value that was previously hard-coded.

    Raises:
        ValueError: If ``dimension`` is not a positive ``int``.

    Database errors are not raised: they are printed and the transaction is
    rolled back.
    """
    # The dimension is interpolated into the DDL text (a type modifier cannot
    # be sent as a bound query parameter), so only a plain positive int is
    # accepted. bool is excluded explicitly because it subclasses int.
    if (
        isinstance(dimension, bool)
        or not isinstance(dimension, int)
        or dimension <= 0
    ):
        raise ValueError(
            f"dimension must be a positive integer, got {dimension!r}"
        )

    # Table that maps a table name and its description to an embedding.
    create_documents = f"""
        CREATE TABLE IF NOT EXISTS documents(
            id SERIAL PRIMARY KEY,
            table_name_eng varchar(50),
            description TEXT,
            embedding vector({dimension})
        );
    """
    # "case" is a reserved word in SQL, hence the quoted identifier.
    create_case = f"""
        CREATE TABLE IF NOT EXISTS "case"(
            id SERIAL PRIMARY KEY,
            question TEXT,
            answer TEXT,
            embedding vector({dimension})
        );
    """

    try:
        with conn.cursor() as cur:
            # The extension must exist before the vector type can be used.
            cur.execute("CREATE EXTENSION IF NOT EXISTS vector")
            cur.execute(create_documents)
            cur.execute(create_case)
        conn.commit()
        print("成功设置向量扩展和表结构")
    except Exception as e:
        print(f"设置数据库结构失败: {e}")
        conn.rollback()
def insert_vectors(conn, data):
    """Bulk-insert rows into the ``documents`` table and commit.

    Args:
        conn: Open database connection providing ``cursor()``, ``commit()``
            and ``rollback()``.
        data: Iterable of ``(table_name_eng, description, embedding)`` rows,
            one value per column named in the INSERT statement. Any iterable
            is accepted, including generators.
            NOTE(review): the embedding presumably has to match the column's
            dimension and be a type the driver can adapt -- confirm with the
            caller.

    Errors are not raised: they are printed and the transaction is rolled
    back.
    """
    try:
        # Materialise once so the rows can be both inserted and counted.
        # Calling len() on an iterator would fail *after* the commit and
        # print a misleading failure message for data that was stored.
        rows = list(data)
        with conn.cursor() as cur:
            execute_values(
                cur,
                "INSERT INTO documents (table_name_eng, description, embedding) VALUES %s",
                rows,
            )
        conn.commit()
        print(f"成功插入 {len(rows)} 条向量数据")
    except Exception as e:
        print(f"插入数据失败: {e}")
        conn.rollback()
def search_similar_table(conn, query_vector, limit=5):
    """Return the ``documents`` rows nearest to ``query_vector``.

    Rows are ranked with pgvector's ``<=>`` operator (cosine distance),
    smallest distance first.

    Args:
        conn: Open database connection providing ``cursor()`` and
            ``rollback()``.
        query_vector: Query embedding; it is passed as a bound parameter and
            cast to ``vector`` inside the statement.
        limit: Maximum number of rows to return.

    Returns:
        A list of ``(table_name_eng, description, distance)`` tuples, or an
        empty list if the query fails. Failures are printed, never raised.
    """
    try:
        with conn.cursor() as cur:
            cur.execute(
                """
                SELECT table_name_eng, description, embedding <=> %s::vector AS distance
                FROM documents
                ORDER BY distance
                LIMIT %s
                """,
                (query_vector, limit),
            )
            return cur.fetchall()
    except Exception as e:
        print(f"搜索相似向量失败: {e}")
        # A failed statement leaves the transaction aborted and every later
        # statement on this connection would be rejected. Roll back, as the
        # write helpers in this file do, so the connection stays usable.
        try:
            conn.rollback()
        except Exception as rollback_error:
            # Best effort: keep the "never raises" contract even if the
            # connection itself is already unusable.
            print(f"回滚失败: {rollback_error}")
        return []
def search_similar_case(conn, query_vector, limit=5):
    """Return the ``case`` rows nearest to ``query_vector``.

    Rows are ranked with pgvector's ``<->`` operator (Euclidean / L2
    distance), smallest distance first.
    NOTE(review): search_similar_table ranks with ``<=>`` (cosine distance)
    instead -- confirm the two searches are meant to use different metrics.

    Args:
        conn: Open database connection providing ``cursor()`` and
            ``rollback()``.
        query_vector: Query embedding; it is passed as a bound parameter and
            cast to ``vector`` inside the statement.
        limit: Maximum number of rows to return.

    Returns:
        A list of ``(question, answer, distance)`` tuples, or an empty list
        if the query fails. Failures are printed, never raised.
    """
    try:
        with conn.cursor() as cur:
            cur.execute(
                """
                SELECT question, answer, embedding <-> %s::vector AS distance
                FROM "case"
                ORDER BY distance
                LIMIT %s
                """,
                (query_vector, limit),
            )
            return cur.fetchall()
    except Exception as e:
        print(f"搜索相似向量失败: {e}")
        # A failed statement leaves the transaction aborted and every later
        # statement on this connection would be rejected. Roll back, as the
        # write helpers in this file do, so the connection stays usable.
        try:
            conn.rollback()
        except Exception as rollback_error:
            # Best effort: keep the "never raises" contract even if the
            # connection itself is already unusable.
            print(f"回滚失败: {rollback_error}")
        return []
def main():
    """Demo entry point: connect, ensure the schema exists, run one search.

    Returns early when no connection could be opened. Otherwise the
    connection is always closed on the way out, even if a step raises.
    """
    conn = connect_to_db()
    if conn is None:
        return

    try:
        # Make sure the pgvector extension and the tables exist.
        setup_vector_extension(conn)

        # Example of loading data. Each row supplies one value per inserted
        # column -- (table_name_eng, description, embedding) -- and the
        # embedding needs as many dimensions as the column (1024).
        # sample_data = [
        #     ("demo_table_1", "第一个文档内容", [0.1] * 1024),
        #     ("demo_table_2", "第二个文档内容", [0.2] * 1024),
        # ]
        # insert_vectors(conn, sample_data)

        # The query vector must have the same dimension as the stored
        # embeddings; a 3-element vector cannot be compared against
        # vector(1024) rows.
        query_vector = [1.0] * 1024
        similar_docs = search_similar_table(conn, query_vector)

        print("\n相似文档搜索结果:")
        # Each row has three columns: table name, description, distance.
        for table_name, description, distance in similar_docs:
            print(f"表名: {table_name}, 内容: {description}, 距离: {distance}")
    finally:
        conn.close()


# Run the demo only when executed as a script, not when imported.
if __name__ == "__main__":
    main()