| | import argparse |
| | import json |
| | import torch |
| | import torch.nn.functional as F |
| | from typing import Dict, List, Tuple |
| | from torch.utils.data import DataLoader |
| |
|
| | |
| | from c1 import ( |
| | IMDBDataset, |
| | TransformerClassifier, |
| | preprocess_data, |
| | evaluate_model, |
| | load_imdb_texts, |
| | MODEL_PATH, |
| | ) |
| |
|
| | |
import os

from openai import OpenAI

# Location of the JSON file holding API credentials. The default is the
# path this script has always used; set OPENAI_API_KEYS_FILE to point at a
# different file without editing the source.
api_file = os.environ.get("OPENAI_API_KEYS_FILE", "/home/mshahidul/api_new.json")

# JSON text is UTF-8, so do not depend on the locale's default encoding.
with open(api_file, "r", encoding="utf-8") as f:
    api_keys = json.load(f)

# The credentials file must be a JSON object with an "openai" entry;
# a missing entry raises KeyError at import time.
openai_api_key = api_keys["openai"]

# Module-level client shared by get_llm_explanation().
client = OpenAI(api_key=openai_api_key)
| | |
| |
|
def get_llm_explanation(review_text: str, true_y: int, pred_y: int) -> str:
    """
    Ask the chat model why the classifier may have mislabeled a review.

    The prompt embeds at most the first 1000 characters of ``review_text``
    together with the readable names of the true and predicted labels
    (0 -> "Negative", 1 -> "Positive"). Any other label value raises
    ``KeyError`` before a request is attempted.

    Returns the stripped text of the first completion choice. If the request
    or the handling of its response raises, the exception is not propagated:
    a string beginning with "LLM Analysis failed: " is returned instead.
    """
    label_names = {0: "Negative", 1: "Positive"}

    # Resolve every prompt ingredient up front so the template stays readable.
    excerpt = review_text[:1000]
    actual = label_names[true_y]
    guessed = label_names[pred_y]

    user_message = f"""
    A Transformer model misclassified the following movie review.

    REVIEW: "{excerpt}"
    TRUE LABEL: {actual}
    MODEL PREDICTED: {guessed}

    Task: Provide a concise (2-3 sentence) explanation of why a machine learning
    model might have struggled with this specific text. Mention linguistic
    features like sarcasm, double negatives, mixed sentiment, or specific keywords.
    """

    try:
        completion = client.chat.completions.create(
            model="gpt-4o-mini",
            messages=[{"role": "user", "content": user_message}],
            temperature=0.2,
        )
        # Kept inside the try on purpose: a malformed or empty response is
        # reported through the same fallback string as a failed request.
        answer = completion.choices[0].message.content
        return answer.strip()
    except Exception as err:
        # Best-effort: one failed explanation must not abort the whole run.
        return "LLM Analysis failed: " + str(err)
| |
|
def analyze_misclassifications_on_texts(
    model: torch.nn.Module,
    texts: List[str],
    labels: List[int],
    vocab: Dict[str, int],
    max_len: int,
    device: torch.device,
    num_examples: int = 10,
) -> List[Dict]:
    """
    Identifies errors, generates LLM explanations, and returns structured results.

    The model is switched to eval mode, the texts are encoded with
    ``preprocess_data`` and scored in batches of 64 in their original order.
    For every example whose argmax prediction differs from its label,
    ``get_llm_explanation`` is called and a progress report is printed.

    Args:
        model: Classifier called as ``model(batch)``; its output is softmaxed
            over dim 1 and columns 0 and 1 are read as the negative and
            positive class scores.
        texts: Raw review texts, aligned index-for-index with ``labels``.
        labels: Ground-truth labels for ``texts``.
        vocab: Vocabulary forwarded to ``preprocess_data``.
        max_len: Sequence length forwarded to ``preprocess_data``.
        device: Device the batches are moved to before the forward pass.
        num_examples: Maximum number of misclassifications to analyze. Values
            less than or equal to zero yield an empty list and no LLM calls.

    Returns:
        One dict per analyzed error, in encounter order, with the keys
        ``example_id`` (1-based), ``true_label``, ``predicted_label``,
        ``confidence_neg``, ``confidence_pos``, ``text`` and ``explanation``.
    """
    model.eval()

    # The limit used to be checked only after an error had been analyzed, so a
    # non-positive limit still cost one LLM call and returned one entry.
    if num_examples <= 0:
        return []

    batch_size = 64
    sequences = preprocess_data(texts, vocab, max_len)
    dataset = IMDBDataset(sequences, labels)
    # shuffle=False keeps batch order aligned with `texts`, which the slice
    # arithmetic below relies on.
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

    error_results: List[Dict] = []

    with torch.no_grad():
        for batch_idx, (batch_seq, batch_lab) in enumerate(loader):
            batch_seq, batch_lab = batch_seq.to(device), batch_lab.to(device)
            logits = model(batch_seq)
            probs = F.softmax(logits, dim=1)
            preds = torch.argmax(probs, dim=1)

            # Recover the raw texts behind this batch; the last batch may be
            # shorter than batch_size, hence the use of the actual batch length.
            start = batch_idx * batch_size
            batch_texts = texts[start:start + batch_seq.size(0)]

            # tolist() yields plain Python numbers (copying off the device if
            # needed), so the LLM helper and the JSON-ready dict never see
            # NumPy scalar types.
            for text, true_y, pred_y, prob_vec in zip(
                batch_texts,
                batch_lab.tolist(),
                preds.tolist(),
                probs.tolist(),
            ):
                if true_y == pred_y:
                    continue

                example_id = len(error_results) + 1

                print(f"Analyzing error #{example_id} with LLM...")
                explanation = get_llm_explanation(text, true_y, pred_y)

                error_results.append(
                    {
                        "example_id": example_id,
                        "true_label": int(true_y),
                        "predicted_label": int(pred_y),
                        "confidence_neg": float(prob_vec[0]),
                        "confidence_pos": float(prob_vec[1]),
                        "text": text,
                        "explanation": explanation,
                    }
                )

                print("=" * 80)
                print(f"True: {true_y} | Pred: {pred_y}")
                print(f"Reasoning: {explanation}")
                print("=" * 80)

                # Stop as soon as the quota is met to avoid extra LLM calls.
                if len(error_results) >= num_examples:
                    return error_results

    return error_results
| |
|
def load_trained_model_from_checkpoint(
    checkpoint_path: str = MODEL_PATH,
    device: torch.device | None = None,
) -> Tuple[torch.nn.Module, Dict[str, int], Dict]:
    """
    Rebuild a TransformerClassifier from a saved checkpoint.

    The checkpoint is read with ``torch.load`` and must contain the keys
    ``"vocab"``, ``"config"`` and ``"model_state_dict"``. The config must
    provide ``d_model``, ``num_heads``, ``num_layers``, ``d_ff`` and
    ``max_len``; the vocabulary size is taken from ``len(vocab)``.

    Args:
        checkpoint_path: File to load; defaults to ``MODEL_PATH``.
        device: Device to map the checkpoint to and place the model on.
            When omitted, CUDA is used if available, otherwise the CPU.

    Returns:
        ``(model, vocab, config)`` with the saved weights loaded into the model.
    """
    target = device
    if target is None:
        target = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    checkpoint = torch.load(checkpoint_path, map_location=target)
    vocab, config = checkpoint["vocab"], checkpoint["config"]

    # Collect the constructor arguments in one place before building the model.
    vocab_size = len(vocab)
    hyperparams = {
        key: config[key]
        for key in ("d_model", "num_heads", "num_layers", "d_ff", "max_len")
    }

    model = TransformerClassifier(vocab_size=vocab_size, **hyperparams)
    model = model.to(target)
    model.load_state_dict(checkpoint["model_state_dict"])
    return model, vocab, config
| |
|
def main():
    """
    Command-line entry point: run the error analysis and save it as JSON.

    Options:
        --split         Dataset split passed to ``load_imdb_texts`` (default "test").
        --num_examples  Maximum number of misclassifications to analyze (default 10).
        --output        Path of the JSON report (default "error_analysis.json").
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("--split", type=str, default="test")
    cli.add_argument("--num_examples", type=int, default=10)
    cli.add_argument("--output", type=str, default="error_analysis.json")
    options = cli.parse_args()

    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")

    # Restore the trained classifier with its vocabulary and hyperparameters.
    model, vocab, config = load_trained_model_from_checkpoint(device=device)

    # Fetch the raw texts and labels of the requested split.
    texts, labels = load_imdb_texts(split=options.split)

    # Collect misclassified examples together with their LLM explanations.
    errors = analyze_misclassifications_on_texts(
        model,
        texts,
        labels,
        vocab,
        config["max_len"],
        device,
        options.num_examples,
    )

    # Persist the structured report.
    with open(options.output, "w") as report:
        json.dump(errors, report, indent=4)
    print(f"\nAnalysis complete. Results saved to {options.output}")


if __name__ == "__main__":
    main()