import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from tqdm import tqdm
import logging
import os
from sklearn.metrics import f1_score, accuracy_score, precision_score, recall_score

logger = logging.getLogger(__name__)


class DistillationTrainer:
    """
    Trainer for knowledge distillation from a teacher model (BERT) to a student model (LSTM).
    """

    def __init__(
        self,
        teacher_model,
        student_model,
        train_loader,
        val_loader,
        test_loader=None,
        temperature=2.0,
        alpha=0.5,
        lr=0.001,
        weight_decay=1e-5,
        max_grad_norm=1.0,
        label_mapping=None,
        num_categories=1,
        num_classes=2,
        device=None
    ):
        self.teacher_model = teacher_model
        self.student_model = student_model
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.test_loader = test_loader
        self.temperature = temperature
        self.alpha = alpha
        self.max_grad_norm = max_grad_norm
        self.num_categories = num_categories
        self.num_classes = num_classes

        self.device = device if device else torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        logger.info(f"Using device: {self.device}")

        self.teacher_model.to(self.device)
        self.student_model.to(self.device)

        # The teacher is frozen during distillation; only the student is optimized.
        self.teacher_model.eval()

        self.optimizer = torch.optim.Adam(
            self.student_model.parameters(),
            lr=lr,
            weight_decay=weight_decay
        )

        # Halve the learning rate when the validation F1 stops improving.
        self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            self.optimizer, mode='max', factor=0.5, patience=2, verbose=True
        )

        self.ce_loss = nn.CrossEntropyLoss()

        self.best_val_f1 = 0.0
        self.best_model_state = None
        self.label_mapping = label_mapping

    def distillation_loss(self, student_logits, teacher_logits, labels, temperature, alpha):
        """
        Compute the knowledge distillation loss.

        Args:
            student_logits: Output logits from the student model
            teacher_logits: Output logits from the teacher model
            labels: Ground truth labels
            temperature: Temperature for softening the probability distributions
            alpha: Weight of the distillation loss relative to the cross-entropy loss

        Returns:
            Combined loss
        """
        soft_targets = F.softmax(teacher_logits / temperature, dim=1)
        soft_prob = F.log_softmax(student_logits / temperature, dim=1)

        # KL divergence between the softened teacher and student distributions,
        # scaled by T^2 (see the note above).
        distill_loss = F.kl_div(soft_prob, soft_targets, reduction='batchmean') * (temperature ** 2)

        if self.num_categories > 1:
            # Hard-label loss: average the cross-entropy over the per-category
            # slices of the concatenated logit vector.
            total_loss = 0
            for i in range(self.num_categories):
                start_idx = i * self.num_classes
                end_idx = (i + 1) * self.num_classes
                category_outputs = student_logits[:, start_idx:end_idx]
                category_labels = labels[:, i]

                if category_labels.max() >= self.num_classes or category_labels.min() < 0:
                    logger.error(f"Category {i} labels out of range [0, {self.num_classes - 1}]: min={category_labels.min()}, max={category_labels.max()}")

                total_loss += self.ce_loss(category_outputs, category_labels)

            ce_loss = total_loss / self.num_categories
        else:
            ce_loss = self.ce_loss(student_logits, labels)

        # Weighted combination of the soft-target and hard-label terms.
        loss = alpha * distill_loss + (1 - alpha) * ce_loss

        return loss

    def train(self, epochs, save_path='best_distilled_model.pth'):
        """
        Train the student model with knowledge distillation.
        """
        logger.info(f"Starting distillation training for {epochs} epochs")
        logger.info(f"Temperature: {self.temperature}, Alpha: {self.alpha}")

        for epoch in range(epochs):
            self.student_model.train()
            train_loss = 0.0
            all_preds = []
            all_labels = []

            train_iterator = tqdm(self.train_loader, desc=f"Epoch {epoch+1}/{epochs} [Train]")
            for batch in train_iterator:
                input_ids = batch['input_ids'].to(self.device)
                attention_mask = batch['attention_mask'].to(self.device)
                labels = batch['label'].to(self.device)

                # Teacher forward pass without gradients; both models are expected
                # to return raw logits of shape [batch, num_categories * num_classes].
                with torch.no_grad():
                    teacher_logits = self.teacher_model(
                        input_ids=input_ids,
                        attention_mask=attention_mask
                    )

                student_logits = self.student_model(
                    input_ids=input_ids,
                    attention_mask=attention_mask
                )

                loss = self.distillation_loss(
                    student_logits,
                    teacher_logits,
                    labels,
                    self.temperature,
                    self.alpha
                )

                self.optimizer.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(self.student_model.parameters(), self.max_grad_norm)
                self.optimizer.step()

                train_loss += loss.item()

                if self.num_categories > 1:
                    # One argmax per category: reshape the flat logit vector to
                    # [batch, num_categories, classes_per_group] before predicting.
                    batch_size, total_classes = student_logits.shape
                    if total_classes % self.num_categories != 0:
                        raise ValueError(f"Number of total classes in the batch must be divisible by {self.num_categories}")

                    classes_per_group = total_classes // self.num_categories

                    reshaped = student_logits.view(student_logits.size(0), -1, classes_per_group)

                    preds = reshaped.argmax(dim=-1)
                else:
                    _, preds = torch.max(student_logits, 1)
                all_preds.extend(preds.cpu().tolist())
                all_labels.extend(labels.cpu().tolist())

                train_iterator.set_postfix({'loss': f"{loss.item():.4f}"})

            train_loss = train_loss / len(self.train_loader)
            if self.num_categories > 1:
                # Flatten the per-category predictions and labels so the metrics
                # are computed over every (example, category) pair.
                all_labels = np.concatenate(all_labels, axis=0)
                all_preds = np.concatenate(all_preds, axis=0)

            train_acc = accuracy_score(all_labels, all_preds)

            val_loss, val_acc, val_precision, val_recall, val_f1 = self.evaluate()

            # Step the scheduler on the validation F1 (mode='max').
            self.scheduler.step(val_f1)

            if val_f1 > self.best_val_f1:
                self.best_val_f1 = val_f1
                # Clone the state dict tensors so this snapshot is not overwritten by later updates.
                self.best_model_state = {k: v.detach().clone() for k, v in self.student_model.state_dict().items()}
                torch.save({
                    'model_state_dict': self.student_model.state_dict(),
                    'label_mapping': self.label_mapping,
                }, save_path)
                logger.info(f"New best model saved with validation F1: {val_f1:.4f}, accuracy: {val_acc:.4f}")

            logger.info(f"Epoch {epoch+1}/{epochs}: "
                        f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}, "
                        f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}, Val Precision: {val_precision:.4f}, Val Recall: {val_recall:.4f}, Val F1: {val_f1:.4f}")

            print(f"Epoch {epoch+1}/{epochs}: ",
                  f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}, ",
                  f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}, Val Precision: {val_precision:.4f}, Val Recall: {val_recall:.4f}, Val F1: {val_f1:.4f}")

        # Restore the best-performing weights before the final test evaluation.
        if self.best_model_state is not None:
            self.student_model.load_state_dict(self.best_model_state)
            logger.info(f"Loaded best model with validation F1: {self.best_val_f1:.4f}")

        if self.test_loader:
            test_loss, test_acc, test_precision, test_recall, test_f1 = self.evaluate(self.test_loader, "Test")
            logger.info(f"Test Loss: {test_loss:.4f}, Test Acc: {test_acc:.4f}, Test F1: {test_f1:.4f}")
            print(f"Test Loss: {test_loss:.4f}, Test Acc: {test_acc:.4f}, Test F1: {test_f1:.4f}")

    def evaluate(self, data_loader=None, phase="Validation", threshold=0.55):
        """
        Evaluate the student model.
        """
        if data_loader is None:
            data_loader = self.val_loader

        self.student_model.eval()
        eval_loss = 0.0
        all_preds = np.array([], dtype=int)
        all_labels = np.array([], dtype=int)

        with torch.no_grad():
            for batch in tqdm(data_loader, desc=f"[{phase}]"):
                input_ids = batch['input_ids'].to(self.device)
                attention_mask = batch['attention_mask'].to(self.device)
                labels = batch['label'].to(self.device)

                student_logits = self.student_model(
                    input_ids=input_ids,
                    attention_mask=attention_mask
                )

                if self.num_categories > 1:
                    total_loss = 0
                    for i in range(self.num_categories):
                        start_idx = i * self.num_classes
                        end_idx = (i + 1) * self.num_classes
                        category_outputs = student_logits[:, start_idx:end_idx]
                        category_labels = labels[:, i]

                        if category_labels.max() >= self.num_classes or category_labels.min() < 0:
                            logger.error(f"Category {i} labels out of range [0, {self.num_classes - 1}]: min={category_labels.min()}, max={category_labels.max()}")

                        total_loss += self.ce_loss(category_outputs, category_labels)

                    loss = total_loss / self.num_categories
                else:
                    loss = self.ce_loss(student_logits, labels)
                eval_loss += loss.item()

                if self.num_categories > 1:
                    batch_size, total_classes = student_logits.shape
                    if total_classes % self.num_categories != 0:
                        raise ValueError(f"Number of total classes in the batch must be divisible by {self.num_categories}")

                    classes_per_group = total_classes // self.num_categories

                    reshaped = student_logits.view(student_logits.size(0), -1, classes_per_group)
                    # Softmax over the class dimension of each category (dim=-1),
                    # not over the categories.
                    probs = F.softmax(reshaped, dim=-1)

                    # Zero out probabilities below the threshold; if every class in a
                    # category is zeroed, argmax falls back to class 0.
                    probs = torch.where(probs > threshold, probs, torch.zeros_like(probs))

                    preds = probs.argmax(dim=-1)
                else:
                    _, preds = torch.max(student_logits, 1)
                all_preds = np.append(all_preds, preds.cpu().numpy())
                all_labels = np.append(all_labels, labels.cpu().numpy())

        eval_loss = eval_loss / len(data_loader)

        # np.append above already flattened the per-category predictions and labels,
        # so they can be scored directly in both the single- and multi-category cases.
        accuracy = accuracy_score(all_labels, all_preds)
        precision = precision_score(all_labels, all_preds, average='weighted')
        recall = recall_score(all_labels, all_preds, average='weighted')
        f1 = f1_score(all_labels, all_preds, average='weighted')

        return eval_loss, accuracy, precision, recall, f1
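

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only): `TeacherBERT`, `StudentLSTM`, and
    # `build_loaders` are hypothetical placeholders for whatever teacher/student
    # modules and DataLoaders the surrounding project defines. Both models are
    # assumed to return raw logits of shape [batch, num_categories * num_classes].
    #
    # from my_models import TeacherBERT, StudentLSTM
    # from my_data import build_loaders
    #
    # logging.basicConfig(level=logging.INFO)
    #
    # teacher = TeacherBERT()
    # student = StudentLSTM()
    # train_loader, val_loader, test_loader = build_loaders(batch_size=32)
    #
    # trainer = DistillationTrainer(
    #     teacher_model=teacher,
    #     student_model=student,
    #     train_loader=train_loader,
    #     val_loader=val_loader,
    #     test_loader=test_loader,
    #     temperature=2.0,
    #     alpha=0.5,
    #     num_categories=1,
    #     num_classes=2,
    # )
    # trainer.train(epochs=5, save_path='best_distilled_model.pth')
    pass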