improve the consistancy

pull/12311/head
Dr. Kiji 1 year ago
parent 610d069b69
commit 75dd8677b9

@ -1,7 +1,8 @@
import json import json
import logging import logging
import os
from collections import defaultdict from collections import defaultdict
from typing import Any, Optional from typing import Any, Dict, List, Optional, Set
from core.rag.datasource.keyword.keyword_base import BaseKeyword from core.rag.datasource.keyword.keyword_base import BaseKeyword
from core.rag.datasource.keyword.mecab.config import MeCabConfig from core.rag.datasource.keyword.mecab.config import MeCabConfig
@ -10,32 +11,28 @@ from core.rag.models.document import Document
from extensions.ext_database import db from extensions.ext_database import db
from extensions.ext_redis import redis_client from extensions.ext_redis import redis_client
from extensions.ext_storage import storage from extensions.ext_storage import storage
from models.dataset import Dataset, DocumentSegment from models.dataset import Dataset, DatasetKeywordTable, DocumentSegment
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class KeywordProcessorError(Exception): class KeywordProcessorError(Exception):
"""Base error for keyword processing.""" """Base error for keyword processing."""
pass pass
class KeywordExtractionError(KeywordProcessorError): class KeywordExtractionError(KeywordProcessorError):
"""Error during keyword extraction.""" """Error during keyword extraction."""
pass pass
class KeywordStorageError(KeywordProcessorError): class KeywordStorageError(KeywordProcessorError):
"""Error during storage operations.""" """Error during storage operations."""
pass pass
class SetEncoder(json.JSONEncoder): class SetEncoder(json.JSONEncoder):
"""JSON encoder that handles sets.""" """JSON encoder that handles sets."""
def default(self, obj): def default(self, obj):
if isinstance(obj, set): if isinstance(obj, set):
return list(obj) return list(obj)
@ -48,164 +45,283 @@ class MeCab(BaseKeyword):
def __init__(self, dataset: Dataset): def __init__(self, dataset: Dataset):
super().__init__(dataset) super().__init__(dataset)
self._config = MeCabConfig() self._config = MeCabConfig()
self._keyword_handler = None self._keyword_handler: MeCabKeywordTableHandler = MeCabKeywordTableHandler()
self._init_handler() self._init_handler()
def _init_handler(self): def _init_handler(self) -> None:
"""Initialize MeCab handler with configuration.""" """Initialize MeCab handler with configuration."""
try: try:
self._keyword_handler = MeCabKeywordTableHandler( self._keyword_handler = MeCabKeywordTableHandler(
dictionary_path=self._config.dictionary_path, user_dictionary_path=self._config.user_dictionary_path dictionary_path=self._config.dictionary_path,
user_dictionary_path=self._config.user_dictionary_path
) )
if self._config.pos_weights: if self._config.pos_weights:
self._keyword_handler.pos_weights = self._config.pos_weights self._keyword_handler.pos_weights = self._config.pos_weights
self._keyword_handler.min_score = self._config.score_threshold self._keyword_handler.min_score = self._config.score_threshold
except Exception as e: except Exception as e:
logger.exception("Failed to initialize MeCab handler") logger.exception("Failed to initialize MeCab handler")
raise KeywordProcessorError(f"MeCab initialization failed: {str(e)}") raise KeywordProcessorError("MeCab initialization failed: {}".format(str(e)))
def create(self, texts: list[Document], **kwargs) -> BaseKeyword: def create(self, texts: List[Document], **kwargs: Any) -> BaseKeyword:
"""Create keyword index for documents.""" """Create keyword index for documents."""
lock_name = f"keyword_indexing_lock_{self.dataset.id}" if not texts:
with redis_client.lock(lock_name, timeout=600):
keyword_table = self._get_dataset_keyword_table()
for text in texts:
keywords = self._keyword_handler.extract_keywords(
text.page_content, self._config.max_keywords_per_chunk
)
if text.metadata is not None:
self._update_segment_keywords(self.dataset.id, text.metadata["doc_id"], list(keywords))
keyword_table = self._add_text_to_keyword_table(
keyword_table or {}, text.metadata["doc_id"], list(keywords)
)
self._save_dataset_keyword_table(keyword_table)
return self return self
def add_texts(self, texts: list[Document], **kwargs): lock_name = "keyword_indexing_lock_{}".format(self.dataset.id)
"""Add new texts to existing index.""" try:
lock_name = f"keyword_indexing_lock_{self.dataset.id}" with redis_client.lock(lock_name, timeout=600):
with redis_client.lock(lock_name, timeout=600): keyword_table = self._get_dataset_keyword_table()
keyword_table = self._get_dataset_keyword_table() if keyword_table is None:
keywords_list = kwargs.get("keywords_list") keyword_table = {}
for text in texts:
if not text.page_content or not text.metadata or "doc_id" not in text.metadata:
logger.warning("Skipping invalid document: {}".format(text))
continue
for i, text in enumerate(texts): try:
if keywords_list:
keywords = keywords_list[i]
if not keywords:
keywords = self._keyword_handler.extract_keywords( keywords = self._keyword_handler.extract_keywords(
text.page_content, self._config.max_keywords_per_chunk text.page_content, self._config.max_keywords_per_chunk
) )
else: self._update_segment_keywords(self.dataset.id, text.metadata["doc_id"], list(keywords))
keywords = self._keyword_handler.extract_keywords( keyword_table = self._add_text_to_keyword_table(
text.page_content, self._config.max_keywords_per_chunk keyword_table, text.metadata["doc_id"], list(keywords)
) )
except Exception as e:
logger.exception("Failed to process document: {}".format(text.metadata.get("doc_id")))
raise KeywordExtractionError("Failed to extract keywords: {}".format(str(e)))
if text.metadata is not None: try:
self._update_segment_keywords(self.dataset.id, text.metadata["doc_id"], list(keywords)) self._save_dataset_keyword_table(keyword_table)
keyword_table = self._add_text_to_keyword_table( except Exception as e:
keyword_table or {}, text.metadata["doc_id"], list(keywords) logger.exception("Failed to save keyword table")
) raise KeywordStorageError("Failed to save keyword table: {}".format(str(e)))
self._save_dataset_keyword_table(keyword_table) except Exception as e:
if not isinstance(e, (KeywordExtractionError, KeywordStorageError)):
logger.exception("Unexpected error during keyword indexing")
raise KeywordProcessorError("Keyword indexing failed: {}".format(str(e)))
raise
return self
def add_texts(self, texts: List[Document], **kwargs: Any) -> None:
"""Add new texts to existing index."""
if not texts:
return
lock_name = "keyword_indexing_lock_{}".format(self.dataset.id)
try:
with redis_client.lock(lock_name, timeout=600):
keyword_table = self._get_dataset_keyword_table()
if keyword_table is None:
keyword_table = {}
keywords_list = kwargs.get("keywords_list")
for i, text in enumerate(texts):
if not text.page_content or not text.metadata or "doc_id" not in text.metadata:
logger.warning("Skipping invalid document: {}".format(text))
continue
try:
if keywords_list:
keywords = keywords_list[i]
if not keywords:
keywords = self._keyword_handler.extract_keywords(
text.page_content, self._config.max_keywords_per_chunk
)
else:
keywords = self._keyword_handler.extract_keywords(
text.page_content, self._config.max_keywords_per_chunk
)
self._update_segment_keywords(self.dataset.id, text.metadata["doc_id"], list(keywords))
keyword_table = self._add_text_to_keyword_table(
keyword_table, text.metadata["doc_id"], list(keywords)
)
except Exception as e:
logger.exception("Failed to process document: {}".format(text.metadata.get("doc_id")))
continue
try:
self._save_dataset_keyword_table(keyword_table)
except Exception as e:
logger.exception("Failed to save keyword table")
raise KeywordStorageError("Failed to save keyword table: {}".format(str(e)))
except Exception as e:
if not isinstance(e, KeywordStorageError):
logger.exception("Unexpected error during keyword indexing")
raise KeywordProcessorError("Keyword indexing failed: {}".format(str(e)))
raise
def text_exists(self, id: str) -> bool: def text_exists(self, id: str) -> bool:
"""Check if text exists in index.""" """Check if text exists in index."""
if not id:
return False
keyword_table = self._get_dataset_keyword_table() keyword_table = self._get_dataset_keyword_table()
if keyword_table is None: if keyword_table is None:
return False return False
return id in set.union(*keyword_table.values()) if keyword_table else False return id in set.union(*keyword_table.values()) if keyword_table else False
def delete_by_ids(self, ids: list[str]) -> None: def delete_by_ids(self, ids: List[str]) -> None:
"""Delete texts by IDs.""" """Delete texts by IDs."""
lock_name = f"keyword_indexing_lock_{self.dataset.id}" if not ids:
with redis_client.lock(lock_name, timeout=600): return
keyword_table = self._get_dataset_keyword_table()
if keyword_table is not None: lock_name = "keyword_indexing_lock_{}".format(self.dataset.id)
keyword_table = self._delete_ids_from_keyword_table(keyword_table, ids) try:
self._save_dataset_keyword_table(keyword_table) with redis_client.lock(lock_name, timeout=600):
keyword_table = self._get_dataset_keyword_table()
if keyword_table is not None:
keyword_table = self._delete_ids_from_keyword_table(keyword_table, ids)
self._save_dataset_keyword_table(keyword_table)
except Exception as e:
logger.exception("Failed to delete documents")
raise KeywordStorageError("Failed to delete documents: {}".format(str(e)))
def delete(self) -> None: def delete(self) -> None:
"""Delete entire index.""" """Delete entire index."""
lock_name = f"keyword_indexing_lock_{self.dataset.id}" lock_name = "keyword_indexing_lock_{}".format(self.dataset.id)
with redis_client.lock(lock_name, timeout=600): try:
dataset_keyword_table = self.dataset.dataset_keyword_table with redis_client.lock(lock_name, timeout=600):
if dataset_keyword_table: dataset_keyword_table = self.dataset.dataset_keyword_table
db.session.delete(dataset_keyword_table) if dataset_keyword_table:
db.session.commit() db.session.delete(dataset_keyword_table)
if dataset_keyword_table.data_source_type != "database": db.session.commit()
file_key = f"keyword_files/{self.dataset.tenant_id}/{self.dataset.id}.txt" if dataset_keyword_table.data_source_type != "database":
storage.delete(file_key) file_key = os.path.join("keyword_files", self.dataset.tenant_id, self.dataset.id + ".txt")
storage.delete(file_key)
except Exception as e:
logger.exception("Failed to delete index")
raise KeywordStorageError("Failed to delete index: {}".format(str(e)))
def search(self, query: str, **kwargs: Any) -> list[Document]: def search(self, query: str, **kwargs: Any) -> List[Document]:
"""Search documents using keywords.""" """Search documents using keywords."""
keyword_table = self._get_dataset_keyword_table() if not query:
k = kwargs.get("top_k", 4) return []
sorted_chunk_indices = self._retrieve_ids_by_query(keyword_table or {}, query, k) try:
keyword_table = self._get_dataset_keyword_table()
k = kwargs.get("top_k", 4)
sorted_chunk_indices = self._retrieve_ids_by_query(keyword_table or {}, query, k)
if not sorted_chunk_indices:
return []
documents = []
for chunk_index in sorted_chunk_indices:
segment = (
db.session.query(DocumentSegment)
.filter(
DocumentSegment.dataset_id == self.dataset.id,
DocumentSegment.index_node_id == chunk_index
)
.first()
)
documents = [] if segment:
for chunk_index in sorted_chunk_indices: documents.append(
segment = ( Document(
db.session.query(DocumentSegment) page_content=segment.content,
.filter(DocumentSegment.dataset_id == self.dataset.id, DocumentSegment.index_node_id == chunk_index) metadata={
.first() "doc_id": chunk_index,
) "doc_hash": segment.index_node_hash,
"document_id": segment.document_id,
"dataset_id": segment.dataset_id,
},
)
)
return documents
except Exception as e:
logger.exception("Failed to search documents")
raise KeywordProcessorError("Search failed: {}".format(str(e)))
if segment: def _get_dataset_keyword_table(self) -> Optional[Dict[str, Set[str]]]:
documents.append( """Get keyword table from storage."""
Document( try:
page_content=segment.content, dataset_keyword_table = self.dataset.dataset_keyword_table
metadata={ if dataset_keyword_table:
"doc_id": chunk_index, keyword_table_dict = dataset_keyword_table.keyword_table_dict
"doc_hash": segment.index_node_hash, if keyword_table_dict:
"document_id": segment.document_id, return dict(keyword_table_dict["__data__"]["table"])
"dataset_id": segment.dataset_id, else:
# Create new dataset keyword table if it doesn't exist
from configs import dify_config
keyword_data_source_type = dify_config.KEYWORD_DATA_SOURCE_TYPE
dataset_keyword_table = DatasetKeywordTable(
dataset_id=self.dataset.id,
keyword_table="",
data_source_type=keyword_data_source_type,
)
if keyword_data_source_type == "database":
dataset_keyword_table.keyword_table = json.dumps(
{
"__type__": "keyword_table",
"__data__": {"index_id": self.dataset.id, "summary": None, "table": {}},
}, },
cls=SetEncoder,
) )
) db.session.add(dataset_keyword_table)
db.session.commit()
return documents return {}
except Exception as e:
logger.exception("Failed to get keyword table")
raise KeywordStorageError("Failed to get keyword table: {}".format(str(e)))
def _get_dataset_keyword_table(self) -> Optional[dict]: def _save_dataset_keyword_table(self, keyword_table: Dict[str, Set[str]]) -> None:
"""Get keyword table from storage."""
dataset_keyword_table = self.dataset.dataset_keyword_table
if dataset_keyword_table:
keyword_table_dict = dataset_keyword_table.keyword_table_dict
if keyword_table_dict:
return dict(keyword_table_dict["__data__"]["table"])
return {}
def _save_dataset_keyword_table(self, keyword_table):
"""Save keyword table to storage.""" """Save keyword table to storage."""
if keyword_table is None:
raise ValueError("Keyword table cannot be None")
table_dict = { table_dict = {
"__type__": "keyword_table", "__type__": "keyword_table",
"__data__": {"index_id": self.dataset.id, "summary": None, "table": keyword_table}, "__data__": {"index_id": self.dataset.id, "summary": None, "table": keyword_table},
} }
dataset_keyword_table = self.dataset.dataset_keyword_table try:
data_source_type = dataset_keyword_table.data_source_type dataset_keyword_table = self.dataset.dataset_keyword_table
if not dataset_keyword_table:
raise KeywordStorageError("Dataset keyword table not found")
if data_source_type == "database": data_source_type = dataset_keyword_table.data_source_type
dataset_keyword_table.keyword_table = json.dumps(table_dict, cls=SetEncoder)
db.session.commit()
else:
file_key = f"keyword_files/{self.dataset.tenant_id}/{self.dataset.id}.txt"
if storage.exists(file_key):
storage.delete(file_key)
storage.save(file_key, json.dumps(table_dict, cls=SetEncoder).encode("utf-8"))
def _add_text_to_keyword_table(self, keyword_table: dict, id: str, keywords: list[str]) -> dict: if data_source_type == "database":
dataset_keyword_table.keyword_table = json.dumps(table_dict, cls=SetEncoder)
db.session.commit()
else:
file_key = os.path.join("keyword_files", self.dataset.tenant_id, self.dataset.id + ".txt")
if storage.exists(file_key):
storage.delete(file_key)
storage.save(file_key, json.dumps(table_dict, cls=SetEncoder).encode("utf-8"))
except Exception as e:
logger.exception("Failed to save keyword table")
raise KeywordStorageError("Failed to save keyword table: {}".format(str(e)))
def _add_text_to_keyword_table(
self, keyword_table: Dict[str, Set[str]], id: str, keywords: List[str]
) -> Dict[str, Set[str]]:
"""Add text keywords to table.""" """Add text keywords to table."""
if not id or not keywords:
return keyword_table
for keyword in keywords: for keyword in keywords:
if keyword not in keyword_table: if keyword not in keyword_table:
keyword_table[keyword] = set() keyword_table[keyword] = set()
keyword_table[keyword].add(id) keyword_table[keyword].add(id)
return keyword_table return keyword_table
def _delete_ids_from_keyword_table(self, keyword_table: dict, ids: list[str]) -> dict: def _delete_ids_from_keyword_table(
self, keyword_table: Dict[str, Set[str]], ids: List[str]
) -> Dict[str, Set[str]]:
"""Delete IDs from keyword table.""" """Delete IDs from keyword table."""
if not keyword_table or not ids:
return keyword_table
node_idxs_to_delete = set(ids) node_idxs_to_delete = set(ids)
keywords_to_delete = set() keywords_to_delete = set()
@ -220,31 +336,127 @@ class MeCab(BaseKeyword):
return keyword_table return keyword_table
def _retrieve_ids_by_query(self, keyword_table: dict, query: str, k: int = 4): def _retrieve_ids_by_query(
self, keyword_table: Dict[str, Set[str]], query: str, k: int = 4
) -> List[str]:
"""Retrieve document IDs by query.""" """Retrieve document IDs by query."""
keywords = self._keyword_handler.extract_keywords(query) if not query or not keyword_table:
return []
# Score documents based on matching keywords try:
chunk_indices_count = defaultdict(int) keywords = self._keyword_handler.extract_keywords(query)
keywords_list = [keyword for keyword in keywords if keyword in set(keyword_table.keys())]
for keyword in keywords_list: # Score documents based on matching keywords
for node_id in keyword_table[keyword]: chunk_indices_count: dict[str, int] = defaultdict(int)
chunk_indices_count[node_id] += 1 keywords_list = [keyword for keyword in keywords if keyword in set(keyword_table.keys())]
sorted_chunk_indices = sorted(chunk_indices_count.keys(), key=lambda x: chunk_indices_count[x], reverse=True) for keyword in keywords_list:
for node_id in keyword_table[keyword]:
chunk_indices_count[node_id] += 1
return sorted_chunk_indices[:k] # Sort by score in descending order
sorted_chunk_indices = sorted(
chunk_indices_count.keys(),
key=lambda x: chunk_indices_count[x],
reverse=True,
)
return sorted_chunk_indices[:k]
except Exception as e:
logger.exception("Failed to retrieve IDs by query")
raise KeywordExtractionError("Failed to retrieve IDs: {}".format(str(e)))
def _update_segment_keywords(self, dataset_id: str, node_id: str, keywords: list[str]): def _update_segment_keywords(self, dataset_id: str, node_id: str, keywords: List[str]) -> None:
"""Update segment keywords in database.""" """Update segment keywords in database."""
document_segment = ( if not dataset_id or not node_id:
db.session.query(DocumentSegment) return
.filter(DocumentSegment.dataset_id == dataset_id, DocumentSegment.index_node_id == node_id)
.first() try:
) document_segment = (
db.session.query(DocumentSegment)
if document_segment: .filter(DocumentSegment.dataset_id == dataset_id, DocumentSegment.index_node_id == node_id)
document_segment.keywords = keywords .first()
db.session.add(document_segment) )
db.session.commit()
if document_segment:
document_segment.keywords = keywords
db.session.add(document_segment)
db.session.commit()
except Exception as e:
logger.exception("Failed to update segment keywords")
raise KeywordStorageError("Failed to update segment keywords: {}".format(str(e)))
def create_segment_keywords(self, node_id: str, keywords: List[str]) -> None:
"""Create keywords for a single segment.
Args:
node_id: The segment node ID
keywords: List of keywords to add
"""
if not node_id or not keywords:
return
try:
keyword_table = self._get_dataset_keyword_table()
self._update_segment_keywords(self.dataset.id, node_id, keywords)
keyword_table = self._add_text_to_keyword_table(keyword_table or {}, node_id, keywords)
self._save_dataset_keyword_table(keyword_table)
except Exception as e:
logger.exception("Failed to create segment keywords")
raise KeywordProcessorError("Failed to create segment keywords: {}".format(str(e)))
def multi_create_segment_keywords(self, pre_segment_data_list: List[Dict[str, Any]]) -> None:
"""Create keywords for multiple segments in batch."""
if not pre_segment_data_list:
return
try:
keyword_table = self._get_dataset_keyword_table()
if keyword_table is None:
keyword_table = {}
for pre_segment_data in pre_segment_data_list:
segment = pre_segment_data["segment"]
if not segment:
continue
try:
if pre_segment_data.get("keywords"):
segment.keywords = pre_segment_data["keywords"]
keyword_table = self._add_text_to_keyword_table(
keyword_table, segment.index_node_id, pre_segment_data["keywords"]
)
else:
keywords = self._keyword_handler.extract_keywords(
segment.content, self._config.max_keywords_per_chunk
)
segment.keywords = list(keywords)
keyword_table = self._add_text_to_keyword_table(
keyword_table, segment.index_node_id, list(keywords)
)
except Exception as e:
logger.exception("Failed to process segment: {}".format(segment.index_node_id))
continue
self._save_dataset_keyword_table(keyword_table)
except Exception as e:
logger.exception("Failed to create multiple segment keywords")
raise KeywordProcessorError("Failed to create multiple segment keywords: {}".format(str(e)))
def update_segment_keywords_index(self, node_id: str, keywords: List[str]) -> None:
"""Update keywords index for a segment.
Args:
node_id: The segment node ID
keywords: List of keywords to update
"""
if not node_id or not keywords:
return
try:
keyword_table = self._get_dataset_keyword_table()
keyword_table = self._add_text_to_keyword_table(keyword_table or {}, node_id, keywords)
self._save_dataset_keyword_table(keyword_table)
except Exception as e:
logger.exception("Failed to update segment keywords index")
raise KeywordStorageError("Failed to update segment keywords index: {}".format(str(e)))

Loading…
Cancel
Save