| | """ |
| | Evaluation and qualitative error analysis helpers for the IMDB Transformer model. |
| | |
| | This module is separate from `c1.py` and focuses only on: |
| | - Loading a previously trained model from disk. |
| | - Evaluating it on an IMDB split. |
| | - Inspecting misclassified examples for qualitative error analysis. |
| | """ |
| |
|
| | from typing import Dict, List, Tuple |
| |
|
| | import argparse |
| | import os |
| |
|
| | import torch |
| | import torch.nn.functional as F |
| | from torch.utils.data import DataLoader |
| |
|
| | from c1 import ( |
| | IMDBDataset, |
| | TransformerClassifier, |
| | preprocess_data, |
| | evaluate_model, |
| | load_imdb_texts, |
| | ) |
| |
|
| | |
# Directory (relative to the current working directory) holding the
# checkpoints saved by `c1.py`.
SAVE_DIR = os.path.join(os.curdir, "saved_model")
# Default checkpoint file read by `load_trained_model_from_checkpoint`.
MODEL_PATH = os.path.join(SAVE_DIR, "transformer_imdb.pt")
| |
|
| |
|
def analyze_misclassifications_on_texts(
    model: torch.nn.Module,
    texts: List[str],
    labels: List[int],
    vocab: Dict[str, int],
    max_len: int,
    device: torch.device,
    num_examples: int = 5,
    max_text_chars: int | None = 1000,
) -> None:
    """
    Print up to ``num_examples`` misclassified reviews for qualitative analysis.

    For each error, the true label, predicted label, per-class softmax
    probabilities and the (optionally truncated) review text are printed.

    Args:
        model: Trained classifier. It is called as ``model(batch_seq)``, and
            softmax is taken over dim 1 of its output (per-class logits).
        texts: Raw review strings, in the same order as ``labels``.
        labels: Ground-truth labels (printed legend: 0=neg, 1=pos).
        vocab: Token-to-index mapping passed to ``preprocess_data``.
        max_len: Sequence length passed to ``preprocess_data``.
        device: Device that each batch is moved to before the forward pass.
        num_examples: Maximum number of misclassified examples to print.
            Values ``<= 0`` print nothing.
        max_text_chars: Truncate each printed review to this many characters.
            ``None`` or a non-positive value prints the full text.

    How to read the output (practical guidance):
    - Start with the true vs. predicted label:
        - For each misclassified review, ask whether the ground-truth label
          actually matches the human-intuitive sentiment. Occasional noisy
          labels are common in IMDB-style datasets.
    - Look at the confidence vector:
        - Very confident but wrong predictions often indicate *systematic bias*
          (e.g., the model over-trusts certain keywords like "great", "worst").
        - Low-confidence errors may simply reflect inherently ambiguous reviews.
    - Scan the text content:
        - Check for **rare or domain-specific words** (brand names, slang,
          technical jargon) that might not appear often enough in training.
        - Look for **negation patterns** ("not good", "hardly bad", "no longer
          terrible") where bag-of-words style cues can mislead attention.
        - Notice **mixed sentiment** or **topic vs. opinion** separation
          (e.g., long plot summary plus a brief opinion at the end).
        - Pay attention to **sarcasm and irony**, which are notoriously hard
          for models relying mostly on local lexical cues.
    - Compare several misclassified examples:
        - If you see many errors with long reviews, consider increasing MAX_LEN
          or using a deeper model.
        - If errors cluster around subtle, low-intensity sentiment, you may need
          more expressive capacity (higher d_model / more layers) or additional
          training data.

    Based on these observations you can propose targeted improvements, such as:
    - Expanding the vocabulary or switching to subword tokenization.
    - Adjusting hyperparameters (sequence length, model size).
    - Incorporating pre-trained language models for richer semantics.
    """
    if num_examples <= 0:
        # Previously the limit was checked only after printing, so a request
        # for zero examples still printed one.
        return

    model.eval()
    sequences = preprocess_data(texts, vocab, max_len)
    dataset = IMDBDataset(sequences, labels)
    # shuffle=False keeps batches in `texts` order so they can be re-aligned below.
    loader = DataLoader(dataset, batch_size=64, shuffle=False)

    printed = 0
    offset = 0  # index in `texts` of the first example of the current batch
    with torch.no_grad():
        for batch_seq, batch_lab in loader:
            batch_seq, batch_lab = batch_seq.to(device), batch_lab.to(device)
            logits = model(batch_seq)
            probs = F.softmax(logits, dim=1)
            preds = torch.argmax(probs, dim=1)

            # Track the position with a running offset instead of
            # batch_idx * batch_size, so alignment does not depend on every
            # batch having the configured size.
            current_size = batch_seq.size(0)
            batch_texts = texts[offset:offset + current_size]
            offset += current_size

            for text, true_y, pred_y, prob_vec in zip(
                batch_texts,
                batch_lab.cpu().numpy(),
                preds.cpu().numpy(),
                probs.cpu().numpy(),
            ):
                if true_y == pred_y:
                    continue

                printed += 1
                print("=" * 80)
                print(f"Misclassified example #{printed}")
                print(f"True label     : {true_y} (0=neg, 1=pos)")
                print(f"Predicted label: {pred_y}")
                print(f"Model confidence (class 0, class 1): {prob_vec}")
                # The text is the core input for qualitative analysis. It was
                # previously unpacked but never shown.
                print("Review text:")
                print(_truncate_for_display(text, max_text_chars))

                if printed >= num_examples:
                    print("=" * 80)
                    print(
                        f"Displayed the first {num_examples} misclassified "
                        "examples on this split."
                    )
                    return

    if printed == 0:
        print("No misclassified examples found on this split (perfect accuracy).")
    else:
        # Fewer errors than requested: close the last section and summarize.
        print("=" * 80)
        print(f"Found only {printed} misclassified example(s) on this split.")


def _truncate_for_display(text: str, max_chars: int | None) -> str:
    """Return ``text`` cut to ``max_chars`` characters, marking any truncation."""
    if max_chars is None or max_chars <= 0 or len(text) <= max_chars:
        return text
    return text[:max_chars].rstrip() + " ...[truncated]"
| |
|
| |
|
def load_trained_model_from_checkpoint(
    checkpoint_path: str = MODEL_PATH,
    device: torch.device | None = None,
) -> Tuple[torch.nn.Module, Dict[str, int], Dict]:
    """
    Restore a trained TransformerClassifier with its vocabulary and config.

    The checkpoint is the one written by `c1.py`. It must contain the entries
    ``"vocab"``, ``"config"`` and ``"model_state_dict"``.

    Args:
        checkpoint_path: Path of the checkpoint file to read.
        device: Target device. Defaults to CUDA when available, else CPU.

    Returns:
        A ``(model, vocab, config)`` tuple with these elements:
        - model: the reconstructed classifier, on ``device`` and in eval mode.
        - vocab: the token-to-index mapping stored under ``"vocab"``.
        - config: the hyperparameter dict stored under ``"config"``.

    Raises:
        KeyError: If the checkpoint or its config lacks a required entry.
    """
    target = device if device is not None else torch.device(
        "cuda" if torch.cuda.is_available() else "cpu"
    )

    # NOTE(review): torch.load relies on pickle; only load trusted checkpoints.
    checkpoint = torch.load(checkpoint_path, map_location=target)
    vocab: Dict[str, int] = checkpoint["vocab"]
    config: Dict = checkpoint["config"]

    # Architecture hyperparameters are read from the saved config so the
    # rebuilt module matches the shapes in the stored state dict.
    arch_kwargs = {
        name: config[name]
        for name in ("d_model", "num_heads", "num_layers", "d_ff", "max_len")
    }
    model = TransformerClassifier(vocab_size=len(vocab), **arch_kwargs).to(target)
    model.load_state_dict(checkpoint["model_state_dict"])
    model.eval()

    return model, vocab, config
| |
|
| |
|
def evaluate_and_analyze_saved_model(
    split: str = "test",
    checkpoint_path: str | None = None,
    model_size: str = "medium",
    num_examples: int = 5,
    device: torch.device | None = None,
) -> None:
    """
    Load a saved model, evaluate it on an IMDB split, then inspect its errors.

    Steps:
    1) Load the trained model/vocab/config from disk.
    2) Evaluate it on the requested IMDB split and print the metrics.
    3) Run qualitative error analysis on that split.

    Args:
        split: IMDB split name passed to ``load_imdb_texts``.
        checkpoint_path: Explicit checkpoint file. When ``None``, the path is
            ``SAVE_DIR/transformer_imdb_{model_size}.pt``.
        model_size: Size suffix used to build the default checkpoint path.
        num_examples: Number of misclassified examples to print. Values
            ``<= 0`` skip the error analysis step.
        device: Target device. Defaults to CUDA when available, else CPU.
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    if checkpoint_path is None:
        checkpoint_path = os.path.join(SAVE_DIR, f"transformer_imdb_{model_size}.pt")

    print(f"Loading trained model from: {checkpoint_path}")
    model, vocab, config = load_trained_model_from_checkpoint(
        checkpoint_path=checkpoint_path,
        device=device,
    )

    print(f"Evaluating on IMDB '{split}' split...")
    texts, labels = load_imdb_texts(split=split)
    sequences = preprocess_data(texts, vocab, config["max_len"])
    dataset = IMDBDataset(sequences, labels)
    # Here the batch size only controls how the (unshuffled) eval set is
    # chunked. Fall back to a default instead of failing when the saved
    # config does not record one.
    eval_batch_size = config.get("batch_size", 64)
    loader = DataLoader(dataset, batch_size=eval_batch_size, shuffle=False)

    metrics = evaluate_model(model, loader, device)
    print("Evaluation metrics:", metrics)

    if num_examples <= 0:
        # No examples requested, so skip the (costly) second pass over the split.
        return

    print("\nRunning qualitative error analysis...")
    analyze_misclassifications_on_texts(
        model=model,
        texts=texts,
        labels=labels,
        vocab=vocab,
        max_len=config["max_len"],
        device=device,
        num_examples=num_examples,
    )
| |
|
| |
|
def main():
    """
    Parse command-line options, then evaluate and error-analyze a saved model.

    Options:
        --split         IMDB split to evaluate (default: "test").
        --checkpoint    Explicit checkpoint path. Takes precedence over --model_size.
        --model_size    One of small/medium/large (default: "medium"). Selects
                        the size-suffixed checkpoint when --checkpoint is omitted.
        --num_examples  Number of misclassified reviews to print (default: 5).

    Example:
        # Evaluate medium model on IMDB test split and show 5 errors
        python c1_analysis.py --split test --model_size medium --num_examples 5
    """
    cli = argparse.ArgumentParser(
        description="IMDB Transformer evaluation and analysis utilities"
    )

    # (flag, add_argument keyword arguments), in the order shown by --help.
    option_specs = (
        (
            "--split",
            {
                "type": str,
                "default": "test",
                "help": "IMDB split to evaluate on (e.g., 'test', 'train').",
            },
        ),
        (
            "--checkpoint",
            {
                "type": str,
                "default": None,
                "help": (
                    "Optional explicit checkpoint path. If provided, this overrides "
                    "--model_size."
                ),
            },
        ),
        (
            "--model_size",
            {
                "type": str,
                "choices": ["small", "medium", "large"],
                "default": "medium",
                "help": (
                    "Model size to load from saved checkpoints. Used when --checkpoint "
                    "is not provided."
                ),
            },
        ),
        (
            "--num_examples",
            {
                "type": int,
                "default": 5,
                "help": "Number of misclassified examples to print in error analysis.",
            },
        ),
    )
    for flag, spec in option_specs:
        cli.add_argument(flag, **spec)

    opts = cli.parse_args()

    run_device = (
        torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    )

    evaluate_and_analyze_saved_model(
        split=opts.split,
        checkpoint_path=opts.checkpoint,
        model_size=opts.model_size,
        num_examples=opts.num_examples,
        device=run_device,
    )


if __name__ == "__main__":
    main()
| |
|
| |
|