# lightbulb/distill.py
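"""Distill a large causal language model (the teacher) into a smaller student.

Two modes are selected by --distill_full_model:
  * Full distillation: soft-target KL distillation of the teacher's logits
    over a Hugging Face dataset loaded with `datasets.load_dataset`.
  * Standard distillation: the same objective over custom JSON/JSONL files
    passed via --query_terms.

The module also bundles world-model utilities built on `lightbulb_custom`
(train_epoch_world_model, save_all_models) and an MCTS-based `infer` routine;
these are defined here but not wired into the CLI entry point.

Training logs go to TensorBoard (--log_dir); checkpoints and the best student
are written to --checkpoint_dir and --save_path.
"""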
import argparse
import os
import sys
import json
import jsonlines
from typing import List, Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset, random_split
from torch.cuda.amp import autocast, GradScaler  # deprecated alias of torch.amp since PyTorch 2.4; kept for backward compatibility
from torch.utils.tensorboard import SummaryWriter
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from tqdm import tqdm
# Set up device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# ======================================
# Import Custom Components from lightbulb_custom
# ======================================
from lightbulb_custom import (
RotaryPositionalEncoding,
MultiHeadAttention,
MoE,
TransformerBlock,
Transformer,
InfoNCE_Loss,
CovarianceRegularization,
DynamicsPerformanceLoss,
ThoughtConsistencyLoss,
PolicyValueJointLoss,
ActionDiversityReward,
ExpectedThoughtValueLoss,
ExplorationRegularization,
KL_DivergenceLoss,
ActionEncoder,
RepresentationNetwork,
DynamicsNetwork,
PredictionNetwork,
ThoughtNode,
MCTS,
State
)
# ==========================
# Custom Dataset Definition
# ==========================
class CustomDataset(Dataset):
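    """Minimal dataset pairing pre-tokenized input tensors with label tensors."""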
def __init__(self, inputs, labels):
self.inputs = inputs
self.labels = labels
def __len__(self):
return len(self.inputs)
def __getitem__(self, idx):
return {'input_ids': self.inputs[idx], 'labels': self.labels[idx]}
# ================================
# Utility Functions for Data Loading
# ================================
def load_filtered_dataset(dataset_name: str, config: str, queries: Optional[List[str]] = None):
dataset = load_dataset(dataset_name, config)
if queries:
def filter_func(examples):
return [any(query.lower() in text.lower() for query in queries) for text in examples["text"]]
dataset = dataset.filter(filter_func, batched=True)
return dataset
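# Example (illustrative arguments): keep only examples whose "text" field
# contains at least one of the query terms, case-insensitively:
#   dataset = load_filtered_dataset("wikitext", "wikitext-2-raw-v1", queries=["physics"])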
def load_custom_data_from_files(file_paths):
custom_data = []
for file_path in file_paths:
if file_path.endswith('.json'):
with open(file_path, 'r') as f:
data = json.load(f)
if isinstance(data, list):
custom_data.extend(data)
else:
custom_data.append(data)
elif file_path.endswith('.jsonl'):
with jsonlines.open(file_path) as reader:
custom_data.extend(reader)
return custom_data
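# Each loaded record is expected to be a dict (or a JSON string decoding to
# one) with 'query' and 'content' fields; the numeric metrics are optional and
# default to 0 in preprocess_custom_data below. Illustrative record:
#   {"query": "...", "content": "...", "episode_reward": 1.0, "loss": 0.5,
#    "cosine_similarity": 0.9, "rag_performance": 0.8,
#    "ranking_model_performance": 0.7}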
def preprocess_custom_data(data_list):
processed_data = []
for item in data_list:
        # Items may arrive as raw JSON strings; decode them into dicts first
if isinstance(item, str):
try:
item = json.loads(item)
except json.JSONDecodeError:
print(f"Failed to parse JSON: {item[:100]}...") # Print first 100 chars for debugging
continue # Skip this item if it's not valid JSON
# Process query and content
query = item.get('query', '')
content = item.get('content', '')
if content == "RAG response generation failed.":
content = ""
# Combine query and content
combined_text = f"Query: {query} Content: {content}"
# Process numerical data (assuming these are available in the item dict)
episode_reward = item.get('episode_reward', 0)
loss = item.get('loss', 0)
cosine_similarity = item.get('cosine_similarity', 0)
rag_performance = item.get('rag_performance', 0)
ranking_model_performance = item.get('ranking_model_performance', 0)
# Create a dictionary with processed data
processed_item = {
'text': combined_text,
'episode_reward': episode_reward,
'loss': loss,
'cosine_similarity': cosine_similarity,
'rag_performance': rag_performance,
'ranking_model_performance': ranking_model_performance
}
processed_data.append(processed_item)
return processed_data
def load_custom_data(args, tokenizer, custom_data):
# Preprocess the custom data
processed_data = preprocess_custom_data(custom_data)
# Create a custom dataset
class CustomDatasetProcessed(torch.utils.data.Dataset):
def __init__(self, data, tokenizer, max_length):
self.data = data
self.tokenizer = tokenizer
self.max_length = max_length
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
item = self.data[idx]
encoded = self.tokenizer.encode_plus(
item['text'],
max_length=self.max_length,
padding='max_length',
truncation=True,
return_tensors='pt'
)
return {
'input_ids': encoded['input_ids'].squeeze(),
'attention_mask': encoded['attention_mask'].squeeze(),
'episode_reward': torch.tensor(item['episode_reward'], dtype=torch.float),
'loss': torch.tensor(item['loss'], dtype=torch.float),
'cosine_similarity': torch.tensor(item['cosine_similarity'], dtype=torch.float),
'rag_performance': torch.tensor(item['rag_performance'], dtype=torch.float),
'ranking_model_performance': torch.tensor(item['ranking_model_performance'], dtype=torch.float)
}
# Create dataset and dataloader
dataset = CustomDatasetProcessed(processed_data, tokenizer, args.max_length)
# Split the dataset into train and eval
train_size = int(0.8 * len(dataset))
eval_size = len(dataset) - train_size
train_dataset, eval_dataset = random_split(dataset, [train_size, eval_size])
train_loader = DataLoader(
train_dataset,
batch_size=args.batch_size,
shuffle=True,
num_workers=4
)
eval_loader = DataLoader(
eval_dataset,
batch_size=args.batch_size,
shuffle=False,
num_workers=4
)
return train_loader, eval_loader
def prepare_data(tokenizer, dataset, max_length, batch_size):
    # Tokenize once; inputs and labels are identical here because the
    # distillation loss is computed from the teacher's logits, not hard labels.
    tokenized = tokenizer(dataset["train"]["text"], return_tensors="pt", padding=True, truncation=True, max_length=max_length)
    # Create the custom dataset
    custom_dataset = CustomDataset(tokenized["input_ids"], tokenized["input_ids"])
# Split into training and validation sets
train_size = int(0.9 * len(custom_dataset))
val_size = len(custom_dataset) - train_size
train_dataset, val_dataset = random_split(custom_dataset, [train_size, val_size])
# Create DataLoaders
train_loader = DataLoader(
train_dataset,
shuffle=True,
batch_size=batch_size,
num_workers=4,
pin_memory=True
)
val_loader = DataLoader(
val_dataset,
shuffle=False,
batch_size=batch_size,
num_workers=4,
pin_memory=True
)
return train_loader, val_loader
# ==========================
# Training and Validation Functions
# ==========================
def save_all_models(transformer_model, representation_network, dynamics_network, prediction_network, action_encoder, save_dir, epoch):
"""
Save all models to the specified directory.
Args:
transformer_model (nn.Module): Transformer model.
representation_network (nn.Module): Representation network.
dynamics_network (nn.Module): Dynamics network.
prediction_network (nn.Module): Prediction network.
action_encoder (nn.Module): Action encoder.
save_dir (str): Directory to save the models.
epoch (int): Current epoch number.
"""
os.makedirs(save_dir, exist_ok=True)
torch.save(transformer_model.state_dict(), os.path.join(save_dir, f'transformer_model_epoch_{epoch}.pt'))
torch.save(representation_network.state_dict(), os.path.join(save_dir, f'representation_network_epoch_{epoch}.pt'))
torch.save(dynamics_network.state_dict(), os.path.join(save_dir, f'dynamics_network_epoch_{epoch}.pt'))
torch.save(prediction_network.state_dict(), os.path.join(save_dir, f'prediction_network_epoch_{epoch}.pt'))
torch.save(action_encoder.state_dict(), os.path.join(save_dir, f'action_encoder_epoch_{epoch}.pt'))
print(f"All models saved for epoch {epoch}.")
def train_epoch_world_model(world_model_components, train_loader, optimizer, scheduler, scaler, args, state_dim, embed_dim, input_dim):
    # The (frozen) transformer travels inside world_model_components along with
    # the world-model networks and the PPO agent.
    representation_network, dynamics_network, prediction_network, action_encoder, ppo_agent, model_transformer = world_model_components
representation_network.train()
dynamics_network.train()
prediction_network.train()
action_encoder.train()
ppo_agent.policy_network.train()
total_loss = 0.0
optimizer.zero_grad()
print(f"Starting World Model training epoch with {len(train_loader)} batches...")
for i, batch in enumerate(train_loader):
print(f"Processing batch {i+1}/{len(train_loader)}...")
# Move batches to the device
src_batch = batch['input_ids'].to(device)
tgt_batch = batch['labels'].to(device)
with torch.cuda.amp.autocast():
print("Forward pass through Transformer (frozen)...")
with torch.no_grad():
transformer_output = model_transformer(src_batch, tgt_batch[:, :-1])
# World Model - Representation
state_representation = representation_network(transformer_output)
# For simplicity, let's assume true actions are provided (e.g., next tokens)
true_actions = tgt_batch[:, :-1]
print(f"True actions shape: {true_actions.shape}")
action_sequences = true_actions
# Get action embeddings
action_embeddings = action_encoder(action_sequences)
print(f"Action embeddings shape: {action_embeddings.shape}")
# Apply dynamics network
predicted_next_state_batch = dynamics_network(state_representation, action_embeddings)
print(f"Predicted next state batch shape: {predicted_next_state_batch.shape}")
# Prediction Network - Policy logits and value
policy_logits, value_estimates = prediction_network(predicted_next_state_batch)
# Define true_policy and true_value as placeholders on the GPU
true_policy = F.one_hot(true_actions, num_classes=input_dim).float()
true_value = torch.zeros_like(value_estimates).to(device)
# Compute individual losses
ppo_loss = ppo_agent.compute_loss(
state_representation,
torch.zeros_like(true_actions, dtype=torch.float32).to(device),
true_actions,
torch.zeros_like(value_estimates, dtype=torch.float32).to(device),
torch.zeros_like(value_estimates, dtype=torch.float32).to(device)
)
info_nce = InfoNCE_Loss()(state_representation.reshape(-1, state_dim),
F.dropout(state_representation.reshape(-1, state_dim), p=0.1, training=True))
covariance = CovarianceRegularization()(predicted_next_state_batch.view(-1, predicted_next_state_batch.size(-1)))
dynamics_loss = DynamicsPerformanceLoss()(state_representation, predicted_next_state_batch)
perturbed_next_state = predicted_next_state_batch + torch.randn_like(predicted_next_state_batch) * 0.01
thought_loss = ThoughtConsistencyLoss()(predicted_next_state_batch, perturbed_next_state)
pv_loss = PolicyValueJointLoss()(policy_logits, true_policy, value_estimates.squeeze(-1), true_value.squeeze(-1))
action_diversity = ActionDiversityReward()(action_embeddings.view(-1, embed_dim))
mcts_best_values = torch.zeros(true_actions.size(0)).to(device)
etv = ExpectedThoughtValueLoss()(mcts_best_values)
visit_counts = torch.ones(true_actions.size(0), policy_logits.size(-1)).to(device)
exploration = ExplorationRegularization()(visit_counts)
old_policy = F.softmax(policy_logits.detach(), dim=-1)
new_policy = F.softmax(policy_logits, dim=-1)
kl_loss = KL_DivergenceLoss()(old_policy, new_policy)
# Total Loss
loss = (
ppo_loss +
info_nce +
covariance +
dynamics_loss +
thought_loss +
pv_loss +
action_diversity +
etv +
exploration +
kl_loss
)
loss = loss / args.accumulation_steps
print("Backward pass...")
scaler.scale(loss).backward()
if (i + 1) % args.accumulation_steps == 0 or (i + 1) == len(train_loader):
print("Gradient clipping...")
scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(
[param for group in optimizer.param_groups for param in group['params']],
args.max_grad_norm
)
print("Optimizer step...")
scaler.step(optimizer)
scaler.update()
print("Zeroing gradients...")
optimizer.zero_grad()
print("Updating learning rate...")
scheduler.step()
total_loss += loss.item() * args.accumulation_steps
# Print individual losses and total loss for this batch
print(f"Batch {i+1} completed. Losses:")
print(f" PPO Loss: {ppo_loss.item():.4f}")
print(f" InfoNCE Loss: {info_nce.item():.4f}")
print(f" Covariance Loss: {covariance.item():.4f}")
print(f" Dynamics Loss: {dynamics_loss.item():.4f}")
print(f" Thought Consistency Loss: {thought_loss.item():.4f}")
print(f" Policy-Value Loss: {pv_loss.item():.4f}")
print(f" Action Diversity Loss: {action_diversity.item():.4f}")
print(f" Expected Thought Value Loss: {etv.item():.4f}")
print(f" Exploration Loss: {exploration.item():.4f}")
print(f" KL Divergence Loss: {kl_loss.item():.4f}")
print(f" Total Loss: {loss.item():.4f}")
avg_loss = total_loss / len(train_loader)
print(f"World Model training epoch completed. Average loss: {avg_loss:.4f}")
return avg_loss
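# Soft-target knowledge distillation (Hinton et al., 2015): teacher and student
# logits are divided by a temperature T before the softmax, the loss is
# KL(softmax(z_teacher / T) || softmax(z_student / T)), and the result is
# scaled by T**2 so gradient magnitudes stay comparable across temperatures.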
def train_step(teacher, student, data_loader, optimizer, criterion, scaler, temperature=2.0):
teacher.eval()
student.train()
total_loss = 0
for batch in tqdm(data_loader, desc="Training"):
inputs = batch["input_ids"].to(device)
labels = batch["labels"].to(device)
with autocast():
with torch.no_grad():
teacher_outputs = teacher(inputs).logits
teacher_logits = teacher_outputs / temperature
student_outputs = student(inputs).logits
student_logits = student_outputs / temperature
# Compute KL Divergence Loss
loss = criterion(nn.functional.log_softmax(student_logits, dim=-1), nn.functional.softmax(teacher_logits, dim=-1))
loss = loss * (temperature ** 2) # Scale loss by temperature squared
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad()
total_loss += loss.item()
avg_loss = total_loss / len(data_loader)
return avg_loss
def validate(teacher, student, data_loader, criterion, temperature=2.0):
teacher.eval()
student.eval()
total_loss = 0
with torch.no_grad():
for batch in tqdm(data_loader, desc="Validation"):
inputs = batch["input_ids"].to(device)
labels = batch["labels"].to(device)
teacher_outputs = teacher(inputs).logits
teacher_logits = teacher_outputs / temperature
student_outputs = student(inputs).logits
student_logits = student_outputs / temperature
loss = criterion(nn.functional.log_softmax(student_logits, dim=-1), nn.functional.softmax(teacher_logits, dim=-1))
loss = loss * (temperature ** 2)
total_loss += loss.item()
avg_loss = total_loss / len(data_loader)
return avg_loss
def save_checkpoint(state, save_dir, epoch):
os.makedirs(save_dir, exist_ok=True)
checkpoint_path = os.path.join(save_dir, f'checkpoint_epoch_{epoch}.pt')
torch.save(state, checkpoint_path)
print(f"Checkpoint saved at {checkpoint_path}")
# ==========================
# Inference Functions
# ==========================
def infer(query, world_model_components, root_thought_node, tokenizer, max_length=2000, inference_mode='world_model', beam_size=5, n_tokens_predict=3, mcts_iterations=10, exploration_constant=1.414):
"""
Perform inference given a query, utilizing the Tree of Thought and MCTS with multi-token beam search.
Args:
query (str): The input query or prompt.
world_model_components (tuple): Tuple containing the model components.
root_thought_node (ThoughtNode): The root node of the Tree of Thought.
tokenizer (transformers.PreTrainedTokenizer): The tokenizer used.
max_length (int): Maximum length for the generated sequence.
        inference_mode (str): Inference mode ('world_model', 'without_world_model', 'world_model_tree_of_thought'); only 'world_model' is currently implemented.
beam_size (int): Size of the beam for beam search
n_tokens_predict (int): Number of tokens to predict at each step
mcts_iterations (int): Number of MCTS iterations
exploration_constant (float): Exploration constant for MCTS
Returns:
List[str] or str: The sequence of actions (thoughts) selected or generated text.
"""
if inference_mode != 'world_model':
print("Inference mode other than 'world_model' not implemented yet.")
return ""
representation_network, dynamics_network, prediction_network, action_encoder, ppo_agent, model_transformer = world_model_components
# Tokenize and encode the query
input_ids = tokenizer.encode(query, return_tensors='pt').to(device)
attention_mask = (input_ids != tokenizer.pad_token_id).long()
# Use the world model components
with torch.no_grad():
transformer_output = model_transformer(input_ids, input_ids)
# Get the initial state representation
initial_representation = representation_network(transformer_output) # Shape: (batch_size=1, seq_len, state_dim)
initial_representation = initial_representation[:, -1, :].unsqueeze(1) # Shape: (batch_size=1, 1, state_dim)
initial_state = State(
representation=initial_representation,
dynamics_network=dynamics_network,
action_encoder=action_encoder,
thought_node=root_thought_node
)
# Use MCTS with Tree of Thought and multi-token beam search
mcts = MCTS(prediction_network, dynamics_network, action_encoder, num_iterations=mcts_iterations, exploration_constant=exploration_constant)
current_state = initial_state
thought_sequence = []
for _ in range(max_length // n_tokens_predict):
best_actions = mcts.search_with_beam(current_state)
thought_sequence.extend(best_actions)
# Apply the best actions to get the next state
for action in best_actions:
current_state = current_state.apply_action(action)
# Check if we've reached a leaf node (no further actions)
if len(current_state.thought_node.children) == 0:
break
return thought_sequence
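# Sketch of a call (the component tuple and root ThoughtNode must come from a
# separately trained world model; the names here are illustrative):
#   thoughts = infer("What is distillation?", world_model_components,
#                    root_thought_node, tokenizer, inference_mode='world_model')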
# ==========================
# Main Training Function
# ==========================
def distill_model(
teacher_model_name: str,
student_model_name: str,
dataset_name: str,
config: str,
distill_full_model: bool = True,
query_terms: Optional[List[str]] = None,
num_epochs: int = 3,
batch_size: int = 4,
max_length: int = 128,
learning_rate: float = 5e-5,
temperature: float = 2.0,
save_path: str = "./distilled_model",
log_dir: str = "./logs",
checkpoint_dir: str = "./checkpoints",
early_stopping_patience: int = 3,
accumulation_steps: int = 1,
max_grad_norm: float = 1.0,
weight_decay: float = 0.01
):
# Initialize TensorBoard writer
writer = SummaryWriter(log_dir=log_dir)
# Load tokenizer
print("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(teacher_model_name)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
print("Tokenizer loaded successfully.")
# Load teacher model
print("Loading teacher model...")
teacher = AutoModelForCausalLM.from_pretrained(teacher_model_name).to(device)
print("Teacher model loaded successfully.")
if distill_full_model:
# Full World Model Distillation
print(f"Starting Full World Model Distillation into '{student_model_name}'.")
# Load or instantiate student model
print(f"Attempting to load student model '{student_model_name}'...")
try:
student = AutoModelForCausalLM.from_pretrained(student_model_name).to(device)
print(f"Student model '{student_model_name}' loaded successfully.")
        except (OSError, ValueError):
            print(f"Student model '{student_model_name}' not found. Falling back to 'distilgpt2' as the student.")
            # Initialize a smaller pre-trained model as the student
            try:
                student = AutoModelForCausalLM.from_pretrained('distilgpt2').to(device)
                # Save the freshly initialized student under the target save path
                student.save_pretrained(save_path)
                tokenizer.save_pretrained(save_path)
                print(f"Student model initialized from 'distilgpt2' and saved to '{save_path}'.")
except Exception as inst_e:
print(f"Failed to instantiate and save student model: {inst_e}")
sys.exit(1)
        # Freeze teacher parameters; the teacher only provides soft targets
for param in teacher.parameters():
param.requires_grad = False
# Load and prepare dataset
print(f"Loading full dataset '{dataset_name}' with config '{config}'...")
dataset = load_dataset(dataset_name, config)
train_loader, val_loader = prepare_data(tokenizer, dataset, max_length, batch_size)
print("Data loaded and preprocessed successfully.")
# Define optimizer, scheduler, and scaler for mixed precision
optimizer = optim.AdamW(student.parameters(), lr=learning_rate, weight_decay=weight_decay)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)
scaler = GradScaler()
# Define loss criterion
criterion = nn.KLDivLoss(reduction="batchmean")
best_val_loss = float('inf')
epochs_no_improve = 0
# Training loop
for epoch in range(1, num_epochs + 1):
print(f"\nEpoch {epoch}/{num_epochs}")
print("-" * 20)
# Training
train_loss = train_step(teacher, student, train_loader, optimizer, criterion, scaler, temperature)
print(f"Training Loss: {train_loss:.4f}")
writer.add_scalar("Loss/Train", train_loss, epoch)
# Validation
val_loss = validate(teacher, student, val_loader, criterion, temperature)
print(f"Validation Loss: {val_loss:.4f}")
writer.add_scalar("Loss/Validation", val_loss, epoch)
# Check for improvement
if val_loss < best_val_loss:
best_val_loss = val_loss
epochs_no_improve = 0
# Save the best model
save_checkpoint({
'epoch': epoch,
'model_state_dict': student.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'scheduler_state_dict': scheduler.state_dict(),
'scaler_state_dict': scaler.state_dict(),
'best_val_loss': best_val_loss
}, checkpoint_dir, epoch)
# Save the model as the best one
student.save_pretrained(save_path)
tokenizer.save_pretrained(save_path)
print(f"Best model saved at epoch {epoch}")
else:
epochs_no_improve += 1
print(f"No improvement in validation loss for {epochs_no_improve} epoch(s)")
if epochs_no_improve >= early_stopping_patience:
print("Early stopping triggered")
break
# Step the scheduler
scheduler.step()
writer.close()
print("\nFull World Model Distillation completed.")
else:
# Standard Language Model Distillation
print(f"Starting Standard Language Model Distillation into '{student_model_name}'.")
if not query_terms:
print("Error: --query_terms must be provided for standard language model distillation.")
sys.exit(1)
# Load or instantiate student model
print(f"Attempting to load student model '{student_model_name}'...")
try:
student = AutoModelForCausalLM.from_pretrained(student_model_name).to(device)
print(f"Student model '{student_model_name}' loaded successfully.")
        except (OSError, ValueError):
            print(f"Student model '{student_model_name}' not found. Falling back to 'distilgpt2' as the student.")
            # Initialize a smaller pre-trained model as the student
            try:
                student = AutoModelForCausalLM.from_pretrained('distilgpt2').to(device)
                # Save the freshly initialized student under the target save path
                student.save_pretrained(save_path)
                tokenizer.save_pretrained(save_path)
                print(f"Student model initialized from 'distilgpt2' and saved to '{save_path}'.")
except Exception as inst_e:
print(f"Failed to instantiate and save student model: {inst_e}")
sys.exit(1)
        # Freeze teacher parameters; the teacher only provides soft targets
for param in teacher.parameters():
param.requires_grad = False
# Load and prepare custom dataset
print(f"Loading custom data files: {query_terms}")
custom_data = load_custom_data_from_files(query_terms)
        # load_custom_data reads both max_length and batch_size from its args
        train_loader, val_loader = load_custom_data(
            args=argparse.Namespace(max_length=max_length, batch_size=batch_size),
            tokenizer=tokenizer,
            custom_data=custom_data
        )
print("Custom data loaded and preprocessed successfully.")
# Define optimizer, scheduler, and scaler for mixed precision
optimizer = optim.AdamW(student.parameters(), lr=learning_rate, weight_decay=weight_decay)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)
scaler = GradScaler()
# Define loss criterion
criterion = nn.KLDivLoss(reduction="batchmean")
best_val_loss = float('inf')
epochs_no_improve = 0
# Training loop
for epoch in range(1, num_epochs + 1):
print(f"\nEpoch {epoch}/{num_epochs}")
print("-" * 20)
# Training
train_loss = train_step(teacher, student, train_loader, optimizer, criterion, scaler, temperature)
print(f"Training Loss: {train_loss:.4f}")
writer.add_scalar("Loss/Train", train_loss, epoch)
# Validation
val_loss = validate(teacher, student, val_loader, criterion, temperature)
print(f"Validation Loss: {val_loss:.4f}")
writer.add_scalar("Loss/Validation", val_loss, epoch)
# Check for improvement
if val_loss < best_val_loss:
best_val_loss = val_loss
epochs_no_improve = 0
# Save the best model
save_checkpoint({
'epoch': epoch,
'model_state_dict': student.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'scheduler_state_dict': scheduler.state_dict(),
'scaler_state_dict': scaler.state_dict(),
'best_val_loss': best_val_loss
}, checkpoint_dir, epoch)
# Save the model as the best one
student.save_pretrained(save_path)
tokenizer.save_pretrained(save_path)
print(f"Best model saved at epoch {epoch}")
else:
epochs_no_improve += 1
print(f"No improvement in validation loss for {epochs_no_improve} epoch(s)")
if epochs_no_improve >= early_stopping_patience:
print("Early stopping triggered")
break
# Step the scheduler
scheduler.step()
writer.close()
print("\nStandard Language Model Distillation completed.")
# ==========================
# Argument Parsing
# ==========================
def parse_args():
parser = argparse.ArgumentParser(description="Distill a large LLM into a smaller one or a full language world model.")
# Required arguments
parser.add_argument("--teacher_model_name", type=str, required=True, help="Name of the teacher model")
parser.add_argument("--student_model_name", type=str, required=True, help="Name of the student model")
# Dataset arguments
parser.add_argument("--dataset_name", type=str, required=True, help="Name of the dataset")
parser.add_argument("--config", type=str, default=None, help="Dataset configuration (e.g., 'wikitext-2-raw-v1')")
# Mode selection
parser.add_argument("--distill_full_model", action="store_true", help="Whether to distill into the full language world model")
# For standard distillation
parser.add_argument("--query_terms", type=str, nargs="+", help="Paths to custom data files for standard language model distillation")
# Training hyperparameters
parser.add_argument("--num_epochs", type=int, default=3, help="Number of epochs")
parser.add_argument("--batch_size", type=int, default=4, help="Batch size")
parser.add_argument("--max_length", type=int, default=128, help="Maximum sequence length")
parser.add_argument("--learning_rate", type=float, default=5e-5, help="Learning rate")
parser.add_argument("--temperature", type=float, default=2.0, help="Distillation temperature")
# Saving and logging
parser.add_argument("--save_path", type=str, default="./distilled_model", help="Path to save the distilled model")
parser.add_argument("--log_dir", type=str, default="./logs", help="Directory for TensorBoard logs")
parser.add_argument("--checkpoint_dir", type=str, default="./checkpoints", help="Directory to save checkpoints")
# Early stopping
parser.add_argument("--early_stopping_patience", type=int, default=3, help="Early stopping patience")
# Gradient accumulation and optimization
parser.add_argument("--accumulation_steps", type=int, default=1, help="Gradient accumulation steps")
parser.add_argument("--max_grad_norm", type=float, default=1.0, help="Maximum gradient norm for clipping")
parser.add_argument("--weight_decay", type=float, default=0.01, help="Weight decay for optimizer")
return parser.parse_args()
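# Example invocations (model and dataset names are illustrative):
#   Full distillation over a Hugging Face dataset:
#     python distill.py --teacher_model_name gpt2 --student_model_name distilgpt2 \
#         --dataset_name wikitext --config wikitext-2-raw-v1 --distill_full_model
#   Standard distillation over custom JSON/JSONL files (--dataset_name is
#   required by the parser but unused in this mode):
#     python distill.py --teacher_model_name gpt2 --student_model_name distilgpt2 \
#         --dataset_name wikitext --query_terms data/run1.jsonl data/run2.json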
# ==========================
# Main Function
# ==========================
def main():
args = parse_args()
print("Arguments parsed successfully.")
# Create save directories
os.makedirs(args.save_path, exist_ok=True)
os.makedirs(args.log_dir, exist_ok=True)
os.makedirs(args.checkpoint_dir, exist_ok=True)
print(f"Save directory created: {args.save_path}")
print(f"Log directory created: {args.log_dir}")
print(f"Checkpoint directory created: {args.checkpoint_dir}")
# Handle dataset loading based on distillation mode
if args.distill_full_model:
# Full World Model Distillation
distill_model(
teacher_model_name=args.teacher_model_name,
student_model_name=args.student_model_name,
dataset_name=args.dataset_name,
config=args.config,
distill_full_model=args.distill_full_model,
query_terms=args.query_terms, # Not used in this mode
num_epochs=args.num_epochs,
batch_size=args.batch_size,
max_length=args.max_length,
learning_rate=args.learning_rate,
temperature=args.temperature,
save_path=args.save_path,
log_dir=args.log_dir,
checkpoint_dir=args.checkpoint_dir,
early_stopping_patience=args.early_stopping_patience,
accumulation_steps=args.accumulation_steps,
max_grad_norm=args.max_grad_norm,
weight_decay=args.weight_decay
)
else:
# Standard Language Model Distillation
distill_model(
teacher_model_name=args.teacher_model_name,
student_model_name=args.student_model_name,
dataset_name=args.dataset_name,
config=args.config,
distill_full_model=args.distill_full_model,
query_terms=args.query_terms,
num_epochs=args.num_epochs,
batch_size=args.batch_size,
max_length=args.max_length,
learning_rate=args.learning_rate,
temperature=args.temperature,
save_path=args.save_path,
log_dir=args.log_dir,
checkpoint_dir=args.checkpoint_dir,
early_stopping_patience=args.early_stopping_patience,
accumulation_steps=args.accumulation_steps,
max_grad_norm=args.max_grad_norm,
weight_decay=args.weight_decay
)
if __name__ == "__main__":
main()