import csv
import json
import os

import torch
import torch.nn as nn
import wandb
from tqdm import tqdm

from src.utils.early_stopping import DynamicClassWeights


class MedicalVQATrainer:
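    """Trainer for a dual-head medical VQA model.

    Drives training and validation for a model with a closed-ended (Yes/No)
    classification head and an open-ended generative head, with AMP, gradient
    accumulation, per-step LR scheduling, checkpointing, history export, and
    post-training plots.
    """
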
    def __init__(self, model, train_loader, val_loader, optimizer, device, config,
                 scheduler=None, pad_token_id=0, beam_width=1):
        self.model = model
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.device = device
        self.config = config
        self.beam_width = beam_width
        # [FIX] Dynamic class weights computed from the actual data distribution.
        # Replaces the hard-coded [1.0, 2.5], which may not match the real imbalance ratio.
        dynamic_weights = DynamicClassWeights.compute_weights(train_loader, device=device)
        self.criterion_closed = nn.CrossEntropyLoss(weight=dynamic_weights)
        # [NOTE] Label smoothing only on the open-ended head: the closed head needs sharp 0/1 targets.
        self.criterion_open = nn.CrossEntropyLoss(
            ignore_index=pad_token_id,
            label_smoothing=config['train'].get('label_smoothing', 0.0)
        )
        self.criterion_closed_hard = nn.CrossEntropyLoss(weight=dynamic_weights)  # no smoothing
        # AMP (Automatic Mixed Precision)
        self.use_amp = config['train'].get('use_amp', False) and device.type == 'cuda'
        self.scaler = torch.cuda.amp.GradScaler(enabled=self.use_amp)
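        # Note: with enabled=False, GradScaler's scale/step/update calls are
        # pass-throughs, so the same training code path also runs unchanged on CPU.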
        self.history = []

    @staticmethod
    def _flatten_dict(data, parent_key="", sep="."):
        """Flatten a nested dict into dot-separated keys for CSV export."""
        items = {}
        for key, value in data.items():
            new_key = f"{parent_key}{sep}{key}" if parent_key else str(key)
            if isinstance(value, dict):
                items.update(MedicalVQATrainer._flatten_dict(value, new_key, sep=sep))
            elif isinstance(value, (list, tuple)):
                # Lists/tuples do not fit a flat CSV row; skip them.
                continue
            else:
                items[new_key] = value
        return items
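    # Example (illustrative): _flatten_dict({"train": {"lr": 1e-4}, "tags": [1, 2]})
    # returns {"train.lr": 1e-4}; the list under "tags" is dropped.
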
    def save_history(self, output_dir):
        os.makedirs(output_dir, exist_ok=True)
        json_path = os.path.join(output_dir, "history.json")
        csv_path = os.path.join(output_dir, "history.csv")
        with open(json_path, "w", encoding="utf-8") as f:
            json.dump(self.history, f, ensure_ascii=False, indent=2)
        flat_rows = [self._flatten_dict(row) for row in self.history]
        if flat_rows:
            fieldnames = sorted({key for row in flat_rows for key in row.keys()})
            with open(csv_path, "w", encoding="utf-8", newline="") as f:
                writer = csv.DictWriter(f, fieldnames=fieldnames)
                writer.writeheader()
                writer.writerows(flat_rows)

    @staticmethod
    def _compute_closed_weights(train_loader):
        """Count the Yes/No label distribution and compute inverse-frequency weights.

        Note: __init__ uses DynamicClassWeights.compute_weights; this helper is an
        in-file alternative implementing the same idea.
        """
        counts = {0: 0, 1: 0}  # 0 = no, 1 = yes
        for batch in train_loader:
            labels = batch['label_closed']
            for lbl in labels:
                v = lbl.item()
                if v in counts:
                    counts[v] += 1
        total = counts[0] + counts[1]
        if total == 0:
            return torch.ones(2)
        # Inverse frequency: the class with fewer samples gets a higher weight.
        w0 = total / (2 * max(counts[0], 1))
        w1 = total / (2 * max(counts[1], 1))
        weights = torch.tensor([w0, w1], dtype=torch.float32)
        print(f"[INFO] Closed question distribution: no={counts[0]}, yes={counts[1]}")
        return weights
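    # Worked example (illustrative numbers): with counts no=300 and yes=100,
    # total=400, so w0 = 400 / (2 * 300) ≈ 0.667 and w1 = 400 / (2 * 100) = 2.0;
    # each minority "yes" sample then contributes ~3x more to the loss.
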
    def train_epoch(self, epoch):
        self.model.train()
        total_loss = 0
        pbar = tqdm(self.train_loader, desc=f"Epoch {epoch}")
        # [OPTIMIZATION] Gradient accumulation for a larger effective batch size.
        accumulation_steps = self.config['train'].get('gradient_accumulation_steps', 2)
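        # With per-step batch size B, gradients are accumulated over
        # `accumulation_steps` micro-batches before each optimizer step, so the
        # effective batch size is B * accumulation_steps (e.g. 16 * 2 = 32) at
        # roughly the memory cost of a single micro-batch.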
        for batch_idx, batch in enumerate(pbar):
            images = batch['image'].to(self.device)
            input_ids = batch['input_ids'].to(self.device)
            attention_mask = batch['attention_mask'].to(self.device)
            label_closed = batch['label_closed'].to(self.device)
            target_ids = batch['target_ids'].to(self.device)
            # Zero gradients only at the start of each accumulation window.
            if batch_idx % accumulation_steps == 0:
                self.optimizer.zero_grad()
            # Run the forward pass under AMP autocast.
            with torch.cuda.amp.autocast(enabled=self.use_amp):
                # Teacher forcing: input is <s> A B, target is A B </s>.
                decoder_input = target_ids[:, :-1]
                decoder_target = target_ids[:, 1:]
                logits_closed, logits_open = self.model(images, input_ids, attention_mask, decoder_input)
                # Loss calculation
                loss = 0
                mask_closed = (label_closed != -1)
                if mask_closed.any():
                    loss += self.criterion_closed(logits_closed[mask_closed], label_closed[mask_closed])
                # Split the generator loss to counter (lazy) mode collapse.
                vocab_size = logits_open.size(-1)
                mask_open = (label_closed == -1)
                # 1. Yes/No questions: scale the generator term way down (0.1) so the model is not biased.
                if mask_closed.any():
                    loss_gen_closed = self.criterion_open(logits_open[mask_closed].reshape(-1, vocab_size),
                                                          decoder_target[mask_closed].reshape(-1))
                    loss += loss_gen_closed * 0.1
                # 2. Open-ended questions: higher weight + length penalty + coverage penalty.
                if mask_open.any():
                    open_logits = logits_open[mask_open]
                    open_targets = decoder_target[mask_open]
                    loss_gen_open = self.criterion_open(open_logits.reshape(-1, vocab_size), open_targets.reshape(-1))
                    # Length penalty: penalize when too few meaningful tokens are generated.
                    # (Computed from target lengths, so it acts as a data-dependent constant.)
                    pred_lengths = (open_targets != self.criterion_open.ignore_index).float().sum(dim=-1).mean()
                    length_penalty = torch.clamp(1.0 - pred_lengths / 15.0, min=0.0)
                    # Replace the coverage loss with an entropy penalty (more appropriate):
                    # penalize when the model is over-confident in a single token.
                    probs = torch.softmax(open_logits, dim=-1)  # [N, seq, vocab]
                    entropy = -(probs * torch.log(probs + 1e-9)).sum(dim=-1).mean()
                    coverage_loss = torch.clamp(2.0 - entropy, min=0.0)  # penalize if entropy < 2.0
                    # [TUNED] Reduce weight 3.0 -> 2.0: the open head was dominating,
                    # causing closed-head accuracy to plateau (observed in the A1/A2 runs).
                    open_loss_weight = self.config.get('open_loss_weight', 2.0)
                    loss += (loss_gen_open + 0.3 * length_penalty + 0.1 * coverage_loss) * open_loss_weight
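                # Putting the pieces together, the per-batch objective as
                # composed above is:
                #   L = L_closed
                #     + 0.1 * L_gen(closed rows)
                #     + open_loss_weight * (L_gen(open rows)
                #                           + 0.3 * length_penalty
                #                           + 0.1 * coverage_loss)
                # with open_loss_weight defaulting to 2.0.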
            # [OPTIMIZATION] Normalize the loss by the accumulation steps for proper gradient scaling.
            loss = loss / accumulation_steps
            # Backward pass through the GradScaler.
            self.scaler.scale(loss).backward()
            # [OPTIMIZATION] Update weights only after accumulating gradients.
            is_last_batch = (batch_idx + 1) == len(self.train_loader)
            if (batch_idx + 1) % accumulation_steps == 0 or is_last_batch:
                # Gradient clipping
                if self.config['train'].get('grad_clip'):
                    self.scaler.unscale_(self.optimizer)
                    nn.utils.clip_grad_norm_(self.model.parameters(), self.config['train']['grad_clip'])
                self.scaler.step(self.optimizer)
                self.scaler.update()
                # [CRITICAL FIX] Step the scheduler per optimizer step instead of per epoch for smoother warmup.
                if self.scheduler:
                    self.scheduler.step()
            total_loss += loss.item() * accumulation_steps
            # [FIX] Log the LR of each param group; show the decoder LR (last group) on the progress bar.
            decoder_lr = self.optimizer.param_groups[-1]['lr']
            vision_lr = self.optimizer.param_groups[0]['lr']
            if wandb.run:
                wandb.log({
                    "batch_loss": loss.item(),
                    "lr_vision": vision_lr,
                    "lr_decoder": decoder_lr,
                })
            pbar.set_postfix({"loss": f"{loss.item():.3f}", "dec_lr": f"{decoder_lr:.1e}", "vis_lr": f"{vision_lr:.1e}"})
        epoch_train_loss = total_loss / len(self.train_loader)
        if wandb.run:
            wandb.log({"train_loss_epoch": epoch_train_loss})
        return epoch_train_loss

    def val_epoch(self, tokenizer, epoch=0):
        """Run evaluation on the validation set after each epoch."""
        from src.engine.medical_eval import evaluate_vqa
        max_ans_len = self.config.get('data', {}).get('max_answer_len', 32)
        max_words = self.config.get('data', {}).get('answer_max_words', 10)
        print(f"\nRunning validation for epoch {epoch} (max_ans_len={max_ans_len})...")
        metrics = evaluate_vqa(
            self.model,
            self.val_loader,
            self.device,
            tokenizer,
            beam_width=self.beam_width,
            max_len=max_ans_len,
            max_words=max_words
        )
        # Print the key metrics.
        print(
            f"[METRICS] Accuracy: {metrics.get('accuracy_normalized', metrics['accuracy']):.4f} | "
            f"F1: {metrics.get('f1_normalized', metrics['f1']):.4f} | "
            f"BLEU-4: {metrics.get('bleu4_normalized', metrics['bleu4']):.4f}"
        )
        if wandb.run:
            wandb.log({
                "epoch": epoch,
                "val_accuracy": metrics["accuracy"],
                "val_accuracy_normalized": metrics.get("accuracy_normalized", metrics["accuracy"]),
                "val_f1": metrics["f1"],
                "val_f1_normalized": metrics.get("f1_normalized", metrics["f1"]),
                "val_bleu4": metrics["bleu4"],
                "val_bleu4_normalized": metrics.get("bleu4_normalized", metrics["bleu4"]),
                "val_bert_score": metrics.get("bert_score", 0),
                "val_bert_score_raw": metrics.get("bert_score_raw", metrics.get("bert_score", 0)),
                "val_semantic_raw": metrics.get("semantic_raw", metrics.get("semantic", 0)),
            })
        return metrics

    def train(self, epochs, tokenizer=None):
        best_val_acc = 0.0
        patience = self.config['train'].get('patience', 10)
        counter = 0
        ckpt_dir = "checkpoints"
        os.makedirs(ckpt_dir, exist_ok=True)
        history_dir = self.config.get("history_dir")
        print(f"[INFO] Starting training for {epochs} epochs...")
        # Log to WandB if available.
        if wandb.run is not None:
            wandb.config.update({
                'total_epochs': epochs,
                'patience': patience,
                'variant': self.config.get('variant', 'Unknown'),
                'device': str(self.device),
                'use_amp': self.use_amp,
            })
        for epoch in range(1, epochs + 1):
            train_loss = self.train_epoch(epoch)
            metrics = self.val_epoch(tokenizer, epoch=epoch)
            val_acc = metrics.get('accuracy_normalized', metrics.get('accuracy', 0))
            closed_eval = metrics.get("closed_eval", {})
            open_eval = metrics.get("open_eval", {})
            is_best = val_acc > best_val_acc
            epoch_record = {
                "epoch": epoch,
                "train_loss": float(train_loss),
                "val_accuracy": float(metrics.get("accuracy", 0.0)),
                "val_accuracy_normalized": float(metrics.get("accuracy_normalized", metrics.get("accuracy", 0.0))),
                "val_f1": float(metrics.get("f1", 0.0)),
                "val_f1_normalized": float(metrics.get("f1_normalized", metrics.get("f1", 0.0))),
                "val_bleu4": float(metrics.get("bleu4", 0.0)),
                "val_bleu4_normalized": float(metrics.get("bleu4_normalized", metrics.get("bleu4", 0.0))),
                "val_bert_score": float(metrics.get("bert_score", 0.0)),
                "val_bert_score_raw": float(metrics.get("bert_score_raw", metrics.get("bert_score", 0.0))),
                "val_semantic_raw": float(metrics.get("semantic_raw", metrics.get("semantic", 0.0))),
                "val_closed_accuracy": float(closed_eval.get("accuracy", metrics.get("closed", {}).get("accuracy", -1))),
                "val_closed_em": float(closed_eval.get("em", metrics.get("closed", {}).get("em", -1))),
                "val_closed_f1": float(closed_eval.get("f1", metrics.get("closed", {}).get("f1", -1))),
                "val_open_accuracy": float(metrics.get("open", {}).get("accuracy", -1)),
                "val_open_semantic": float(open_eval.get("semantic", metrics.get("open", {}).get("semantic", -1))),
                "val_open_bertscore": float(open_eval.get("bert_score", metrics.get("open", {}).get("bert_score", -1))),
                "val_open_f1": float(open_eval.get("f1", metrics.get("open", {}).get("f1", -1))),
                "val_open_rouge_l": float(open_eval.get("rouge_l", metrics.get("open", {}).get("rouge_l", -1))),
                "best_so_far": bool(is_best),
                "metrics": metrics,
            }
            self.history.append(epoch_record)
            # Check for and save the best checkpoint.
            if is_best:
                best_val_acc = val_acc
                counter = 0
                variant = self.config.get('variant', 'A')
                save_path = os.path.join(ckpt_dir, f"medical_vqa_{variant}_best.pth")
                torch.save(self.model.state_dict(), save_path)
                resume_path = os.path.join(ckpt_dir, f"medical_vqa_{variant}_resume.pth")
                checkpoint = {
                    'epoch': epoch,
                    'model_state_dict': self.model.state_dict(),
                    'optimizer_state_dict': self.optimizer.state_dict(),
                    'scheduler_state_dict': self.scheduler.state_dict() if self.scheduler else None,
                    'best_val_acc': best_val_acc,
                    'train_loss': float(train_loss),
                }
                torch.save(checkpoint, resume_path)
                print(f"Best model saved with Accuracy: {val_acc:.4f}")
            else:
                counter += 1
            if history_dir:
                self.save_history(history_dir)
            if counter >= patience:
                print(f"Early stopping at epoch {epoch}!")
                break
        print("[INFO] Training complete.")
        if history_dir:
            self.save_history(history_dir)
        # ── Auto-plot once training finishes ────────────────────────────────
        if history_dir and len(self.history) >= 1:
            chart_paths = self.plot_training_results(history_dir)
            print(f"[INFO] Saved {len(chart_paths)} charts to: {history_dir}")
        return self.history

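    # Resume sketch (illustrative; the key layout matches the checkpoint dict
    # saved above, but the loading code itself is not part of this trainer):
    #   ckpt = torch.load("checkpoints/medical_vqa_A_resume.pth", map_location=device)
    #   model.load_state_dict(ckpt["model_state_dict"])
    #   optimizer.load_state_dict(ckpt["optimizer_state_dict"])
    #   if scheduler and ckpt["scheduler_state_dict"] is not None:
    #       scheduler.load_state_dict(ckpt["scheduler_state_dict"])
    #   start_epoch = ckpt["epoch"] + 1
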
    # ── Visualization ───────────────────────────────────────────────────────
    def plot_training_results(self, output_dir: str) -> list:
        """
        Automatically draw and save 4 charts once training finishes:
        1. Train loss per epoch
        2. Val accuracy + F1 + BLEU-4 (multi-metric)
        3. Closed vs. open accuracy (bars per epoch)
        4. BERTScore + semantic score
        Returns the list of saved image file paths.
        """
        try:
            import matplotlib
            matplotlib.use("Agg")  # Non-interactive backend (safe on headless servers).
            import matplotlib.pyplot as plt
            import matplotlib.ticker as mticker
        except ImportError:
            print("[WARNING] matplotlib is not installed; skipping chart plotting.")
            return []
        os.makedirs(output_dir, exist_ok=True)
        variant = self.config.get('variant', 'Model')
        epochs = [r["epoch"] for r in self.history]
        saved = []
        # Palette
        COLORS = {
            "loss": "#e74c3c",
            "accuracy": "#2ecc71",
            "f1": "#3498db",
            "bleu4": "#9b59b6",
            "bert": "#e67e22",
            "semantic": "#1abc9c",
            "closed": "#2980b9",
            "open": "#e74c3c",
        }

        def _finish(fig, fname):
            fig.tight_layout()
            path = os.path.join(output_dir, fname)
            fig.savefig(path, dpi=150, bbox_inches="tight")
            plt.close(fig)
            # Upload to WandB if available.
            if wandb.run:
                wandb.log({fname.replace(".png", ""): wandb.Image(path)})
            saved.append(path)

        # ── Chart 1: Train loss ─────────────────────────────────────────────
        fig, ax = plt.subplots(figsize=(9, 5))
        ax.plot(epochs, [r["train_loss"] for r in self.history],
                color=COLORS["loss"], linewidth=2.5, marker="o", markersize=5,
                label="Train Loss")
        ax.set_title(f"[{variant}] Train Loss per Epoch", fontsize=14, fontweight="bold")
        ax.set_xlabel("Epoch"); ax.set_ylabel("Loss")
        ax.xaxis.set_major_locator(mticker.MaxNLocator(integer=True))
        ax.legend(); ax.grid(True, alpha=0.3)
        _finish(fig, f"{variant}_01_train_loss.png")

        # ── Chart 2: Validation metrics (Acc / F1 / BLEU-4) ─────────────────
        fig, ax = plt.subplots(figsize=(10, 5))
        ax.plot(epochs, [r["val_accuracy_normalized"] for r in self.history],
                color=COLORS["accuracy"], linewidth=2.5, marker="o", label="Accuracy")
        ax.plot(epochs, [r["val_f1_normalized"] for r in self.history],
                color=COLORS["f1"], linewidth=2.5, marker="s", label="F1")
        ax.plot(epochs, [r["val_bleu4_normalized"] for r in self.history],
                color=COLORS["bleu4"], linewidth=2.5, marker="^", label="BLEU-4")
        # Mark the best epoch.
        best_epoch = max(self.history, key=lambda r: r["val_accuracy_normalized"])
        ax.axvline(x=best_epoch["epoch"], color="gray", linestyle="--", alpha=0.6,
                   label=f"Best epoch {best_epoch['epoch']} ({best_epoch['val_accuracy_normalized']:.2%})")
        ax.set_title(f"[{variant}] Validation Metrics per Epoch", fontsize=14, fontweight="bold")
        ax.set_xlabel("Epoch"); ax.set_ylabel("Score")
        ax.set_ylim(0, 1.05)
        ax.xaxis.set_major_locator(mticker.MaxNLocator(integer=True))
        ax.yaxis.set_major_formatter(mticker.PercentFormatter(xmax=1.0))
        ax.legend(loc="lower right"); ax.grid(True, alpha=0.3)
        _finish(fig, f"{variant}_02_val_metrics.png")

        # ── Chart 3: Closed vs. open accuracy ───────────────────────────────
        closed_vals = [r["val_closed_accuracy"] for r in self.history]
        open_vals = [r["val_open_accuracy"] for r in self.history]
        has_closed = any(v >= 0 for v in closed_vals)
        has_open = any(v >= 0 for v in open_vals)
        if has_closed or has_open:
            fig, ax = plt.subplots(figsize=(10, 5))
            w = 0.35
            x = range(len(epochs))
            if has_closed:
                c_vals = [v if v >= 0 else 0 for v in closed_vals]
                ax.bar([i - w/2 for i in x], c_vals, w, label="Closed (Yes/No)",
                       color=COLORS["closed"], alpha=0.85)
            if has_open:
                o_vals = [v if v >= 0 else 0 for v in open_vals]
                ax.bar([i + w/2 for i in x], o_vals, w, label="Open-ended",
                       color=COLORS["open"], alpha=0.85)
            ax.set_xticks(list(x)); ax.set_xticklabels([f"E{e}" for e in epochs])
            ax.set_title(f"[{variant}] Closed vs Open Accuracy per Epoch",
                         fontsize=14, fontweight="bold")
            ax.set_ylabel("Accuracy")
            ax.set_ylim(0, 1.05)
            ax.yaxis.set_major_formatter(mticker.PercentFormatter(xmax=1.0))
            ax.legend(); ax.grid(True, alpha=0.3, axis="y")
            _finish(fig, f"{variant}_03_closed_vs_open.png")

        # ── Chart 4: BERTScore + semantic score ─────────────────────────────
        bert_vals = [r["val_bert_score_raw"] for r in self.history]
        semantic_vals = [r["val_semantic_raw"] for r in self.history]
        if any(v > 0 for v in bert_vals + semantic_vals):
            fig, ax = plt.subplots(figsize=(9, 5))
            ax.plot(epochs, bert_vals, color=COLORS["bert"], linewidth=2.5,
                    marker="o", label="BERTScore")
            ax.plot(epochs, semantic_vals, color=COLORS["semantic"], linewidth=2.5,
                    marker="s", label="Semantic Score")
            ax.set_title(f"[{variant}] BERTScore & Semantic Score per Epoch",
                         fontsize=14, fontweight="bold")
            ax.set_xlabel("Epoch"); ax.set_ylabel("Score")
            ax.set_ylim(0, 1.05)
            ax.xaxis.set_major_locator(mticker.MaxNLocator(integer=True))
            ax.yaxis.set_major_formatter(mticker.PercentFormatter(xmax=1.0))
            ax.legend(); ax.grid(True, alpha=0.3)
            _finish(fig, f"{variant}_04_bert_semantic.png")

        # ── Print the final summary table ───────────────────────────────────
        print("\n" + "─" * 72)
        print(f" TRAINING SUMMARY | {variant}")
        print("─" * 72)
        print(f" {'Epoch':>5} {'TrainLoss':>10} {'Accuracy':>9} {'F1':>7} {'BLEU-4':>7} {'Best':>5}")
        print("─" * 72)
        for r in self.history:
            star = " *" if r.get("best_so_far") else ""
            print(
                f" {r['epoch']:>5} {r['train_loss']:>10.4f} "
                f"{r['val_accuracy_normalized']:>9.2%} "
                f"{r['val_f1_normalized']:>7.2%} "
                f"{r['val_bleu4_normalized']:>7.2%}{star}"
            )
        print("─" * 72 + "\n")
        return saved
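

# ── Usage sketch ─────────────────────────────────────────────────────────────
# Minimal wiring example. `build_model`, `build_loaders`, and `tokenizer` are
# hypothetical placeholders (not defined in this file); the config keys mirror
# the ones this trainer reads.
#
#   import torch
#   from torch.optim import AdamW
#
#   config = {
#       "train": {"use_amp": True, "label_smoothing": 0.1,
#                 "gradient_accumulation_steps": 2, "grad_clip": 1.0,
#                 "patience": 10},
#       "data": {"max_answer_len": 32, "answer_max_words": 10},
#       "variant": "A",
#       "history_dir": "outputs/history",
#   }
#   device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#   model = build_model(config).to(device)            # hypothetical factory
#   train_loader, val_loader = build_loaders(config)  # hypothetical factory
#   optimizer = AdamW(model.parameters(), lr=1e-4)
#   trainer = MedicalVQATrainer(model, train_loader, val_loader,
#                               optimizer, device, config)
#   trainer.train(epochs=30, tokenizer=tokenizer)     # tokenizer: project-specific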