import argparse
import math
import os
import sys
import json
import jsonlines
import copy
from typing import List, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset, random_split
from torch.cuda.amp import autocast, GradScaler
from torch.utils.tensorboard import SummaryWriter

from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
from tqdm import tqdm

from lightbulb_custom import (
    RotaryPositionalEncoding,
    MultiHeadAttention,
    MoE,
    TransformerBlock,
    Transformer,
    InfoNCE_Loss,
    CovarianceRegularization,
    DynamicsPerformanceLoss,
    ThoughtConsistencyLoss,
    PolicyValueJointLoss,
    ActionDiversityReward,
    ExpectedThoughtValueLoss,
    ExplorationRegularization,
    KL_DivergenceLoss,
    ActionEncoder,
    RepresentationNetwork,
    DynamicsNetwork,
    PredictionNetwork,
    ThoughtNode,
    MCTS,
    State
)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class CustomDataset(Dataset):
    def __init__(self, inputs, labels):
        self.inputs = inputs
        self.labels = labels

    def __len__(self):
        return len(self.inputs)

    def __getitem__(self, idx):
        return {'input_ids': self.inputs[idx], 'labels': self.labels[idx]}


def load_filtered_dataset(dataset_name: str, config: str, queries: Optional[List[str]] = None):
    dataset = load_dataset(dataset_name, config)
    if queries:
        def filter_func(examples):
            return [any(query.lower() in text.lower() for query in queries) for text in examples["text"]]

        dataset = dataset.filter(filter_func, batched=True)
    return dataset

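# Usage sketch for load_filtered_dataset (illustrative only; any Hub dataset with a
# "text" column works):
#
#   dataset = load_filtered_dataset("wikitext", "wikitext-2-raw-v1", queries=["physics"])
#
# When `queries` is given, only examples whose text contains one of the query strings
# (case-insensitive) are kept.
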
def load_custom_data_from_files(file_paths):
    custom_data = []
    for file_path in file_paths:
        if file_path.endswith('.json'):
            with open(file_path, 'r') as f:
                data = json.load(f)
            if isinstance(data, list):
                custom_data.extend(data)
            else:
                custom_data.append(data)
        elif file_path.endswith('.jsonl'):
            with jsonlines.open(file_path) as reader:
                custom_data.extend(reader)
    return custom_data

def preprocess_custom_data(data_list):
    processed_data = []
    for item in data_list:
        # Some items may arrive as raw JSON strings; parse them before use.
        if isinstance(item, str):
            try:
                item = json.loads(item)
            except json.JSONDecodeError:
                print(f"Failed to parse JSON: {item[:100]}...")
                continue

        query = item.get('query', '')
        content = item.get('content', '')
        if content == "RAG response generation failed.":
            content = ""

        combined_text = f"Query: {query} Content: {content}"

        # Auxiliary metrics, defaulting to 0 when absent.
        episode_reward = item.get('episode_reward', 0)
        loss = item.get('loss', 0)
        cosine_similarity = item.get('cosine_similarity', 0)
        rag_performance = item.get('rag_performance', 0)
        ranking_model_performance = item.get('ranking_model_performance', 0)

        processed_item = {
            'text': combined_text,
            'episode_reward': episode_reward,
            'loss': loss,
            'cosine_similarity': cosine_similarity,
            'rag_performance': rag_performance,
            'ranking_model_performance': ranking_model_performance
        }

        processed_data.append(processed_item)

    return processed_data

def load_custom_data(args, tokenizer, custom_data):
    # `args` must provide `max_length` and `batch_size`.
    processed_data = preprocess_custom_data(custom_data)

    class CustomDatasetProcessed(torch.utils.data.Dataset):
        def __init__(self, data, tokenizer, max_length):
            self.data = data
            self.tokenizer = tokenizer
            self.max_length = max_length

        def __len__(self):
            return len(self.data)

        def __getitem__(self, idx):
            item = self.data[idx]
            encoded = self.tokenizer.encode_plus(
                item['text'],
                max_length=self.max_length,
                padding='max_length',
                truncation=True,
                return_tensors='pt'
            )
            return {
                'input_ids': encoded['input_ids'].squeeze(),
                'attention_mask': encoded['attention_mask'].squeeze(),
                'episode_reward': torch.tensor(item['episode_reward'], dtype=torch.float),
                'loss': torch.tensor(item['loss'], dtype=torch.float),
                'cosine_similarity': torch.tensor(item['cosine_similarity'], dtype=torch.float),
                'rag_performance': torch.tensor(item['rag_performance'], dtype=torch.float),
                'ranking_model_performance': torch.tensor(item['ranking_model_performance'], dtype=torch.float)
            }

    dataset = CustomDatasetProcessed(processed_data, tokenizer, args.max_length)

    train_size = int(0.8 * len(dataset))
    eval_size = len(dataset) - train_size
    train_dataset, eval_dataset = random_split(dataset, [train_size, eval_size])

    train_loader = DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=4
    )
    eval_loader = DataLoader(
        eval_dataset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=4
    )

    return train_loader, eval_loader

def prepare_data(tokenizer, dataset, max_length, batch_size):
    # For causal-LM distillation the labels mirror the inputs, so tokenize once and
    # reuse the same ids for both.
    tokenized = tokenizer(dataset["train"]["text"], return_tensors="pt", padding=True, truncation=True, max_length=max_length)

    custom_dataset = CustomDataset(tokenized["input_ids"], tokenized["input_ids"])

    train_size = int(0.9 * len(custom_dataset))
    val_size = len(custom_dataset) - train_size
    train_dataset, val_dataset = random_split(custom_dataset, [train_size, val_size])

    train_loader = DataLoader(
        train_dataset,
        shuffle=True,
        batch_size=batch_size,
        num_workers=4,
        pin_memory=True
    )
    val_loader = DataLoader(
        val_dataset,
        shuffle=False,
        batch_size=batch_size,
        num_workers=4,
        pin_memory=True
    )

    return train_loader, val_loader

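# A minimal usage sketch for prepare_data: `dataset` is expected to be a
# datasets.DatasetDict with a "train" split exposing a "text" column, e.g. the output
# of load_dataset("wikitext", "wikitext-2-raw-v1"). The call below is illustrative only.
#
#   train_loader, val_loader = prepare_data(tokenizer, dataset, max_length=128, batch_size=4)
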
def save_all_models(transformer_model, representation_network, dynamics_network, prediction_network, action_encoder, save_dir, epoch):
    """
    Save all models to the specified directory.

    Args:
        transformer_model (nn.Module): Transformer model.
        representation_network (nn.Module): Representation network.
        dynamics_network (nn.Module): Dynamics network.
        prediction_network (nn.Module): Prediction network.
        action_encoder (nn.Module): Action encoder.
        save_dir (str): Directory to save the models.
        epoch (int): Current epoch number.
    """
    os.makedirs(save_dir, exist_ok=True)

    torch.save(transformer_model.state_dict(), os.path.join(save_dir, f'transformer_model_epoch_{epoch}.pt'))
    torch.save(representation_network.state_dict(), os.path.join(save_dir, f'representation_network_epoch_{epoch}.pt'))
    torch.save(dynamics_network.state_dict(), os.path.join(save_dir, f'dynamics_network_epoch_{epoch}.pt'))
    torch.save(prediction_network.state_dict(), os.path.join(save_dir, f'prediction_network_epoch_{epoch}.pt'))
    torch.save(action_encoder.state_dict(), os.path.join(save_dir, f'action_encoder_epoch_{epoch}.pt'))

    print(f"All models saved for epoch {epoch}.")

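# Reloading these component checkpoints later is symmetric (a sketch, assuming the
# networks were constructed with the same arguments used at training time):
#
#   representation_network.load_state_dict(
#       torch.load(os.path.join(save_dir, f'representation_network_epoch_{epoch}.pt'), map_location=device)
#   )
#
# and likewise for the dynamics network, prediction network, and action encoder.
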
def train_epoch_world_model(world_model_components, train_loader, optimizer, scheduler, scaler, args, model_transformer, state_dim, embed_dim, input_dim):
    # Note: the transformer unpacked from the tuple overrides the model_transformer argument.
    representation_network, dynamics_network, prediction_network, action_encoder, ppo_agent, model_transformer = world_model_components
    representation_network.train()
    dynamics_network.train()
    prediction_network.train()
    action_encoder.train()
    ppo_agent.policy_network.train()

    total_loss = 0.0
    optimizer.zero_grad()
    print(f"Starting World Model training epoch with {len(train_loader)} batches...")

    for i, batch in enumerate(train_loader):
        print(f"Processing batch {i+1}/{len(train_loader)}...")

        src_batch = batch['input_ids'].to(device)
        tgt_batch = batch['labels'].to(device)

        with torch.cuda.amp.autocast():
            print("Forward pass through Transformer (frozen)...")
            with torch.no_grad():
                transformer_output = model_transformer(src_batch, tgt_batch[:, :-1])

            state_representation = representation_network(transformer_output)

            true_actions = tgt_batch[:, :-1]
            print(f"True actions shape: {true_actions.shape}")
            action_sequences = true_actions

            action_embeddings = action_encoder(action_sequences)
            print(f"Action embeddings shape: {action_embeddings.shape}")

            predicted_next_state_batch = dynamics_network(state_representation, action_embeddings)
            print(f"Predicted next state batch shape: {predicted_next_state_batch.shape}")

            policy_logits, value_estimates = prediction_network(predicted_next_state_batch)

            # Placeholder targets: a one-hot policy over the observed tokens and a zero value target.
            true_policy = F.one_hot(true_actions, num_classes=input_dim).float()
            true_value = torch.zeros_like(value_estimates).to(device)

            # PPO loss with zeroed log-probs, advantages, and returns as placeholders.
            ppo_loss = ppo_agent.compute_loss(
                state_representation,
                torch.zeros_like(true_actions, dtype=torch.float32).to(device),
                true_actions,
                torch.zeros_like(value_estimates, dtype=torch.float32).to(device),
                torch.zeros_like(value_estimates, dtype=torch.float32).to(device)
            )

            info_nce = InfoNCE_Loss()(
                state_representation.reshape(-1, state_dim),
                F.dropout(state_representation.reshape(-1, state_dim), p=0.1, training=True)
            )

            covariance = CovarianceRegularization()(predicted_next_state_batch.view(-1, predicted_next_state_batch.size(-1)))
            dynamics_loss = DynamicsPerformanceLoss()(state_representation, predicted_next_state_batch)

            perturbed_next_state = predicted_next_state_batch + torch.randn_like(predicted_next_state_batch) * 0.01
            thought_loss = ThoughtConsistencyLoss()(predicted_next_state_batch, perturbed_next_state)

            pv_loss = PolicyValueJointLoss()(policy_logits, true_policy, value_estimates.squeeze(-1), true_value.squeeze(-1))
            action_diversity = ActionDiversityReward()(action_embeddings.view(-1, embed_dim))

            mcts_best_values = torch.zeros(true_actions.size(0)).to(device)
            etv = ExpectedThoughtValueLoss()(mcts_best_values)

            visit_counts = torch.ones(true_actions.size(0), policy_logits.size(-1)).to(device)
            exploration = ExplorationRegularization()(visit_counts)

            old_policy = F.softmax(policy_logits.detach(), dim=-1)
            new_policy = F.softmax(policy_logits, dim=-1)
            kl_loss = KL_DivergenceLoss()(old_policy, new_policy)

            loss = (
                ppo_loss +
                info_nce +
                covariance +
                dynamics_loss +
                thought_loss +
                pv_loss +
                action_diversity +
                etv +
                exploration +
                kl_loss
            )
            loss = loss / args.accumulation_steps

        print("Backward pass...")
        scaler.scale(loss).backward()

        if (i + 1) % args.accumulation_steps == 0 or (i + 1) == len(train_loader):
            print("Gradient clipping...")
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(
                [param for group in optimizer.param_groups for param in group['params']],
                args.max_grad_norm
            )

            print("Optimizer step...")
            scaler.step(optimizer)
            scaler.update()

            print("Zeroing gradients...")
            optimizer.zero_grad()

            print("Updating learning rate...")
            scheduler.step()

        total_loss += loss.item() * args.accumulation_steps

        print(f"Batch {i+1} completed. Losses:")
        print(f"  PPO Loss: {ppo_loss.item():.4f}")
        print(f"  InfoNCE Loss: {info_nce.item():.4f}")
        print(f"  Covariance Loss: {covariance.item():.4f}")
        print(f"  Dynamics Loss: {dynamics_loss.item():.4f}")
        print(f"  Thought Consistency Loss: {thought_loss.item():.4f}")
        print(f"  Policy-Value Loss: {pv_loss.item():.4f}")
        print(f"  Action Diversity Loss: {action_diversity.item():.4f}")
        print(f"  Expected Thought Value Loss: {etv.item():.4f}")
        print(f"  Exploration Loss: {exploration.item():.4f}")
        print(f"  KL Divergence Loss: {kl_loss.item():.4f}")
        print(f"  Total Loss: {loss.item():.4f}")

    avg_loss = total_loss / len(train_loader)
    print(f"World Model training epoch completed. Average loss: {avg_loss:.4f}")
    return avg_loss

def train_step(teacher, student, data_loader, optimizer, criterion, scaler, temperature=2.0):
    teacher.eval()
    student.train()
    total_loss = 0

    for batch in tqdm(data_loader, desc="Training"):
        inputs = batch["input_ids"].to(device)
        # Hard labels are optional; the distillation loss uses only the teacher's soft targets.
        labels = batch["labels"].to(device) if "labels" in batch else None

        with autocast():
            with torch.no_grad():
                teacher_outputs = teacher(inputs).logits
                teacher_logits = teacher_outputs / temperature

            student_outputs = student(inputs).logits
            student_logits = student_outputs / temperature

            # Soft-target KL loss; the temperature**2 factor keeps gradient magnitudes
            # comparable across temperatures (standard distillation scaling).
            loss = criterion(nn.functional.log_softmax(student_logits, dim=-1), nn.functional.softmax(teacher_logits, dim=-1))
            loss = loss * (temperature ** 2)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()

        total_loss += loss.item()

    avg_loss = total_loss / len(data_loader)
    return avg_loss

def validate(teacher, student, data_loader, criterion, temperature=2.0):
    teacher.eval()
    student.eval()
    total_loss = 0

    with torch.no_grad():
        for batch in tqdm(data_loader, desc="Validation"):
            inputs = batch["input_ids"].to(device)
            # Labels are optional here as well; only soft targets are compared.
            labels = batch["labels"].to(device) if "labels" in batch else None

            teacher_outputs = teacher(inputs).logits
            teacher_logits = teacher_outputs / temperature

            student_outputs = student(inputs).logits
            student_logits = student_outputs / temperature

            loss = criterion(nn.functional.log_softmax(student_logits, dim=-1), nn.functional.softmax(teacher_logits, dim=-1))
            loss = loss * (temperature ** 2)

            total_loss += loss.item()

    avg_loss = total_loss / len(data_loader)
    return avg_loss

def save_checkpoint(state, save_dir, epoch):
    os.makedirs(save_dir, exist_ok=True)
    checkpoint_path = os.path.join(save_dir, f'checkpoint_epoch_{epoch}.pt')
    torch.save(state, checkpoint_path)
    print(f"Checkpoint saved at {checkpoint_path}")

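# Resuming from one of these checkpoints is the mirror image (a sketch, not wired into
# the CLI below; the epoch number, `checkpoint_dir`, and the `student`/`optimizer`
# objects are assumed to exist in the caller's scope):
#
#   checkpoint = torch.load(os.path.join(checkpoint_dir, 'checkpoint_epoch_3.pt'), map_location=device)
#   student.load_state_dict(checkpoint['model_state_dict'])
#   optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
#   scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
#   scaler.load_state_dict(checkpoint['scaler_state_dict'])
#   start_epoch = checkpoint['epoch'] + 1
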
def infer(query, world_model_components, root_thought_node, tokenizer, max_length=2000, inference_mode='world_model', beam_size=5, n_tokens_predict=3, mcts_iterations=10, exploration_constant=1.414):
    """
    Perform inference for a query, using the Tree of Thought and MCTS with multi-token beam search.

    Args:
        query (str): The input query or prompt.
        world_model_components (tuple): Tuple containing the model components.
        root_thought_node (ThoughtNode): The root node of the Tree of Thought.
        tokenizer (transformers.PreTrainedTokenizer): The tokenizer used.
        max_length (int): Maximum length for the generated sequence.
        inference_mode (str): Inference mode ('world_model', 'without_world_model', 'world_model_tree_of_thought').
        beam_size (int): Size of the beam for beam search.
        n_tokens_predict (int): Number of tokens to predict at each step.
        mcts_iterations (int): Number of MCTS iterations.
        exploration_constant (float): Exploration constant for MCTS.

    Returns:
        List[str] or str: The sequence of actions (thoughts) selected, or the generated text.
    """
    if inference_mode != 'world_model':
        print("Inference mode other than 'world_model' not implemented yet.")
        return ""

    representation_network, dynamics_network, prediction_network, action_encoder, ppo_agent, model_transformer = world_model_components

    input_ids = tokenizer.encode(query, return_tensors='pt').to(device)
    attention_mask = (input_ids != tokenizer.pad_token_id).long()  # currently unused by the custom Transformer

    with torch.no_grad():
        transformer_output = model_transformer(input_ids, input_ids)

    initial_representation = representation_network(transformer_output)
    initial_representation = initial_representation[:, -1, :].unsqueeze(1)
    initial_state = State(
        representation=initial_representation,
        dynamics_network=dynamics_network,
        action_encoder=action_encoder,
        thought_node=root_thought_node
    )

    mcts = MCTS(prediction_network, dynamics_network, action_encoder, num_iterations=mcts_iterations, exploration_constant=exploration_constant)

    current_state = initial_state
    thought_sequence = []

    for _ in range(max_length // n_tokens_predict):
        best_actions = mcts.search_with_beam(current_state)

        thought_sequence.extend(best_actions)

        for action in best_actions:
            current_state = current_state.apply_action(action)

        # Stop once the Tree of Thought has no further children to expand.
        if len(current_state.thought_node.children) == 0:
            break

    return thought_sequence

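# A minimal usage sketch for infer (assumes the caller has already built the world-model
# components and a Tree of Thought; none of these objects are constructed in this script,
# and the query string is illustrative):
#
#   thoughts = infer(
#       "What is the capital of France?",
#       world_model_components,   # (representation, dynamics, prediction, action_encoder, ppo_agent, transformer)
#       root_thought_node,        # a lightbulb_custom.ThoughtNode
#       tokenizer,
#       inference_mode='world_model',
#   )
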
def distill_model(
    teacher_model_name: str,
    student_model_name: str,
    dataset_name: str,
    config: str,
    distill_full_model: bool = True,
    query_terms: Optional[List[str]] = None,
    num_epochs: int = 3,
    batch_size: int = 4,
    max_length: int = 128,
    learning_rate: float = 5e-5,
    temperature: float = 2.0,
    save_path: str = "./distilled_model",
    log_dir: str = "./logs",
    checkpoint_dir: str = "./checkpoints",
    early_stopping_patience: int = 3,
    accumulation_steps: int = 1,
    max_grad_norm: float = 1.0,
    weight_decay: float = 0.01
):
    writer = SummaryWriter(log_dir=log_dir)

    print("Loading tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(teacher_model_name)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    print("Tokenizer loaded successfully.")

    print("Loading teacher model...")
    teacher = AutoModelForCausalLM.from_pretrained(teacher_model_name).to(device)
    print("Teacher model loaded successfully.")

    if distill_full_model:
        print(f"Starting Full World Model Distillation into '{student_model_name}'.")

        print(f"Attempting to load student model '{student_model_name}'...")
        try:
            student = AutoModelForCausalLM.from_pretrained(student_model_name).to(device)
            print(f"Student model '{student_model_name}' loaded successfully.")
        except (OSError, ValueError) as e:
            print(f"Student model '{student_model_name}' not found ({e}). Initializing a new student from 'distilgpt2'.")
            try:
                student = AutoModelForCausalLM.from_pretrained('distilgpt2').to(device)
                student.save_pretrained(save_path)
                tokenizer.save_pretrained(save_path)
                print(f"New student model initialized from 'distilgpt2' and saved to '{save_path}'.")
            except Exception as inst_e:
                print(f"Failed to instantiate and save student model: {inst_e}")
                sys.exit(1)

        # Freeze the teacher; only the student is trained.
        for param in teacher.parameters():
            param.requires_grad = False

        print(f"Loading full dataset '{dataset_name}' with config '{config}'...")
        dataset = load_dataset(dataset_name, config)
        train_loader, val_loader = prepare_data(tokenizer, dataset, max_length, batch_size)
        print("Data loaded and preprocessed successfully.")

        optimizer = optim.AdamW(student.parameters(), lr=learning_rate, weight_decay=weight_decay)
        scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)
        scaler = GradScaler()

        criterion = nn.KLDivLoss(reduction="batchmean")

        best_val_loss = float('inf')
        epochs_no_improve = 0

        for epoch in range(1, num_epochs + 1):
            print(f"\nEpoch {epoch}/{num_epochs}")
            print("-" * 20)

            train_loss = train_step(teacher, student, train_loader, optimizer, criterion, scaler, temperature)
            print(f"Training Loss: {train_loss:.4f}")
            writer.add_scalar("Loss/Train", train_loss, epoch)

            val_loss = validate(teacher, student, val_loader, criterion, temperature)
            print(f"Validation Loss: {val_loss:.4f}")
            writer.add_scalar("Loss/Validation", val_loss, epoch)

            if val_loss < best_val_loss:
                best_val_loss = val_loss
                epochs_no_improve = 0

                save_checkpoint({
                    'epoch': epoch,
                    'model_state_dict': student.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'scheduler_state_dict': scheduler.state_dict(),
                    'scaler_state_dict': scaler.state_dict(),
                    'best_val_loss': best_val_loss
                }, checkpoint_dir, epoch)

                student.save_pretrained(save_path)
                tokenizer.save_pretrained(save_path)
                print(f"Best model saved at epoch {epoch}")
            else:
                epochs_no_improve += 1
                print(f"No improvement in validation loss for {epochs_no_improve} epoch(s)")
                if epochs_no_improve >= early_stopping_patience:
                    print("Early stopping triggered")
                    break

            scheduler.step()

        writer.close()
        print("\nFull World Model Distillation completed.")

    else:
        print(f"Starting Standard Language Model Distillation into '{student_model_name}'.")

        if not query_terms:
            print("Error: --query_terms must be provided for standard language model distillation.")
            sys.exit(1)

        print(f"Attempting to load student model '{student_model_name}'...")
        try:
            student = AutoModelForCausalLM.from_pretrained(student_model_name).to(device)
            print(f"Student model '{student_model_name}' loaded successfully.")
        except (OSError, ValueError) as e:
            print(f"Student model '{student_model_name}' not found ({e}). Initializing a new student from 'distilgpt2'.")
            try:
                student = AutoModelForCausalLM.from_pretrained('distilgpt2').to(device)
                student.save_pretrained(save_path)
                tokenizer.save_pretrained(save_path)
                print(f"New student model initialized from 'distilgpt2' and saved to '{save_path}'.")
            except Exception as inst_e:
                print(f"Failed to instantiate and save student model: {inst_e}")
                sys.exit(1)

        # Freeze the teacher; only the student is trained.
        for param in teacher.parameters():
            param.requires_grad = False

        print(f"Loading custom data files: {query_terms}")
        custom_data = load_custom_data_from_files(query_terms)
        train_loader, val_loader = load_custom_data(
            # load_custom_data reads both max_length and batch_size from `args`.
            args=argparse.Namespace(max_length=max_length, batch_size=batch_size),
            tokenizer=tokenizer,
            custom_data=custom_data
        )
        print("Custom data loaded and preprocessed successfully.")

        optimizer = optim.AdamW(student.parameters(), lr=learning_rate, weight_decay=weight_decay)
        scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)
        scaler = GradScaler()

        criterion = nn.KLDivLoss(reduction="batchmean")

        best_val_loss = float('inf')
        epochs_no_improve = 0

        for epoch in range(1, num_epochs + 1):
            print(f"\nEpoch {epoch}/{num_epochs}")
            print("-" * 20)

            train_loss = train_step(teacher, student, train_loader, optimizer, criterion, scaler, temperature)
            print(f"Training Loss: {train_loss:.4f}")
            writer.add_scalar("Loss/Train", train_loss, epoch)

            val_loss = validate(teacher, student, val_loader, criterion, temperature)
            print(f"Validation Loss: {val_loss:.4f}")
            writer.add_scalar("Loss/Validation", val_loss, epoch)

            if val_loss < best_val_loss:
                best_val_loss = val_loss
                epochs_no_improve = 0

                save_checkpoint({
                    'epoch': epoch,
                    'model_state_dict': student.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'scheduler_state_dict': scheduler.state_dict(),
                    'scaler_state_dict': scaler.state_dict(),
                    'best_val_loss': best_val_loss
                }, checkpoint_dir, epoch)

                student.save_pretrained(save_path)
                tokenizer.save_pretrained(save_path)
                print(f"Best model saved at epoch {epoch}")
            else:
                epochs_no_improve += 1
                print(f"No improvement in validation loss for {epochs_no_improve} epoch(s)")
                if epochs_no_improve >= early_stopping_patience:
                    print("Early stopping triggered")
                    break

            scheduler.step()

        writer.close()
        print("\nStandard Language Model Distillation completed.")

def parse_args():
    parser = argparse.ArgumentParser(description="Distill a large LLM into a smaller one or a full language world model.")

    parser.add_argument("--teacher_model_name", type=str, required=True, help="Name of the teacher model")
    parser.add_argument("--student_model_name", type=str, required=True, help="Name of the student model")

    parser.add_argument("--dataset_name", type=str, required=True, help="Name of the dataset")
    parser.add_argument("--config", type=str, default=None, help="Dataset configuration (e.g., 'wikitext-2-raw-v1')")

    parser.add_argument("--distill_full_model", action="store_true", help="Whether to distill into the full language world model")

    parser.add_argument("--query_terms", type=str, nargs="+", help="Paths to custom data files for standard language model distillation")

    parser.add_argument("--num_epochs", type=int, default=3, help="Number of epochs")
    parser.add_argument("--batch_size", type=int, default=4, help="Batch size")
    parser.add_argument("--max_length", type=int, default=128, help="Maximum sequence length")
    parser.add_argument("--learning_rate", type=float, default=5e-5, help="Learning rate")
    parser.add_argument("--temperature", type=float, default=2.0, help="Distillation temperature")

    parser.add_argument("--save_path", type=str, default="./distilled_model", help="Path to save the distilled model")
    parser.add_argument("--log_dir", type=str, default="./logs", help="Directory for TensorBoard logs")
    parser.add_argument("--checkpoint_dir", type=str, default="./checkpoints", help="Directory to save checkpoints")

    parser.add_argument("--early_stopping_patience", type=int, default=3, help="Early stopping patience")

    parser.add_argument("--accumulation_steps", type=int, default=1, help="Gradient accumulation steps")
    parser.add_argument("--max_grad_norm", type=float, default=1.0, help="Maximum gradient norm for clipping")
    parser.add_argument("--weight_decay", type=float, default=0.01, help="Weight decay for optimizer")

    return parser.parse_args()

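# Example invocation for full world-model distillation (script name, model names, and
# dataset are illustrative; substitute your own):
#
#   python path/to/this_script.py \
#       --teacher_model_name gpt2-large \
#       --student_model_name ./my_student \
#       --dataset_name wikitext --config wikitext-2-raw-v1 \
#       --distill_full_model \
#       --num_epochs 3 --batch_size 4 --max_length 128
#
# For standard distillation over custom JSON/JSONL files, omit --distill_full_model and
# pass the file paths via --query_terms.
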
def main():
    args = parse_args()
    print("Arguments parsed successfully.")

    os.makedirs(args.save_path, exist_ok=True)
    os.makedirs(args.log_dir, exist_ok=True)
    os.makedirs(args.checkpoint_dir, exist_ok=True)
    print(f"Save directory created: {args.save_path}")
    print(f"Log directory created: {args.log_dir}")
    print(f"Checkpoint directory created: {args.checkpoint_dir}")

    # Both modes share the same call; distill_model branches internally on
    # args.distill_full_model.
    distill_model(
        teacher_model_name=args.teacher_model_name,
        student_model_name=args.student_model_name,
        dataset_name=args.dataset_name,
        config=args.config,
        distill_full_model=args.distill_full_model,
        query_terms=args.query_terms,
        num_epochs=args.num_epochs,
        batch_size=args.batch_size,
        max_length=args.max_length,
        learning_rate=args.learning_rate,
        temperature=args.temperature,
        save_path=args.save_path,
        log_dir=args.log_dir,
        checkpoint_dir=args.checkpoint_dir,
        early_stopping_patience=args.early_stopping_patience,
        accumulation_steps=args.accumulation_steps,
        max_grad_norm=args.max_grad_norm,
        weight_decay=args.weight_decay
    )


if __name__ == "__main__":
    main()