import json
import os
from typing import Dict, List, Optional

import faiss
import numpy as np
from sentence_transformers import SentenceTransformer
|
class Retriever:
    """Semantic course retriever backed by a FAISS inner-product index.

    Courses are embedded (name + short description) with a multilingual
    SentenceTransformer model; vectors are L2-normalized and stored in a
    ``faiss.IndexFlatIP``, so inner-product search equals cosine similarity.
    The index, raw embeddings, and a position -> course-id mapping are
    cached on disk under ``index_dir`` and reloaded on construction.
    """

    # Max characters of course text fed to the embedding model.
    _MAX_TEXT_LEN = 220

    def __init__(self, index_dir: str = 'data/index'):
        """Create the retriever and attempt to load a cached index.

        Args:
            index_dir: Directory holding ``index.faiss``, ``embeddings.npy``
                and ``meta.json``. Defaults to the original hard-coded path,
                so existing callers are unaffected.
        """
        self.index_dir = index_dir
        self.model = None        # SentenceTransformer, loaded lazily
        self.index = None        # faiss.IndexFlatIP or None
        self.meta = {}           # str(row position) -> course id
        self.embeddings = None   # np.ndarray (n, dim) float32 or None
        self._load_index()

    def _path(self, filename: str) -> str:
        """Path of a cache file inside ``index_dir``."""
        return os.path.join(self.index_dir, filename)

    def _load_index(self):
        """Load index, embeddings and metadata from the on-disk cache.

        Best-effort: when the cache is missing or unreadable the retriever
        is left empty (it can be rebuilt via ``build_or_load_index``).
        """
        index_path = self._path('index.faiss')
        emb_path = self._path('embeddings.npy')
        meta_path = self._path('meta.json')
        try:
            # Fix: require ALL three cache files. The original only checked
            # two and then crashed into the except on a missing .npy,
            # leaving self.index set but embeddings/meta inconsistent.
            if all(os.path.exists(p) for p in (index_path, emb_path, meta_path)):
                self.index = faiss.read_index(index_path)
                self.embeddings = np.load(emb_path)

                with open(meta_path, 'r', encoding='utf-8') as f:
                    raw = json.load(f)
                # JSON object keys are always strings; normalize so lookups
                # in retrieve() behave identically whether the mapping was
                # just built in memory or round-tripped through disk.
                self.meta = {str(k): v for k, v in raw.items()}

                print('Индекс загружен из кэша')
            else:
                print('Индекс не найден, будет создан при первом использовании')
        except Exception as e:
            # Reset partial state so a later rebuild starts clean.
            self.index = None
            self.embeddings = None
            self.meta = {}
            print(f'Ошибка загрузки индекса: {e}')

    def _load_model(self):
        """Lazily load the embedding model (idempotent).

        Raises:
            Exception: re-raised if the model cannot be loaded — the
            retriever is unusable without it.
        """
        if self.model is not None:
            return
        try:
            self.model = SentenceTransformer('sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2')
            print('Модель эмбеддингов загружена')
        except Exception as e:
            print(f'Ошибка загрузки модели: {e}')
            raise

    def _build_index(self, courses: List[Dict]):
        """Embed *courses* and (re)build the FAISS index, then cache it.

        Args:
            courses: Dicts with optional 'name', 'short_desc' and 'id' keys.
        """
        if not courses:
            return

        self._load_model()

        texts = []
        meta_data = {}
        for i, course in enumerate(courses):
            text = f"{course.get('name', '')} {course.get('short_desc', '')}"
            # Truncate after normalization to keep encoder input bounded.
            text = text.lower().strip()[:self._MAX_TEXT_LEN]
            texts.append(text)
            # Fix: string keys, so the mapping is identical before and
            # after a JSON round-trip (json stringifies object keys).
            meta_data[str(i)] = course.get('id', str(i))

        if not texts:
            return

        embeddings = self.model.encode(texts, convert_to_numpy=True, show_progress_bar=True)
        embeddings = embeddings.astype(np.float32)
        # Unit-length vectors make inner product equal cosine similarity.
        faiss.normalize_L2(embeddings)

        self.index = faiss.IndexFlatIP(embeddings.shape[1])
        self.index.add(embeddings)
        self.embeddings = embeddings
        self.meta = meta_data

        self._save_index()

    def _save_index(self):
        """Persist index, embeddings and metadata to ``index_dir`` (best-effort)."""
        try:
            os.makedirs(self.index_dir, exist_ok=True)

            faiss.write_index(self.index, self._path('index.faiss'))
            np.save(self._path('embeddings.npy'), self.embeddings)

            with open(self._path('meta.json'), 'w', encoding='utf-8') as f:
                json.dump(self.meta, f, ensure_ascii=False, indent=2)

            print('Индекс сохранен')
        except Exception as e:
            print(f'Ошибка сохранения индекса: {e}')

    def retrieve(self, query: str, k: int = 6, threshold: float = 0.35) -> List[Dict]:
        """Return up to *k* courses semantically similar to *query*.

        Args:
            query: Free-text query; lowercased/stripped like indexed text.
            k: Maximum number of hits to consider.
            threshold: Minimum cosine similarity to include a hit.

        Returns:
            List of ``{'course_id': ..., 'score': float}`` dicts, best first;
            empty list when no index is loaded or nothing clears the threshold.
        """
        if self.index is None or self.index.ntotal == 0:
            return []

        self._load_model()

        query_embedding = self.model.encode([query.lower().strip()], convert_to_numpy=True)
        query_embedding = query_embedding.astype(np.float32)
        faiss.normalize_L2(query_embedding)

        # Clamp k: asking FAISS for more than ntotal pads results with -1.
        scores, indices = self.index.search(query_embedding, min(k, self.index.ntotal))

        results = []
        for score, idx in zip(scores[0], indices[0]):
            if idx < 0:
                continue  # FAISS padding for "no more neighbors"
            # Fix: meta keys are strings (see _build_index/_load_index);
            # the original int-keyed lookup failed after a cache reload.
            key = str(int(idx))
            if score >= threshold and key in self.meta:
                results.append({
                    'course_id': self.meta[key],
                    'score': float(score)
                })

        return results

    def build_or_load_index(self, courses: Optional[List[Dict]] = None):
        """Build the index from *courses* if none was loaded from cache."""
        if self.index is None and courses:
            print('Создание индекса...')
            self._build_index(courses)
        elif self.index is None:
            print('Индекс не найден и данные не предоставлены')

    def get_embedding_dim(self) -> int:
        """Dimensionality of the stored embeddings, or 0 if none loaded."""
        if self.embeddings is not None:
            return self.embeddings.shape[1]
        return 0

    def get_index_size(self) -> int:
        """Number of vectors in the index, or 0 if no index is loaded."""
        if self.index is not None:
            return self.index.ntotal
        return 0