| | import torch.nn as nn |
| | import torch.nn.functional as F |
| | import numpy as np |
| | import torch, einops |
| |
|
| | |
class HookPoint(nn.Module):
    """Identity module used as a named attachment point for forward/backward hooks.

    Insert a HookPoint wherever an activation should be observable; `add_hook`
    then registers callbacks that receive that activation (or its gradient)
    tagged with the hook point's name.
    """

    def __init__(self):
        super().__init__()
        # Handles returned by register_*_hook, kept so they can be removed later.
        self.fwd_hooks = []
        self.bwd_hooks = []

    def give_name(self, name):
        """Assign the name passed to hooks (set by the owning model from the module path)."""
        self.name = name

    def add_hook(self, hook, dir='fwd'):
        """Register `hook(tensor, name=...)` on this point.

        For dir='fwd' the hook receives the module output tensor; for
        dir='bwd' it receives the grad_output tuple.
        Raises ValueError for any other `dir`.
        """
        def full_hook(module, module_input, module_output):
            # Adapt PyTorch's (module, input, output) hook signature to the
            # simpler (tensor, name=...) signature used by callers.
            return hook(module_output, name=self.name)

        if dir == 'fwd':
            handle = self.register_forward_hook(full_hook)
            self.fwd_hooks.append(handle)
        elif dir == 'bwd':
            # Fix: register_backward_hook is deprecated and can report
            # incorrect gradients; register_full_backward_hook is the
            # documented replacement with the same callback signature.
            handle = self.register_full_backward_hook(full_hook)
            self.bwd_hooks.append(handle)
        else:
            raise ValueError(f"Invalid direction {dir}")

    def remove_hooks(self, dir='fwd'):
        """Remove registered hooks; `dir` is 'fwd', 'bwd' or 'both'."""
        if dir not in ['fwd', 'bwd', 'both']:
            raise ValueError(f"Invalid direction {dir}")
        if dir in ('fwd', 'both'):
            for handle in self.fwd_hooks:
                handle.remove()
            self.fwd_hooks = []
        if dir in ('bwd', 'both'):
            for handle in self.bwd_hooks:
                handle.remove()
            self.bwd_hooks = []

    def forward(self, x):
        # Identity; exists only so the registered hooks fire on x.
        return x
| |
|
| | |
class Embed(nn.Module):
    """Embed a pair of tokens and sum the two embeddings.

    Input is a (batch, 2) tensor (or nested list) of token indices. Output is
    (batch, 1, d), where d is d_vocab for 'one_hot' and d_model for 'learned'.
    """

    def __init__(self, d_vocab, d_model, embed_type='one_hot'):
        super().__init__()
        self.d_vocab = d_vocab
        self.embed_type = embed_type

        if embed_type == 'learned':
            # Learned embedding matrix, scaled Gaussian init.
            self.W_E = nn.Parameter(torch.randn(d_model, d_vocab) / np.sqrt(d_model))
        elif embed_type == 'one_hot':
            # One-hot embedding needs no parameters.
            self.W_E = None
        else:
            raise ValueError(f"Invalid embed_type: {embed_type}. Must be 'one_hot' or 'learned'")

    def forward(self, x):
        # Accept plain Python lists for convenience; place the tensor on the
        # same device as the weights (CPU when there are no weights).
        if isinstance(x, list):
            target = 'cpu' if self.W_E is None else self.W_E.device
            x = torch.tensor(x, device=target)

        assert x.ndim == 2 and x.shape[1] == 2, f"Expected input shape (batch_size, 2), got {x.shape}"

        if self.embed_type == 'one_hot':
            # (batch, 2, d_vocab) -> sum the pair -> (batch, d_vocab)
            summed = F.one_hot(x, num_classes=self.d_vocab).float().sum(dim=1)
        elif self.embed_type == 'learned':
            # W_E[:, x] is (d_model, batch, 2); move d_model last, sum the pair.
            summed = self.W_E[:, x].permute(1, 2, 0).sum(dim=1)

        # Add a singleton "position" axis: (batch, 1, d)
        return summed.unsqueeze(1)
| |
|
class LayerNorm(nn.Module):
    """LayerNorm that can be toggled off via the owning model's `use_ln` flag.

    The model is passed wrapped in a one-element list so it is not registered
    as a submodule (avoids a recursive module reference).
    """

    def __init__(self, d_model, epsilon=1e-4, model=[None]):
        super().__init__()
        self.model = model
        # Learnable scale and shift, initialised to the identity transform.
        self.w_ln = nn.Parameter(torch.ones(d_model))
        self.b_ln = nn.Parameter(torch.zeros(d_model))
        self.epsilon = epsilon

    def forward(self, x):
        # Pass through untouched when normalization is disabled on the model.
        if not self.model[0].use_ln:
            return x
        # Normalize over the last axis (std is torch's default unbiased estimate),
        # then apply the learned affine transform.
        centered = x - x.mean(dim=-1, keepdim=True)
        scaled = centered / (centered.std(dim=-1, keepdim=True) + self.epsilon)
        return scaled * self.w_ln + self.b_ln
| |
|
| | |
class MLP(nn.Module):
    """Single hidden layer: x -> act(W_in @ x) -> W_out, mapping d_model -> d_mlp -> d_vocab.

    Args:
        d_model: Input feature dimension.
        d_mlp: Hidden width.
        d_vocab: Output dimension (and size of the Fourier basis).
        act_type: 'ReLU' | 'GeLU' | 'Quad' | 'Id', or any callable activation.
        model: One-element list wrapping the owning model (indirection avoids
            registering the parent as a submodule).
        init_type: 'random' (scaled Gaussian) or 'single-freq' (each neuron
            initialized to a single Fourier frequency).
        init_scale: Multiplier applied to the initial weights.
    """

    def __init__(self, d_model, d_mlp, d_vocab, act_type, model, init_type='random', init_scale=0.1):
        super().__init__()
        self.model = model
        self.init_type = init_type
        self.init_scale = init_scale

        # Fourier basis over the vocabulary; used by 'single-freq' init and
        # kept as a buffer for later analysis. Computed once (was computed
        # twice). get_fourier_basis is deterministic, so hoisting it does not
        # perturb the RNG stream consumed by the weight init below.
        fourier_basis, _ = get_fourier_basis(d_vocab)

        if init_type == 'random':
            self.W_in = nn.Parameter(self.init_scale * torch.randn(d_mlp, d_model)/np.sqrt(d_model))
            self.W_out = nn.Parameter(self.init_scale * torch.randn(d_vocab, d_mlp)/np.sqrt(d_model))
        elif init_type == 'single-freq':
            # NOTE(review): these matmuls only line up when d_model == d_vocab
            # — confirm against callers.
            freq_num = (d_vocab-1)//2
            init_freq = decide_frequencies(d_mlp, d_model, freq_num)
            # NOTE(review): sparse_initialization is called twice with fresh
            # random values, so W_in and W_out are NOT tied to the same sparse
            # pattern values — confirm this is intended.
            self.W_in = nn.Parameter(self.init_scale * np.sqrt(d_vocab/2) * sparse_initialization(d_mlp, d_model, init_freq) @ fourier_basis)
            self.W_out = nn.Parameter(self.init_scale * np.sqrt(d_vocab/2) * fourier_basis.T @ sparse_initialization(d_mlp, d_model, init_freq).T)
        else:
            # Fix: message previously read "Invalid init_type: ini{...}" with a
            # stray "ini" prefix.
            raise ValueError(f"Invalid init_type: {init_type}. Must be 'random' or 'single-freq'")

        self.act_type = act_type
        self.hook_pre = HookPoint()   # pre-activation observation point
        self.hook_post = HookPoint()  # post-activation observation point

        # Validate act_type: either a known string or any callable.
        if isinstance(act_type, str):
            assert act_type in ['ReLU', 'GeLU', 'Quad', 'Id'], f"Invalid activation type: {act_type}"
        elif not callable(act_type):
            raise ValueError("act_type must be either a string ('ReLU', 'GeLU', 'Quad', 'Id') or a callable function")

        self.register_buffer('basis', fourier_basis.clone().detach())

    def forward(self, x):
        """Map (batch, pos, d_model) -> (batch, pos, d_vocab)."""
        # Hidden pre-activations: (batch, pos, d_mlp)
        x = self.hook_pre(torch.einsum('md,bpd->bpm', self.W_in, x))

        if callable(self.act_type):
            # Custom activation supplied by the caller.
            x = self.act_type(x)
        elif self.act_type == 'ReLU':
            x = F.relu(x)
        elif self.act_type == 'GeLU':
            x = F.gelu(x)
        elif self.act_type == "Quad":
            x = torch.square(x)
        elif self.act_type == "Id":
            x = x

        x = self.hook_post(x)

        # Project back out: (batch, pos, d_vocab)
        x = torch.einsum('dm,bpm->bpd', self.W_out, x)
        return x
| |
|
class EmbedMLP(nn.Module):
    """One-hidden-layer MLP over summed pair embeddings.

    forward maps a (batch, 2) tensor/list of token pairs to (batch, d_vocab)
    logits. HookPoints inside the sub-modules are named after their module
    path so activations can be cached by name.
    """

    def __init__(self, d_vocab, d_model, d_mlp, act_type, use_cache=False, use_ln=True, init_type='random', init_scale=0.1, embed_type='one_hot'):
        super().__init__()
        self.cache = {}
        self.use_cache = use_cache
        self.init_type = init_type

        self.embed = Embed(d_vocab, d_model, embed_type=embed_type)
        # Pass ourselves wrapped in a list so sub-modules can read model-level
        # flags without creating a recursive module registration.
        self.mlp = MLP(d_model, d_mlp, d_vocab, act_type, model=[self], init_type=init_type, init_scale=init_scale)

        # NOTE(review): use_ln is stored but no LayerNorm is instantiated here;
        # presumably read by sub-modules via the model reference — confirm.
        self.use_ln = use_ln

        # Give every HookPoint its module path as a name.
        for name, module in self.named_modules():
            if isinstance(module, HookPoint):  # fix: was `type(module) == HookPoint`
                module.give_name(name)

    def forward(self, x):
        x = self.embed(x)   # (batch, 2) -> (batch, 1, d)
        x = self.mlp(x)     # (batch, 1, d) -> (batch, 1, d_vocab)
        return x.squeeze(1)  # -> (batch, d_vocab)

    def set_use_cache(self, use_cache):
        self.use_cache = use_cache

    def hook_points(self):
        """All HookPoint modules (matched by 'hook' in the module path)."""
        return [module for name, module in self.named_modules() if 'hook' in name]

    def remove_all_hooks(self):
        """Detach every forward and backward hook from every hook point."""
        for hp in self.hook_points():
            hp.remove_hooks('both')  # fix: was two calls ('fwd' then 'bwd')

    def cache_all(self, cache, incl_bwd=False):
        """Register hooks that store every activation (and optionally its gradient) in `cache`."""
        def save_hook(tensor, name):
            cache[name] = tensor.detach()

        def save_hook_back(tensor, name):
            # Backward hooks receive a tuple of gradients; store the first.
            cache[name + '_grad'] = tensor[0].detach()

        for hp in self.hook_points():
            hp.add_hook(save_hook, 'fwd')
            if incl_bwd:
                hp.add_hook(save_hook_back, 'bwd')
| |
|
| |
|
| | |
def get_fourier_basis(p):
    """Return the (p, p) orthonormal real Fourier basis over Z_p.

    Rows are: the constant vector, then cos/sin pairs for frequencies
    1 .. ceil(p/2)-1, and for even p a final Nyquist cosine row.

    Bug fix: the loop previously ran i up to p//2 inclusive, so for even p it
    appended the all-zero sin(p/2) vector (whose normalization divides by
    zero, producing a NaN row) and the Nyquist cosine was then appended a
    second time by the `p % 2 == 0` branch, yielding a (p+2, p) matrix.
    Output for odd p is unchanged.

    Returns:
        (fourier_basis, fourier_basis_names): a (p, p) tensor whose rows are
        unit-norm basis vectors, and a parallel list of row labels.
    """
    fourier_basis = []
    fourier_basis_names = []

    # Constant (DC) component.
    fourier_basis.append(torch.ones(p) / np.sqrt(p))
    fourier_basis_names.append('Const')

    # cos/sin pairs strictly below the Nyquist frequency.
    for i in range(1, (p + 1) // 2):
        cosine = torch.cos(2 * torch.pi * torch.arange(p) * i / p)
        sine = torch.sin(2 * torch.pi * torch.arange(p) * i / p)

        cosine /= cosine.norm()
        sine /= sine.norm()

        fourier_basis.append(cosine)
        fourier_basis.append(sine)
        fourier_basis_names.append(f'cos {i}')
        fourier_basis_names.append(f'sin {i}')

    # For even p the Nyquist frequency contributes only a cosine term
    # (its sine is identically zero).
    if p % 2 == 0:
        cosine = torch.cos(torch.pi * torch.arange(p))
        cosine /= cosine.norm()
        fourier_basis.append(cosine)
        fourier_basis_names.append(f'cos {p // 2}')

    fourier_basis = torch.stack(fourier_basis, dim=0)

    return fourier_basis, fourier_basis_names
| |
|
def decide_frequencies(d_mlp, d_model, freq_num):
    """Assign one frequency to each of `d_mlp` neurons.

    Valid frequencies are the integers 1 .. (d_model-1)//2. This samples
    `freq_num` distinct frequencies uniformly at random, spreads them across
    the neurons as evenly as possible, and shuffles the assignment order.

    Args:
        d_mlp (int): Number of neurons (rows).
        d_model (int): Number of columns in the weight matrix.
        freq_num (int): Number of distinct frequencies to sample.

    Returns:
        np.ndarray: Length-d_mlp array of per-neuron frequencies.

    Raises:
        ValueError: If freq_num exceeds the number of available frequencies.
    """
    max_freq = (d_model - 1) // 2
    if freq_num > max_freq:
        raise ValueError(f"freq_num ({freq_num}) cannot exceed the number of available frequencies ({max_freq}).")

    # Sample `freq_num` distinct frequencies without replacement.
    chosen = np.random.choice(np.arange(1, max_freq + 1), size=freq_num, replace=False)

    # Cycle through the chosen frequencies until every neuron has one
    # (np.resize repeats the array cyclically to the requested length),
    # then shuffle so the order carries no structure.
    assignments = np.resize(chosen, d_mlp)
    np.random.shuffle(assignments)

    return assignments
| |
|
def sparse_initialization(d_mlp, d_model, freq_assignments):
    """Build a sparse (d_mlp, d_model) weight matrix from frequency assignments.

    Each row i gets a unit-norm Gaussian 2-vector placed in columns
    (2*f - 1, 2*f), where f = freq_assignments[i]; every other entry is zero.

    Args:
        d_mlp (int): Number of neurons (rows).
        d_model (int): Number of columns.
        freq_assignments (np.ndarray): Length-d_mlp per-neuron frequencies.

    Returns:
        torch.Tensor: The sparsely initialized (d_mlp, d_model) matrix.

    Raises:
        IndexError: If a frequency maps to a column outside the matrix.
    """
    weight = torch.zeros(d_mlp, d_model)

    for row, freq in enumerate(freq_assignments):
        first, second = 2 * freq - 1, 2 * freq
        if second >= d_model:
            # Frequency too high for this width.
            raise IndexError(f"Computed column index {second} is out of bounds for d_model={d_model}.")

        # Unit-norm random direction in the 2D (cos, sin) subspace.
        pair = torch.randn(2, device=weight.device, dtype=weight.dtype)
        pair = pair / torch.norm(pair, p=2)

        weight[row, first] = pair[0]
        weight[row, second] = pair[1]

    return weight