File size: 11,130 Bytes
f6c9376
 
 
 
 
34991da
28f4bd1
4032184
 
f6c9376
 
4032184
f6c9376
 
 
 
eab288c
f6c9376
 
 
 
 
 
28f4bd1
34991da
28f4bd1
 
 
c025e27
 
34991da
 
 
28f4bd1
 
34991da
28f4bd1
34991da
f6c9376
28f4bd1
 
 
44013a5
28f4bd1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44013a5
28f4bd1
 
 
 
 
44013a5
28f4bd1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44013a5
34991da
 
 
28f4bd1
34991da
 
 
 
28f4bd1
34991da
 
9dcf8cb
 
 
 
28f4bd1
34991da
 
28f4bd1
34991da
 
28f4bd1
 
34991da
 
 
 
 
28f4bd1
34991da
 
 
28f4bd1
34991da
 
28f4bd1
 
 
 
 
 
 
 
 
 
34991da
28f4bd1
 
 
 
 
 
 
 
34991da
 
28f4bd1
 
 
34991da
28f4bd1
34991da
28f4bd1
34991da
 
 
 
 
 
 
 
 
 
 
28f4bd1
 
 
 
 
014529e
 
28f4bd1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f6c9376
 
 
28f4bd1
f6c9376
8b81b1d
44013a5
 
 
 
34991da
 
28f4bd1
 
 
34991da
44013a5
34991da
c025e27
34991da
 
44013a5
34991da
 
 
 
 
 
 
 
 
96f79c9
34991da
 
96f79c9
34991da
 
 
 
44013a5
f6c9376
44013a5
 
28f4bd1
 
34991da
44013a5
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
from typing import List, Dict
from .config import get_settings
from .gemini_client import GeminiClient
from loguru import logger
import asyncio
import hashlib
import time
# from .constants import BATCH_STATUS_MESSAGES
# from .utils import get_random_message

class Reranker:
    def __init__(self):
        settings = get_settings()
        self.provider = getattr(settings, 'rerank_provider', settings.llm_provider)
        self.model = getattr(settings, 'rerank_model', settings.llm_model)
        if self.provider == 'gemini':
            self.client = GeminiClient()
        # elif self.provider == 'openai':
        #     self.client = OpenAIClient(settings.openai_api_key, model=self.model)
        # elif self.provider == 'cohere':
        #     self.client = CohereClient(settings.cohere_api_key, model=self.model)
        else:
            raise NotImplementedError(f"Rerank provider {self.provider} not supported yet.")
        # Cải thiện cache với TTL và quản lý memory
        self._rerank_cache = {}
        self._cache_ttl = 3600  # 1 giờ
        self._max_cache_size = 200  # Tăng cache size
        self._cache_timestamps = {}
        # Sử dụng max_docs_to_rerank từ config
        self.max_docs_to_rerank = settings.max_docs_to_rerank

    def _get_cache_key(self, query: str, docs: List[Dict]) -> str:
        """Tạo cache key từ query và docs."""
        # Tối ưu hóa cache key generation
        query_normalized = query.lower().strip()
        doc_ids = [str(doc.get('id', '')) for doc in docs[:15]]  # Chỉ cache top 15 docs
        cache_content = query_normalized + "|".join(sorted(doc_ids))
        return hashlib.md5(cache_content.encode()).hexdigest()

    def _clean_cache(self):
        """Dọn dẹp cache cũ và quản lý memory."""
        current_time = time.time()
        
        # Xóa cache entries đã hết hạn
        expired_keys = [
            key for key, timestamp in self._cache_timestamps.items()
            if current_time - timestamp > self._cache_ttl
        ]
        
        for key in expired_keys:
            del self._rerank_cache[key]
            del self._cache_timestamps[key]
        
        # Nếu cache vẫn quá lớn, xóa entries cũ nhất
        if len(self._rerank_cache) > self._max_cache_size:
            sorted_keys = sorted(
                self._cache_timestamps.keys(),
                key=lambda k: self._cache_timestamps[k]
            )
            
            # Xóa 20% cache entries cũ nhất
            keys_to_remove = sorted_keys[:len(sorted_keys) // 5]
            for key in keys_to_remove:
                del self._rerank_cache[key]
                del self._cache_timestamps[key]
            
            logger.info(f"[RERANK] Cleaned cache: removed {len(keys_to_remove)} old entries")

    def _get_cached_result(self, cache_key: str, top_k: int) -> List[Dict]:
        """Lấy kết quả từ cache nếu có và còn hợp lệ."""
        if cache_key in self._rerank_cache:
            current_time = time.time()
            if current_time - self._cache_timestamps.get(cache_key, 0) <= self._cache_ttl:
                cached_result = self._rerank_cache[cache_key][:top_k]
                logger.info(f"[RERANK] Cache hit for query, returning {len(cached_result)} cached results")
                return cached_result
            else:
                # Cache đã hết hạn, xóa
                del self._rerank_cache[cache_key]
                del self._cache_timestamps[cache_key]
        
        return []

    def _set_cached_result(self, cache_key: str, scored_docs: List[Dict]):
        """Lưu kết quả vào cache."""
        self._rerank_cache[cache_key] = scored_docs
        self._cache_timestamps[cache_key] = time.time()
        
        # Dọn dẹp cache nếu cần
        if len(self._rerank_cache) > self._max_cache_size:
            self._clean_cache()

    async def _batch_score_docs(self, query: str, docs: List[Dict]) -> List[Dict]:
        """
        Score nhiều documents cùng lúc bằng một prompt duy nhất.
        Không cắt bớt nội dung luật.
        """
        if not docs:
            return []
        
        # Không giới hạn content length, giữ nguyên nội dung luật
        docs_content = []
        for i, doc in enumerate(docs):
            # tieude = (doc.get('tieude') or '').strip()
            # noidung = (doc.get('noidung') or '').strip()
            # content = f"{tieude} {noidung}".strip()
            content = (doc.get('fullcontent') or '').strip()
            docs_content.append(f"{i+1}. {content}")
        
        batch_prompt = (
            f"Đánh giá mức độ liên quan giữa câu hỏi và các đoạn luật sau:\n\n"
            f"Câu hỏi: {query}\n\n"
            f"Các đoạn luật:\n" + "\n".join(docs_content) + "\n\n"
            f"Trả về điểm số từ 0-10 cho từng đoạn, phân cách bằng dấu phẩy.\n"
            f"Ví dụ: 8,5,7,3,9"
        )
        
        try:
            if self.provider == 'gemini':
                loop = asyncio.get_event_loop()
                logger.info(f"[RERANK] Sending batch prompt to Gemini for {len(docs)} docs")
                response = await loop.run_in_executor(None, self.client.generate_text, batch_prompt)
                logger.info(f"[RERANK] Got batch scores from Gemini: {response}")
                
                # Cải thiện parsing scores
                scores_text = str(response).strip()
                scores = []
                
                # Xử lý nhiều format response có thể có
                if ',' in scores_text:
                    score_parts = scores_text.split(',')
                elif ' ' in scores_text:
                    score_parts = scores_text.split()
                else:
                    score_parts = scores_text.replace('.', ',').split(',')
                
                for score_str in score_parts:
                    try:
                        clean_score = ''.join(c for c in score_str.strip() if c.isdigit() or c == '.')
                        if clean_score:
                            score = float(clean_score)
                            score = max(0, min(10, score))
                            scores.append(score)
                        else:
                            scores.append(0)
                    except (ValueError, TypeError):
                        scores.append(0)
                
                while len(scores) < len(docs):
                    scores.append(0)
                
                for i, doc in enumerate(docs):
                    doc['rerank_score'] = scores[i]
                
                logger.info(f"[RERANK] Successfully scored {len(docs)} docs with scores: {scores}")
                return docs
                
            else:
                raise NotImplementedError(f"Rerank provider {self.provider} not supported yet in batch method.")
            
        except Exception as e:
            logger.error(f"[RERANK] Lỗi khi batch score: {e}")
            for doc in docs:
                doc['rerank_score'] = 0
            return docs

    async def _score_doc(self, query: str, doc: Dict) -> Dict:
        """
        Score một document với query.
        Không cắt bớt nội dung luật.
        """
        tieude = (doc.get('tieude') or '').strip()
        noidung = (doc.get('noidung') or '').strip()
        content = f"{tieude} {noidung}".strip()
        prompt = (
            f"Đánh giá mức độ liên quan:\n"
            f"Luật: {content}\n"
            f"Hỏi: {query}\n"
            f"Điểm (0-10):"
        )
        try:
            if self.provider == 'gemini':
                loop = asyncio.get_event_loop()
                logger.info(f"[RERANK] Sending individual prompt to Gemini")
                score_response = await loop.run_in_executor(None, self.client.generate_text, prompt)
                logger.info(f"[RERANK] Got individual score from Gemini: {score_response}")
                score_text = str(score_response).strip()
                try:
                    clean_score = ''.join(c for c in score_text if c.isdigit() or c == '.')
                    if clean_score:
                        score = float(clean_score)
                        score = max(0, min(10, score))
                    else:
                        score = 0
                except (ValueError, TypeError):
                    score = 0
                doc['rerank_score'] = score
                return doc
            else:
                raise NotImplementedError(f"Rerank provider {self.provider} not supported yet in rerank method.")
        except Exception as e:
            logger.error(f"[RERANK] Lỗi khi tính score: {e} | doc: {doc}")
            doc['rerank_score'] = 0
            return doc

    async def rerank(self, query: str, docs: List[Dict], top_k: int = 5) -> List[Dict]:
        """
        Rerank docs theo độ liên quan với query, trả về top_k docs.
        Sử dụng batch processing và caching để tối ưu hiệu suất.
        """
        logger.info(f"[RERANK] Start rerank for query: {query} | docs: {len(docs)} | top_k: {top_k}")
        
        if not docs:
            return []
        
        # Kiểm tra cache trước
        cache_key = self._get_cache_key(query, docs)
        cached_result = self._get_cached_result(cache_key, top_k)
        
        if cached_result:
            return cached_result
        
        # Giới hạn số lượng docs để rerank - chỉ rerank top 15 docs có similarity cao nhất
        max_docs_to_rerank = self.max_docs_to_rerank
        docs_to_rerank = docs[:max_docs_to_rerank]
        logger.info(f"[RERANK] Will rerank {len(docs_to_rerank)} docs (limited to top {max_docs_to_rerank})")
        
        # Sử dụng batch processing thay vì individual scoring
        try:
            scored = await self._batch_score_docs(query, docs_to_rerank)
            logger.info(f"[RERANK] Batch processing completed, scored {len(scored)} docs")
        except Exception as e:
            logger.error(f"[RERANK] Batch processing failed, falling back to individual scoring: {e}")
            # Fallback về individual scoring nếu batch processing thất bại
            scored = []
            for doc in docs_to_rerank:
                try:
                    scored_doc = await self._score_doc(query, doc)
                    scored.append(scored_doc)
                except Exception as e:
                    logger.error(f"[RERANK] Error scoring individual doc: {e}")
                    doc['rerank_score'] = 0
                    scored.append(doc)
        
        # Sort theo score và trả về top_k
        scored = sorted(scored, key=lambda x: x['rerank_score'], reverse=True)
        result = scored[:top_k]
        
        # Cache kết quả với system mới
        self._set_cached_result(cache_key, scored)
        
        logger.info(f"[RERANK] Top reranked docs: {result}")
        return result