import os
import json
import numpy as np
import faiss
from typing import List, Dict
from sentence_transformers import SentenceTransformer

class Retriever:
    """Semantic course search backed by a FAISS index of sentence embeddings."""

    def __init__(self):
        self.model = None
        self.index = None
        self.meta = {}
        self.embeddings = None
        self._load_index()
    
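    def _load_index(self):
        """Load the cached index, embeddings, and metadata if all three files exist."""
        paths = ('data/index/index.faiss', 'data/index/embeddings.npy', 'data/index/meta.json')
        try:
            if all(os.path.exists(path) for path in paths):
                index = faiss.read_index(paths[0])
                embeddings = np.load(paths[1])

                with open(paths[2], 'r', encoding='utf-8') as f:
                    # JSON object keys are always strings; convert them back to the
                    # integer row positions that FAISS returns from search().
                    meta = {int(key): value for key, value in json.load(f).items()}

                # Assign only after every file has loaded, so a failure leaves no partial state.
                self.index, self.embeddings, self.meta = index, embeddings, meta
                print('Index loaded from cache')
            else:
                print('Index not found; it will be built on first use')
        except Exception as e:
            print(f'Error loading index: {e}')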
    
    def _load_model(self):
        """Load the embedding model lazily, on first use."""
        if self.model is None:
            try:
                self.model = SentenceTransformer('sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2')
                print('Embedding model loaded')
            except Exception as e:
                print(f'Error loading model: {e}')
                raise
    
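    def _build_index(self, courses: List[Dict]):
        """Embed each course's name and short description, then build and save the index."""
        if not courses:
            return

        self._load_model()

        texts = []
        meta_data = {}

        for i, course in enumerate(courses):
            text = f"{course.get('name', '')} {course.get('short_desc', '')}"
            text = text.lower().strip()

            # Cap each text at 220 characters before embedding.
            if len(text) > 220:
                text = text[:220]

            texts.append(text)
            # Map the row position in the index to the course id (fall back to the position).
            meta_data[i] = course.get('id', str(i))

        if not texts:
            return

        embeddings = self.model.encode(texts, convert_to_numpy=True, show_progress_bar=True)

        # FAISS expects float32; normalize_L2 rescales each row to unit length in place,
        # so the inner product computed by IndexFlatIP equals cosine similarity.
        embeddings = embeddings.astype(np.float32)
        faiss.normalize_L2(embeddings)

        self.index = faiss.IndexFlatIP(embeddings.shape[1])
        self.index.add(embeddings)
        self.embeddings = embeddings
        self.meta = meta_data

        self._save_index()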
    
    def _save_index(self):
        """Write the index, embeddings, and metadata to data/index."""
        try:
            os.makedirs('data/index', exist_ok=True)

            faiss.write_index(self.index, 'data/index/index.faiss')
            np.save('data/index/embeddings.npy', self.embeddings)

            with open('data/index/meta.json', 'w', encoding='utf-8') as f:
                json.dump(self.meta, f, ensure_ascii=False, indent=2)

            print('Index saved')
        except Exception as e:
            print(f'Error saving index: {e}')
    
    def retrieve(self, query: str, k: int = 6, threshold: float = 0.35) -> List[Dict]:
        """Return up to k courses whose similarity score to the query is at least threshold."""
        if self.index is None:
            return []

        self._load_model()

        # Preprocess the query the same way as the indexed texts (lowercased, stripped).
        query_embedding = self.model.encode([query.lower().strip()], convert_to_numpy=True)
        query_embedding = query_embedding.astype(np.float32)
        faiss.normalize_L2(query_embedding)

        scores, indices = self.index.search(query_embedding, k)

        results = []
        for score, idx in zip(scores[0], indices[0]):
            # FAISS pads missing results with -1; a plain int also matches the metadata keys.
            idx = int(idx)
            if score >= threshold and idx in self.meta:
                course_id = self.meta[idx]
                results.append({
                    'course_id': course_id,
                    'score': float(score)
                })

        return results
    
    def build_or_load_index(self, courses: List[Dict] = None):
        """Build the index from courses if no cached index was loaded."""
        if self.index is None and courses:
            print('Building index...')
            self._build_index(courses)
        elif self.index is None:
            print('Index not found and no course data provided')
    
    def get_embedding_dim(self) -> int:
        if self.embeddings is not None:
            return self.embeddings.shape[1]
        return 0
    
    def get_index_size(self) -> int:
        if self.index is not None:
            return self.index.ntotal
        return 0