# test5 / retriever.py
# Uploaded by vydrking ("Upload 18 files"), commit 2fc8dc5.
import os
import json
import numpy as np
import faiss
from typing import List, Dict
from sentence_transformers import SentenceTransformer
class Retriever:
    """Semantic search over courses using a FAISS inner-product index.

    Each course is embedded from its ``name`` and ``short_desc`` fields with a
    SentenceTransformer model. Vectors are L2-normalised before being added to
    an ``IndexFlatIP``, so search scores are cosine similarities. The index,
    the embedding matrix and a row-number -> course-id mapping are cached in
    ``index_dir`` and reloaded when the retriever is constructed.
    """

    DEFAULT_MODEL = 'sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2'
    # Course texts longer than this many characters are truncated before encoding.
    MAX_TEXT_LEN = 220

    def __init__(self, index_dir: str = 'data/index', model_name: str = DEFAULT_MODEL):
        """Create the retriever and try to load a cached index.

        Args:
            index_dir: directory holding ``index.faiss``, ``embeddings.npy``
                and ``meta.json``.
            model_name: SentenceTransformer model id; the model is loaded
                lazily on first use.
        """
        self.index_dir = index_dir
        self.model_name = model_name
        self.model = None
        self.index = None
        self.meta = {}  # FAISS row number (int) -> course id
        self.embeddings = None
        self._load_index()

    def _paths(self):
        """Return the (index, embeddings, meta) file paths inside ``index_dir``."""
        return (
            os.path.join(self.index_dir, 'index.faiss'),
            os.path.join(self.index_dir, 'embeddings.npy'),
            os.path.join(self.index_dir, 'meta.json'),
        )

    @staticmethod
    def _normalize_meta(raw: Dict) -> Dict:
        """Restore integer row keys in meta decoded from JSON.

        ``json.dump`` converts the int keys written by ``_build_index`` into
        strings. Without converting them back, lookups by FAISS row number in
        ``retrieve`` would never match and every query would return nothing.
        """
        return {int(key): value for key, value in raw.items()}

    def _load_index(self):
        """Load the cached index, embeddings and meta from ``index_dir``.

        The cache is used only when both ``index.faiss`` and ``meta.json``
        exist; ``embeddings.npy`` is optional. Everything is read into locals
        first, so a failure leaves the retriever in its empty state instead of
        half-loaded. Errors are reported with ``print``, not raised.
        """
        index_path, emb_path, meta_path = self._paths()
        if not (os.path.exists(index_path) and os.path.exists(meta_path)):
            print('Индекс не найден, будет создан при первом использовании')
            return
        try:
            index = faiss.read_index(index_path)
            embeddings = np.load(emb_path) if os.path.exists(emb_path) else None
            with open(meta_path, 'r', encoding='utf-8') as f:
                meta = self._normalize_meta(json.load(f))
        except Exception as e:
            print(f'Ошибка загрузки индекса: {e}')
            return
        self.index = index
        self.embeddings = embeddings
        self.meta = meta
        print('Индекс загружен из кэша')

    def _load_model(self):
        """Lazily instantiate the SentenceTransformer model.

        Raises:
            Exception: whatever ``SentenceTransformer`` raises; it is printed
                and re-raised.
        """
        if self.model is None:
            try:
                self.model = SentenceTransformer(self.model_name)
                print('Модель эмбеддингов загружена')
            except Exception as e:
                print(f'Ошибка загрузки модели: {e}')
                raise

    def _build_index(self, courses: List[Dict]):
        """Embed ``courses``, build a fresh inner-product index and save it.

        The text for each course is ``"<name> <short_desc>"``, lower-cased,
        stripped and truncated to ``MAX_TEXT_LEN`` characters. Row ``i`` of the
        index maps to ``course['id']``, falling back to ``str(i)`` when the
        course has no ``id`` key. Does nothing for an empty list.
        """
        if not courses:
            return
        self._load_model()
        texts = []
        meta_data = {}
        for i, course in enumerate(courses):
            text = f"{course.get('name', '')} {course.get('short_desc', '')}"
            texts.append(text.lower().strip()[:self.MAX_TEXT_LEN])
            meta_data[i] = course.get('id', str(i))
        embeddings = self.model.encode(texts, convert_to_numpy=True, show_progress_bar=True)
        embeddings = embeddings.astype(np.float32)
        # Normalising makes inner-product scores equal to cosine similarity.
        faiss.normalize_L2(embeddings)
        self.index = faiss.IndexFlatIP(embeddings.shape[1])
        self.index.add(embeddings)
        self.embeddings = embeddings
        self.meta = meta_data
        self._save_index()

    def _save_index(self):
        """Write index, embeddings and meta to ``index_dir``; errors are printed."""
        index_path, emb_path, meta_path = self._paths()
        try:
            os.makedirs(self.index_dir, exist_ok=True)
            faiss.write_index(self.index, index_path)
            np.save(emb_path, self.embeddings)
            with open(meta_path, 'w', encoding='utf-8') as f:
                # Int row keys become strings here; _load_index converts them back.
                json.dump(self.meta, f, ensure_ascii=False, indent=2)
            print('Индекс сохранен')
        except Exception as e:
            print(f'Ошибка сохранения индекса: {e}')

    def retrieve(self, query: str, k: int = 6, threshold: float = 0.35) -> List[Dict]:
        """Return up to ``k`` courses whose cosine similarity to ``query`` >= ``threshold``.

        Args:
            query: free-text query; lower-cased and stripped before encoding.
            k: maximum number of hits; clamped to the index size.
            threshold: minimum score for a hit to be returned.

        Returns:
            A list of ``{'course_id': ..., 'score': float}`` dicts in FAISS
            result order. It is empty when there is no index, the index is
            empty, or ``k <= 0``.
        """
        if self.index is None or k <= 0:
            return []
        k = min(k, self.index.ntotal)
        if k == 0:
            return []
        self._load_model()
        query_embedding = self.model.encode([query.lower().strip()], convert_to_numpy=True)
        query_embedding = query_embedding.astype(np.float32)
        faiss.normalize_L2(query_embedding)
        scores, indices = self.index.search(query_embedding, k)
        results = []
        for score, idx in zip(scores[0], indices[0]):
            row = int(idx)  # FAISS uses -1 for empty slots; it never matches a meta key
            if score >= threshold and row in self.meta:
                results.append({
                    'course_id': self.meta[row],
                    'score': float(score),
                })
        return results

    def build_or_load_index(self, courses: List[Dict] = None):
        """Build the index from ``courses`` if no index is loaded yet.

        An already-loaded index is kept as-is, even if ``courses`` is given.
        """
        if self.index is None and courses:
            print('Создание индекса...')
            self._build_index(courses)
        elif self.index is None:
            print('Индекс не найден и данные не предоставлены')

    def get_embedding_dim(self) -> int:
        """Return the embedding dimensionality, or 0 when nothing is loaded."""
        if self.embeddings is not None:
            return self.embeddings.shape[1]
        if self.index is not None:
            # Embeddings are optional in the cache; the index knows its dimension.
            return self.index.d
        return 0

    def get_index_size(self) -> int:
        """Return the number of vectors in the index, or 0 when none is loaded."""
        if self.index is not None:
            return self.index.ntotal
        return 0