"""Smoke-test a fine-tuned E5 embedding model on a Vietnamese math retrieval probe.

Embeds one query and a handful of candidate passages, ranks the passages by
cosine similarity to the query, and prints the ranking together with
reciprocal-rank and Recall@k figures for the passage labelled as correct.
"""

import numpy as np
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity

# Model identifier handed to SentenceTransformer; presumably a Hugging Face
# Hub repository id (weights fetched on first load) — confirm for offline use.
MODEL_NAME = "ThanhLe0125/e5-math"

# Loaded once at import time; every later step reuses this instance.
model = SentenceTransformer(MODEL_NAME)

# Report banner: title line followed by a 50-character rule.
print("🧪 Testing MRR-optimized fine-tuned model:", "=" * 50, sep="\n")

# Retrieval probe: one query and five candidate passages. E5-style models are
# fed "query: " / "passage: " prefixes, so every text carries the matching
# prefix. Order matters: the ranking labels used later in this script treat
# index 0 as the passage that answers the query.
query = "query: " + "Định nghĩa hàm số đồng biến"

_PASSAGE_TEXTS = (
    # Definition of an increasing function -> the correct answer.
    "Hàm số đồng biến trên khoảng (a;b) là hàm số mà với mọi x1 < x2 thì f(x1) < f(x2)",
    # Worked example on the same topic -> related but not the answer.
    "Ví dụ: Tìm khoảng đồng biến của hàm số y = x^2 - 2x + 1",
    # Distractors on unrelated topics.
    "Phương trình bậc hai ax^2 + bx + c = 0 có delta = b^2 - 4ac",
    "Tính đạo hàm của hàm số đa thức",
    "Giới hạn của dãy số",
)
chunks = ["passage: " + text for text in _PASSAGE_TEXTS]

# Embed the query (as a one-element batch) and all passages, then score every
# passage by cosine similarity against the query embedding.
query_emb = model.encode([query])
chunk_embs = model.encode(chunks)
score_matrix = cosine_similarity(query_emb, chunk_embs)  # one row: the query
similarities = score_matrix[0]

# Passage indices ordered from most to least similar to the query.
ranked_indices = np.argsort(similarities)[::-1]

# Number of characters of each passage shown in the ranking listing.
PREVIEW_WIDTH = 70


def _preview(text: str, width: int = PREVIEW_WIDTH) -> str:
    """Return *text* cut to *width* characters, adding '...' only if it was cut."""
    return text if len(text) <= width else text[:width] + "..."


print("🎯 MRR-Optimized Rankings:")

# Ground-truth label for each entry of ``chunks``, in the same order.
chunk_types = ["CORRECT", "RELATED", "IRRELEVANT", "IRRELEVANT", "IRRELEVANT"]
# The labels are a parallel list; fail loudly instead of mislabelling rows
# (or hitting an IndexError mid-listing) if the two drift apart.
if len(chunk_types) != len(chunks):
    raise ValueError(
        f"chunk_types has {len(chunk_types)} labels but there are "
        f"{len(chunks)} chunks"
    )

for rank, idx in enumerate(ranked_indices, 1):
    print(f"Rank {rank}: {chunk_types[idx]:>10} (Score: {similarities[idx]:.4f})")
    print(f" {_preview(chunks[idx])}")
    print()

# --- Per-query retrieval metrics --------------------------------------------
# The passage labelled "CORRECT" is the single relevant item for this query.
# Look its index up from the labels instead of hard-coding 0 so the metric
# stays in sync with ``chunk_types``.
correct_idx = chunk_types.index("CORRECT")

# 1-based position of the correct passage in the ranking, or None if absent.
positions = np.flatnonzero(ranked_indices == correct_idx)
correct_rank = int(positions[0]) + 1 if positions.size else None

if correct_rank is not None:
    # With one query and one relevant passage, MRR reduces to 1 / rank.
    mrr = 1.0 / correct_rank
    # Recall@k is 1 when the relevant passage is within the top k, else 0;
    # computed for every cut-off up to the number of ranked passages.
    recall_at_k = {
        k: int(correct_rank <= k) for k in range(1, len(ranked_indices) + 1)
    }

    print("📊 Query Metrics:")
    print(f" MRR: {mrr:.4f} (correct chunk at rank #{correct_rank})")
    # Three Recall@k entries per line: identical to the original two-line
    # layout for five chunks, but no longer assumes exactly five.
    recall_cells = [f"Recall@{k}: {hit}" for k, hit in recall_at_k.items()]
    for start in range(0, len(recall_cells), 3):
        print(" " + " | ".join(recall_cells[start:start + 3]))

    if correct_rank == 1:
        print(" 🌟 PERFECT! Correct chunk at rank #1!")
    elif correct_rank <= 2:
        print(" 🎯 EXCELLENT! Correct chunk in top 2!")
    elif correct_rank <= 3:
        print(" 👍 GOOD! Correct chunk in top 3!")
    else:
        print(" 📈 Could be better - but still found the answer!")
else:
    # Previously this case printed nothing at all; make it visible.
    print(" ⚠️ Correct chunk not found in the ranking.")

# Closing summary. These are fixed talking points, not values derived from
# the measurements printed above.
_BENEFIT_LINES = (
    " ✅ Pushes correct chunks to rank #1",
    " ✅ Reduces inference cost (need fewer chunks)",
    " ✅ Improves user experience (instant answers)",
    " ✅ Specialized for Vietnamese mathematics",
)
print("\n" + "=" * 50, "💡 Fine-tuning Benefits:", *_BENEFIT_LINES, sep="\n")