| | """ |
| | PlainMLP vs ResMLP Comparison on Distant Identity Task |
| | |
| | This experiment demonstrates the vanishing gradient problem in deep networks |
| | and how residual connections solve it. |
| | """ |
| |
|
| | import torch |
| | import torch.nn as nn |
| | import numpy as np |
| | import matplotlib.pyplot as plt |
| | from typing import Dict, List, Tuple |
| | import json |
| |
|
| | |
# Fix the global RNG seeds so data generation, weight initialisation and
# mini-batch sampling are reproducible from run to run.
torch.manual_seed(42)
np.random.seed(42)

# ---- Experiment configuration -------------------------------------------
NUM_LAYERS = 20        # number of Linear layers in each network
HIDDEN_DIM = 64        # width of every layer; also the data dimension
NUM_SAMPLES = 1024     # size of the synthetic training set
TRAINING_STEPS = 500   # optimizer steps per model
LEARNING_RATE = 1e-3   # Adam learning rate
BATCH_SIZE = 64        # samples drawn (with replacement) per step

# Echo the configuration so every log starts with the settings it used.
print(f"[Config] Layers: {NUM_LAYERS}, Hidden Dim: {HIDDEN_DIM}")
print(
    f"[Config] Samples: {NUM_SAMPLES}, Steps: {TRAINING_STEPS}, "
    f"LR: {LEARNING_RATE}"
)
| |
|
| |
|
class PlainMLP(nn.Module):
    """Deep MLP without skip connections.

    Every layer maps ``dim -> dim`` and the update is ``x <- ReLU(W x + b)``,
    so the signal (and its gradient) must pass through every layer.

    Args:
        dim: feature width of the input, every hidden layer and the output.
        num_layers: number of Linear+ReLU layers stacked in sequence.
    """

    def __init__(self, dim: int, num_layers: int):
        super().__init__()
        self.layers = nn.ModuleList([self._make_layer(dim) for _ in range(num_layers)])
        self.activation = nn.ReLU()

    @staticmethod
    def _make_layer(dim: int) -> nn.Linear:
        """Build one square Linear layer with He (fan-in, ReLU) weights and zero bias."""
        linear = nn.Linear(dim, dim)
        # He init keeps the forward signal's second moment roughly constant
        # through ReLU layers, so the comparison is not skewed by a bad init.
        nn.init.kaiming_normal_(linear.weight, mode='fan_in', nonlinearity='relu')
        nn.init.zeros_(linear.bias)
        return linear

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply ``ReLU(layer(.))`` for each layer in order and return the result."""
        hidden = x
        for linear in self.layers:
            hidden = self.activation(linear(hidden))
        return hidden
| |
|
| |
|
class ResMLP(nn.Module):
    """Deep MLP with an identity skip around every layer.

    Each layer computes ``x <- x + ReLU(W x + b)``. The skip path gives the
    identity map (and its gradient) a direct route through the whole stack.

    NOTE(review): the residual branch is ReLU-gated, hence non-negative, and
    is not rescaled or normalised, so activation magnitudes may grow with
    depth. Confirm this is the intended experimental setup.

    Args:
        dim: feature width of the input, every hidden layer and the output.
        num_layers: number of residual Linear+ReLU layers stacked in sequence.
    """

    def __init__(self, dim: int, num_layers: int):
        super().__init__()
        self.layers = nn.ModuleList([self._make_layer(dim) for _ in range(num_layers)])
        self.activation = nn.ReLU()

    @staticmethod
    def _make_layer(dim: int) -> nn.Linear:
        """Build one square Linear layer with He (fan-in, ReLU) weights and zero bias."""
        linear = nn.Linear(dim, dim)
        # Same initialisation as PlainMLP so only the skip connection differs.
        nn.init.kaiming_normal_(linear.weight, mode='fan_in', nonlinearity='relu')
        nn.init.zeros_(linear.bias)
        return linear

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Add ``ReLU(layer(h))`` to the running state ``h`` for each layer in order."""
        state = x
        for linear in self.layers:
            branch = self.activation(linear(state))
            state = state + branch
        return state
| |
|
| |
|
def generate_identity_data(num_samples: int, dim: int) -> Tuple[torch.Tensor, torch.Tensor]:
    """Build an identity-regression dataset.

    Inputs are drawn element-wise from U(-1, 1) using torch's global RNG.
    The targets are a copy of the inputs, so the task is to learn f(x) = x.

    Args:
        num_samples: number of rows to generate.
        dim: number of features per row.

    Returns:
        ``(X, Y)``, both of shape ``(num_samples, dim)``. ``Y`` equals ``X``
        but owns separate storage.
    """
    inputs = torch.empty(num_samples, dim)
    inputs.uniform_(-1, 1)
    targets = inputs.clone()
    return inputs, targets
| |
|
| |
|
def train_model(model: nn.Module, X: torch.Tensor, Y: torch.Tensor,
                steps: int, lr: float, batch_size: int) -> List[float]:
    """Fit ``model`` to ``(X, Y)`` with Adam + MSE and return the loss history.

    Each step draws ``batch_size`` row indices uniformly *with replacement*
    from torch's global RNG, takes one optimizer step on that mini-batch, and
    records its loss. Progress is printed every 100 steps.

    Args:
        model: module to train in place.
        X: input rows (indexed along dim 0).
        Y: target rows aligned with ``X``.
        steps: number of optimizer steps.
        lr: Adam learning rate.
        batch_size: rows sampled per step.

    Returns:
        The mini-batch MSE loss of every step, in order (``len == steps``).
    """
    opt = torch.optim.Adam(model.parameters(), lr=lr)
    mse = nn.MSELoss()
    history: List[float] = []
    n_rows = X.shape[0]

    for step in range(steps):
        idx = torch.randint(0, n_rows, (batch_size,))

        opt.zero_grad()
        loss = mse(model(X[idx]), Y[idx])
        loss.backward()
        opt.step()

        loss_value = loss.item()
        history.append(loss_value)
        if step % 100 == 0:
            print(f" Step {step}/{steps}, Loss: {loss_value:.6f}")

    return history
| |
|
| |
|
class ActivationGradientHook:
    """Capture per-layer outputs and output-gradients via module hooks.

    ``register_hooks`` attaches a forward hook and a full backward hook to
    every module in ``model.layers``. On each forward pass, every hooked
    module's output is snapshotted into ``activations``. During
    ``loss.backward()``, the gradient of the loss with respect to that
    output is snapshotted into ``gradients``. Both lists keep growing across
    passes until ``clear()`` is called.

    NOTE(review): the statistics describe the direct output of each entry of
    ``model.layers``. In this file those entries are ``nn.Linear``, so values
    are taken before the ReLU (and, for ResMLP, before the residual add).
    """

    def __init__(self):
        self.activations: List[torch.Tensor] = []  # forward outputs, in call order
        self.gradients: List[torch.Tensor] = []    # dLoss/d(output), in hook-firing order
        self.handles = []                          # handles returned by register_*_hook

    def register_hooks(self, model: nn.Module):
        """Attach a forward hook and a full backward hook to each module in ``model.layers``."""
        for module in model.layers:
            self.handles += [
                module.register_forward_hook(self._forward_hook),
                module.register_full_backward_hook(self._backward_hook),
            ]

    def _forward_hook(self, module, inputs, output):
        # Detach + clone so later in-place changes to `output` can't alter the snapshot.
        snapshot = output.detach().clone()
        self.activations.append(snapshot)

    def _backward_hook(self, module, grad_input, grad_output):
        # grad_output[0] is the loss gradient w.r.t. this module's (single) output.
        upstream = grad_output[0]
        self.gradients.append(upstream.detach().clone())

    def clear(self):
        """Drop all captured tensors; registered hooks stay in place."""
        self.activations, self.gradients = [], []

    def remove_hooks(self):
        """Detach every hook this object registered."""
        for h in self.handles:
            h.remove()
        self.handles = []

    def get_activation_stats(self) -> Tuple[List[float], List[float]]:
        """Return ``(means, stds)`` of each captured activation, in capture order."""
        means: List[float] = []
        stds: List[float] = []
        for act in self.activations:
            means.append(act.mean().item())
            stds.append(act.std().item())
        return means, stds

    def get_gradient_norms(self) -> List[float]:
        """Return the L2 norm of each captured output-gradient, ordered shallow to deep."""
        # Backward hooks fire from the last layer back to the first, so flip
        # the list. This assumes each layer runs exactly once per forward pass.
        return [g.norm(2).item() for g in self.gradients[::-1]]
| |
|
| |
|
def analyze_final_state(model: nn.Module, dim: int, batch_size: int = 64) -> Dict:
    """Run one instrumented forward/backward pass and summarise each layer.

    A fresh identity batch (``X ~ U(-1, 1)``, ``Y = X``) is drawn from torch's
    global RNG. The model's gradients are zeroed, the MSE loss is
    back-propagated, and per-layer statistics are collected through
    ``ActivationGradientHook``. The hooks are always removed afterwards, even
    if the forward or backward pass raises, so the model is never left
    instrumented.

    Note that parameter ``.grad`` tensors from this pass remain populated on
    ``model`` after the call.

    Args:
        model: network exposing a ``layers`` iterable of modules to hook.
        dim: feature width of the synthetic batch.
        batch_size: number of rows in the synthetic batch.

    Returns:
        dict with ``activation_means``/``activation_stds`` (per layer, in
        forward order), ``gradient_norms`` (per layer, shallow to deep) and
        ``final_loss`` (MSE on this analysis batch, not a training batch).
    """
    hook = ActivationGradientHook()
    hook.register_hooks(model)
    try:
        X_test = torch.empty(batch_size, dim).uniform_(-1, 1)
        Y_test = X_test.clone()

        model.zero_grad()
        output = model(X_test)
        loss = nn.MSELoss()(output, Y_test)
        loss.backward()

        act_means, act_stds = hook.get_activation_stats()
        grad_norms = hook.get_gradient_norms()
    finally:
        # Previously an exception here left the hooks attached to the model.
        hook.remove_hooks()

    return {
        'activation_means': act_means,
        'activation_stds': act_stds,
        'gradient_norms': grad_norms,
        'final_loss': loss.item()
    }
| |
|
| |
|
def plot_training_loss(plain_losses: List[float], res_losses: List[float], save_path: str):
    """Save a log-scale plot of per-step training loss for both models.

    Args:
        plain_losses: PlainMLP loss history; its length defines the x-axis.
        res_losses: ResMLP loss history, plotted against the same x-axis.
        save_path: output image file (this function does not create directories).
    """
    fig, ax = plt.subplots(figsize=(10, 6))
    x_axis = range(len(plain_losses))

    curves = ((plain_losses, 'PlainMLP', 'red'), (res_losses, 'ResMLP', 'blue'))
    for series, name, colour in curves:
        ax.plot(x_axis, series, label=name, color=colour, alpha=0.8)

    ax.set_xlabel('Training Steps', fontsize=12)
    ax.set_ylabel('MSE Loss', fontsize=12)
    ax.set_title('Training Loss: PlainMLP vs ResMLP on Identity Task', fontsize=14)
    ax.legend(fontsize=11)
    ax.grid(True, alpha=0.3)
    ax.set_yscale('log')

    fig.tight_layout()
    fig.savefig(save_path, dpi=150, bbox_inches='tight')
    plt.close(fig)
    print(f"[Plot] Saved training loss plot to {save_path}")
| |
|
| |
|
def plot_gradient_magnitudes(plain_grads: List[float], res_grads: List[float], save_path: str):
    """Save a log-scale plot of per-layer gradient norms for both models.

    Args:
        plain_grads: PlainMLP gradient norms ordered shallow to deep; layers
            are numbered from 1 on the x-axis.
        res_grads: ResMLP gradient norms in the same layer order.
        save_path: output image file (this function does not create directories).
    """
    fig, ax = plt.subplots(figsize=(10, 6))
    depth = range(1, len(plain_grads) + 1)

    ax.plot(depth, plain_grads, 'o-', label='PlainMLP', color='red', markersize=6)
    ax.plot(depth, res_grads, 's-', label='ResMLP', color='blue', markersize=6)

    ax.set_xlabel('Layer Depth', fontsize=12)
    ax.set_ylabel('Gradient L2 Norm', fontsize=12)
    ax.set_title('Gradient Magnitude vs Layer Depth (After Training)', fontsize=14)
    ax.legend(fontsize=11)
    ax.grid(True, alpha=0.3)
    ax.set_yscale('log')

    fig.tight_layout()
    fig.savefig(save_path, dpi=150, bbox_inches='tight')
    plt.close(fig)
    print(f"[Plot] Saved gradient magnitude plot to {save_path}")
| |
|
| |
|
def plot_activation_means(plain_means: List[float], res_means: List[float], save_path: str):
    """Save a linear-scale plot of per-layer activation means for both models.

    Args:
        plain_means: PlainMLP activation means in layer order; layers are
            numbered from 1 on the x-axis.
        res_means: ResMLP activation means in the same layer order.
        save_path: output image file (this function does not create directories).
    """
    fig, ax = plt.subplots(figsize=(10, 6))
    depth = range(1, len(plain_means) + 1)

    ax.plot(depth, plain_means, 'o-', label='PlainMLP', color='red', markersize=6)
    ax.plot(depth, res_means, 's-', label='ResMLP', color='blue', markersize=6)

    ax.set_xlabel('Layer Depth', fontsize=12)
    ax.set_ylabel('Activation Mean', fontsize=12)
    ax.set_title('Activation Mean vs Layer Depth (After Training)', fontsize=14)
    ax.legend(fontsize=11)
    ax.grid(True, alpha=0.3)

    fig.tight_layout()
    fig.savefig(save_path, dpi=150, bbox_inches='tight')
    plt.close(fig)
    print(f"[Plot] Saved activation mean plot to {save_path}")
| |
|
| |
|
def plot_activation_stds(plain_stds: List[float], res_stds: List[float], save_path: str):
    """Save a linear-scale plot of per-layer activation std-devs for both models.

    Args:
        plain_stds: PlainMLP activation standard deviations in layer order;
            layers are numbered from 1 on the x-axis.
        res_stds: ResMLP activation standard deviations in the same order.
        save_path: output image file (this function does not create directories).
    """
    fig, ax = plt.subplots(figsize=(10, 6))
    depth = range(1, len(plain_stds) + 1)

    ax.plot(depth, plain_stds, 'o-', label='PlainMLP', color='red', markersize=6)
    ax.plot(depth, res_stds, 's-', label='ResMLP', color='blue', markersize=6)

    ax.set_xlabel('Layer Depth', fontsize=12)
    ax.set_ylabel('Activation Std', fontsize=12)
    ax.set_title('Activation Standard Deviation vs Layer Depth (After Training)', fontsize=14)
    ax.legend(fontsize=11)
    ax.grid(True, alpha=0.3)

    fig.tight_layout()
    fig.savefig(save_path, dpi=150, bbox_inches='tight')
    plt.close(fig)
    print(f"[Plot] Saved activation std plot to {save_path}")
| |
|
| |
|
def main():
    """Run the PlainMLP vs ResMLP identity-task experiment end to end.

    Steps: generate the identity dataset, build both networks from the
    module-level config, train each with ``train_model``, and analyse each
    trained model on a fresh batch with ``analyze_final_state``. It then
    writes four figures under ``./plots/`` (creating the directory if
    needed) and dumps a JSON summary to ``./results.json``.

    Returns:
        The results dict that was written to ``results.json``. Its
        ``final_loss``/``initial_loss`` entries are the last/first *training*
        mini-batch losses. The "Final Loss" printed in the summary is instead
        the loss on the separate analysis batch.
    """
    import os  # local import: only needed to create the plot output directory

    print("=" * 60)
    print("PlainMLP vs ResMLP: Distant Identity Task Experiment")
    print("=" * 60)

    print("\n[1] Generating synthetic identity data...")
    X, Y = generate_identity_data(NUM_SAMPLES, HIDDEN_DIM)
    print(f" Data shape: X={X.shape}, Y={Y.shape}")
    print(f" X range: [{X.min():.3f}, {X.max():.3f}]")

    print("\n[2] Initializing models...")
    plain_mlp = PlainMLP(HIDDEN_DIM, NUM_LAYERS)
    res_mlp = ResMLP(HIDDEN_DIM, NUM_LAYERS)

    plain_params = sum(p.numel() for p in plain_mlp.parameters())
    res_params = sum(p.numel() for p in res_mlp.parameters())
    print(f" PlainMLP parameters: {plain_params:,}")
    print(f" ResMLP parameters: {res_params:,}")

    print("\n[3] Training PlainMLP...")
    plain_losses = train_model(plain_mlp, X, Y, TRAINING_STEPS, LEARNING_RATE, BATCH_SIZE)
    print(f" Final loss: {plain_losses[-1]:.6f}")

    print("\n[4] Training ResMLP...")
    res_losses = train_model(res_mlp, X, Y, TRAINING_STEPS, LEARNING_RATE, BATCH_SIZE)
    print(f" Final loss: {res_losses[-1]:.6f}")

    print("\n[5] Analyzing final state of trained models...")
    print(" Analyzing PlainMLP...")
    plain_stats = analyze_final_state(plain_mlp, HIDDEN_DIM)
    print(" Analyzing ResMLP...")
    res_stats = analyze_final_state(res_mlp, HIDDEN_DIM)

    print("\n[6] Analysis Summary:")
    print(f" PlainMLP - Final Loss: {plain_stats['final_loss']:.6f}")
    print(f" ResMLP - Final Loss: {res_stats['final_loss']:.6f}")
    print(f" PlainMLP - Gradient norm range: [{min(plain_stats['gradient_norms']):.2e}, {max(plain_stats['gradient_norms']):.2e}]")
    print(f" ResMLP - Gradient norm range: [{min(res_stats['gradient_norms']):.2e}, {max(res_stats['gradient_norms']):.2e}]")

    print("\n[7] Generating plots...")
    # savefig does not create missing directories; without this the first
    # plot raised FileNotFoundError on a fresh checkout.
    os.makedirs('plots', exist_ok=True)
    plot_training_loss(plain_losses, res_losses, 'plots/training_loss.png')
    plot_gradient_magnitudes(plain_stats['gradient_norms'], res_stats['gradient_norms'],
                             'plots/gradient_magnitude.png')
    plot_activation_means(plain_stats['activation_means'], res_stats['activation_means'],
                          'plots/activation_mean.png')
    plot_activation_stds(plain_stats['activation_stds'], res_stats['activation_stds'],
                         'plots/activation_std.png')

    results = {
        'config': {
            'num_layers': NUM_LAYERS,
            'hidden_dim': HIDDEN_DIM,
            'num_samples': NUM_SAMPLES,
            'training_steps': TRAINING_STEPS,
            'learning_rate': LEARNING_RATE,
            'batch_size': BATCH_SIZE
        },
        'plain_mlp': {
            'final_loss': plain_losses[-1],
            'initial_loss': plain_losses[0],
            'gradient_norms': plain_stats['gradient_norms'],
            'activation_means': plain_stats['activation_means'],
            'activation_stds': plain_stats['activation_stds']
        },
        'res_mlp': {
            'final_loss': res_losses[-1],
            'initial_loss': res_losses[0],
            'gradient_norms': res_stats['gradient_norms'],
            'activation_means': res_stats['activation_means'],
            'activation_stds': res_stats['activation_stds']
        }
    }

    with open('results.json', 'w', encoding='utf-8') as f:
        json.dump(results, f, indent=2)
    print("\n[8] Results saved to results.json")

    print("\n" + "=" * 60)
    print("Experiment completed successfully!")
    print("=" * 60)

    return results


if __name__ == "__main__":
    # Keep the returned dict bound at module scope for interactive inspection.
    results = main()
| |
|