|
|
import asyncio

from openai import AsyncOpenAI

from env import LlmBaseConfig, EmbBaseConfig
|
|
|
|
|
|
# Async OpenAI-compatible client for chat-completion (LLM) calls.
# Credentials and endpoint come from the project's env config.
llm_client = AsyncOpenAI(
    api_key=LlmBaseConfig.api_key,
    base_url=LlmBaseConfig.base_url,
)

# Separate async client for embedding requests; the embedding service may
# live behind a different base_url / key than the chat model.
emb_client = AsyncOpenAI(
    api_key=EmbBaseConfig.api_key,
    base_url=EmbBaseConfig.base_url,
)
|
|
|
|
|
|
async def generation_rule(prompt):
    """Run one chat completion and return the reply text.

    :param prompt: Chat message list, e.g. [{"role": ..., "content": ...}, ...].
    :return: Content string of the first (and only) completion choice.
    """
    completion = await llm_client.chat.completions.create(
        model=LlmBaseConfig.model_name,
        messages=prompt,
        n=1,
        stream=False,
        temperature=0.0,  # deterministic output for rule generation
        max_tokens=600,
        top_p=1.0,
        frequency_penalty=0.0,
        presence_penalty=0.0,
        # stop=["Q:"]
    )
    return completion.choices[0].message.content
|
|
|
|
|
|
async def generation_vector(text):
    """Embed *text* with the configured embedding model.

    :param text: Input string (or list of strings) to embed.
    :return: Embedding vector (list of floats) of the first input item.
    """
    result = await emb_client.embeddings.create(
        model=EmbBaseConfig.model_name,  # actual embedding model name from config
        input=text,
        encoding_format="float",
    )
    return result.data[0].embedding
|
|
|
|
|
|
async def rerank_documents(query, documents, top_n=None):
    """Rerank candidate documents by relevance to *query* using the LLM.

    The model is prompted to return only a list of 1-based document
    indices ordered from most to least relevant, e.g. "[2, 1, 3]".

    :param query: Query string.
    :param documents: List of candidate document strings.
    :param top_n: Number of top results to request from the model;
        defaults to min(4, len(documents)).
    :return: The raw model response string (expected to be an index list).
    """
    if top_n is None:
        top_n = min(4, len(documents))

    # Number documents 1-based so the model can refer to them by index.
    documents_str = "\n".join(f"{i+1}. {doc}" for i, doc in enumerate(documents))

    messages = [
        {
            "role": "system",
            "content": "你是一个文档相关性排序助手。请根据查询语句与文档的相关性对文档进行排序,只返回文档序号的排序结果,如:[2, 1, 3]"
        },
        {
            "role": "user",
            # BUG FIX: top_n was computed but never used — the model always
            # returned a full ordering. Ask it for only the top_n indices.
            "content": f"查询:{query}\n\n候选文档:\n{documents_str}\n\n请按相关性从高到低排序,只返回前{top_n}个序号列表:"
        }
    ]

    response = await llm_client.chat.completions.create(
        model=LlmBaseConfig.model_name,
        messages=messages,
        temperature=0.0,  # deterministic ranking
        max_tokens=100,   # an index list is short; cap output tightly
    )

    return response.choices[0].message.content
|
|
|
|
|
|
# Wraps the top-level await in an async function so it can be driven
# by asyncio.run().
async def main():
    """Demo entry point: rerank three AI-related documents for one query.

    Sends a rerank request for the query "机器学习" and prints the
    model's ordering of the candidate documents.
    """
    candidate_docs = [
        "人工智能是计算机科学的一个分支,它企图了解智能的实质,并生产出一种新的能以人类智能相似的方式做出反应的智能机器。",
        "机器学习是人工智能的一个分支,主要研究计算机如何从数据中学习规律,并利用这些规律对未知数据进行预测。",
        "深度学习是机器学习的一个子集,它模仿人脑的工作方式,通过多层神经网络来学习数据的特征。",
    ]

    ranking = await rerank_documents(query="机器学习", documents=candidate_docs, top_n=2)
    print(ranking)
    return ranking
|
|
|
|
|
|
# Script entry point: drive the async demo with asyncio.
# (The mid-file `import asyncio` was moved to the top-of-file import block.)
if __name__ == "__main__":
    asyncio.run(main())
|