【Dify】 全文检索的分片-完善算法

pull/22121/head
liuchangsheng@wisdomidata.com 11 months ago
parent 94c7537170
commit e792f213bf

@ -190,22 +190,42 @@ def get_search_keywords_texts_sql(search_keywords_texts:list[str]):
if texts_len == 1: if texts_len == 1:
sql = texts[0] sql = texts[0]
elif texts_len == 2: elif texts_len == 2:
sql = f"{texts[0]} & {texts[1]} | {texts[0]}{texts[1]}" merge_text = merge_strings(texts[0],texts[1])
sql = f"{texts[0]} & {texts[1]} | {merge_text}"
else: else:
sql_texts:list[str] = [] sql_texts:list[str] = []
for idx,text in enumerate(texts): for idx,text in enumerate(texts):
if idx == 0: if idx == 0:
sql_texts.append(f"({text} | {text}{texts[idx + 1]})") merge_text = merge_strings(text,texts[idx + 1])
sql_texts.append(f"({text} | {merge_text})")
elif idx == texts_len - 2: elif idx == texts_len - 2:
sql_texts.append(f"({text} | {text}{texts[idx + 1]} | {texts[idx-1]}{text} & {texts[idx + 1]})") merge_text1 = merge_strings(text,texts[idx + 1])
merge_text2 = merge_strings(texts[idx-1],text)
sql_texts.append(f"({text} | {merge_text1} | {merge_text2} & {texts[idx + 1]})")
elif idx == texts_len - 1: elif idx == texts_len - 1:
sql_texts.append(f"({text} | {texts[idx - 1]}{text})") merge_text = merge_strings(texts[idx-1],text)
sql_texts.append(f"({text} | {merge_text})")
else: else:
sql_texts.append(f"({text} | {text}{texts[idx + 1]} | {texts[idx-1]}{text} & ({texts[idx + 1]} | {texts[idx + 1]}{texts[idx + 2]}))") merge_text1 = merge_strings(text,texts[idx + 1])
merge_text2 = merge_strings(texts[idx-1],text)
merge_text3 = merge_strings(texts[idx + 1],texts[idx + 2])
sql_texts.append(f"({text} | {merge_text1} | {merge_text2} & ({texts[idx + 1]} | {merge_text3}))")
sql = " & ".join(sql_texts) sql = " & ".join(sql_texts)
print(sql) print(sql)
return f"{sql} | {query_sql}" return f"({sql}) | ({query_sql})"
def merge_strings(text1: str, text2: str) -> str:
    """Join two strings, collapsing the longest suffix/prefix overlap.

    The longest suffix of ``text1`` that is also a prefix of ``text2`` is
    emitted only once, e.g. ``merge_strings("第二", "二层")`` gives ``"第二层"``.
    When the strings do not overlap (or either one is empty) the result is
    the plain concatenation ``text1 + text2``.

    Args:
        text1: Left-hand string; its tail is compared against ``text2``.
        text2: Right-hand string; its head is compared against ``text1``.

    Returns:
        The merged string with the shared boundary segment de-duplicated.
    """
    # Probe candidate overlap lengths from longest to shortest so that the
    # first match is the maximal overlap and we can stop immediately.
    for size in range(min(len(text1), len(text2)), 0, -1):
        if text1.endswith(text2[:size]):
            return text1 + text2[size:]
    return text1 + text2
def get_min_search_keywords_texts(texts:list[str]): def get_min_search_keywords_texts(texts:list[str]):
# import pdb; pdb.set_trace() # import pdb; pdb.set_trace()
@ -250,9 +270,18 @@ def set_full_search_score(query:str,doc_list:list[dict[str, Any]]):
def score(value): def score(value):
return round(20 * math.exp(-0.4 * value), 2) / 100 return round(20 * math.exp(-0.4 * value), 2) / 100
def get_main_keywords_texts_test(
    query_text: str,
    stop_words_path: str = "d://stopwords.txt",
) -> list[str]:
    """Extract keywords from ``query_text`` with jieba's tag extractor.

    Debug/test helper. It relies on ``jieba.analyse`` being imported at
    module level (the import is not visible in this excerpt).

    Args:
        query_text: Text to extract keywords from.
        stop_words_path: File passed to ``jieba.analyse.set_stop_words``.
            Defaults to the path the original code hard-coded, so existing
            callers behave the same; pass another path on machines where
            that file does not exist.

    Returns:
        Whatever ``jieba.analyse.extract_tags`` returns for ``topK=200`` and
        ``withWeight=False`` (annotated here as a list of keyword strings).
    """
    # Register the stop-word file before extracting.
    jieba.analyse.set_stop_words(stop_words_path)
    # A custom IDF corpus was left disabled in the original code:
    # jieba.analyse.set_idf_path("extensions/utils/idfwords.txt")
    # Request up to 200 tags, without weights.
    return jieba.analyse.extract_tags(query_text, topK=200, withWeight=False)
if __name__ == "__main__": if __name__ == "__main__":
print(score(1)) # print(merge_strings("第二","二层"))
# get_keywords("分类码") get_keywords("我的")
# search_texts=["湖人","阵容"] # search_texts=["湖人","阵容"]
# score, max_index_list =get_full_search_text_max_score(search_texts=search_texts, source="所以,**严格讲,詹姆斯在湖人确实拥有超级巨星(戴维斯),但不像热火三巨头那样多核并立。**更多时候,他还是湖人阵容的绝对核心和领袖。") # score, max_index_list =get_full_search_text_max_score(search_texts=search_texts, source="所以,**严格讲,詹姆斯在湖人确实拥有超级巨星(戴维斯),但不像热火三巨头那样多核并立。**更多时候,他还是湖人阵容的绝对核心和领袖。")
# print(score, len(max_index_list)) # print(score, len(max_index_list))

@ -224,6 +224,7 @@ class DocumentExtService:
print(keywords.__dict__) print(keywords.__dict__)
segments_rows, document_rows = DocumentExtService.get_keyword_search_segments( segments_rows, document_rows = DocumentExtService.get_keyword_search_segments(
dataset_ids=dataset_ids, dataset_ids=dataset_ids,
query_text=query_text,
keywords=keywords, keywords=keywords,
) )
# 过滤文件ID # 过滤文件ID
@ -292,19 +293,32 @@ class DocumentExtService:
def get_keyword_search_segments(dataset_ids: list[str], def get_keyword_search_segments(dataset_ids: list[str],
query_text: str,
keywords: Keywords keywords: Keywords
) -> (list[Row],list[Row]): ) -> (list[Row],list[Row]):
params = {"query": query_text, "dataset_ids" : dataset_ids}
where = [" s.content ILIKE :query "]
if keywords.main_texts:
params["keywords"] = keywords.search_sql
where.append(f" to_tsvector('chinese', s.content) @@ to_tsquery(:keywords) ")
where_str = " or ".join(where)
sql = text(f""" sql = text(f"""
SELECT s.id segment_id, s.document_id, s.content segment_content, d.name document_name,d.doc_metadata SELECT s.id segment_id, s.document_id, s.content segment_content, d.name document_name,d.doc_metadata
FROM document_segments s FROM document_segments s
LEFT JOIN documents d ON d.id = s.document_id LEFT JOIN documents d ON d.id = s.document_id
WHERE to_tsvector('chinese', s.content) @@ to_tsquery(:keywords) AND d.dataset_id::text = ANY(:dataset_ids) WHERE ({where_str}) AND d.dataset_id::text = ANY(:dataset_ids)
""") """)
print(sql,keywords.search_sql,dataset_ids[0]) print(sql,keywords.search_sql,dataset_ids[0])
segments_rows = db.session.execute(sql, {"keywords": keywords.search_sql, "dataset_ids" : dataset_ids}).fetchall() segments_rows = db.session.execute(sql, params).fetchall()
sql = text(""" params = {"query": query_text, "dataset_ids" : dataset_ids}
where = [f" d.name ILIKE :query "]
if keywords.main_texts:
params["keywords"] = keywords.search_sql
where.append(f" to_tsvector('chinese', d.name) @@ to_tsquery(:keywords) ")
where_str = " or ".join(where)
sql = text(f"""
SELECT d.id AS document_id, SELECT d.id AS document_id,
d.name AS document_name, d.name AS document_name,
s.id AS segment_id, s.id AS segment_id,
@ -320,9 +334,9 @@ class DocumentExtService:
GROUP BY document_id GROUP BY document_id
) s2 ON s1.document_id = s2.document_id AND s1.position = s2.first_position ) s2 ON s1.document_id = s2.document_id AND s1.position = s2.first_position
) s ON d.id = s.document_id ) s ON d.id = s.document_id
WHERE to_tsvector('chinese', d.name) @@ to_tsquery(:keywords) and d.dataset_id::text = ANY(:dataset_ids_) WHERE ({where_str}) and d.dataset_id::text = ANY(:dataset_ids)
""") """)
document_rows = db.session.execute(sql, {"keywords": keywords.search_sql, "dataset_ids_" : dataset_ids}).fetchall() document_rows = db.session.execute(sql, params).fetchall()
return segments_rows, document_rows return segments_rows, document_rows
# 计算分值 # 计算分值

Loading…
Cancel
Save