You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

96 lines
2.9 KiB
Python

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

from openai import AsyncOpenAI
from env import LlmBaseConfig, EmbBaseConfig
# Client used for chat-completion requests; credentials and endpoint are read
# from LlmBaseConfig.
llm_client = AsyncOpenAI(api_key=LlmBaseConfig.api_key, base_url=LlmBaseConfig.base_url)

# Separate client used for embedding requests; credentials and endpoint are
# read from EmbBaseConfig.
emb_client = AsyncOpenAI(api_key=EmbBaseConfig.api_key, base_url=EmbBaseConfig.base_url)
async def generation_rule(prompt):
    """Send *prompt* to the chat model and return the text of the first reply.

    :param prompt: passed through unchanged as the ``messages`` argument of
        the chat-completions request
    :return: ``message.content`` of the first choice in the response
    """
    # One non-streamed choice, at most 600 tokens, temperature 0.0, top_p 1.0,
    # and no frequency/presence penalty. No stop sequence is configured.
    request_options = {
        "model": LlmBaseConfig.model_name,
        "messages": prompt,
        "n": 1,
        "stream": False,
        "temperature": 0.0,
        "max_tokens": 600,
        "top_p": 1.0,
        "frequency_penalty": 0.0,
        "presence_penalty": 0.0,
    }
    completion = await llm_client.chat.completions.create(**request_options)
    first_choice = completion.choices[0]
    return first_choice.message.content
async def generation_vector(text):
    """Request an embedding for *text* and return it.

    :param text: passed through unchanged as the ``input`` argument of the
        embeddings request
    :return: the ``embedding`` attribute of the first entry in the response's
        ``data`` list (any further entries are ignored)
    """
    embedding_response = await emb_client.embeddings.create(
        # The embedding model name comes from configuration; point
        # EmbBaseConfig.model_name at the model actually deployed.
        model=EmbBaseConfig.model_name,
        input=text,
        encoding_format="float",
    )
    first_entry = embedding_response.data[0]
    return first_entry.embedding
def _limit_ranking(raw_ranking, top_n):
    """Trim a textual ranking such as ``"[2, 1, 3]"`` to its first *top_n* entries.

    The text is returned unchanged when *top_n* is ``None``, when the text is
    empty or ``None``, or when it does not contain a bracketed, comma-separated
    list of plain integers. An unexpected model reply is therefore never lost.
    """
    if top_n is None or not raw_ranking:
        return raw_ranking
    start = raw_ranking.find("[")
    if start == -1:
        return raw_ranking
    end = raw_ranking.find("]", start + 1)
    if end == -1:
        return raw_ranking
    fields = [field.strip() for field in raw_ranking[start + 1:end].split(",")]
    # Only rewrite the reply when every field is a plain ASCII integer;
    # anything else is handed back verbatim for the caller to inspect.
    if not all(field.isascii() and field.isdigit() for field in fields):
        return raw_ranking
    ranked = [int(field) for field in fields]
    # A negative top_n is treated as 0 rather than slicing from the end.
    return str(ranked[:max(top_n, 0)])


async def rerank_documents(query, documents, top_n=None):
    """Ask the chat model to order *documents* by relevance to *query*.

    The documents are numbered from 1 and sent to the model, which is
    instructed to answer with a list of those numbers, most relevant first.

    :param query: the query text
    :param documents: list of candidate document strings
    :param top_n: keep only the first ``top_n`` entries of the ranking;
        ``None`` (the default) returns the model's reply in full
    :return: the ranking as text, e.g. ``"[2, 1]"``. With an empty
        *documents* list this is ``"[]"`` and no request is made. If the
        reply cannot be parsed as a bracketed integer list, the raw reply is
        returned as-is.
    """
    # Nothing to rank: skip the model call entirely.
    if not documents:
        return "[]"

    documents_str = "\n".join(
        f"{number}. {doc}" for number, doc in enumerate(documents, start=1)
    )
    messages = [
        {
            "role": "system",
            "content": "你是一个文档相关性排序助手。请根据查询语句与文档的相关性对文档进行排序,只返回文档序号的排序结果,如:[2, 1, 3]"
        },
        {
            "role": "user",
            "content": f"查询:{query}\n\n候选文档:\n{documents_str}\n\n请按相关性从高到低排序,只返回序号列表:"
        },
    ]
    response = await llm_client.chat.completions.create(
        model=LlmBaseConfig.model_name,  # reuses the chat model; no dedicated rerank model is called
        messages=messages,
        temperature=0.0,
        max_tokens=100,
    )
    # Previously top_n was computed but never applied; honour it here.
    return _limit_ranking(response.choices[0].message.content, top_n)
# Wrap the top-level await calls in an async function so they can be run from a script.
async def main():
    """Demo run: rerank three sample passages against a fixed query.

    Calls ``rerank_documents`` with ``top_n=2``, prints whatever it returns,
    and returns the same value.
    """
    # The other helpers can be tried the same way, e.g.
    # ``await generation_rule([...chat messages...])`` or
    # ``await generation_vector("hello world")``.
    sample_passages = [
        "人工智能是计算机科学的一个分支,它企图了解智能的实质,并生产出一种新的能以人类智能相似的方式做出反应的智能机器。",
        "机器学习是人工智能的一个分支,主要研究计算机如何从数据中学习规律,并利用这些规律对未知数据进行预测。",
        "深度学习是机器学习的一个子集,它模仿人脑的工作方式,通过多层神经网络来学习数据的特征。",
    ]
    ranking = await rerank_documents(
        query="机器学习",
        documents=sample_passages,
        top_n=2,
    )
    print(ranking)
    return ranking
# Script entry point: run the async demo through asyncio.
import asyncio

if __name__ == "__main__":
    # asyncio.run() drives the main() coroutine to completion.
    asyncio.run(main())