"""Evaluation metrics for VQA: Accuracy, EM, F1, BLEU-1~4, METEOR, and Semantic Score."""
from __future__ import annotations
from collections import Counter
import numpy as np
import torch
from nltk.translate.bleu_score import SmoothingFunction, sentence_bleu
from nltk.translate.meteor_score import meteor_score as _nltk_meteor
import nltk
try:
    nltk.data.find('corpora/wordnet')
except LookupError:
    print("[INFO] Downloading the NLTK WordNet corpus for METEOR score...")
    nltk.download('wordnet', quiet=True)
    nltk.download('omw-1.4', quiet=True)
# 1. Semantic Score (SentenceTransformer)
try:
    from sentence_transformers import SentenceTransformer, util
    semantic_model = SentenceTransformer('all-MiniLM-L6-v2')
except Exception as e:
    semantic_model = None
    print(f"Warning: Could not load SentenceTransformer: {e}")
# 2. BERTScore
try:
    from bert_score import BERTScorer
    # Force the multilingual model to avoid a Tokenizer attribute error on Python 3.12
    device = "cuda" if torch.cuda.is_available() else "cpu"
    bert_scorer = BERTScorer(model_type="bert-base-multilingual-cased", device=device)
except ImportError:
    print("[WARNING] The bert_score library is not installed.")
    bert_scorer = None
except Exception as e:
    bert_scorer = None
    print(f"Warning: Could not load BERTScorer: {e}")
# 3. ROUGE-L
try:
    from rouge_score import rouge_scorer as rs
    rouge_l_scorer = rs.RougeScorer(['rougeL'], use_stemmer=True)
except Exception as e:
    rouge_l_scorer = None
    print(f"Warning: Could not load rouge-score: {e}")
# [FIX] Import from the local text_utils instead of non-existent src.data.preprocessing
from .text_utils import normalize_answer, majority_answer
def compute_rouge_l(pred: str, refs) -> float:
    """Compute ROUGE-L F-measure, taking the max over multiple references."""
    if not rouge_l_scorer: return 0.0
    if isinstance(refs, str): refs = [refs]
    best_rouge = 0.0
    for r in refs:
        score = rouge_l_scorer.score(normalize_answer(r), normalize_answer(pred))['rougeL'].fmeasure
        best_rouge = max(best_rouge, score)
    return best_rouge
def compute_bertscore(preds: list[str], refs: list) -> float:
    """Compute BERTScore F1 for the whole batch."""
    if not bert_scorer or not preds or not refs:
        return 0.0
    # Guard against empty strings (substitute ".") so the scorer always receives a token
    clean_preds = [normalize_answer(p) if normalize_answer(p).strip() else "." for p in preds]
    clean_refs = [majority_answer(r) if isinstance(r, list) else normalize_answer(r) for r in refs]
    clean_refs = [r if r.strip() else "." for r in clean_refs]
    try:
        # Speed up by disabling idf weighting if needed
        P, R, F1 = bert_scorer.score(clean_preds, clean_refs)
        return float(F1.mean().item())
    except Exception as e:
        print(f"[WARNING] BERTScore error: {e}")
        return 0.0
def compute_exact_match(pred: str, refs) -> float:
    """Exact match: 1.0 if the normalized prediction equals any normalized reference, else 0.0."""
    if isinstance(refs, str): refs = [refs]
    return float(any(normalize_answer(pred) == normalize_answer(r) for r in refs))
def compute_f1(pred: str, refs) -> float:
    """Token-level F1 score, taking the max over multiple references."""
    if isinstance(refs, str): refs = [refs]
    best_f1 = 0.0
    p_toks = normalize_answer(pred).split()
    for r in refs:
        r_toks = normalize_answer(r).split()
        if not p_toks or not r_toks:
            # Both empty counts as a match; only one empty counts as a miss
            f1 = float(p_toks == r_toks)
        else:
            common = Counter(p_toks) & Counter(r_toks)
            num_same = sum(common.values())
            if num_same == 0:
                f1 = 0.0
            else:
                precision = num_same / len(p_toks)
                recall = num_same / len(r_toks)
                f1 = 2 * precision * recall / (precision + recall)
        best_f1 = max(best_f1, f1)
    return best_f1
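# Worked example (hypothetical inputs, assuming normalize_answer lowercases and strips
# punctuation): pred "left lung" vs. ref "left upper lung" shares 2 tokens, so
# precision = 2/2 = 1.0, recall = 2/3, and F1 = 2 * 1.0 * (2/3) / (1.0 + 2/3) = 0.8.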
def compute_bleu(pred: str, refs) -> dict[str, float]:
    """Compute BLEU-1 through BLEU-4 (sentence-level, smoothed) against multiple references."""
    if isinstance(refs, str): refs = [refs]
    smoothie = SmoothingFunction().method4
    p_toks = normalize_answer(pred).split()
    r_toks_list = [normalize_answer(r).split() for r in refs if normalize_answer(r).strip()]
    if not p_toks or not r_toks_list:
        return {"bleu1": 0.0, "bleu2": 0.0, "bleu3": 0.0, "bleu4": 0.0}
    weights = [
        (1, 0, 0, 0),              # BLEU-1
        (0.5, 0.5, 0, 0),          # BLEU-2
        (0.33, 0.33, 0.33, 0),     # BLEU-3
        (0.25, 0.25, 0.25, 0.25),  # BLEU-4
    ]
    return {
        f"bleu{i+1}": sentence_bleu(r_toks_list, p_toks, weights=w, smoothing_function=smoothie)
        for i, w in enumerate(weights)
    }
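# Note: sentence_bleu treats each weight tuple as exponents in a geometric mean of
# n-gram precisions, so e.g. BLEU-2 averages unigram and bigram precision; the
# smoothing function avoids zero scores when higher-order n-grams have no overlap,
# which matters for the short answers typical of VQA.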
def compute_meteor(pred: str, refs) -> float:
    """Compute METEOR score (supports multiple references)."""
    if isinstance(refs, str): refs = [refs]
    p_toks = normalize_answer(pred).split()
    r_toks_list = [normalize_answer(r).split() for r in refs if normalize_answer(r).strip()]
    if not p_toks or not r_toks_list:
        return 0.0
    return _nltk_meteor(r_toks_list, p_toks)
def compute_vqa_accuracy(pred: str, direct_answers) -> float:
    """
    Soft VQA accuracy: min(#annotators giving the same answer / 3, 1.0).
    Used for datasets labeled by multiple annotators (e.g., A-OKVQA).
    """
    if isinstance(direct_answers, str):
        return compute_exact_match(pred, direct_answers)
    normed_pred = normalize_answer(pred)
    matches = sum(1 for a in direct_answers if normalize_answer(a) == normed_pred)
    return min(matches / 3.0, 1.0)
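# Worked example (hypothetical inputs): if 2 of 10 annotators answered "pneumonia"
# and the model predicts "pneumonia", the score is min(2 / 3, 1.0) ≈ 0.667; with
# 3 or more agreeing annotators the prediction earns full credit (1.0).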
def compute_semantic_score(preds: list[str], refs: list) -> float:
    """Semantic similarity: mean cosine similarity between sentence embeddings."""
    if not semantic_model or not preds or not refs:
        return 0.0
    clean_preds = [normalize_answer(p) for p in preds]
    # Take the most representative string if the reference is a list
    clean_refs = [majority_answer(r) if isinstance(r, list) else normalize_answer(r) for r in refs]
    # Encode to embedding vectors
    pred_embs = semantic_model.encode(clean_preds, convert_to_tensor=True, show_progress_bar=False)
    ref_embs = semantic_model.encode(clean_refs, convert_to_tensor=True, show_progress_bar=False)
    # Compute the cosine-similarity matrix and take the diagonal (1-to-1 pairs)
    cosine_scores = util.cos_sim(pred_embs, ref_embs)
    scores = torch.diag(cosine_scores)
    return float(scores.mean().item())
def batch_metrics(predictions: list[str], references: list) -> dict[str, float]:
    """Aggregate all metrics over a batch of predictions and references."""
    results = {
        "accuracy": [], "em": [], "f1": [], "meteor": [],
        "bleu1": [], "bleu2": [], "bleu3": [], "bleu4": [],
        "rouge_l": [],
    }
    for pred, ref in zip(predictions, references):
        # Pass the full reference list so per-example scores are maximized over refs
        results["accuracy"].append(compute_vqa_accuracy(pred, ref))
        results["em"].append(compute_exact_match(pred, ref))
        results["f1"].append(compute_f1(pred, ref))
        results["meteor"].append(compute_meteor(pred, ref))
        results["rouge_l"].append(compute_rouge_l(pred, ref))
        bleus = compute_bleu(pred, ref)
        for k, v in bleus.items():
            results[k].append(v)
    # Average the per-example metrics
    final_metrics = {k: float(np.mean(v)) for k, v in results.items()}
    # Semantic Score and BERTScore are computed once over the whole batch
    final_metrics["semantic"] = compute_semantic_score(predictions, references)
    final_metrics["bert_score"] = compute_bertscore(predictions, references)
    return final_metrics
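# Minimal usage sketch (illustrative, with hypothetical inputs): assumes the optional
# dependencies above are installed and the package layout allows the relative import,
# e.g. run from the repo root as `python -m src.utils.metrics`.
if __name__ == "__main__":
    demo_preds = ["pneumonia", "left lung"]  # hypothetical model outputs
    demo_refs = [["pneumonia", "pneumonia", "infection"], "left upper lung"]
    for name, value in batch_metrics(demo_preds, demo_refs).items():
        print(f"{name}: {value:.4f}")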