shahidul034's picture
Add files using upload-large-folder tool
c7a6fe6 verified
import argparse
import json
import torch
import torch.nn.functional as F
from typing import Dict, List, Tuple
from torch.utils.data import DataLoader
# Assuming these are in your c1.py
from c1 import (
IMDBDataset,
TransformerClassifier,
preprocess_data,
evaluate_model,
load_imdb_texts,
MODEL_PATH,
)
# You would need to install openai: pip install openai
from openai import OpenAI
import os

# Credentials: the local JSON key file is the primary source. When that file
# does not exist (e.g. on another machine), fall back to the OPENAI_API_KEY
# environment variable instead of crashing at import time.
api_file = "/home/mshahidul/api_new.json"
try:
    with open(api_file, "r", encoding="utf-8") as f:
        api_keys = json.load(f)
except FileNotFoundError:
    api_keys = {}
    openai_api_key = os.environ.get("OPENAI_API_KEY")
else:
    # A key file that exists but lacks the "openai" entry is a configuration
    # error and still raises KeyError, as before.
    openai_api_key = api_keys["openai"]

# Module-level client used by get_llm_explanation().
client = OpenAI(api_key=openai_api_key)
def get_llm_explanation(
    review_text: str,
    true_y: int,
    pred_y: int,
    *,
    model: str = "gpt-4o-mini",
    max_chars: int = 1000,
    temperature: float = 0.2,
) -> str:
    """Ask an LLM why the classifier may have misclassified a review.

    Args:
        review_text: Raw review text. Only the first ``max_chars`` characters
            are embedded in the prompt.
        true_y: Ground-truth label (0 = Negative, 1 = Positive).
        pred_y: Label predicted by the classifier (same encoding).
        model: Chat-completion model name passed to the OpenAI client.
        max_chars: Maximum number of review characters sent to the LLM.
        temperature: Sampling temperature for the completion request.

    Returns:
        The explanation with surrounding whitespace stripped, or a string
        starting with ``"LLM Analysis failed:"`` when the request fails or
        yields no text. Request failures are reported, never raised.
    """
    sentiment = {0: "Negative", 1: "Positive"}
    # Labels outside {0, 1} are shown verbatim rather than raising KeyError,
    # so one unexpected label cannot abort a whole analysis run.
    true_name = sentiment.get(true_y, str(true_y))
    pred_name = sentiment.get(pred_y, str(pred_y))

    prompt = (
        "\n"
        "A Transformer model misclassified the following movie review.\n"
        f'REVIEW: "{review_text[:max_chars]}"\n'
        f"TRUE LABEL: {true_name}\n"
        f"MODEL PREDICTED: {pred_name}\n"
        "Task: Provide a concise (2-3 sentence) explanation of why a machine learning\n"
        "model might have struggled with this specific text. Mention linguistic\n"
        "features like sarcasm, double negatives, mixed sentiment, or specific keywords.\n"
    )

    # Deliberate best-effort boundary: any failure of the remote call (network,
    # auth, malformed response) becomes a readable message for the report.
    try:
        response = client.chat.completions.create(
            model=model,  # default 4o-mini: a high-performance proxy for "mini" models
            messages=[{"role": "user", "content": prompt}],
            temperature=temperature,
        )
        content = response.choices[0].message.content
    except Exception as e:
        return f"LLM Analysis failed: {str(e)}"

    # The API may return a message without text content; report it explicitly
    # instead of relying on an AttributeError from None.strip().
    if content is None:
        return "LLM Analysis failed: the model returned no text content"
    return content.strip()
def analyze_misclassifications_on_texts(
    model: torch.nn.Module,
    texts: List[str],
    labels: List[int],
    vocab: Dict[str, int],
    max_len: int,
    device: torch.device,
    num_examples: int = 10,
    batch_size: int = 64,
) -> List[Dict]:
    """Find misclassified texts and attach an LLM explanation to each.

    The texts are encoded with ``preprocess_data``, run through ``model`` in
    order (no shuffling), and every prediction that differs from its label is
    sent to ``get_llm_explanation``. Processing stops as soon as
    ``num_examples`` errors have been collected.

    Args:
        model: Classifier called as ``model(batch)``; its output is softmaxed
            over dim 1 and the probabilities at index 0 and 1 are recorded.
        texts: Raw input texts, aligned index-by-index with ``labels``.
        labels: Ground-truth labels for ``texts``.
        vocab: Token-to-id mapping forwarded to ``preprocess_data``.
        max_len: Sequence length forwarded to ``preprocess_data``.
        device: Device the input batches are moved to before inference.
        num_examples: Maximum number of errors to analyze; values <= 0
            return an empty list without running the model.
        batch_size: Inference batch size.

    Returns:
        One dict per analyzed error with keys ``example_id``, ``true_label``,
        ``predicted_label``, ``confidence_neg``, ``confidence_pos``, ``text``
        and ``explanation``.
    """
    # The limit used to be checked only after an error had been recorded, so a
    # non-positive limit still cost one LLM call and returned one entry.
    if num_examples <= 0:
        return []

    model.eval()
    sequences = preprocess_data(texts, vocab, max_len)
    dataset = IMDBDataset(sequences, labels)
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

    error_results: List[Dict] = []
    # Index into `texts` of the first item of the current batch; valid because
    # the loader iterates in dataset order (shuffle=False).
    offset = 0
    with torch.no_grad():
        for batch_seq, batch_lab in loader:
            batch_seq = batch_seq.to(device)
            logits = model(batch_seq)
            probs = F.softmax(logits, dim=1)
            preds = torch.argmax(probs, dim=1)

            batch_len = batch_seq.size(0)
            batch_texts = texts[offset:offset + batch_len]
            offset += batch_len

            # .tolist() yields plain Python numbers, which are JSON-serializable
            # and avoid a device round-trip for the labels.
            rows = zip(batch_texts, batch_lab.tolist(), preds.tolist(), probs.tolist())
            for text, true_y, pred_y, prob_vec in rows:
                if true_y == pred_y:
                    continue

                example_id = len(error_results) + 1
                print(f"Analyzing error #{example_id} with LLM...")
                explanation = get_llm_explanation(text, true_y, pred_y)
                error_results.append(
                    {
                        "example_id": example_id,
                        "true_label": int(true_y),
                        "predicted_label": int(pred_y),
                        "confidence_neg": float(prob_vec[0]),
                        "confidence_pos": float(prob_vec[1]),
                        "text": text,
                        "explanation": explanation,
                    }
                )

                # Print to console for immediate feedback
                print("=" * 80)
                print(f"True: {true_y} | Pred: {pred_y}")
                print(f"Reasoning: {explanation}")
                print("=" * 80)

                if len(error_results) >= num_examples:
                    return error_results
    return error_results
def load_trained_model_from_checkpoint(
    checkpoint_path: str = MODEL_PATH,
    device: torch.device | None = None,
) -> Tuple[torch.nn.Module, Dict[str, int], Dict]:
    """Rebuild the classifier stored in a checkpoint file.

    Args:
        checkpoint_path: Path to a checkpoint containing the keys ``"vocab"``,
            ``"config"`` and ``"model_state_dict"``.
        device: Device to load onto; when ``None``, CUDA is used if available,
            otherwise CPU.

    Returns:
        A ``(model, vocab, config)`` tuple. The model has its weights loaded
        and lives on the chosen device; its train/eval mode is not changed.
    """
    target = device
    if target is None:
        target = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # NOTE(review): torch.load unpickles the file - only load trusted checkpoints.
    state = torch.load(checkpoint_path, map_location=target)
    vocab, config = state["vocab"], state["config"]

    # Architecture hyperparameters are read from the saved config so the
    # rebuilt network matches the stored weights.
    vocab_size = len(vocab)
    arch_keys = ("d_model", "num_heads", "num_layers", "d_ff", "max_len")
    arch_kwargs = {key: config[key] for key in arch_keys}

    model = TransformerClassifier(vocab_size=vocab_size, **arch_kwargs)
    model = model.to(target)
    model.load_state_dict(state["model_state_dict"])
    return model, vocab, config
def main():
    """Command-line entry point: find test-set errors and save LLM explanations.

    Options:
        --split: Dataset split passed to ``load_imdb_texts`` (default "test").
        --num_examples: Maximum number of errors to analyze (default 10).
        --output: Path of the JSON report (default "error_analysis.json").
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("--split", type=str, default="test")
    cli.add_argument("--num_examples", type=int, default=10)
    cli.add_argument("--output", type=str, default="error_analysis.json")
    options = cli.parse_args()

    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")

    # Step 1: restore the trained classifier together with its vocab/config.
    model, vocab, config = load_trained_model_from_checkpoint(device=device)

    # Step 2: fetch the raw texts and labels of the requested split.
    texts, labels = load_imdb_texts(split=options.split)

    # Step 3: collect misclassifications and their LLM explanations.
    analysis_kwargs = {
        "model": model,
        "texts": texts,
        "labels": labels,
        "vocab": vocab,
        "max_len": config["max_len"],
        "device": device,
        "num_examples": options.num_examples,
    }
    errors = analyze_misclassifications_on_texts(**analysis_kwargs)

    # Step 4: write the report as indented JSON.
    with open(options.output, "w") as report:
        report.write(json.dumps(errors, indent=4))
    print(f"\nAnalysis complete. Results saved to {options.output}")


# Run the analysis only when executed as a script, not when imported.
if __name__ == "__main__":
    main()