"""Distill a large causal language model into a smaller student model using
temperature-scaled KL divergence between teacher and student logits."""

import argparse
import os
from typing import List, Optional

import torch
import torch.nn as nn
import torch.optim as optim
from torch.cuda.amp import autocast, GradScaler
from torch.utils.data import DataLoader, Dataset, random_split
from torch.utils.tensorboard import SummaryWriter

from datasets import load_dataset
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class CustomDataset(Dataset):
    """Wraps pre-tokenized input and label tensors for use with a DataLoader."""

    def __init__(self, inputs, labels):
        self.inputs = inputs
        self.labels = labels

    def __len__(self):
        return len(self.inputs)

    def __getitem__(self, idx):
        return {'input_ids': self.inputs[idx], 'labels': self.labels[idx]}

def load_filtered_dataset(dataset_name: str, config: str, queries: Optional[List[str]] = None):
    """Load a dataset and optionally keep only examples whose text mentions any query term."""
    dataset = load_dataset(dataset_name, config)
    if queries:
        def filter_func(example):
            return any(query.lower() in example["text"].lower() for query in queries)

        # Filter example-by-example; example["text"] is a single string here.
        dataset = dataset.filter(filter_func)
    return dataset

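# Illustrative usage (the dataset, config, and query terms here are placeholders,
# not requirements of this script):
#   dataset = load_filtered_dataset("wikitext", "wikitext-2-raw-v1", queries=["physics"])
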
def prepare_data(tokenizer, dataset, max_length, batch_size):
    """Tokenize the training split and build train/validation DataLoaders."""
    # Tokenize once; for causal-LM distillation the labels simply mirror the inputs
    # (the distillation loss below only compares teacher and student logits).
    tokenized = tokenizer(dataset["train"]["text"], return_tensors="pt", padding=True,
                          truncation=True, max_length=max_length)

    custom_dataset = CustomDataset(tokenized["input_ids"], tokenized["input_ids"].clone())

    # 90/10 train/validation split.
    train_size = int(0.9 * len(custom_dataset))
    val_size = len(custom_dataset) - train_size
    train_dataset, val_dataset = random_split(custom_dataset, [train_size, val_size])

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True,
                              num_workers=4, pin_memory=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False,
                            num_workers=4, pin_memory=True)

    return train_loader, val_loader

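# The distillation objective implemented by train_step and validate below is the
# standard soft-target loss (a common formulation; the temperature T softens both
# distributions, and the T**2 factor keeps gradient magnitudes roughly constant
# across temperatures):
#
#     L_KD = T**2 * KL( softmax(z_teacher / T) || softmax(z_student / T) )
#
# where z_teacher and z_student are the raw logits.
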
def train_step(teacher, student, data_loader, optimizer, criterion, scaler, temperature=2.0):
    """Run one epoch of distillation and return the average training loss."""
    teacher.eval()
    student.train()
    total_loss = 0

    for batch in tqdm(data_loader, desc="Training"):
        inputs = batch["input_ids"].to(device)
        labels = batch["labels"].to(device)  # mirrors the inputs; the pure distillation loss below does not use it

        optimizer.zero_grad()

        with autocast():
            # The teacher only provides soft targets; no gradients are needed for it.
            with torch.no_grad():
                teacher_outputs = teacher(inputs).logits
            teacher_logits = teacher_outputs / temperature

            student_outputs = student(inputs).logits
            student_logits = student_outputs / temperature

            # KLDivLoss expects log-probabilities as input and probabilities as target.
            loss = criterion(nn.functional.log_softmax(student_logits, dim=-1),
                             nn.functional.softmax(teacher_logits, dim=-1))
            loss = loss * (temperature ** 2)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        total_loss += loss.item()

    avg_loss = total_loss / len(data_loader)
    return avg_loss

def validate(teacher, student, data_loader, criterion, temperature=2.0):
    """Compute the average distillation loss on the validation set."""
    teacher.eval()
    student.eval()
    total_loss = 0

    with torch.no_grad():
        for batch in tqdm(data_loader, desc="Validation"):
            inputs = batch["input_ids"].to(device)
            labels = batch["labels"].to(device)

            teacher_outputs = teacher(inputs).logits
            teacher_logits = teacher_outputs / temperature

            student_outputs = student(inputs).logits
            student_logits = student_outputs / temperature

            loss = criterion(nn.functional.log_softmax(student_logits, dim=-1),
                             nn.functional.softmax(teacher_logits, dim=-1))
            loss = loss * (temperature ** 2)

            total_loss += loss.item()

    avg_loss = total_loss / len(data_loader)
    return avg_loss

def save_checkpoint(state, save_dir, epoch):
    os.makedirs(save_dir, exist_ok=True)
    checkpoint_path = os.path.join(save_dir, f'checkpoint_epoch_{epoch}.pt')
    torch.save(state, checkpoint_path)
    print(f"Checkpoint saved at {checkpoint_path}")

def load_checkpoint(model, optimizer, scheduler, scaler, save_dir, epoch):
    checkpoint_path = os.path.join(save_dir, f'checkpoint_epoch_{epoch}.pt')
    if os.path.isfile(checkpoint_path):
        checkpoint = torch.load(checkpoint_path, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        scaler.load_state_dict(checkpoint['scaler_state_dict'])
        print(f"Loaded checkpoint from {checkpoint_path}")
    else:
        print(f"No checkpoint found at {checkpoint_path}")

def distill_model(
    teacher_model_name: str,
    student_model_name: str,
    dataset_name: str,
    config: str,
    distill_full_model: bool = True,
    query_terms: Optional[List[str]] = None,
    num_epochs: int = 3,
    batch_size: int = 4,
    max_length: int = 128,
    learning_rate: float = 5e-5,
    temperature: float = 2.0,
    save_path: str = "./distilled_model",
    log_dir: str = "./logs",
    checkpoint_dir: str = "./checkpoints",
    early_stopping_patience: int = 3
):
    """Distill the teacher model into the student on the given dataset, with
    mixed-precision training, TensorBoard logging, checkpointing, and early stopping."""
    writer = SummaryWriter(log_dir=log_dir)

    # The teacher's tokenizer is reused for data preparation and saved with the student;
    # the logit-level KL loss assumes teacher and student share the same vocabulary.
    # Many causal LMs ship without a pad token, so fall back to the EOS token.
    tokenizer = AutoTokenizer.from_pretrained(teacher_model_name)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    teacher = AutoModelForCausalLM.from_pretrained(teacher_model_name).to(device)
    student = AutoModelForCausalLM.from_pretrained(student_model_name).to(device)

    # Freeze the teacher; only the student is trained.
    for param in teacher.parameters():
        param.requires_grad = False

    if distill_full_model:
        dataset = load_dataset(dataset_name, config)
    else:
        dataset = load_filtered_dataset(dataset_name, config, query_terms)

    train_loader, val_loader = prepare_data(tokenizer, dataset, max_length, batch_size)

    optimizer = optim.AdamW(student.parameters(), lr=learning_rate)
    # Cosine annealing over the full run; the scheduler is stepped once per epoch.
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)
    scaler = GradScaler()

    # KLDivLoss(input=log-probs, target=probs) with "batchmean" reduction matches the
    # temperature-scaled soft-target loss computed in train_step/validate.
    criterion = nn.KLDivLoss(reduction="batchmean")

    best_val_loss = float('inf')
    epochs_no_improve = 0

    for epoch in range(1, num_epochs + 1):
        print(f"\nEpoch {epoch}/{num_epochs}")
        print("-" * 20)

        train_loss = train_step(teacher, student, train_loader, optimizer, criterion, scaler, temperature)
        print(f"Training Loss: {train_loss:.4f}")
        writer.add_scalar("Loss/Train", train_loss, epoch)

        val_loss = validate(teacher, student, val_loader, criterion, temperature)
        print(f"Validation Loss: {val_loss:.4f}")
        writer.add_scalar("Loss/Validation", val_loss, epoch)

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            epochs_no_improve = 0

            save_checkpoint({
                'epoch': epoch,
                'model_state_dict': student.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict(),
                'scaler_state_dict': scaler.state_dict(),
                'best_val_loss': best_val_loss
            }, checkpoint_dir, epoch)

            student.save_pretrained(save_path)
            tokenizer.save_pretrained(save_path)
            print(f"Best model saved at epoch {epoch}")
        else:
            epochs_no_improve += 1
            print(f"No improvement in validation loss for {epochs_no_improve} epoch(s)")
            if epochs_no_improve >= early_stopping_patience:
                print("Early stopping triggered")
                break

        scheduler.step()

    writer.close()
    print("\nDistillation completed.")

def parse_args():
    parser = argparse.ArgumentParser(description="Distill a large LLM into a smaller one.")
    parser.add_argument("--teacher_model_name", type=str, required=True, help="Name of the teacher model")
    parser.add_argument("--student_model_name", type=str, required=True, help="Name of the student model")
    parser.add_argument("--dataset_name", type=str, required=True, help="Name of the dataset")
    parser.add_argument("--config", type=str, default=None, help="Dataset configuration (e.g., 'wikitext-2-raw-v1')")
    parser.add_argument("--distill_full_model", action="store_true", help="Distill on the full dataset instead of a query-filtered subset")
    parser.add_argument("--query_terms", type=str, nargs="+", help="Query terms for filtering the dataset")
    parser.add_argument("--num_epochs", type=int, default=3, help="Number of epochs")
    parser.add_argument("--batch_size", type=int, default=4, help="Batch size")
    parser.add_argument("--max_length", type=int, default=128, help="Maximum sequence length")
    parser.add_argument("--learning_rate", type=float, default=5e-5, help="Learning rate")
    parser.add_argument("--temperature", type=float, default=2.0, help="Distillation temperature")
    parser.add_argument("--save_path", type=str, default="./distilled_model", help="Path to save the distilled model")
    parser.add_argument("--log_dir", type=str, default="./logs", help="Directory for TensorBoard logs")
    parser.add_argument("--checkpoint_dir", type=str, default="./checkpoints", help="Directory to save checkpoints")
    parser.add_argument("--early_stopping_patience", type=int, default=3, help="Early stopping patience")
    return parser.parse_args()

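# Example invocation (the script filename and the model/dataset names below are
# illustrative placeholders, not defaults of this script):
#   python distill.py --teacher_model_name gpt2-large --student_model_name gpt2 \
#       --dataset_name wikitext --config wikitext-2-raw-v1 \
#       --num_epochs 3 --batch_size 4 --max_length 128
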
if __name__ == "__main__":
    args = parse_args()
    distill_model(
        teacher_model_name=args.teacher_model_name,
        student_model_name=args.student_model_name,
        dataset_name=args.dataset_name,
        config=args.config,
        distill_full_model=args.distill_full_model,
        query_terms=args.query_terms,
        num_epochs=args.num_epochs,
        batch_size=args.batch_size,
        max_length=args.max_length,
        learning_rate=args.learning_rate,
        temperature=args.temperature,
        save_path=args.save_path,
        log_dir=args.log_dir,
        checkpoint_dir=args.checkpoint_dir,
        early_stopping_patience=args.early_stopping_patience
    )