import torch
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
from typing import Dict, List, Tuple, Optional
from tqdm import tqdm
import numpy as np
import gc
import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

class GRPOTrainer:
    def __init__(
        self,
        actor_model,
        reward_model,
        ref_model,
        tokenizer,
        learning_rate: float = 1e-6,
        kl_coef: float = 0.04,
        group_size: int = 4,
        clip_epsilon: float = 0.2,
        max_grad_norm: float = 1.0,
        grpo_epochs: int = 1,
        update_batch_size: int = 4,
        use_amp: bool = True,
        value_clip: bool = False,
        entropy_coef: float = 0.01,
        advantage_normalization: str = 'group',
        kl_estimation_method: str = 'forward'
    ):
        self.actor = actor_model
        self.reward_model = reward_model
        self.ref_model = ref_model
        self.tokenizer = tokenizer

        self.kl_coef = kl_coef
        self.group_size = group_size
        self.clip_epsilon = clip_epsilon
        self.max_grad_norm = max_grad_norm
        self.grpo_epochs = grpo_epochs
        self.update_batch_size = update_batch_size
        self.use_amp = use_amp
        self.entropy_coef = entropy_coef
        self.advantage_normalization = advantage_normalization
        self.kl_estimation_method = kl_estimation_method

        self.device = next(actor_model.parameters()).device

        # The reference and reward models are frozen; only the actor is trained.
        self.ref_model.eval()
        self.ref_model.requires_grad_(False)
        self.reward_model.eval()
        self.reward_model.requires_grad_(False)

        # Optimizer over the trainable actor parameters only.
        self.optimizer = optim.AdamW(
            filter(lambda p: p.requires_grad, actor_model.parameters()),
            lr=learning_rate,
            weight_decay=0.01,
            betas=(0.9, 0.95),
            eps=1e-8
        )

        # Gradient scaler for mixed-precision training.
        self.scaler = torch.amp.GradScaler('cuda', enabled=self.use_amp)

        self.training_stats = {
            'iterations': 0,
            'total_samples': 0,
            'avg_rewards': [],
            'avg_kl': [],
            'policy_losses': []
        }

        logger.info("GRPO Trainer initialized:")
        logger.info(f"  Group Size: {group_size}")
        logger.info(f"  KL Coef: {kl_coef}")
        logger.info(f"  Clip Epsilon: {clip_epsilon}")
        logger.info(f"  Learning Rate: {learning_rate}")
        logger.info(f"  Update Batch Size: {update_batch_size}")
        logger.info(f"  Mixed Precision: {use_amp}")
        logger.info(f"  KL Estimation: {kl_estimation_method}")

    def _compute_kl_divergence(
        self,
        log_probs: torch.Tensor,
        ref_log_probs: torch.Tensor,
        mask: torch.Tensor
    ) -> torch.Tensor:
        # Per-token KL penalty between the actor and the frozen reference
        # policy, summed over response tokens.
        if self.kl_estimation_method == 'forward':
            # k1 estimator of KL(actor || ref) under samples from the actor.
            kl = log_probs - ref_log_probs
        elif self.kl_estimation_method == 'reverse':
            kl = ref_log_probs - log_probs
        else:
            # Low-variance k3 estimator: exp(r) - r - 1 with r = ref_log_probs - log_probs.
            # (Averaging the forward and reverse k1 terms would cancel to zero.)
            log_ratio = ref_log_probs - log_probs
            kl = torch.exp(log_ratio) - log_ratio - 1.0

        kl_penalty = (kl * mask).sum(dim=-1)
        return kl_penalty

    @torch.no_grad()
    def generate_experience(
        self,
        prompts_dataloader: DataLoader,
        max_gen_len: int,
        temperature: float = 1.0,
        top_p: float = 0.9
    ) -> Dict:
        self.actor.eval()

        all_sequences = []
        all_log_probs = []
        all_advantages = []
        all_prompt_lens = []
        all_rewards = []

        logger.info("Generating experience...")

        for prompts in tqdm(prompts_dataloader, desc="Generating Experience"):
            try:
                if isinstance(prompts, (list, tuple)):
                    prompts = prompts[0]

                prompts = prompts.to(self.device)
                batch_size = prompts.shape[0]

                # Repeat each prompt group_size times so every prompt gets a
                # group of sampled responses.
                prompts_repeated = prompts.repeat_interleave(self.group_size, dim=0)
                prompt_len = prompts_repeated.shape[1]

                input_data = {
                    'segments': [{
                        'type': 'text',
                        'data': prompts_repeated,
                        'modality_id': 0
                    }]
                }

                # Sample responses from the current policy.
                with torch.amp.autocast('cuda', enabled=self.use_amp):
                    response_ids = self.actor.generate(
                        input_data,
                        max_new_tokens=max_gen_len,
                        do_sample=True,
                        temperature=temperature,
                        top_p=top_p,
                        eos_token_id=self.tokenizer.eos_token_id,
                        pad_token_id=self.tokenizer.pad_token_id,
                        use_cache=True
                    )

                # generate() is assumed to return only the newly sampled tokens;
                # concatenate them back onto the prompts.
                sequences = torch.cat([prompts_repeated, response_ids], dim=1)

                if sequences.shape[1] <= prompt_len:
                    logger.warning("Generated sequence too short, skipping batch")
                    continue

                full_input_data = {
                    'segments': [{
                        'type': 'text',
                        'data': sequences,
                        'modality_id': 0
                    }]
                }

                # Score the full sequences under the actor and the reference model.
                with torch.amp.autocast('cuda', enabled=self.use_amp):
                    actor_out = self.actor(full_input_data)
                    ref_out = self.ref_model(full_input_data)

                logits = actor_out['logits'][:, :-1, :]
                ref_logits = ref_out['logits'][:, :-1, :]
                targets = sequences[:, 1:]

                log_probs = F.log_softmax(logits, dim=-1)
                ref_log_probs = F.log_softmax(ref_logits, dim=-1)

                # Per-token log-probabilities of the sampled tokens.
                per_token_log_probs = torch.gather(
                    log_probs, -1, targets.unsqueeze(-1)
                ).squeeze(-1)
                per_token_ref_log_probs = torch.gather(
                    ref_log_probs, -1, targets.unsqueeze(-1)
                ).squeeze(-1)

                # Mask that is 1 on response-token positions and 0 on prompt positions.
                response_mask = torch.arange(
                    sequences.size(1) - 1, device=self.device
                ) >= (prompt_len - 1)
                response_mask = response_mask.unsqueeze(0).expand_as(per_token_log_probs)
                response_mask = response_mask.float()

                kl_penalty = self._compute_kl_divergence(
                    per_token_log_probs,
                    per_token_ref_log_probs,
                    response_mask
                )

                with torch.amp.autocast('cuda', enabled=self.use_amp):
                    reward_output = self.reward_model(full_input_data)

                if reward_output.dim() == 2:
                    raw_rewards = reward_output[:, -1]
                else:
                    raw_rewards = reward_output.squeeze(-1)

                # Sequence-level reward with KL penalty against the reference policy.
                total_rewards = raw_rewards - self.kl_coef * kl_penalty

                # Group-relative advantages: normalize rewards within each prompt's group.
                rewards_grouped = total_rewards.view(batch_size, self.group_size)

                if self.advantage_normalization == 'group':
                    mean_grouped = rewards_grouped.mean(dim=1, keepdim=True)
                    std_grouped = rewards_grouped.std(dim=1, keepdim=True) + 1e-8
                    advantages = (rewards_grouped - mean_grouped) / std_grouped
                elif self.advantage_normalization == 'global':
                    advantages = (rewards_grouped - rewards_grouped.mean()) / (
                        rewards_grouped.std() + 1e-8
                    )
                else:
                    advantages = rewards_grouped - rewards_grouped.mean(dim=1, keepdim=True)

                advantages = advantages.view(-1)

                # Store rollout data on CPU. Note: the concatenation below assumes
                # every batch yields sequences of the same length (e.g. generation
                # always fills max_new_tokens); otherwise padding would be needed here.
                all_sequences.append(sequences.cpu())
                all_log_probs.append(per_token_log_probs.detach().cpu())
                all_advantages.append(advantages.detach().cpu())
                all_prompt_lens.append(
                    torch.full((sequences.size(0),), prompt_len, dtype=torch.long)
                )
                all_rewards.append(total_rewards.detach().cpu())

                # Free per-batch activations before the next generation step.
                del logits, ref_logits, actor_out, ref_out
                del log_probs, ref_log_probs, reward_output

            except Exception as e:
                logger.error(f"Error generating experience for batch: {e}")
                import traceback
                traceback.print_exc()
                continue

            finally:
                torch.cuda.empty_cache()

        if not all_sequences:
            raise RuntimeError("No valid sequences generated")

        experience = {
            'sequences': torch.cat(all_sequences, dim=0),
            'log_probs': torch.cat(all_log_probs, dim=0),
            'advantages': torch.cat(all_advantages, dim=0),
            'prompt_lengths': torch.cat(all_prompt_lens, dim=0),
            'rewards': torch.cat(all_rewards, dim=0)
        }

        logger.info(f"Generated {len(experience['sequences'])} sequences")
        logger.info(f"Avg Reward: {experience['rewards'].mean().item():.4f}")
        logger.info(f"Reward Std: {experience['rewards'].std().item():.4f}")
        logger.info(f"Avg Advantage: {experience['advantages'].mean().item():.4f}")

        return experience

    def grpo_step(
        self,
        dataset: TensorDataset
    ) -> Dict[str, float]:
        self.actor.train()

        dataloader = DataLoader(
            dataset,
            batch_size=self.update_batch_size,
            shuffle=True,
            drop_last=False
        )

        epoch_stats = {
            'total_loss': 0.0,
            'policy_loss': 0.0,
            'entropy': 0.0,
            'approx_kl': 0.0,
            'clip_fraction': 0.0,
            'steps': 0
        }

        for batch_data in dataloader:
            sequences, old_log_probs, advantages, prompt_lens = batch_data

            sequences = sequences.to(self.device)
            old_log_probs = old_log_probs.to(self.device)
            advantages = advantages.to(self.device)

            input_data = {
                'segments': [{
                    'type': 'text',
                    'data': sequences,
                    'modality_id': 0
                }]
            }

            with torch.amp.autocast('cuda', enabled=self.use_amp):
                outputs = self.actor(input_data)
                logits = outputs['logits'][:, :-1, :]

            # Recompute per-token log-probabilities under the current policy.
            targets = sequences[:, 1:]
            log_probs_dist = F.log_softmax(logits, dim=-1)
            new_log_probs = torch.gather(
                log_probs_dist, -1, targets.unsqueeze(-1)
            ).squeeze(-1)

            # Mask out prompt positions so the loss covers response tokens only.
            mask = torch.zeros_like(new_log_probs)
            for i, pl in enumerate(prompt_lens):
                mask[i, pl-1:] = 1.0

            # Importance ratio between the current and rollout policies.
            ratio = torch.exp(new_log_probs - old_log_probs)

            # Broadcast the sequence-level advantage to every response token.
            adv_expanded = advantages.unsqueeze(-1).expand_as(new_log_probs)

            # PPO-style clipped surrogate objective.
            surr1 = ratio * adv_expanded
            surr2 = torch.clamp(
                ratio,
                1.0 - self.clip_epsilon,
                1.0 + self.clip_epsilon
            ) * adv_expanded

            policy_loss = -torch.min(surr1, surr2)
            policy_loss = (policy_loss * mask).sum() / (mask.sum() + 1e-8)

            # Entropy bonus over response tokens to discourage policy collapse.
            probs = F.softmax(logits, dim=-1)
            entropy = -(probs * log_probs_dist).sum(dim=-1)
            entropy_bonus = (entropy * mask).sum() / (mask.sum() + 1e-8)

            loss = policy_loss - self.entropy_coef * entropy_bonus

            # Diagnostics: approximate KL to the rollout policy and clip fraction.
            with torch.no_grad():
                log_ratio = new_log_probs - old_log_probs
                approx_kl = ((ratio - 1) - log_ratio) * mask
                approx_kl = approx_kl.sum() / (mask.sum() + 1e-8)

                clip_fraction = ((ratio > 1 + self.clip_epsilon) |
                                 (ratio < 1 - self.clip_epsilon)).float()
                clip_fraction = (clip_fraction * mask).sum() / (mask.sum() + 1e-8)

            self.optimizer.zero_grad()
            self.scaler.scale(loss).backward()

            # Unscale before clipping so the gradient norm is measured in true units.
            self.scaler.unscale_(self.optimizer)
            grad_norm = torch.nn.utils.clip_grad_norm_(
                self.actor.parameters(),
                self.max_grad_norm
            )

            self.scaler.step(self.optimizer)
            self.scaler.update()

            epoch_stats['total_loss'] += loss.item()
            epoch_stats['policy_loss'] += policy_loss.item()
            epoch_stats['entropy'] += entropy_bonus.item()
            epoch_stats['approx_kl'] += approx_kl.item()
            epoch_stats['clip_fraction'] += clip_fraction.item()
            epoch_stats['steps'] += 1

        # Average the accumulated statistics over the number of update steps.
        for key in epoch_stats:
            if key != 'steps':
                epoch_stats[key] /= max(epoch_stats['steps'], 1)

        return epoch_stats

    def train(
        self,
        prompt_dataloader: DataLoader,
        num_iterations: int = 1,
        max_gen_len: int = 50,
        temperature: float = 1.0,
        save_every: int = 5,
        save_path: str = "checkpoints"
    ):
        logger.info(f"\n{'='*80}")
        logger.info("Starting GRPO Training")
        logger.info(f"  Iterations: {num_iterations}")
        logger.info(f"  Max Gen Length: {max_gen_len}")
        logger.info(f"  Temperature: {temperature}")
        logger.info(f"{'='*80}\n")

        for iteration in range(num_iterations):
            try:
                # Phase 1: roll out the current policy and score the samples.
                experience = self.generate_experience(
                    prompt_dataloader,
                    max_gen_len,
                    temperature
                )

                dataset = TensorDataset(
                    experience['sequences'],
                    experience['log_probs'],
                    experience['advantages'],
                    experience['prompt_lengths']
                )

                # Phase 2: optimize the policy on the collected experience.
                logger.info(f"Optimizing policy for {self.grpo_epochs} epochs...")
                all_epoch_stats = []

                for epoch in range(self.grpo_epochs):
                    stats = self.grpo_step(dataset)
                    all_epoch_stats.append(stats)

                    logger.info(
                        f"  Epoch {epoch+1}/{self.grpo_epochs} | "
                        f"Loss: {stats['total_loss']:.4f} | "
                        f"KL: {stats['approx_kl']:.4f} | "
                        f"Clip%: {stats['clip_fraction']*100:.1f}"
                    )

                avg_stats = {
                    key: np.mean([s[key] for s in all_epoch_stats])
                    for key in all_epoch_stats[0].keys()
                }

                self.training_stats['iterations'] += 1
                self.training_stats['total_samples'] += len(experience['sequences'])
                self.training_stats['avg_rewards'].append(
                    experience['rewards'].mean().item()
                )
                self.training_stats['avg_kl'].append(avg_stats['approx_kl'])
                self.training_stats['policy_losses'].append(avg_stats['policy_loss'])

                logger.info(f"\n{'='*80}")
                logger.info(f"Iteration {iteration+1}/{num_iterations} Complete")
                logger.info(f"  Avg Reward: {experience['rewards'].mean():.4f}")
                logger.info(f"  Avg Advantage: {experience['advantages'].mean():.4f}")
                logger.info(f"  Policy Loss: {avg_stats['policy_loss']:.4f}")
                logger.info(f"  Approx KL: {avg_stats['approx_kl']:.4f}")
                logger.info(f"  Entropy: {avg_stats['entropy']:.4f}")
                logger.info(f"  Clip Fraction: {avg_stats['clip_fraction']*100:.1f}%")
                logger.info(f"{'='*80}\n")

                if (iteration + 1) % save_every == 0:
                    self.save_checkpoint(
                        f"{save_path}/grpo_iter_{iteration+1}.pt"
                    )

                # Release rollout buffers before the next iteration.
                del experience, dataset
                gc.collect()
                torch.cuda.empty_cache()

            except Exception as e:
                logger.error(f"Error in iteration {iteration+1}: {e}")
                import traceback
                traceback.print_exc()
                continue

        logger.info("GRPO Training Complete!")
        self.print_training_summary()

    def save_checkpoint(self, path: str):
        import os
        os.makedirs(os.path.dirname(path), exist_ok=True)

        checkpoint = {
            'actor_state_dict': self.actor.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'scaler_state_dict': self.scaler.state_dict(),
            'training_stats': self.training_stats,
            'config': {
                'kl_coef': self.kl_coef,
                'group_size': self.group_size,
                'clip_epsilon': self.clip_epsilon,
            }
        }

        torch.save(checkpoint, path)
        logger.info(f"Checkpoint saved to {path}")

    def load_checkpoint(self, path: str):
        checkpoint = torch.load(path, map_location=self.device)

        self.actor.load_state_dict(checkpoint['actor_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

        if 'scaler_state_dict' in checkpoint and self.use_amp:
            self.scaler.load_state_dict(checkpoint['scaler_state_dict'])

        self.training_stats = checkpoint['training_stats']

        logger.info(f"Checkpoint loaded from {path}")

    def print_training_summary(self):
        logger.info("\n" + "="*80)
        logger.info("Training Summary")
        logger.info("="*80)
        logger.info(f"Total Iterations: {self.training_stats['iterations']}")
        logger.info(f"Total Samples: {self.training_stats['total_samples']}")

        if self.training_stats['avg_rewards']:
            logger.info(
                f"Final Avg Reward: "
                f"{self.training_stats['avg_rewards'][-1]:.4f}"
            )
            logger.info(
                f"Reward Improvement: "
                f"{self.training_stats['avg_rewards'][-1] - self.training_stats['avg_rewards'][0]:.4f}"
            )

        logger.info("="*80 + "\n")
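

# ---------------------------------------------------------------------------
# Example usage: a minimal sketch of wiring the trainer together. The model
# classes (`MyActorModel`, `MyRewardModel`), the tokenizer, and the random
# prompt tensors below are hypothetical stand-ins, not part of this module;
# substitute your own actor/reference/reward models and prompt data.
#
#     actor = MyActorModel().cuda()
#     ref = MyActorModel().cuda()       # frozen copy of the initial actor
#     reward = MyRewardModel().cuda()
#     tokenizer = ...                   # anything exposing eos_token_id / pad_token_id
#
#     prompt_ids = torch.randint(0, 32000, (32, 16))        # (num_prompts, prompt_len)
#     prompt_loader = DataLoader(TensorDataset(prompt_ids), batch_size=8)
#
#     trainer = GRPOTrainer(actor, reward, ref, tokenizer, group_size=4)
#     trainer.train(prompt_loader, num_iterations=10, max_gen_len=64)
# ---------------------------------------------------------------------------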