From 8048e839795361eb4cfe0349814b84c7fb01d9c0 Mon Sep 17 00:00:00 2001 From: "liuchangsheng@wisdomidata.com" Date: Mon, 16 Jun 2025 10:09:03 +0800 Subject: [PATCH] =?UTF-8?q?=E3=80=90Dify=E3=80=91=20=E6=9F=A5=E8=AF=A2?= =?UTF-8?q?=E6=94=B9=E4=B8=BA=E5=85=B3=E9=94=AE=E8=AF=8D=E5=8C=B9=E9=85=8D?= =?UTF-8?q?=E6=9F=A5=E8=AF=A2?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- api/extensions/utils/search_tool.py | 114 ++++++++++ api/pyproject.toml | 2 + api/services/ext/dataset_ext_service.py | 194 +++++++++++++++-- api/services/ext/stopwords.txt | 264 ++++++++++++++++++++++++ 4 files changed, 557 insertions(+), 17 deletions(-) create mode 100644 api/extensions/utils/search_tool.py create mode 100644 api/services/ext/stopwords.txt diff --git a/api/extensions/utils/search_tool.py b/api/extensions/utils/search_tool.py new file mode 100644 index 0000000000..8ace63f483 --- /dev/null +++ b/api/extensions/utils/search_tool.py @@ -0,0 +1,114 @@ +import difflib +from collections import defaultdict, Counter +import itertools +import re + +class TextIndex: + def __init__(self, text_text, index): + self.text_text = text_text + self.index = index + + def to_dict(self): + return { + "text_text": self.text_text, + "index": self.index, + } + +def find_all_occurrences(source: str, target: str): + return [match.start() for match in re.finditer(re.escape(source), target)] + +def get_text_max_score(search_texts: list[str],search_index: int, pos_map,root_list:list[TextIndex], groups:list[list[TextIndex]]): + + if len(search_texts) == search_index and len(root_list) > 0: + groups.append(root_list) + return + search_text = search_texts[search_index] + text_indexs = pos_map[search_text] + next_index = search_index + 1 + if text_indexs: + new_root_list = root_list[:] + for t_idx,text_index in enumerate(text_indexs): + this_root_list = [] + if t_idx > 0: + this_root_list=new_root_list[:] + else: + this_root_list = root_list + this_root_list.append(text_index) + get_text_max_score(search_texts=search_texts,search_index=next_index,pos_map=pos_map,root_list=this_root_list,groups=groups) + else: + get_text_max_score(search_texts=search_texts,search_index=next_index,pos_map=pos_map,root_list=root_list,groups=groups) + +# def get_text_index_score(text_indexs: list[TextIndex],search_texts: list[str]): +# # 去掉一个最后面的 +# # 去掉一个最前面的 + +def get_text_index_score(text_indexs: list[TextIndex],search_texts: list[str]): + + deduct_points = 0 + search_text_count = len("".join(search_texts)) + text_count = 0 + for idx,text_index in enumerate(text_indexs): + text_count += len(text_index.text_text) + if idx < len(text_indexs) - 1: + next_text_index = text_indexs[idx + 1] + t_score = 0 + if next_text_index.index > text_index.index: + t_score = next_text_index.index - text_index.index - len(text_index.text_text) - 1 + else: + t_score = text_index.index - next_text_index.index - len(next_text_index.text_text) + t_score = abs(t_score) + deduct_points += t_score + if deduct_points > 50: + return 0 + deduct_points += (search_text_count - text_count) * 3 + return 100 - deduct_points + +def get_full_search_text_max_score(search_texts: list[str], target_text: str) -> (int, list[TextIndex]): + import pdb; pdb.set_trace() + # 1. 建立 source 中每个字符的索引映射 + # pos_map = defaultdict(list) + text_index_groups:list[list[TextIndex]] = [] + for search_text in search_texts: + idxs = find_all_occurrences(source=search_text, target=target_text) + text_indexs = [TextIndex(text_text=search_text,index=idx) for idx in idxs] + # pos_map[search_text].extend(text_indexs) + text_index_groups.append(text_indexs) + + import pdb; pdb.set_trace() + # groups:list[list[TextIndex]] = [] + max_score = -100000 + max_index_list:list[TextIndex] + for text_index_s in itertools.product(*text_index_groups): + text_index_list:list[TextIndex] = list(text_index_s) + score_ = get_text_index_score(text_indexs=text_index_list,search_texts=search_texts) + if score_ < 50: + continue + if score_ > max_score: + max_score = score_ + max_index_list = text_index_list + + # get_text_max_score(search_texts=search_texts,search_index=0,pos_map=pos_map,root_list=[], groups=groups) + # max_index_list:list[TextIndex] = [] + # max_score = -100000 + # import pdb; pdb.set_trace() + # for g_list in groups: + # score_,milist = get_text_index_score(text_indexs=g_list,search_texts=search_texts) + # if score_ > max_score: + # max_score = score_ + # max_index_list = g_list + # print("score_",score_) + # texts = [] + # for text_index in g_list: + # t_len = len(text_index.text_text) + # t_idx = text_index.index + # text = target_text[t_idx : t_idx+t_len] + # texts.append(text) + # print("--------------------------") + # print("".join(texts)) + import pdb; pdb.set_trace() + return (max_score,max_index_list) + +if __name__ == "__main__": + search_texts=["湖人","阵容"] + score, max_index_list =get_full_search_text_max_score(search_texts=search_texts, source="所以,**严格讲,詹姆斯在湖人确实拥有超级巨星(戴维斯),但不像热火三巨头那样多核并立。**更多时候,他还是湖人阵容的绝对核心和领袖。") + print(score, len(max_index_list)) diff --git a/api/pyproject.toml b/api/pyproject.toml index 9d41ea502f..6cc021b0c3 100644 --- a/api/pyproject.toml +++ b/api/pyproject.toml @@ -104,6 +104,8 @@ package = false # Required for development and running tests ############################################################ dev = [ + "fuzzywuzzy~=0.18.0", + "python-Levenshtein~=0.12.2", "coverage~=7.2.4", "dotenv-linter~=0.5.0", "faker~=32.1.0", diff --git a/api/services/ext/dataset_ext_service.py b/api/services/ext/dataset_ext_service.py index 148b98bb8f..4e7f412d67 100644 --- a/api/services/ext/dataset_ext_service.py +++ b/api/services/ext/dataset_ext_service.py @@ -21,6 +21,27 @@ from services.entities.knowledge_entities.knowledge_entities import KnowledgeCon from sqlalchemy import text, bindparam,select,func from collections import defaultdict from sqlalchemy.dialects.postgresql import ARRAY +from sqlalchemy.engine import Row +import jieba +import jieba.analyse +import difflib +from extensions.utils.search_tool import get_full_search_text_max_score +import json + +class Keywords: + def __init__(self, texts, main_texts, search_texts, search_sql): + self.texts = texts + self.main_texts = main_texts + self.search_texts=search_texts + self.search_sql=search_sql + + def to_dict(self): + return { + "texts": self.texts, + "main_texts": self.main_texts, + "search_texts": self.search_texts, + "search_sql": self.search_sql, + } class DatasetExtService: resource_type = "dataset" @@ -201,32 +222,36 @@ class DocumentExtService: break return next_segment - def get_full_search_data(dataset_names: list[str], tenant_id : str, query_text: str, file_ids: str) -> list[dict]: - + import pdb; pdb.set_trace() if not file_ids: return [] datasets = db.session.query(Dataset).filter(Dataset.name.in_(dataset_names),Dataset.tenant_id == tenant_id).all() dataset_ids = [dataset.id for dataset in datasets] # 精准查询的向量片段 - fetch_segments = DocumentExtService.get_full_search_segments(dataset_ids=dataset_ids,query_text=query_text) - - search_datas = [] - for segment in fetch_segments: - search_data = { - "title": segment.document_name, - "content": segment.segment_content, - "doc_metadata": segment.doc_metadata, - "query": query_text - } - search_datas.append(search_data) - - search_datas = DocumentExtService.filter_by_file_ids(search_datas=search_datas, file_ids=file_ids) - return search_datas + # fetch_segments = DocumentExtService.get_full_search_segments(dataset_ids=dataset_ids,query_text=query_text) + keywords = DocumentExtService.get_keywords(query_text=query_text) + segments_rows, document_rows = DocumentExtService.get_keyword_search_segments( + dataset_ids=dataset_ids, + keywords=keywords, + ) + # 过滤文件ID + segments_rows = DocumentExtService.filter_rows_by_file_ids(segments_rows, file_ids) + # 过滤文件ID + document_rows = DocumentExtService.filter_rows_by_file_ids(document_rows, file_ids) + import pdb; pdb.set_trace() + # 计算分值高的数据 + segment_datas = DocumentExtService.get_full_search_segments_by_score( + keywords=keywords, + query_text=query_text, + segments_rows=segments_rows, + document_rows=document_rows + ) + return segment_datas def get_full_search_segments(dataset_ids: list[str], query_text: str): @@ -279,9 +304,144 @@ class DocumentExtService: fetch_segments.append(segment_list[1]) return fetch_segments + def get_keywords(query_text: str) -> Keywords: + # 分词器分词关键词 + keyword_texts = list(jieba.cut(query_text)) + import pdb; pdb.set_trace() + # 判断关键词的长度 + jieba.analyse.set_stop_words("services/ext/stopwords.txt") + # def get_text(): + # return text + # 提取关键词,默认 topK=30,withWeight=True + main_keywords_texts__ = jieba.analyse.extract_tags(query_text, topK=200, withWeight=False) + + main_keywords_texts = [] + for text in keyword_texts: + if text in main_keywords_texts__: + main_keywords_texts.append(text) + + keyword_len = len(main_keywords_texts) + main_keywords_len = 0 + # 提取80% + if keyword_len > 2: + main_keywords_len = int(keyword_len * 0.8) + else: + main_keywords_len = keyword_len + + main_keywords_len = len(main_keywords_texts) if main_keywords_len > len(main_keywords_texts) else main_keywords_len + # 得出最关键的分词 + search_keywords_texts = main_keywords_texts[:main_keywords_len + 1] + + search_sql = ' & '.join(search_keywords_texts) + # 按照最关键的分词查询 + keywords = Keywords( + texts=main_keywords_texts, + main_texts=main_keywords_texts, + search_texts=search_keywords_texts, + search_sql=search_sql + ) + return keywords + + def get_keyword_search_segments(dataset_ids: list[str], + keywords: Keywords + ) -> (list[Row],list[Row]): + + sql = text(f""" + SELECT s.id segment_id, s.document_id, s.content segment_content, d.name document_name,d.doc_metadata + FROM document_segments s + left join documents d on d.id = s.document_id + WHERE to_tsvector('chinese', s.content) @@ to_tsquery(:keywords) and d.dataset_id::text = ANY(:dataset_ids) + """) + + segments_rows = db.session.execute(sql, {"keywords": keywords.search_sql, "dataset_ids" : dataset_ids}).fetchall() + + sql = text(""" + SELECT d.id AS document_id, + d.name AS document_name, + s.id AS segment_id, + s.content AS segment_content, + d.doc_metadata + FROM documents d + JOIN ( + SELECT s1.* + FROM document_segments s1 + INNER JOIN ( + SELECT document_id, MIN(position) AS first_position + FROM document_segments + GROUP BY document_id + ) s2 ON s1.document_id = s2.document_id AND s1.position = s2.first_position + ) s ON d.id = s.document_id + WHERE to_tsvector('chinese', d.name) @@ to_tsquery(:keywords) and d.dataset_id::text = ANY(:dataset_ids_) + """) + document_rows = db.session.execute(sql, {"keywords": keywords.search_sql, "dataset_ids_" : dataset_ids}).fetchall() + return segments_rows, document_rows + + # 计算分值 + def get_full_search_segments_by_score(keywords : Keywords, + query_text : str, + segments_rows: list[Row], + document_rows: list[Row]) -> list[dict]: + + segment_datas = [] + for document in document_rows: + score, s_list = get_full_search_text_max_score(search_texts=keywords.main_texts,target_text=document.document_name) + segment_data = { + "document_id" : str(document.document_id), + "title": document.document_name, + "content": document.segment_content, + "doc_metadata": document.doc_metadata, + "query": query_text, + "score": score, + } + segment_datas.append(segment_data) + + for segment in segments_rows: + score,s_list = get_full_search_text_max_score(search_texts=keywords.main_texts,target_text=segment.segment_content) + segment_data = { + "document_id" : str(segment.document_id), + "title": segment.document_name, + "content": segment.segment_content, + "doc_metadata": segment.doc_metadata, + "query": query_text, + "score": score, + } + segment_datas.append(segment_data) + + grouped = defaultdict(list) + + for segment in segment_datas: + grouped[segment["document_id"]].append(segment) + + max_score_segments = [] + # 遍历 grouped + for document_id, segment_list in grouped.items(): + if len(segment_list) == 1: + max_score_segments.append(segment_list[0]) + else: + max_segment = max(segment_list[1:], key=lambda x: x['score']) + max_score_segments.append(max_segment) + + # 按照分值排序 + max_score_segments = sorted(max_score_segments, key=lambda x: x['score'], reverse=True) + import pdb; pdb.set_trace() + return max_score_segments + + def filter_rows_by_file_ids(search_datas: list[Row], file_ids: str) -> list[Row]: + file_id_list = file_ids.split(",") + + filter_rows = [] + for item in search_datas: + doc_metadata = item.doc_metadata + if not doc_metadata: + doc_metadata = {} + if doc_metadata["file_id"] and doc_metadata["file_id"] in file_id_list: + filter_rows.append(item) + + return filter_rows + def filter_by_file_ids(search_datas: list[dict], file_ids: str) -> list[dict]: file_id_list = file_ids.split(",") return [ item for item in search_datas - if item.get("doc_metadata", {}).get("file_id") in file_ids + if item.get("doc_metadata", {}).get("file_id") in file_id_list ] diff --git a/api/services/ext/stopwords.txt b/api/services/ext/stopwords.txt new file mode 100644 index 0000000000..191fb7265f --- /dev/null +++ b/api/services/ext/stopwords.txt @@ -0,0 +1,264 @@ +的 +了 +和 +是 +我 +也 +就 +都 +而 +及 +与 +着 +或 +一个 +没有 +我们 +你们 +他们 +她们 +它们 +自己 +这 +那 +这些 +那些 +它 +被 +在 +对于 +因为 +所以 +如果 +然后 +而且 +并且 +并 +但是 +不过 +不是 +而是 +还有 +还 +已 +已经 +正在 +非常 +很 +较 +更 +最 +吧 +啊 +呀 +嘛 +呢 +么 +吗 +哦 +恩 +呃 +咯 +啊呀 +啥 +哈 +啊哈 +啦 +咱 +什么 +多少 +几 +多 +你 +我 +他 +她 +它 +咱们 +此 +其 +某 +某个 +某些 +每 +每个 +各 +个 +等 +等于 +以及 +其中 +从而 +因此 +除此之外 +据此 +比如 +例如 +比如说 +譬如 +比如说的 +说 +要 +来 +去 +把 +被 +给 +使 +令 +让 +令得 +所 +之 +之所以 +以 +以便 +以免 +以至 +以致 +以内 +以来 +之后 +之前 +之后 +期间 +前后 +上下 +以上 +以下 +左右 +当时 +当年 +当下 +眼下 +马上 +立刻 +即将 +刚刚 +刚才 +后来 +曾经 +仍然 +依然 +一直 +一直到 +尚且 +甚至 +最终 +总是 +总共 +其实 +本来 +原来 +明显 +确实 +大概 +大约 +差不多 +也许 +可能 +估计 +基本 +尤其 +尽管 +虽然 +然而 +然而却 +不过 +但是 +还是 +但是呢 +毕竟 +同时 +并不是 +并非 +并无 +未必 +尚未 +不如 +不然 +以外 +之外 +其中 +而后 +而今 +而后 +除此 +除此之外 +否则 +万一 +万万 +若 +若是 +假如 +假设 +要是 +若非 +非得 +无非 +何况 +况且 +再说 +试问 +只要 +除非 +只有 +宁愿 +宁可 +宁肯 +不如 +不妨 +不必 +务必 +务须 +尚需 +无需 +都 +谁 +啥 +哪 +哪儿 +哪里 +怎样 +怎么 +怎么样 +何时 +几时 +多少 +几多 +多么 +啥也 +哪怕 +纵然 +即使 +假使 +便 +则 +即 +乃 +虽 +虽说 +且 +以致于 +为止 +为此 +为的是 +不管 +无论 +任凭 +凡是 +凡 +既然 +既 +既已 +既然如此 +说实话 +说到底 +也就是说 +一句话 +总之 +总的来说 +总而言之 +总的说来 +换句话说 +话说