import os
import json
import numpy as np
import faiss
from typing import List, Dict, Optional
from sentence_transformers import SentenceTransformer


class Retriever:
    """Semantic course retriever backed by a FAISS inner-product index.

    Course texts (name + short description) are embedded with a
    SentenceTransformer model, L2-normalised and stored in a
    ``faiss.IndexFlatIP``; because vectors are normalised, search scores are
    cosine similarities. The index, the raw embeddings and a
    ``row id -> course id`` mapping are cached under ``index_dir`` and
    reloaded on construction.
    """

    DEFAULT_INDEX_DIR = 'data/index'
    DEFAULT_MODEL_NAME = 'sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2'
    # Course texts longer than this many characters are truncated before embedding.
    MAX_TEXT_CHARS = 220

    _INDEX_FILE = 'index.faiss'
    _EMBEDDINGS_FILE = 'embeddings.npy'
    _META_FILE = 'meta.json'

    def __init__(self, index_dir: str = DEFAULT_INDEX_DIR,
                 model_name: str = DEFAULT_MODEL_NAME):
        """Create a retriever and try to load a cached index from ``index_dir``.

        Args:
            index_dir: Directory holding the cached index files.
            model_name: SentenceTransformer model used for embeddings; it is
                loaded lazily on first use.
        """
        self.index_dir = index_dir
        self.model_name = model_name
        self.model = None
        self.index = None
        # Maps FAISS row id (int) -> course id.
        self.meta = {}
        self.embeddings = None
        self._load_index()

    def _path(self, filename: str) -> str:
        """Return the full path of a cache file inside ``index_dir``."""
        return os.path.join(self.index_dir, filename)

    def _load_index(self):
        """Load the cached index, embeddings and metadata if present.

        State is assigned only after every read succeeds, so a failure half-way
        never leaves ``self.index`` set with empty metadata (which would block
        a rebuild in :meth:`build_or_load_index`). Errors are printed, not raised.
        """
        index_path = self._path(self._INDEX_FILE)
        emb_path = self._path(self._EMBEDDINGS_FILE)
        meta_path = self._path(self._META_FILE)
        try:
            if os.path.exists(index_path) and os.path.exists(meta_path):
                index = faiss.read_index(index_path)
                # Embeddings are informational only; search does not need them.
                embeddings = np.load(emb_path) if os.path.exists(emb_path) else None
                with open(meta_path, 'r', encoding='utf-8') as f:
                    raw_meta = json.load(f)
                # JSON object keys are always strings; restore the integer row
                # ids so lookups with FAISS indices in retrieve() match.
                meta = {int(key): value for key, value in raw_meta.items()}
                self.index = index
                self.embeddings = embeddings
                self.meta = meta
                print('Индекс загружен из кэша')
            else:
                print('Индекс не найден, будет создан при первом использовании')
        except Exception as e:
            print(f'Ошибка загрузки индекса: {e}')

    def _load_model(self):
        """Load the embedding model once; re-raise if loading fails."""
        if self.model is None:
            try:
                self.model = SentenceTransformer(self.model_name)
                print('Модель эмбеддингов загружена')
            except Exception as e:
                print(f'Ошибка загрузки модели: {e}')
                raise

    def _build_index(self, courses: List[Dict]):
        """Embed ``courses`` into a fresh FAISS index and persist it.

        Each course contributes ``"<name> <short_desc>"`` (lower-cased,
        stripped, truncated to ``MAX_TEXT_CHARS``). Row ``i`` maps to
        ``course['id']``, or ``str(i)`` when the course has no ``id`` key.
        Does nothing when ``courses`` is empty.
        """
        if not courses:
            return
        self._load_model()

        texts = []
        meta_data = {}
        for i, course in enumerate(courses):
            # `or ''` avoids embedding the literal text "None" for null fields.
            name = course.get('name') or ''
            short_desc = course.get('short_desc') or ''
            text = f"{name} {short_desc}".lower().strip()
            texts.append(text[:self.MAX_TEXT_CHARS])
            meta_data[i] = course.get('id', str(i))

        embeddings = self.model.encode(texts, convert_to_numpy=True, show_progress_bar=True)
        embeddings = embeddings.astype(np.float32)
        # Normalise in place so inner product == cosine similarity.
        faiss.normalize_L2(embeddings)

        self.index = faiss.IndexFlatIP(embeddings.shape[1])
        self.index.add(embeddings)
        self.embeddings = embeddings
        self.meta = meta_data
        self._save_index()

    def _save_index(self):
        """Write the index, embeddings and metadata to ``index_dir``.

        Errors are printed, not raised. Integer meta keys become strings in
        JSON; :meth:`_load_index` converts them back.
        """
        try:
            os.makedirs(self.index_dir, exist_ok=True)
            faiss.write_index(self.index, self._path(self._INDEX_FILE))
            if self.embeddings is not None:
                np.save(self._path(self._EMBEDDINGS_FILE), self.embeddings)
            with open(self._path(self._META_FILE), 'w', encoding='utf-8') as f:
                json.dump(self.meta, f, ensure_ascii=False, indent=2)
            print('Индекс сохранен')
        except Exception as e:
            print(f'Ошибка сохранения индекса: {e}')

    def retrieve(self, query: str, k: int = 6, threshold: float = 0.35) -> List[Dict]:
        """Return up to ``k`` courses whose cosine similarity is >= ``threshold``.

        Args:
            query: Free-text query; lower-cased and stripped like indexed texts.
            k: Maximum number of results. Values <= 0 yield no results.
            threshold: Minimum similarity score to keep a hit.

        Returns:
            A list of ``{'course_id': ..., 'score': float}`` dicts in the order
            FAISS returns them (best match first). The list is empty when no
            index is available.
        """
        if self.index is None or k <= 0 or self.index.ntotal == 0:
            return []
        self._load_model()

        query_embedding = self.model.encode([query.lower().strip()], convert_to_numpy=True)
        query_embedding = query_embedding.astype(np.float32)
        faiss.normalize_L2(query_embedding)

        # Asking for more than ntotal only yields -1 padding rows.
        scores, indices = self.index.search(query_embedding, min(k, self.index.ntotal))

        results = []
        for score, idx in zip(scores[0], indices[0]):
            row = int(idx)
            # FAISS marks missing neighbours with -1.
            if row < 0 or score < threshold or row not in self.meta:
                continue
            results.append({
                'course_id': self.meta[row],
                'score': float(score)
            })
        return results

    def build_or_load_index(self, courses: Optional[List[Dict]] = None, force: bool = False):
        """Build the index from ``courses`` when none is loaded (or ``force``).

        Args:
            courses: Course dicts to index; when empty or ``None`` nothing is built.
            force: Rebuild even if a cached index is already loaded.
        """
        if courses and (self.index is None or force):
            print('Создание индекса...')
            self._build_index(courses)
        elif self.index is None:
            print('Индекс не найден и данные не предоставлены')

    def get_embedding_dim(self) -> int:
        """Return the embedding dimension, or 0 when nothing is loaded."""
        if self.embeddings is not None:
            return self.embeddings.shape[1]
        if self.index is not None:
            return self.index.d
        return 0

    def get_index_size(self) -> int:
        """Return the number of vectors in the index, or 0 when none is loaded."""
        if self.index is not None:
            return self.index.ntotal
        return 0