| import hashlib |
| import pickle |
| from pathlib import Path |
|
|
| import numpy as np |
| import torch |
| from rank_bm25 import BM25Okapi |
| from sentence_transformers import SentenceTransformer |
| import warnings |
| warnings.filterwarnings('ignore') |
|
|
| from get_documents import load_and_process_data |
| from parse_documents import process_documents |
| from lemmatizer import RussianLemmatizer |
|
|
|
|
| def normalize_array(arr): |
| min_val = np.min(arr) |
| max_val = np.max(arr) |
| return (arr - min_val) / (max_val - min_val) |
|
|
|
|
| class Retrieval: |
| """ |
| Структура хранения данных: |
| ============================ |
| |
| 1. ДАТАФРЕЙМ ПАРАГРАФОВ (self.paragraphs_df): |
| ┌──────────────────────┬─────────────────────────────────┐ |
| │ Колонка │ Описание │ |
| ├──────────────────────┼─────────────────────────────────┤ |
| │ paragraph_id │ Уникальный ID параграфа │ |
| │ summary │ Название документа/раздела │ |
| │ start_year │ Год начала периода │ |
| │ end_year │ Год окончания периода │ |
| │ text │ Текст │ |
| │ document_id │ Ссылка на исходный документ │ |
| └──────────────────────┴─────────────────────────────────┘ |
| |
| 2. ДАТАФРЕЙМ ЧАНКОВ (self.chunks_df): |
| ┌──────────────────────┬─────────────────────────────────┐ |
| │ Колонка │ Описание │ |
| ├──────────────────────┼─────────────────────────────────┤ |
| │ chunk_id │ Уникальный ID чанка │ |
| │ paragraph_id │ Foreign key на параграф │ |
| │ text │ Исходный текст чанка │ |
| │ lemmatized_text │ Лемматизированный текст │ |
| │ (embeddings) │ (будет добавлено в будущем) │ |
| └──────────────────────┴─────────────────────────────────┘ |
| |
| 3. ОБЪЕДИНЁННЫЙ ДАТАФРЕЙМ (get_merged_data()): |
| Комбинирует оба датафрейма через JOIN по paragraph_id. |
| Содержит все колонки обоих датафреймов. |
| Используется для поиска и фильтрации. |
| """ |
| |
| def __init__(self, use_gpu: bool = False, use_cache: bool = True): |
| print("Инициализация RAG системы...") |
| self.device = "cuda" if use_gpu and torch.cuda.is_available() else "cpu" |
| self.use_cache = use_cache |
| |
| |
| self.cache_dir = Path('.cache') |
| if self.use_cache: |
| self.cache_dir.mkdir(exist_ok=True) |
| |
| |
| print(" Инициализация лемматизатора...") |
| self.lemmatizer = RussianLemmatizer() |
| |
| |
| print("1. Загрузка данных из JSON...") |
| self.documents, self.docs_names = load_and_process_data() |
| |
| print(f" Загружено {len(self.documents)} сообщений") |
| |
| |
| self.paragraphs_df, self.chunks_df = process_documents(self.documents) |
| |
| |
| print("2. Лемматизация текстов (с кэшированием)...") |
| self.chunks_df['lemmatized_text'] = self._lemmatize_with_cache(self.chunks_df['text']) |
| |
| |
| |
| self.embedder = SentenceTransformer('cointegrated/rubert-tiny2', cache_folder="/tmp") |
| |
| self.embeddings_of_summary = self.embedder.encode(self.paragraphs_df['summary'], convert_to_tensor=True) |
|
|
| print("RAG система готова к использованию") |
|
|
|
|
| |
|
|
| def _load_cache(self) -> dict: |
| """ |
| Загружает кэш лемматизации из файловой системы. |
| |
| Returns: |
| dict: {text_hash -> lemmatized_tokens} |
| """ |
| cache_file = self.cache_dir / 'lemmatization_cache.pkl' |
| |
| if cache_file.exists(): |
| try: |
| with open(cache_file, 'rb') as f: |
| cache = pickle.load(f) |
| print(f" ✓ Кэш загружен ({len(cache)} записей)") |
| return cache |
| except Exception as e: |
| print(f" ⚠ Ошибка при загрузке кэша: {e}") |
| return {} |
| return {} |
| |
| def _lemmatize_with_cache(self, texts: list[str]) -> list: |
| """ |
| Лемматизирует тексты с использованием кэша. |
| Проверяет хэши текстов - если хэш совпадает с кэшированным, |
| использует кэшированный результат. Иначе перелемматизирует. |
| |
| Args: |
| texts: Список текстов для лемматизации |
| |
| Returns: |
| list: Лемматизированные тексты |
| """ |
| if not self.use_cache: |
| |
| return [self.lemmatizer.tokenize_text(text) for text in texts] |
| |
| |
| cache = self._load_cache() |
| text_hashes = {} |
| results = [] |
| needs_save = False |
| |
| for text in texts: |
| text_hash = hashlib.sha256(text.encode('utf-8')).hexdigest() |
| text_hashes[text] = text_hash |
| |
| if text_hash in cache: |
| |
| results.append(cache[text_hash]) |
| else: |
| |
| lemmatized = self.lemmatizer.tokenize_text(text) |
| results.append(lemmatized) |
| cache[text_hash] = lemmatized |
| needs_save = True |
| |
| |
| if needs_save: |
| with open(self.cache_dir / 'lemmatization_cache.pkl', 'wb') as f: |
| pickle.dump(cache, f) |
| print(f" ✓ Кэш сохранён ({len(cache)} записей)") |
| |
| return results |
| |
| def semantic_search(self, query: str) -> torch.Tensor: |
| |
| query_embedding = torch.tensor(self.embedder.encode_query(query)) |
| semantic_scores = torch.nn.functional.cosine_similarity(self.embeddings_of_summary, query_embedding, eps=1e-8).cpu() |
| return semantic_scores |
| |
| def bm25_search(self, query: str) -> np.ndarray: |
| """BM25 поиск, используя лемматизированные чанки. |
| |
| Args: |
| query: Текст запроса |
| |
| Returns: |
| np.ndarray: Скоры для каждого абзаца (не предложения!) |
| """ |
| bm25 = BM25Okapi(self.chunks_df['lemmatized_text']) |
| tokenized_query = self.lemmatizer.tokenize_text(query) |
| sentences_scores = bm25.get_scores(tokenized_query) |
| df = self.chunks_df['paragraph_id'].to_frame().copy() |
| df['score'] = sentences_scores |
| paragraph_scores = df.groupby('paragraph_id')['score'].max().reindex(self.paragraphs_df['paragraph_id']).fillna(0) |
| return paragraph_scores |
| |
| def search(self, query: str) -> None: |
| bm25_scores = self.bm25_search(query) |
| semantic_scores = self.semantic_search(query).numpy() |
| bm25_scores = normalize_array(bm25_scores) |
| return semantic_scores + 1.0 * bm25_scores |