"""
Train NFQA Classification Model from Scratch

Trains a multilingual NFQA classifier using XLM-RoBERTa on LLM-annotated WebFAQ data.

Usage (single file with automatic splitting):
    python train_nfqa_model.py --input data.jsonl --output-dir ./model --epochs 10

Usage (pre-split files):
    python train_nfqa_model.py --train train.jsonl --val val.jsonl --test test.jsonl --output-dir ./model --epochs 10

Author: Ali
Date: December 2024
"""
|
|
| import pandas as pd |
| import numpy as np |
| import torch |
| import json |
| import argparse |
| import os |
| from collections import Counter |
| from datetime import datetime |
| from torch.utils.data import Dataset, DataLoader |
| from torch.optim import AdamW |
| from transformers import ( |
| AutoTokenizer, |
| AutoConfig, |
| AutoModelForSequenceClassification, |
| get_linear_schedule_with_warmup |
| ) |
| from sklearn.model_selection import train_test_split |
| from sklearn.metrics import ( |
| classification_report, |
| confusion_matrix, |
| accuracy_score, |
| f1_score |
| ) |
| import matplotlib |
| matplotlib.use('Agg') |
| import matplotlib.pyplot as plt |
| import seaborn as sns |
| from tqdm import tqdm |
|
|
| |
# Reproducibility: seed numpy and torch RNGs (CUDA is seeded later in main()).
RANDOM_SEED = 42
np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)


# The NFQA taxonomy. Order matters: the list index is the integer class id
# used by the model head, the datasets, and all saved artifacts.
NFQA_CATEGORIES = [
    'NOT-A-QUESTION',
    'FACTOID',
    'DEBATE',
    'EVIDENCE-BASED',
    'INSTRUCTION',
    'REASON',
    'EXPERIENCE',
    'COMPARISON'
]


# Bidirectional name <-> id mappings derived from the category order above.
LABEL2ID = {label: idx for idx, label in enumerate(NFQA_CATEGORIES)}
ID2LABEL = {idx: label for label, idx in LABEL2ID.items()}
|
|
|
|
class NFQADataset(Dataset):
    """Torch dataset wrapping parallel (question, label) lists.

    Each item is tokenized lazily in ``__getitem__`` and returned as
    fixed-length 1-D tensors ready for a HuggingFace sequence-classification
    model (keys: ``input_ids``, ``attention_mask``, ``labels``).
    """

    def __init__(self, questions, labels, tokenizer, max_length=128):
        self.questions = questions
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.questions)

    def __getitem__(self, idx):
        # Tokenize one question, padding/truncating to a fixed max_length so
        # batches can be stacked without a custom collate function.
        encoded = self.tokenizer(
            str(self.questions[idx]),
            add_special_tokens=True,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_attention_mask=True,
            return_tensors='pt',
        )
        return {
            'input_ids': encoded['input_ids'].flatten(),
            'attention_mask': encoded['attention_mask'].flatten(),
            'labels': torch.tensor(int(self.labels[idx]), dtype=torch.long),
        }
|
|
|
|
def train_epoch(model, train_loader, optimizer, scheduler, device):
    """Run one full optimization pass over `train_loader`.

    Returns a ``(average_loss, accuracy)`` tuple computed over the epoch.
    """
    model.train()
    running_loss = 0
    all_preds = []
    all_targets = []

    pbar = tqdm(train_loader, desc="Training")
    for batch in pbar:
        # Move the batch onto the training device.
        ids = batch['input_ids'].to(device)
        mask = batch['attention_mask'].to(device)
        targets = batch['labels'].to(device)

        # Forward pass; HF models return the loss when labels are supplied.
        optimizer.zero_grad()
        outputs = model(input_ids=ids, attention_mask=mask, labels=targets)

        loss = outputs.loss
        running_loss += loss.item()

        # Backward pass with gradient clipping, then step the LR schedule.
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        scheduler.step()

        # Track hard predictions for epoch-level accuracy.
        all_preds.extend(torch.argmax(outputs.logits, dim=1).cpu().numpy())
        all_targets.extend(targets.cpu().numpy())

        pbar.set_postfix({'loss': f'{loss.item():.4f}'})

    return running_loss / len(train_loader), accuracy_score(all_targets, all_preds)
|
|
|
|
def evaluate(model, data_loader, device, languages=None, desc="Evaluating", show_analysis=False):
    """Evaluate `model` on a labelled dataloader without gradient updates.

    Parameters
    ----------
    languages : optional per-example language codes; required for the
        detailed breakdown printed when `show_analysis` is True.

    Returns
    -------
    (avg_loss, accuracy, macro_f1, predictions, true_labels)
    """
    model.eval()
    loss_total = 0
    pred_list = []
    gold_list = []

    with torch.no_grad():
        for batch in tqdm(data_loader, desc=desc):
            ids = batch['input_ids'].to(device)
            mask = batch['attention_mask'].to(device)
            gold = batch['labels'].to(device)

            outputs = model(input_ids=ids, attention_mask=mask, labels=gold)
            loss_total += outputs.loss.item()

            pred_list.extend(torch.argmax(outputs.logits, dim=1).cpu().numpy())
            gold_list.extend(gold.cpu().numpy())

    avg_loss = loss_total / len(data_loader)
    accuracy = accuracy_score(gold_list, pred_list)
    f1 = f1_score(gold_list, pred_list, average='macro')

    # Optional drill-down: per-category / per-language / combination analysis.
    if show_analysis and languages is not None:
        print("\n" + "-"*70)
        print("VALIDATION ANALYSIS")
        print("-"*70)

        analyze_performance_by_category(pred_list, gold_list)
        analyze_performance_by_language(pred_list, gold_list, languages, top_n=5)
        analyze_language_category_combinations(pred_list, gold_list, languages, top_n=10)

        print("-"*70)

    return avg_loss, accuracy, f1, pred_list, gold_list
|
|
|
|
def load_data(file_path):
    """Load annotated data from a JSONL file.

    Accepts three label encodings, normalized to integer class ids:
      * 'label_id'            -- integer ids, used as-is
      * 'ensemble_prediction' -- category names, mapped via LABEL2ID
      * 'label'               -- integer ids OR category names (names are
                                 mapped via LABEL2ID; previously a string
                                 'label' column crashed astype(int))

    Parameters
    ----------
    file_path : path to a JSON-lines file with at least a 'question' column.

    Returns
    -------
    (questions, labels, languages) as parallel lists; `languages` falls back
    to 'unknown' placeholders when no 'language' column exists.

    Raises
    ------
    FileNotFoundError if the file is missing; ValueError when required
    columns are absent.
    """
    print(f"Loading data from: {file_path}\n")

    try:
        df = pd.read_json(file_path, lines=True)
        print(f"✓ Loaded {len(df)} annotated examples")

        if 'question' not in df.columns:
            raise ValueError("Missing 'question' column")

        # Resolve the label column, normalizing everything to integer ids.
        if 'label_id' in df.columns:
            label_col = 'label_id'
        elif 'ensemble_prediction' in df.columns:
            df['label_id'] = df['ensemble_prediction'].map(LABEL2ID)
            label_col = 'label_id'
        elif 'label' in df.columns:
            label_col = 'label'
            # Fix: a 'label' column may hold category-name strings; map them
            # to ids so the distribution print and astype(int) below work.
            if not pd.api.types.is_numeric_dtype(df['label']):
                df['label'] = df['label'].map(LABEL2ID)
        else:
            raise ValueError("No label column found (expected: 'label', 'label_id', or 'ensemble_prediction')")

        # Drop rows with missing questions or missing/unmappable labels
        # (unknown category names become NaN via .map and are removed here).
        df = df.dropna(subset=['question', label_col])

        print(f"✓ Data cleaned: {len(df)} examples with valid labels")

        print("\nLabel distribution:")
        label_counts = df[label_col].value_counts().sort_index()
        for label_id, count in label_counts.items():
            cat_name = ID2LABEL.get(int(label_id), f"UNKNOWN_{label_id}")
            print(f"  {cat_name:20s}: {count:4d} ({count/len(df)*100:5.1f}%)")

        questions = df['question'].tolist()
        labels = df[label_col].astype(int).tolist()
        languages = df['language'].tolist() if 'language' in df.columns else ['unknown'] * len(df)

        print(f"\n✓ Prepared {len(questions)} question-label pairs")

        return questions, labels, languages

    except FileNotFoundError:
        print(f"❌ Error: File not found: {file_path}")
        raise
    except Exception as e:
        print(f"❌ Error loading data: {e}")
        raise
|
|
|
|
|
|
def plot_training_curves(history, best_val_f1, output_dir):
    """Render loss / accuracy / F1 curves per epoch and save a PNG.

    `history` is the dict of per-epoch metric lists built during training;
    the plot is written to `<output_dir>/training_curves.png`.
    """
    fig, (ax_loss, ax_acc, ax_f1) = plt.subplots(1, 3, figsize=(18, 5))

    xs = range(1, len(history['train_loss']) + 1)

    # Panel 1: train vs. validation loss.
    ax_loss.plot(xs, history['train_loss'], 'b-', label='Train Loss', linewidth=2)
    ax_loss.plot(xs, history['val_loss'], 'r-', label='Val Loss', linewidth=2)
    ax_loss.set_xlabel('Epoch')
    ax_loss.set_ylabel('Loss')
    ax_loss.set_title('Training and Validation Loss')

    # Panel 2: train vs. validation accuracy.
    ax_acc.plot(xs, history['train_accuracy'], 'b-', label='Train Accuracy', linewidth=2)
    ax_acc.plot(xs, history['val_accuracy'], 'r-', label='Val Accuracy', linewidth=2)
    ax_acc.set_xlabel('Epoch')
    ax_acc.set_ylabel('Accuracy')
    ax_acc.set_title('Training and Validation Accuracy')

    # Panel 3: validation macro-F1 with the best score marked.
    ax_f1.plot(xs, history['val_f1'], 'g-', label='Val F1 (Macro)', linewidth=2)
    ax_f1.axhline(y=best_val_f1, color='r', linestyle='--', label=f'Best F1: {best_val_f1:.4f}')
    ax_f1.set_xlabel('Epoch')
    ax_f1.set_ylabel('F1 Score')
    ax_f1.set_title('Validation F1 Score')

    # Shared cosmetics for all three panels.
    for ax in (ax_loss, ax_acc, ax_f1):
        ax.legend()
        ax.grid(True, alpha=0.3)

    plt.tight_layout()
    plot_file = os.path.join(output_dir, 'training_curves.png')
    plt.savefig(plot_file, dpi=300, bbox_inches='tight')
    plt.close()

    print(f"✓ Training curves saved to: {plot_file}")
|
|
|
|
def analyze_performance_by_language(predictions, true_labels, languages, top_n=10):
    """Print the `top_n` lowest-accuracy languages (min. 5 examples each).

    Returns (raw per-language counts, accuracy records sorted worst-first).
    """
    from collections import defaultdict

    lang_stats = defaultdict(lambda: {'correct': 0, 'total': 0})
    for predicted, actual, lang in zip(predictions, true_labels, languages):
        bucket = lang_stats[lang]
        bucket['total'] += 1
        bucket['correct'] += int(predicted == actual)

    # Keep only languages with enough support to be meaningful.
    lang_accuracies = [
        {
            'language': lang,
            'accuracy': counts['correct'] / counts['total'],
            'correct': counts['correct'],
            'total': counts['total'],
            'errors': counts['total'] - counts['correct'],
        }
        for lang, counts in lang_stats.items()
        if counts['total'] >= 5
    ]
    lang_accuracies.sort(key=lambda rec: rec['accuracy'])

    print(f"\n{'='*70}")
    print(f"WORST {top_n} LANGUAGES (with >= 5 examples)")
    print(f"{'='*70}")
    print(f"{'Language':<12} {'Accuracy':<12} {'Errors':<10} {'Total':<10}")
    print(f"{'-'*70}")
    for rec in lang_accuracies[:top_n]:
        print(f"{rec['language']:<12} {rec['accuracy']:>10.2%} {rec['errors']:>8} {rec['total']:>8}")

    return lang_stats, lang_accuracies
|
|
|
|
def analyze_performance_by_category(predictions, true_labels):
    """Print per-category accuracy, worst category first.

    Returns (raw per-category counts keyed by class id, sorted records).
    """
    from collections import defaultdict

    cat_stats = defaultdict(lambda: {'correct': 0, 'total': 0})
    for predicted, actual in zip(predictions, true_labels):
        bucket = cat_stats[actual]
        bucket['total'] += 1
        bucket['correct'] += int(predicted == actual)

    cat_accuracies = [
        {
            'category': ID2LABEL[cat_id],
            'accuracy': counts['correct'] / counts['total'],
            'correct': counts['correct'],
            'total': counts['total'],
            'errors': counts['total'] - counts['correct'],
        }
        for cat_id, counts in cat_stats.items()
    ]
    cat_accuracies.sort(key=lambda rec: rec['accuracy'])

    print(f"\n{'='*70}")
    print(f"PERFORMANCE BY CATEGORY")
    print(f"{'='*70}")
    print(f"{'Category':<20} {'Accuracy':<12} {'Errors':<10} {'Total':<10}")
    print(f"{'-'*70}")
    for rec in cat_accuracies:
        print(f"{rec['category']:<20} {rec['accuracy']:>10.2%} {rec['errors']:>8} {rec['total']:>8}")

    return cat_stats, cat_accuracies
|
|
|
|
def analyze_language_category_combinations(predictions, true_labels, languages, top_n=15):
    """Print the worst (language, category) cells (min. 3 examples each).

    Returns (raw counts keyed by (language, category-name), sorted records).
    """
    from collections import defaultdict

    combo_stats = defaultdict(lambda: {'correct': 0, 'total': 0})
    for predicted, actual, lang in zip(predictions, true_labels, languages):
        cell = combo_stats[(lang, ID2LABEL[actual])]
        cell['total'] += 1
        cell['correct'] += int(predicted == actual)

    # Require a minimum of 3 examples per cell to reduce noise.
    combo_accuracies = [
        {
            'language': lang,
            'category': cat,
            'accuracy': counts['correct'] / counts['total'],
            'correct': counts['correct'],
            'total': counts['total'],
            'errors': counts['total'] - counts['correct'],
        }
        for (lang, cat), counts in combo_stats.items()
        if counts['total'] >= 3
    ]
    combo_accuracies.sort(key=lambda rec: rec['accuracy'])

    print(f"\n{'='*80}")
    print(f"WORST {top_n} LANGUAGE-CATEGORY COMBINATIONS (with >= 3 examples)")
    print(f"{'='*80}")
    print(f"{'Language':<12} {'Category':<20} {'Accuracy':<12} {'Errors':<8} {'Total':<8}")
    print(f"{'-'*80}")
    for rec in combo_accuracies[:top_n]:
        print(f"{rec['language']:<12} {rec['category']:<20} {rec['accuracy']:>10.2%} {rec['errors']:>6} {rec['total']:>6}")

    return combo_stats, combo_accuracies
|
|
|
|
def plot_confusion_matrix(test_true, test_preds, output_dir):
    """Render the test-set confusion matrix as a heatmap PNG.

    Rows are true categories, columns are predictions; the image is saved
    to `<output_dir>/confusion_matrix.png`.
    """
    # Fix the label order so all 8 categories appear even with zero counts.
    label_ids = list(range(len(NFQA_CATEGORIES)))
    matrix = confusion_matrix(test_true, test_preds, labels=label_ids)

    plt.figure(figsize=(12, 10))
    sns.heatmap(
        matrix,
        annot=True,
        fmt='d',
        cmap='Blues',
        xticklabels=NFQA_CATEGORIES,
        yticklabels=NFQA_CATEGORIES,
        cbar_kws={'label': 'Count'},
    )
    plt.xlabel('Predicted Category')
    plt.ylabel('True Category')
    plt.title('Confusion Matrix - Test Set')
    plt.xticks(rotation=45, ha='right')
    plt.yticks(rotation=0)
    plt.tight_layout()

    cm_file = os.path.join(output_dir, 'confusion_matrix.png')
    plt.savefig(cm_file, dpi=300, bbox_inches='tight')
    plt.close()

    print(f"✓ Confusion matrix saved to: {cm_file}")
|
|
|
|
def main():
    """CLI entry point: parse args, prepare data, train, evaluate, report.

    Pipeline: argument parsing -> device setup -> data loading (either one
    file split automatically with stratification, or pre-split train/val/test
    files) -> tokenizer/model construction -> epoch loop with best-checkpoint
    saving by validation macro-F1 -> final test evaluation on the best
    checkpoint, plus plots and JSON artifacts in the output directory.
    """
    parser = argparse.ArgumentParser(description='Train NFQA Classification Model')

    # Data arguments: either a single --input file or pre-split files.
    parser.add_argument('--input', type=str,
                        help='Input JSONL file with annotated data (will be split automatically)')
    parser.add_argument('--train', type=str,
                        help='Training set JSONL file (use with --val and --test)')
    parser.add_argument('--val', type=str,
                        help='Validation set JSONL file (use with --train and --test)')
    parser.add_argument('--test', type=str,
                        help='Test set JSONL file (use with --train and --val)')
    parser.add_argument('--output-dir', type=str, default='./nfqa_model_trained',
                        help='Output directory for model and results')

    # Model arguments.
    parser.add_argument('--model-name', type=str, default='xlm-roberta-base',
                        help='Pretrained model name (default: xlm-roberta-base)')
    parser.add_argument('--max-length', type=int, default=128,
                        help='Maximum sequence length (default: 128)')

    # Training hyperparameters.
    parser.add_argument('--batch-size', type=int, default=16,
                        help='Batch size (default: 16)')
    parser.add_argument('--epochs', type=int, default=10,
                        help='Number of epochs (default: 10)')
    parser.add_argument('--learning-rate', type=float, default=2e-5,
                        help='Learning rate (default: 2e-5)')
    parser.add_argument('--warmup-ratio', type=float, default=0.1,
                        help='Fraction of total training steps used for warmup (default: 0.1)')
    parser.add_argument('--weight-decay', type=float, default=0.01,
                        help='Weight decay (default: 0.01)')
    parser.add_argument('--dropout', type=float, default=0.1,
                        help='Dropout probability (default: 0.1)')

    # Split sizes; only used in single-file (--input) mode.
    parser.add_argument('--test-size', type=float, default=0.2,
                        help='Test set size (default: 0.2)')
    parser.add_argument('--val-size', type=float, default=0.1,
                        help='Validation set size (default: 0.1)')

    # Hardware selection.
    parser.add_argument('--device', type=str, default='auto',
                        help='Device to use: cuda, cpu, or auto (default: auto)')

    args = parser.parse_args()

    # The two input modes are mutually exclusive: exactly one must be given.
    has_single_input = args.input is not None
    has_split_inputs = all([args.train, args.val, args.test])

    if not has_single_input and not has_split_inputs:
        parser.error("Either --input OR (--train, --val, --test) must be provided")

    if has_single_input and has_split_inputs:
        parser.error("Cannot use --input together with --train/--val/--test. Choose one approach.")

    # Print the run-configuration banner.
    print("="*80)
    print("NFQA MODEL TRAINING")
    print("="*80)
    if has_single_input:
        print(f"Input file: {args.input}")
        print(f"Data splitting: automatic (test={args.test_size}, val={args.val_size})")
    else:
        print(f"Train file: {args.train}")
        print(f"Val file: {args.val}")
        print(f"Test file: {args.test}")
        print(f"Data splitting: manual (pre-split)")
    print(f"Output directory: {args.output_dir}")
    print(f"Model: {args.model_name}")
    print(f"Epochs: {args.epochs}")
    print(f"Batch size: {args.batch_size}")
    print(f"Learning rate: {args.learning_rate}")
    print(f"Max length: {args.max_length}")
    print(f"Weight decay: {args.weight_decay}")
    print(f"Warmup ratio: {args.warmup_ratio}")
    print(f"Dropout: {args.dropout}")
    print("="*80 + "\n")

    # Resolve the compute device ('auto' prefers CUDA when available).
    if args.device == 'auto':
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    else:
        device = torch.device(args.device)

    # Seed CUDA RNGs too, for reproducibility on GPU.
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(RANDOM_SEED)

    print(f"Device: {device}")
    print(f"PyTorch version: {torch.__version__}")
    if torch.cuda.is_available():
        print(f"CUDA device: {torch.cuda.get_device_name(0)}\n")

    os.makedirs(args.output_dir, exist_ok=True)

    # ---- Data loading / splitting -------------------------------------
    if has_single_input:
        questions, labels, languages = load_data(args.input)

        # NOTE: redundant re-import (train_test_split is already imported at
        # module level); kept as-is.
        from sklearn.model_selection import train_test_split

        # First split off the test set, stratified on labels.
        train_val_questions, test_questions, train_val_labels, test_labels, train_val_langs, test_langs = train_test_split(
            questions, labels, languages,
            test_size=args.test_size,
            random_state=RANDOM_SEED,
            stratify=labels
        )

        # Then carve the validation set out of the remainder; the ratio is
        # rescaled so val_size stays a fraction of the ORIGINAL dataset.
        train_questions, val_questions, train_labels, val_labels, train_langs, val_langs = train_test_split(
            train_val_questions, train_val_labels, train_val_langs,
            test_size=args.val_size / (1 - args.test_size),
            random_state=RANDOM_SEED,
            stratify=train_val_labels
        )

        print(f"\nData splits:")
        print(f"  Training:   {len(train_questions):4d} examples ({len(train_questions)/len(questions)*100:5.1f}%)")
        print(f"  Validation: {len(val_questions):4d} examples ({len(val_questions)/len(questions)*100:5.1f}%)")
        print(f"  Test:       {len(test_questions):4d} examples ({len(test_questions)/len(questions)*100:5.1f}%)")
        print(f"  Total:      {len(questions):4d} examples")
    else:
        # Pre-split mode: load the three files independently.
        print("Loading pre-split datasets...\n")
        train_questions, train_labels, train_langs = load_data(args.train)
        val_questions, val_labels, val_langs = load_data(args.val)
        test_questions, test_labels, test_langs = load_data(args.test)

        total_examples = len(train_questions) + len(val_questions) + len(test_questions)
        print(f"\nData splits:")
        print(f"  Training:   {len(train_questions):4d} examples ({len(train_questions)/total_examples*100:5.1f}%)")
        print(f"  Validation: {len(val_questions):4d} examples ({len(val_questions)/total_examples*100:5.1f}%)")
        print(f"  Test:       {len(test_questions):4d} examples ({len(test_questions)/total_examples*100:5.1f}%)")
        print(f"  Total:      {total_examples:4d} examples")

    # Report class balance for each split.
    print("\nClass distribution per split:")
    for split_name, split_labels in [('Train', train_labels), ('Val', val_labels), ('Test', test_labels)]:
        counts = Counter(split_labels)
        print(f"\n{split_name}:")
        for label_id in sorted(counts.keys()):
            cat_name = ID2LABEL[label_id]
            print(f"  {cat_name:20s}: {counts[label_id]:3d}")

    # ---- Tokenizer and model ------------------------------------------
    print(f"\nLoading tokenizer: {args.model_name}")
    tokenizer = AutoTokenizer.from_pretrained(args.model_name)
    print("✓ Tokenizer loaded")

    print(f"\nLoading model: {args.model_name}")

    # Configure an 8-way classification head and the requested dropout.
    config = AutoConfig.from_pretrained(args.model_name)
    config.num_labels = len(NFQA_CATEGORIES)
    config.id2label = ID2LABEL
    config.label2id = LABEL2ID
    config.hidden_dropout_prob = args.dropout
    config.attention_probs_dropout_prob = args.dropout
    config.classifier_dropout = args.dropout

    model = AutoModelForSequenceClassification.from_pretrained(
        args.model_name,
        config=config
    )
    model.to(device)

    print(f"✓ Model loaded")
    print(f"  Number of parameters: {sum(p.numel() for p in model.parameters()):,}")
    print(f"  Trainable parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}")

    # ---- Datasets and dataloaders -------------------------------------
    print("\nCreating datasets...")
    train_dataset = NFQADataset(train_questions, train_labels, tokenizer, args.max_length)
    val_dataset = NFQADataset(val_questions, val_labels, tokenizer, args.max_length)
    test_dataset = NFQADataset(test_questions, test_labels, tokenizer, args.max_length)

    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=args.batch_size)
    test_loader = DataLoader(test_dataset, batch_size=args.batch_size)

    print(f"✓ Datasets created")
    print(f"  Train: {len(train_dataset)} examples ({len(train_loader)} batches)")
    print(f"  Val:   {len(val_dataset)} examples ({len(val_loader)} batches)")
    print(f"  Test:  {len(test_dataset)} examples ({len(test_loader)} batches)")

    # ---- Optimizer and linear-warmup schedule -------------------------
    optimizer = AdamW(
        model.parameters(),
        lr=args.learning_rate,
        weight_decay=args.weight_decay
    )

    total_steps = len(train_loader) * args.epochs
    warmup_steps = int(args.warmup_ratio * total_steps)
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=warmup_steps,
        num_training_steps=total_steps
    )

    print(f"\n✓ Optimizer and scheduler configured")
    print(f"  Total training steps: {total_steps}")
    print(f"  Warmup steps: {warmup_steps} ({args.warmup_ratio*100:.0f}% of total)")

    # Per-epoch metric history (serialized to JSON after training).
    history = {
        'train_loss': [],
        'train_accuracy': [],
        'val_loss': [],
        'val_accuracy': [],
        'val_f1': []
    }

    best_val_f1 = 0
    best_epoch = 0

    print("\n" + "="*80)
    print("STARTING TRAINING")
    print("="*80 + "\n")

    # ---- Training loop: checkpoint on best validation macro-F1 --------
    for epoch in range(args.epochs):
        print(f"\nEpoch {epoch + 1}/{args.epochs}")
        print("-" * 80)

        train_loss, train_acc = train_epoch(model, train_loader, optimizer, scheduler, device)

        val_loss, val_acc, val_f1, val_preds, val_true = evaluate(
            model, val_loader, device,
            languages=val_langs,
            desc="Validating",
            show_analysis=False
        )

        history['train_loss'].append(train_loss)
        history['train_accuracy'].append(train_acc)
        history['val_loss'].append(val_loss)
        history['val_accuracy'].append(val_acc)
        history['val_f1'].append(val_f1)

        print(f"\nEpoch {epoch + 1} Summary:")
        print(f"  Train Loss: {train_loss:.4f}")
        print(f"  Train Accuracy: {train_acc:.4f}")
        print(f"  Val Loss: {val_loss:.4f}")
        print(f"  Val Accuracy: {val_acc:.4f}")
        print(f"  Val F1 (Macro): {val_f1:.4f}")

        # Save a checkpoint whenever validation macro-F1 improves.
        if val_f1 > best_val_f1:
            best_val_f1 = val_f1
            best_epoch = epoch + 1

            model_path = os.path.join(args.output_dir, 'best_model')
            model.save_pretrained(model_path)
            tokenizer.save_pretrained(model_path)

            print(f"  ✓ New best model saved! (F1: {val_f1:.4f})")

    print("\n" + "="*80)
    print("TRAINING COMPLETE")
    print("="*80)
    print(f"Best epoch: {best_epoch}")
    print(f"Best validation F1: {best_val_f1:.4f}")
    print("="*80)

    # ---- Artifacts: history, final model, curves ----------------------
    history_file = os.path.join(args.output_dir, 'training_history.json')
    with open(history_file, 'w') as f:
        json.dump(history, f, indent=2)
    print(f"\n✓ Training history saved to: {history_file}")

    # Also save the last-epoch weights (may differ from best checkpoint).
    final_model_path = os.path.join(args.output_dir, 'final_model')
    model.save_pretrained(final_model_path)
    tokenizer.save_pretrained(final_model_path)
    print(f"✓ Final model saved to: {final_model_path}")

    plot_training_curves(history, best_val_f1, args.output_dir)

    # ---- Final evaluation on the BEST checkpoint, not the last epoch --
    print("\nLoading best model for final evaluation...")
    best_model_path = os.path.join(args.output_dir, 'best_model')
    model = AutoModelForSequenceClassification.from_pretrained(best_model_path)
    model.to(device)

    test_loss, test_acc, test_f1, test_preds, test_true = evaluate(model, test_loader, device, desc="Testing")

    print("\n" + "="*80)
    print("FINAL TEST SET RESULTS")
    print("="*80)
    print(f"Test Loss: {test_loss:.4f}")
    print(f"Test Accuracy: {test_acc:.4f}")
    print(f"Test F1 (Macro): {test_f1:.4f}")
    print("="*80)

    print("\n" + "="*80)
    print("PER-CATEGORY PERFORMANCE")
    print("="*80 + "\n")

    report = classification_report(
        test_true,
        test_preds,
        labels=list(range(len(NFQA_CATEGORIES))),
        target_names=NFQA_CATEGORIES,
        zero_division=0
    )
    print(report)

    report_file = os.path.join(args.output_dir, 'classification_report.txt')
    with open(report_file, 'w') as f:
        f.write(report)
    print(f"✓ Classification report saved to: {report_file}")

    plot_confusion_matrix(test_true, test_preds, args.output_dir)

    # Detailed error breakdowns: category, language, and their combinations.
    print("\n" + "="*80)
    print("DETAILED PERFORMANCE ANALYSIS")
    print("="*80)

    analyze_performance_by_category(test_preds, test_true)

    analyze_performance_by_language(test_preds, test_true, test_langs, top_n=10)

    analyze_language_category_combinations(test_preds, test_true, test_langs, top_n=15)

    print("\n" + "="*80)

    # Persist the final metrics together with the full run configuration.
    test_results = {
        'test_loss': float(test_loss),
        'test_accuracy': float(test_acc),
        'test_f1_macro': float(test_f1),
        'best_epoch': int(best_epoch),
        'best_val_f1': float(best_val_f1),
        'num_train_examples': len(train_questions),
        'num_val_examples': len(val_questions),
        'num_test_examples': len(test_questions),
        'config': {
            'model_name': args.model_name,
            'max_length': args.max_length,
            'batch_size': args.batch_size,
            'learning_rate': args.learning_rate,
            'num_epochs': args.epochs,
            'warmup_ratio': args.warmup_ratio,
            'warmup_steps': warmup_steps,
            'weight_decay': args.weight_decay,
            'dropout': args.dropout,
            'data_source': 'pre-split' if has_split_inputs else 'single_file',
            'train_file': args.train if has_split_inputs else args.input,
            'val_file': args.val if has_split_inputs else None,
            'test_file': args.test if has_split_inputs else None,
            'auto_split': not has_split_inputs,
            'test_size': args.test_size if not has_split_inputs else None,
            'val_size': args.val_size if not has_split_inputs else None
        },
        'timestamp': datetime.now().isoformat()
    }

    results_file = os.path.join(args.output_dir, 'test_results.json')
    with open(results_file, 'w') as f:
        json.dump(test_results, f, indent=2)
    print(f"✓ Test results saved to: {results_file}")

    # ---- Final human-readable summary ---------------------------------
    print("\n" + "="*80)
    print("TRAINING SUMMARY")
    print("="*80)
    print(f"\nModel: {args.model_name}")
    print(f"Training examples: {len(train_questions)}")
    print(f"Validation examples: {len(val_questions)}")
    print(f"Test examples: {len(test_questions)}")
    print(f"\nBest epoch: {best_epoch}/{args.epochs}")
    print(f"Best validation F1: {best_val_f1:.4f}")
    print(f"\nFinal test results:")
    print(f"  Accuracy: {test_acc:.4f}")
    print(f"  F1 Score (Macro): {test_f1:.4f}")
    print(f"\nModel saved to: {args.output_dir}")
    print(f"\nGenerated files:")
    print(f"  - best_model/ (best checkpoint)")
    print(f"  - final_model/ (last epoch)")
    print(f"  - training_history.json")
    print(f"  - training_curves.png")
    print(f"  - test_results.json")
    print(f"  - classification_report.txt")
    print(f"  - confusion_matrix.png")
    print("\n" + "="*80)
    print("✅ Training complete! Model ready for deployment.")
    print("="*80)
|
|
|
|
# Script entry point.
if __name__ == '__main__':
    main()
|
|