"""Distill a large causal LM into a smaller student, plus world-model training/inference helpers.

Two distillation modes are exposed through ``distill_model``:

* full mode (``distill_full_model=True``): trains on a Hugging Face dataset
  loaded with ``load_dataset(dataset_name, config)``;
* standard mode: trains on custom ``.json`` / ``.jsonl`` files whose paths are
  passed in ``query_terms``.
"""

import argparse
import math
import os
import sys
import json
import copy
from typing import List, Optional, Tuple

import jsonlines
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset, random_split
from torch.cuda.amp import autocast, GradScaler
from torch.utils.tensorboard import SummaryWriter
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
from tqdm import tqdm

# ======================================
# Import Custom Components from lightbulb_custom
# ======================================
from lightbulb_custom import (
    RotaryPositionalEncoding,
    MultiHeadAttention,
    MoE,
    TransformerBlock,
    Transformer,
    InfoNCE_Loss,
    CovarianceRegularization,
    DynamicsPerformanceLoss,
    ThoughtConsistencyLoss,
    PolicyValueJointLoss,
    ActionDiversityReward,
    ExpectedThoughtValueLoss,
    ExplorationRegularization,
    KL_DivergenceLoss,
    ActionEncoder,
    RepresentationNetwork,
    DynamicsNetwork,
    PredictionNetwork,
    ThoughtNode,
    MCTS,
    State,
)

# Set up device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# ==========================
# Custom Dataset Definition
# ==========================
class CustomDataset(Dataset):
    """Index-aligned pair of input and label sequences."""

    def __init__(self, inputs, labels):
        self.inputs = inputs
        self.labels = labels

    def __len__(self):
        return len(self.inputs)

    def __getitem__(self, idx):
        return {'input_ids': self.inputs[idx], 'labels': self.labels[idx]}


class CustomDatasetProcessed(Dataset):
    """Dataset over records produced by ``preprocess_custom_data``.

    Defined at module level (rather than inside ``load_custom_data``) so that
    instances can be pickled for DataLoader worker processes.
    """

    # Numeric fields copied from each record into the sample as float tensors.
    _METRIC_KEYS = (
        'episode_reward',
        'loss',
        'cosine_similarity',
        'rag_performance',
        'ranking_model_performance',
    )

    def __init__(self, data, tokenizer, max_length):
        self.data = data
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]
        encoded = self.tokenizer.encode_plus(
            item['text'],
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt',
        )
        # squeeze(0) drops only the leading batch axis added by
        # return_tensors='pt'; a bare squeeze() would also drop the sequence
        # axis when max_length == 1.
        sample = {
            'input_ids': encoded['input_ids'].squeeze(0),
            'attention_mask': encoded['attention_mask'].squeeze(0),
        }
        for key in self._METRIC_KEYS:
            sample[key] = torch.tensor(item[key], dtype=torch.float)
        return sample


# ================================
# Utility Functions for Data Loading
# ================================
def load_filtered_dataset(dataset_name: str, config: str, queries: Optional[List[str]] = None):
    """Load a dataset and optionally keep only rows whose text mentions a query.

    Args:
        dataset_name (str): Name passed to ``load_dataset``.
        config (str): Configuration passed to ``load_dataset``.
        queries (Optional[List[str]]): Case-insensitive substrings; a row is
            kept when its ``"text"`` field contains at least one of them.
            When empty or ``None`` the dataset is returned unfiltered.

    Returns:
        The (possibly filtered) dataset object returned by ``load_dataset``.
    """
    dataset = load_dataset(dataset_name, config)
    if queries:
        # Lower-case the queries once instead of once per row.
        lowered_queries = [query.lower() for query in queries]

        def filter_func(examples):
            return [
                any(query in text.lower() for query in lowered_queries)
                for text in examples["text"]
            ]

        dataset = dataset.filter(filter_func, batched=True)
    return dataset


def load_custom_data_from_files(file_paths):
    """Read records from ``.json`` and ``.jsonl`` files into one flat list.

    Args:
        file_paths: Iterable of file paths. A ``.json`` file contributes its
            elements when it holds a list, otherwise the single decoded
            object. A ``.jsonl`` file contributes one record per line.

    Returns:
        list: All records in file order. Paths with any other extension are
        skipped with a printed message.
    """
    custom_data = []
    for file_path in file_paths:
        if file_path.endswith('.json'):
            with open(file_path, 'r', encoding='utf-8') as f:
                data = json.load(f)
            if isinstance(data, list):
                custom_data.extend(data)
            else:
                custom_data.append(data)
        elif file_path.endswith('.jsonl'):
            with jsonlines.open(file_path) as reader:
                custom_data.extend(reader)
        else:
            # Previously ignored silently; make the skip visible.
            print(f"Skipping unsupported file (expected .json or .jsonl): {file_path}")
    return custom_data


def preprocess_custom_data(data_list):
    """Normalise raw records into dicts with a ``text`` field and numeric metrics.

    Args:
        data_list: Iterable of dicts, or of strings holding JSON-encoded dicts.

    Returns:
        list: One dict per usable record with keys ``text``,
        ``episode_reward``, ``loss``, ``cosine_similarity``,
        ``rag_performance`` and ``ranking_model_performance``. Missing
        metrics default to 0. Strings that are not valid JSON and records
        that are not dicts are skipped.
    """
    processed_data = []
    for item in data_list:
        # Check if the item is a string (JSON)
        if isinstance(item, str):
            try:
                item = json.loads(item)
            except json.JSONDecodeError:
                print(f"Failed to parse JSON: {item[:100]}...")  # Print first 100 chars for debugging
                continue  # Skip this item if it's not valid JSON

        # A decoded value may still be a list/number/etc., which has no .get().
        if not isinstance(item, dict):
            print(f"Skipping non-dict item of type {type(item).__name__}")
            continue

        # Process query and content
        query = item.get('query', '')
        content = item.get('content', '')
        if content == "RAG response generation failed.":
            content = ""

        processed_data.append({
            'text': f"Query: {query} Content: {content}",
            'episode_reward': item.get('episode_reward', 0),
            'loss': item.get('loss', 0),
            'cosine_similarity': item.get('cosine_similarity', 0),
            'rag_performance': item.get('rag_performance', 0),
            'ranking_model_performance': item.get('ranking_model_performance', 0),
        })
    return processed_data


def load_custom_data(args, tokenizer, custom_data):
    """Build train/eval DataLoaders (80/20 random split) from raw custom records.

    Args:
        args: Object exposing ``max_length`` and ``batch_size`` attributes.
        tokenizer: Tokenizer providing ``encode_plus``.
        custom_data: Raw records as returned by ``load_custom_data_from_files``.

    Returns:
        Tuple[DataLoader, DataLoader]: ``(train_loader, eval_loader)``.
    """
    processed_data = preprocess_custom_data(custom_data)
    dataset = CustomDatasetProcessed(processed_data, tokenizer, args.max_length)

    # Split the dataset into train and eval
    train_size = int(0.8 * len(dataset))
    eval_size = len(dataset) - train_size
    train_dataset, eval_dataset = random_split(dataset, [train_size, eval_size])

    train_loader = DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=4,
    )
    eval_loader = DataLoader(
        eval_dataset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=4,
    )
    return train_loader, eval_loader


def prepare_data(tokenizer, dataset, max_length, batch_size):
    """Tokenise ``dataset["train"]["text"]`` and build train/val DataLoaders (90/10 split).

    Args:
        tokenizer: Callable tokenizer returning a mapping with ``"input_ids"``.
        dataset: Mapping with a ``"train"`` split exposing a ``"text"`` column.
        max_length (int): Truncation length.
        batch_size (int): Batch size for both loaders.

    Returns:
        Tuple[DataLoader, DataLoader]: ``(train_loader, val_loader)``; each
        batch has ``input_ids`` and ``labels`` holding the same token ids.
    """
    # Inputs and labels are the same token ids, so tokenise only once.
    tokenized = tokenizer(
        dataset["train"]["text"],
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=max_length,
    )
    input_ids = tokenized["input_ids"]
    custom_dataset = CustomDataset(input_ids, input_ids)

    # Split into training and validation sets
    train_size = int(0.9 * len(custom_dataset))
    val_size = len(custom_dataset) - train_size
    train_dataset, val_dataset = random_split(custom_dataset, [train_size, val_size])

    train_loader = DataLoader(
        train_dataset,
        shuffle=True,
        batch_size=batch_size,
        num_workers=4,
        pin_memory=True,
    )
    val_loader = DataLoader(
        val_dataset,
        shuffle=False,
        batch_size=batch_size,
        num_workers=4,
        pin_memory=True,
    )
    return train_loader, val_loader


# ==========================
# Training and Validation Functions
# ==========================
def save_all_models(transformer_model, representation_network, dynamics_network,
                    prediction_network, action_encoder, save_dir, epoch):
    """
    Save all models to the specified directory.

    Args:
        transformer_model (nn.Module): Transformer model.
        representation_network (nn.Module): Representation network.
        dynamics_network (nn.Module): Dynamics network.
        prediction_network (nn.Module): Prediction network.
        action_encoder (nn.Module): Action encoder.
        save_dir (str): Directory to save the models.
        epoch (int): Current epoch number.
    """
    os.makedirs(save_dir, exist_ok=True)
    named_models = (
        ('transformer_model', transformer_model),
        ('representation_network', representation_network),
        ('dynamics_network', dynamics_network),
        ('prediction_network', prediction_network),
        ('action_encoder', action_encoder),
    )
    for name, model in named_models:
        torch.save(model.state_dict(), os.path.join(save_dir, f'{name}_epoch_{epoch}.pt'))
    print(f"All models saved for epoch {epoch}.")


def train_epoch_world_model(world_model_components, train_loader, optimizer, scheduler, scaler,
                            args, model_transformer, state_dim, embed_dim, input_dim):
    """Run one world-model training epoch with AMP and gradient accumulation.

    Args:
        world_model_components (tuple): ``(representation_network,
            dynamics_network, prediction_network, action_encoder, ppo_agent,
            model_transformer)``.
        train_loader: Iterable of batches with ``input_ids`` and ``labels``.
        optimizer: Optimizer stepped every ``args.accumulation_steps`` batches
            and on the final batch.
        scheduler: LR scheduler stepped together with the optimizer.
        scaler (GradScaler): Gradient scaler for mixed precision.
        args: Object exposing ``accumulation_steps`` and ``max_grad_norm``.
        model_transformer: Kept for interface compatibility. NOTE: the value
            is overwritten by the transformer unpacked from
            ``world_model_components``, which is the one actually used.
        state_dim (int): Last-dimension size used to flatten state
            representations for the InfoNCE loss.
        embed_dim (int): Last-dimension size used to flatten action embeddings.
        input_dim (int): Number of classes for the one-hot policy targets.

    Returns:
        float: Mean (un-normalised) total loss per batch.
    """
    (representation_network, dynamics_network, prediction_network,
     action_encoder, ppo_agent, model_transformer) = world_model_components

    representation_network.train()
    dynamics_network.train()
    prediction_network.train()
    action_encoder.train()
    ppo_agent.policy_network.train()

    total_loss = 0.0
    optimizer.zero_grad()
    print(f"Starting World Model training epoch with {len(train_loader)} batches...")

    for i, batch in enumerate(train_loader):
        print(f"Processing batch {i+1}/{len(train_loader)}...")

        # Move batches to the device
        src_batch = batch['input_ids'].to(device)
        tgt_batch = batch['labels'].to(device)

        with autocast():
            print("Forward pass through Transformer (frozen)...")
            with torch.no_grad():
                transformer_output = model_transformer(src_batch, tgt_batch[:, :-1])

            # World Model - Representation
            state_representation = representation_network(transformer_output)

            # For simplicity, the target tokens (minus the last) act as the true actions.
            true_actions = tgt_batch[:, :-1]
            print(f"True actions shape: {true_actions.shape}")

            # Get action embeddings
            action_embeddings = action_encoder(true_actions)
            print(f"Action embeddings shape: {action_embeddings.shape}")

            # Apply dynamics network
            predicted_next_state_batch = dynamics_network(state_representation, action_embeddings)
            print(f"Predicted next state batch shape: {predicted_next_state_batch.shape}")

            # Prediction Network - Policy logits and value
            policy_logits, value_estimates = prediction_network(predicted_next_state_batch)

            # Placeholder targets: one-hot of the true actions and all-zero values.
            true_policy = F.one_hot(true_actions, num_classes=input_dim).float()
            true_value = torch.zeros_like(value_estimates)

            # Compute individual losses (the PPO inputs other than the
            # actions are zero placeholders).
            ppo_loss = ppo_agent.compute_loss(
                state_representation,
                torch.zeros_like(true_actions, dtype=torch.float32),
                true_actions,
                torch.zeros_like(value_estimates, dtype=torch.float32),
                torch.zeros_like(value_estimates, dtype=torch.float32),
            )

            flat_states = state_representation.reshape(-1, state_dim)
            info_nce = InfoNCE_Loss()(
                flat_states,
                F.dropout(state_representation.reshape(-1, state_dim), p=0.1, training=True),
            )

            covariance = CovarianceRegularization()(
                predicted_next_state_batch.view(-1, predicted_next_state_batch.size(-1))
            )
            dynamics_loss = DynamicsPerformanceLoss()(state_representation, predicted_next_state_batch)

            perturbed_next_state = (
                predicted_next_state_batch + torch.randn_like(predicted_next_state_batch) * 0.01
            )
            thought_loss = ThoughtConsistencyLoss()(predicted_next_state_batch, perturbed_next_state)

            pv_loss = PolicyValueJointLoss()(
                policy_logits, true_policy, value_estimates.squeeze(-1), true_value.squeeze(-1)
            )
            action_diversity = ActionDiversityReward()(action_embeddings.view(-1, embed_dim))

            mcts_best_values = torch.zeros(true_actions.size(0)).to(device)
            etv = ExpectedThoughtValueLoss()(mcts_best_values)

            visit_counts = torch.ones(true_actions.size(0), policy_logits.size(-1)).to(device)
            exploration = ExplorationRegularization()(visit_counts)

            old_policy = F.softmax(policy_logits.detach(), dim=-1)
            new_policy = F.softmax(policy_logits, dim=-1)
            kl_loss = KL_DivergenceLoss()(old_policy, new_policy)

            # Total Loss
            loss = (
                ppo_loss
                + info_nce
                + covariance
                + dynamics_loss
                + thought_loss
                + pv_loss
                + action_diversity
                + etv
                + exploration
                + kl_loss
            )
            loss = loss / args.accumulation_steps

        print("Backward pass...")
        scaler.scale(loss).backward()

        if (i + 1) % args.accumulation_steps == 0 or (i + 1) == len(train_loader):
            print("Gradient clipping...")
            # Gradients must be unscaled before their norm is clipped.
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(
                [param for group in optimizer.param_groups for param in group['params']],
                args.max_grad_norm,
            )

            print("Optimizer step...")
            scaler.step(optimizer)
            scaler.update()

            print("Zeroing gradients...")
            optimizer.zero_grad()

            print("Updating learning rate...")
            scheduler.step()

        # Undo the accumulation normalisation so the running total is per-batch loss.
        total_loss += loss.item() * args.accumulation_steps

        # Print individual losses and total loss for this batch
        print(f"Batch {i+1} completed. Losses:")
        print(f" PPO Loss: {ppo_loss.item():.4f}")
        print(f" InfoNCE Loss: {info_nce.item():.4f}")
        print(f" Covariance Loss: {covariance.item():.4f}")
        print(f" Dynamics Loss: {dynamics_loss.item():.4f}")
        print(f" Thought Consistency Loss: {thought_loss.item():.4f}")
        print(f" Policy-Value Loss: {pv_loss.item():.4f}")
        print(f" Action Diversity Loss: {action_diversity.item():.4f}")
        print(f" Expected Thought Value Loss: {etv.item():.4f}")
        print(f" Exploration Loss: {exploration.item():.4f}")
        print(f" KL Divergence Loss: {kl_loss.item():.4f}")
        print(f" Total Loss: {loss.item():.4f}")

    avg_loss = total_loss / len(train_loader)
    print(f"World Model training epoch completed. Average loss: {avg_loss:.4f}")
    return avg_loss


def train_step(teacher, student, data_loader, optimizer, criterion, scaler, temperature=2.0,
               accumulation_steps=1, max_grad_norm=None):
    """Train the student for one epoch against the teacher's softened distribution.

    Only ``batch["input_ids"]`` is read, so loaders that do not provide a
    ``labels`` field (such as the one built by ``load_custom_data``) work too.

    Args:
        teacher: Model whose call result exposes ``.logits``; put in eval mode.
        student: Model whose call result exposes ``.logits``; put in train mode.
        data_loader: Iterable of batches containing ``input_ids``.
        optimizer: Optimizer over the student's parameters.
        criterion: Loss taking ``(student_log_probs, teacher_probs)``.
        scaler (GradScaler): Gradient scaler for mixed precision.
        temperature (float): Softmax temperature; the loss is multiplied by
            ``temperature ** 2``.
        accumulation_steps (int): Number of batches whose gradients are
            accumulated before each optimizer step. Values below 1 are
            treated as 1.
        max_grad_norm (Optional[float]): When given, gradients are clipped to
            this norm before each optimizer step.

    Returns:
        float: Mean loss per batch.
    """
    teacher.eval()
    student.train()

    accumulation_steps = max(1, accumulation_steps)
    num_batches = len(data_loader)
    total_loss = 0.0
    optimizer.zero_grad()

    for step, batch in enumerate(tqdm(data_loader, desc="Training"), start=1):
        inputs = batch["input_ids"].to(device)

        with autocast():
            with torch.no_grad():
                teacher_logits = teacher(inputs).logits / temperature
            student_logits = student(inputs).logits / temperature

            # Compute KL Divergence Loss
            loss = criterion(
                F.log_softmax(student_logits, dim=-1),
                F.softmax(teacher_logits, dim=-1),
            )
            loss = loss * (temperature ** 2)  # Scale loss by temperature squared

        # Normalise so accumulated gradients average over the accumulation window.
        scaler.scale(loss / accumulation_steps).backward()

        if step % accumulation_steps == 0 or step == num_batches:
            if max_grad_norm is not None:
                # Clip on true (unscaled) gradient magnitudes.
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(student.parameters(), max_grad_norm)
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()

        total_loss += loss.item()

    return total_loss / num_batches


def validate(teacher, student, data_loader, criterion, temperature=2.0):
    """Compute the mean distillation loss over ``data_loader`` without gradient tracking.

    Only ``batch["input_ids"]`` is read from each batch.

    Args:
        teacher: Model whose call result exposes ``.logits``.
        student: Model whose call result exposes ``.logits``.
        data_loader: Iterable of batches containing ``input_ids``.
        criterion: Loss taking ``(student_log_probs, teacher_probs)``.
        temperature (float): Softmax temperature; the loss is multiplied by
            ``temperature ** 2``.

    Returns:
        float: Mean loss per batch.
    """
    teacher.eval()
    student.eval()
    total_loss = 0.0

    with torch.no_grad():
        for batch in tqdm(data_loader, desc="Validation"):
            inputs = batch["input_ids"].to(device)

            teacher_logits = teacher(inputs).logits / temperature
            student_logits = student(inputs).logits / temperature

            loss = criterion(
                F.log_softmax(student_logits, dim=-1),
                F.softmax(teacher_logits, dim=-1),
            )
            total_loss += (loss * (temperature ** 2)).item()

    return total_loss / len(data_loader)


def save_checkpoint(state, save_dir, epoch):
    """Save ``state`` to ``<save_dir>/checkpoint_epoch_<epoch>.pt``, creating the directory if needed."""
    os.makedirs(save_dir, exist_ok=True)
    checkpoint_path = os.path.join(save_dir, f'checkpoint_epoch_{epoch}.pt')
    torch.save(state, checkpoint_path)
    print(f"Checkpoint saved at {checkpoint_path}")


# ==========================
# Inference Functions
# ==========================
def infer(query, world_model_components, root_thought_node, tokenizer, max_length=2000,
          inference_mode='world_model', beam_size=5, n_tokens_predict=3, mcts_iterations=10,
          exploration_constant=1.414):
    """
    Perform inference given a query, utilizing the Tree of Thought and MCTS.

    Args:
        query (str): The input query or prompt.
        world_model_components (tuple): ``(representation_network,
            dynamics_network, prediction_network, action_encoder, ppo_agent,
            model_transformer)``.
        root_thought_node (ThoughtNode): The root node of the Tree of Thought.
        tokenizer (transformers.PreTrainedTokenizer): The tokenizer used.
        max_length (int): Upper bound used to derive the number of search
            rounds (``max_length // n_tokens_predict``).
        inference_mode (str): Only ``'world_model'`` is implemented.
        beam_size (int): Accepted for interface compatibility; currently not
            forwarded to the search.
        n_tokens_predict (int): Divisor of ``max_length`` that fixes the
            number of search rounds.
        mcts_iterations (int): Number of MCTS iterations.
        exploration_constant (float): Exploration constant for MCTS.

    Returns:
        list or str: The sequence of actions returned by the search, or an
        empty string when ``inference_mode`` is not ``'world_model'``.
    """
    if inference_mode != 'world_model':
        print("Inference mode other than 'world_model' not implemented yet.")
        return ""

    (representation_network, dynamics_network, prediction_network,
     action_encoder, ppo_agent, model_transformer) = world_model_components

    # Tokenize and encode the query. (No attention mask is built: it was never
    # consumed, and deriving it from pad_token_id fails when that id is None.)
    input_ids = tokenizer.encode(query, return_tensors='pt').to(device)

    # Use the world model components
    with torch.no_grad():
        transformer_output = model_transformer(input_ids, input_ids)
        # Get the initial state representation and keep only the last
        # position, restoring the middle axis so the result stays 3-D.
        initial_representation = representation_network(transformer_output)
        initial_representation = initial_representation[:, -1, :].unsqueeze(1)

    initial_state = State(
        representation=initial_representation,
        dynamics_network=dynamics_network,
        action_encoder=action_encoder,
        thought_node=root_thought_node,
    )

    # Use MCTS with Tree of Thought and multi-token beam search
    mcts = MCTS(
        prediction_network,
        dynamics_network,
        action_encoder,
        num_iterations=mcts_iterations,
        exploration_constant=exploration_constant,
    )

    current_state = initial_state
    thought_sequence = []

    for _ in range(max_length // n_tokens_predict):
        best_actions = mcts.search_with_beam(current_state)
        thought_sequence.extend(best_actions)

        # Apply the best actions to get the next state
        for action in best_actions:
            current_state = current_state.apply_action(action)

        # Check if we've reached a leaf node (no further actions)
        if len(current_state.thought_node.children) == 0:
            break

    return thought_sequence


# ==========================
# Main Training Function
# ==========================
def _load_student_model(student_model_name, tokenizer, save_path):
    """Load the named student model, falling back to a fresh ``distilgpt2`` saved at ``save_path``."""
    print(f"Attempting to load student model '{student_model_name}'...")
    try:
        student = AutoModelForCausalLM.from_pretrained(student_model_name).to(device)
    except (OSError, ValueError):
        print(f"Student model '{student_model_name}' not found. Instantiating a new student model.")
    else:
        print(f"Student model '{student_model_name}' loaded successfully.")
        return student

    # Instantiate a smaller pre-trained model as the student, e.g., distilgpt2
    try:
        student = AutoModelForCausalLM.from_pretrained('distilgpt2').to(device)
        # Save the instantiated student model with the desired name
        student.save_pretrained(save_path)
        tokenizer.save_pretrained(save_path)
        print(f"New student model '{student_model_name}' instantiated and saved to '{save_path}'.")
    except Exception as inst_e:
        print(f"Failed to instantiate and save student model: {inst_e}")
        sys.exit(1)
    return student


def _run_distillation_loop(teacher, student, tokenizer, train_loader, val_loader, writer, *,
                           num_epochs, learning_rate, weight_decay, temperature, save_path,
                           checkpoint_dir, early_stopping_patience, accumulation_steps,
                           max_grad_norm):
    """Run the epoch loop: train, validate, checkpoint on improvement, early-stop on stagnation."""
    # Define optimizer, scheduler, and scaler for mixed precision
    optimizer = optim.AdamW(student.parameters(), lr=learning_rate, weight_decay=weight_decay)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)
    scaler = GradScaler()

    # Define loss criterion
    criterion = nn.KLDivLoss(reduction="batchmean")

    best_val_loss = float('inf')
    epochs_no_improve = 0

    for epoch in range(1, num_epochs + 1):
        print(f"\nEpoch {epoch}/{num_epochs}")
        print("-" * 20)

        # Training
        train_loss = train_step(
            teacher, student, train_loader, optimizer, criterion, scaler, temperature,
            accumulation_steps=accumulation_steps, max_grad_norm=max_grad_norm,
        )
        print(f"Training Loss: {train_loss:.4f}")
        writer.add_scalar("Loss/Train", train_loss, epoch)

        # Validation
        val_loss = validate(teacher, student, val_loader, criterion, temperature)
        print(f"Validation Loss: {val_loss:.4f}")
        writer.add_scalar("Loss/Validation", val_loss, epoch)

        # Check for improvement
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            epochs_no_improve = 0
            # Save the best model
            save_checkpoint({
                'epoch': epoch,
                'model_state_dict': student.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict(),
                'scaler_state_dict': scaler.state_dict(),
                'best_val_loss': best_val_loss,
            }, checkpoint_dir, epoch)
            # Save the model as the best one
            student.save_pretrained(save_path)
            tokenizer.save_pretrained(save_path)
            print(f"Best model saved at epoch {epoch}")
        else:
            epochs_no_improve += 1
            print(f"No improvement in validation loss for {epochs_no_improve} epoch(s)")
            if epochs_no_improve >= early_stopping_patience:
                print("Early stopping triggered")
                break

        # Step the scheduler
        scheduler.step()


def distill_model(
    teacher_model_name: str,
    student_model_name: str,
    dataset_name: str,
    config: str,
    distill_full_model: bool = True,
    query_terms: Optional[List[str]] = None,
    num_epochs: int = 3,
    batch_size: int = 4,
    max_length: int = 128,
    learning_rate: float = 5e-5,
    temperature: float = 2.0,
    save_path: str = "./distilled_model",
    log_dir: str = "./logs",
    checkpoint_dir: str = "./checkpoints",
    early_stopping_patience: int = 3,
    accumulation_steps: int = 1,
    max_grad_norm: float = 1.0,
    weight_decay: float = 0.01
):
    """Distill ``teacher_model_name`` into ``student_model_name``.

    Args:
        teacher_model_name (str): Name or path of the teacher model; its
            tokenizer is used for both models.
        student_model_name (str): Name or path of the student model. If it
            cannot be loaded, ``distilgpt2`` is used and saved to ``save_path``.
        dataset_name (str): Dataset name for full mode.
        config (str): Dataset configuration for full mode.
        distill_full_model (bool): ``True`` trains on the named dataset;
            ``False`` trains on the custom files listed in ``query_terms``.
        query_terms (Optional[List[str]]): Paths to ``.json`` / ``.jsonl``
            files; required when ``distill_full_model`` is ``False``.
        num_epochs (int): Maximum number of epochs.
        batch_size (int): Batch size for both loaders.
        max_length (int): Maximum token length per example.
        learning_rate (float): AdamW learning rate.
        temperature (float): Distillation temperature.
        save_path (str): Directory for the best student model and tokenizer.
        log_dir (str): TensorBoard log directory.
        checkpoint_dir (str): Directory for training checkpoints.
        early_stopping_patience (int): Epochs without validation improvement
            before stopping.
        accumulation_steps (int): Gradient accumulation steps.
        max_grad_norm (float): Gradient-norm clipping threshold.
        weight_decay (float): AdamW weight decay.

    Raises:
        SystemExit: When ``query_terms`` is missing in standard mode, or when
            no student model can be instantiated.
    """
    # Validate before loading any model so a bad invocation fails fast.
    if not distill_full_model and not query_terms:
        print("Error: --query_terms must be provided for standard language model distillation.")
        sys.exit(1)

    # Initialize TensorBoard writer
    writer = SummaryWriter(log_dir=log_dir)
    try:
        # Load tokenizer
        print("Loading tokenizer...")
        tokenizer = AutoTokenizer.from_pretrained(teacher_model_name)
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
        print("Tokenizer loaded successfully.")

        # Load teacher model
        print("Loading teacher model...")
        teacher = AutoModelForCausalLM.from_pretrained(teacher_model_name).to(device)
        print("Teacher model loaded successfully.")

        if distill_full_model:
            mode_label = "Full World Model Distillation"
        else:
            mode_label = "Standard Language Model Distillation"
        print(f"Starting {mode_label} into '{student_model_name}'.")

        student = _load_student_model(student_model_name, tokenizer, save_path)

        # Freeze teacher model parameters
        teacher.requires_grad_(False)

        if distill_full_model:
            # Load and prepare dataset
            print(f"Loading full dataset '{dataset_name}' with config '{config}'...")
            dataset = load_dataset(dataset_name, config)
            train_loader, val_loader = prepare_data(tokenizer, dataset, max_length, batch_size)
            print("Data loaded and preprocessed successfully.")
        else:
            # Load and prepare custom dataset
            print(f"Loading custom data files: {query_terms}")
            custom_data = load_custom_data_from_files(query_terms)
            train_loader, val_loader = load_custom_data(
                # load_custom_data reads both max_length and batch_size.
                args=argparse.Namespace(max_length=max_length, batch_size=batch_size),
                tokenizer=tokenizer,
                custom_data=custom_data,
            )
            print("Custom data loaded and preprocessed successfully.")

        _run_distillation_loop(
            teacher, student, tokenizer, train_loader, val_loader, writer,
            num_epochs=num_epochs,
            learning_rate=learning_rate,
            weight_decay=weight_decay,
            temperature=temperature,
            save_path=save_path,
            checkpoint_dir=checkpoint_dir,
            early_stopping_patience=early_stopping_patience,
            accumulation_steps=accumulation_steps,
            max_grad_norm=max_grad_norm,
        )
        print(f"\n{mode_label} completed.")
    finally:
        # Flush TensorBoard events even when training aborts.
        writer.close()


# ==========================
# Argument Parsing
# ==========================
def parse_args():
    """Parse the command-line arguments for the distillation script."""
    parser = argparse.ArgumentParser(
        description="Distill a large LLM into a smaller one or a full language world model."
    )

    # Required arguments
    parser.add_argument("--teacher_model_name", type=str, required=True, help="Name of the teacher model")
    parser.add_argument("--student_model_name", type=str, required=True, help="Name of the student model")

    # Dataset arguments
    parser.add_argument("--dataset_name", type=str, required=True, help="Name of the dataset")
    parser.add_argument("--config", type=str, default=None, help="Dataset configuration (e.g., 'wikitext-2-raw-v1')")

    # Mode selection
    parser.add_argument("--distill_full_model", action="store_true", help="Whether to distill into the full language world model")

    # For standard distillation
    parser.add_argument("--query_terms", type=str, nargs="+", help="Paths to custom data files for standard language model distillation")

    # Training hyperparameters
    parser.add_argument("--num_epochs", type=int, default=3, help="Number of epochs")
    parser.add_argument("--batch_size", type=int, default=4, help="Batch size")
    parser.add_argument("--max_length", type=int, default=128, help="Maximum sequence length")
    parser.add_argument("--learning_rate", type=float, default=5e-5, help="Learning rate")
    parser.add_argument("--temperature", type=float, default=2.0, help="Distillation temperature")

    # Saving and logging
    parser.add_argument("--save_path", type=str, default="./distilled_model", help="Path to save the distilled model")
    parser.add_argument("--log_dir", type=str, default="./logs", help="Directory for TensorBoard logs")
    parser.add_argument("--checkpoint_dir", type=str, default="./checkpoints", help="Directory to save checkpoints")

    # Early stopping
    parser.add_argument("--early_stopping_patience", type=int, default=3, help="Early stopping patience")

    # Gradient accumulation and optimization
    parser.add_argument("--accumulation_steps", type=int, default=1, help="Gradient accumulation steps")
    parser.add_argument("--max_grad_norm", type=float, default=1.0, help="Maximum gradient norm for clipping")
    parser.add_argument("--weight_decay", type=float, default=0.01, help="Weight decay for optimizer")

    return parser.parse_args()


# ==========================
# Main Function
# ==========================
def main():
    """Parse arguments, create the output directories and run the distillation."""
    args = parse_args()
    print("Arguments parsed successfully.")

    # Create save directories
    os.makedirs(args.save_path, exist_ok=True)
    os.makedirs(args.log_dir, exist_ok=True)
    os.makedirs(args.checkpoint_dir, exist_ok=True)
    print(f"Save directory created: {args.save_path}")
    print(f"Log directory created: {args.log_dir}")
    print(f"Checkpoint directory created: {args.checkpoint_dir}")

    # The mode is selected inside distill_model via distill_full_model, so a
    # single call covers both full and standard distillation.
    distill_model(
        teacher_model_name=args.teacher_model_name,
        student_model_name=args.student_model_name,
        dataset_name=args.dataset_name,
        config=args.config,
        distill_full_model=args.distill_full_model,
        query_terms=args.query_terms,
        num_epochs=args.num_epochs,
        batch_size=args.batch_size,
        max_length=args.max_length,
        learning_rate=args.learning_rate,
        temperature=args.temperature,
        save_path=args.save_path,
        log_dir=args.log_dir,
        checkpoint_dir=args.checkpoint_dir,
        early_stopping_patience=args.early_stopping_patience,
        accumulation_steps=args.accumulation_steps,
        max_grad_norm=args.max_grad_norm,
        weight_decay=args.weight_decay,
    )


if __name__ == "__main__":
    main()