from dataclasses import dataclass
from typing import Callable, Optional, Tuple, Union

import torch
from torch import nn


class HomogeneousSequential(nn.Sequential):
    """
    HomogeneousSequential is a sequential container that requires all child modules
    to be of the same type and have matching input/output shapes. In turn, it may be
    compiled with the `scan` higher-order operator to save compile time.
    """

    repeated_layer: type
    """The type of the layer being looped over."""

    def __init__(self, *args: nn.Module) -> None:
        super().__init__(*args)
        types = set(type(module) for module in args)
        assert len(types) == 1, f"All modules must be of the same type. Got {types}"
        self.repeated_layer = types.pop()

    def forward(self, *input, **broadcasted_inputs):
        """
        Much like `torch.nn.Sequential`, this takes `input` and forwards it to the
        first module it contains. It then "chains" outputs to inputs sequentially for
        each subsequent module, finally returning the output of the last module.

        Different from `torch.nn.Sequential`, you may specify `broadcasted_inputs` via
        keyword arguments. The same keyword arguments are passed to every layer
        without changes (i.e. "broadcasted").
        """
        for module in self:
            input = module(*splat(input), **broadcasted_inputs)
        return input
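

# A minimal usage sketch (illustrative; the layer type and sizes are arbitrary).
# Keyword arguments passed to `forward` are forwarded to every child layer, so
# each layer must accept them.
#
#   blocks = HomogeneousSequential(*[nn.Linear(16, 16) for _ in range(4)])
#   y = blocks(torch.randn(2, 16))  # output chained through all four layers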


def splat(input):
    """Normalize `input` to a tuple/list so it can be unpacked into a module call."""
    if not isinstance(input, list | tuple):
        input = (input,)
    return input


@dataclass(kw_only=True)
class RopeScaling:
    """
    RoPE scaling parameters. The defaults are what was selected in Llama 3.1.
    """

    factor: float = 8.0
    low_freq_factor: float = 1.0
    high_freq_factor: float = 4.0
    original_context_len: int = 8192
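

# Illustrative only: the defaults reproduce the Llama 3.1 configuration, and any
# field can be overridden by keyword.
#
#   scaling = RopeScaling()            # Llama 3.1 defaults
#   longer = RopeScaling(factor=16.0)  # arbitrary override, for illustration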


def default_rope_frequencies(
    head_dim: int,
    theta: float = 10000.0,
) -> torch.Tensor:
    """
    Computes the original RoPE frequencies as used in e.g. Llama 2.

    Args:
        head_dim: the size of a single attention head.
        theta: a hyperparameter controlling how fast the embeddings rotate.

    Returns:
        The frequencies for the RoPE embeddings.
    """
    return 1.0 / (
        theta ** (torch.arange(0, head_dim, 2, dtype=torch.int64).float() / head_dim)
    )
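

# A quick shape check (illustrative): there is one frequency per rotated pair of
# dimensions, i.e. head_dim // 2 values.
#
#   freqs = default_rope_frequencies(128)
#   assert freqs.shape == (64,)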


def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Applies Rotary Position Embedding to the query and key tensors.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        position_ids (`torch.Tensor`, *optional*):
            Deprecated and unused.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.

    Returns:
        `tuple(torch.Tensor)` comprising the query and key tensors rotated using the Rotary Position Embedding.
    """
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed
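

# A minimal sketch assuming a [batch, heads, seq, head_dim] layout, so the
# default unsqueeze_dim=1 broadcasts cos/sin of shape [batch, seq, head_dim];
# all sizes below are arbitrary.
#
#   b, h, s, d = 2, 4, 16, 64
#   q, k = torch.randn(b, h, s, d), torch.randn(b, h, s, d)
#   freqs = default_rope_frequencies(d)                # [d // 2]
#   angles = torch.arange(s).float()[:, None] * freqs  # [s, d // 2]
#   emb = torch.cat((angles, angles), dim=-1)          # [s, d]
#   cos, sin = emb.cos().expand(b, s, d), emb.sin().expand(b, s, d)
#   q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin)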


def transition(x_0, sigma, maskable_mask, mask_token_id, mask_block_size: int = 1):
    """Apply masking to input tokens. If mask_block_size > 1, use block masking for all rows."""
    if mask_block_size == 1:
        # Independently mask each maskable position with probability `sigma`.
        move_indices = (
            torch.rand(*x_0.shape, device=x_0.device) < sigma
        ) & maskable_mask
        x_t = torch.where(move_indices, mask_token_id, x_0)
        return x_t

    # Otherwise mask contiguous blocks of `mask_block_size` tokens.
    return block_masking(x_0, sigma, maskable_mask, mask_token_id, mask_block_size)
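

# A minimal sketch: mask roughly 30% of maskable tokens; the mask token id (103)
# and all other values are arbitrary placeholders.
#
#   x_0 = torch.randint(0, 1000, (2, 32))
#   maskable = torch.ones_like(x_0, dtype=torch.bool)
#   x_t = transition(x_0, sigma=0.3, maskable_mask=maskable, mask_token_id=103)
#   x_t_blocks = transition(x_0, 0.3, maskable, 103, mask_block_size=4)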


def block_masking(x_0, sigma, maskable_mask, mask_token_id, mask_block_size):
    """
    XLA-compatible block masking applied uniformly to all rows in the batch.
    Uses efficient tensor operations to avoid dynamic loops.
    """
    batch_size, seq_len = x_0.shape

    if seq_len < mask_block_size:
        return x_0

    # Every window of `mask_block_size` consecutive positions is a candidate block.
    num_windows = seq_len - mask_block_size + 1

    # all_positions[w, i] is the i-th position covered by window w: [num_windows, mask_block_size].
    window_starts = torch.arange(num_windows, device=x_0.device)
    block_offsets = torch.arange(mask_block_size, device=x_0.device)
    all_positions = window_starts.unsqueeze(1) + block_offsets.unsqueeze(0)

    # A window is eligible only if every position it covers is maskable: [batch_size, num_windows].
    maskable_blocks = (
        maskable_mask.unsqueeze(1)
        .expand(-1, num_windows, -1)
        .gather(2, all_positions.unsqueeze(0).expand(batch_size, -1, -1))
    )
    fully_maskable = maskable_blocks.all(dim=2)

    # Each position is covered by up to `mask_block_size` windows, each sampled
    # independently, so reduce the per-window rate such that (ignoring edge
    # effects) a position ends up masked with probability roughly `sigma`:
    # 1 - (1 - effective_sigma) ** mask_block_size == sigma.
    effective_sigma = 1 - (1 - sigma) ** (1 / mask_block_size)
    should_mask = (
        torch.rand(batch_size, num_windows, device=x_0.device) < effective_sigma
    ) & fully_maskable

    # Scatter the selected windows back onto sequence positions without a loop.
    position_indices = torch.arange(seq_len, device=x_0.device)

    # Reshape for broadcasting: compare each sequence position against the
    # positions covered by each window.
    position_indices = position_indices.unsqueeze(0).unsqueeze(0)
    all_positions = all_positions.unsqueeze(0)
    should_mask = should_mask.unsqueeze(2)

    # position_matches[0, w, p] is True if window w covers position p.
    position_matches = (position_indices == all_positions.unsqueeze(3)).any(dim=2)

    # A position is masked if any selected window covers it.
    should_mask_positions = should_mask & position_matches
    final_mask = should_mask_positions.any(dim=1)

    result = torch.where(final_mask, mask_token_id, x_0)
    return result
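

# Normally reached through `transition(..., mask_block_size=k)`; a direct call
# looks like this (values are arbitrary):
#
#   x_0 = torch.randint(0, 1000, (2, 32))
#   maskable = torch.ones_like(x_0, dtype=torch.bool)
#   x_t = block_masking(x_0, 0.3, maskable, 103, mask_block_size=4)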


def prefix_input_ids(input_ids, maskable_mask, apply_prefix):
    """
    Treat a random-length prefix of each row as a prompt that must not be masked,
    for the rows where `apply_prefix` is True (typically sampled by the caller
    with some configured probability). Returns the updated maskable mask;
    `input_ids` itself is not modified.
    """
    batch_size, seq_len = input_ids.shape

    # One random prefix length per row, in [1, seq_len).
    prefix_lengths = torch.randint(1, seq_len, (batch_size,), device=input_ids.device)

    position_indices = torch.arange(seq_len, device=input_ids.device).unsqueeze(0)

    # prefix_mask[b, p] is True for positions inside row b's prefix.
    prefix_mask = position_indices < prefix_lengths.unsqueeze(1)

    # Exclude prefix positions from masking for rows where the prefix applies.
    maskable_mask = maskable_mask & ~(apply_prefix.unsqueeze(1) & prefix_mask)
    return maskable_mask
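

# A minimal sketch: protect a random-length prefix in the first row only
# (all values are arbitrary):
#
#   ids = torch.randint(0, 1000, (2, 32))
#   maskable = torch.ones_like(ids, dtype=torch.bool)
#   apply_prefix = torch.tensor([True, False])
#   maskable = prefix_input_ids(ids, maskable, apply_prefix)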


def truncate_input_ids(input_ids, apply_truncate, pad_token_id):
    """
    Truncate each row at a random position and fill the remainder with the pad
    token, for the rows where `apply_truncate` is True. Returns the updated
    `input_ids`.
    """
    batch_size, seq_len = input_ids.shape

    # One random truncation point per row, in [1, seq_len).
    truncate_positions = torch.randint(
        1, seq_len, (batch_size,), device=input_ids.device
    )

    position_indices = torch.arange(seq_len, device=input_ids.device).unsqueeze(0)

    # truncate_mask[b, p] is True for positions at or after row b's cut point.
    truncate_mask = position_indices >= truncate_positions.unsqueeze(1)

    # Replace the truncated suffix with the pad token only where truncation applies.
    input_ids = torch.where(
        apply_truncate.unsqueeze(1) & truncate_mask, pad_token_id, input_ids
    )
    return input_ids
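

# A minimal sketch: truncate only the second row and pad its suffix
# (the pad id of 0 is arbitrary):
#
#   ids = torch.randint(1, 1000, (2, 32))
#   apply_truncate = torch.tensor([False, True])
#   ids = truncate_input_ids(ids, apply_truncate, pad_token_id=0)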