| | from typing import Callable, Optional, Tuple, Union |
| | from dataclasses import dataclass |
| |
|
| | import torch |
| | from torch import nn |
| |
|
class HomogeneousSequential(nn.Sequential):
    """
    A sequential container whose children must all be the same module type
    with matching input/output shapes. Because every layer is identical in
    structure, the container can be compiled with the `scan` higher order
    operator to save compile time.
    """

    repeated_layer: type
    """The type of the layer being looped over."""

    def __init__(self, *args: nn.Module) -> None:
        super().__init__(*args)
        # Collect the distinct child types; exactly one is allowed.
        types = {type(layer) for layer in args}
        assert len(types) == 1, f"All modules must be of the same type. Got {types}"
        self.repeated_layer = types.pop()

    def forward(self, *input, **broadcasted_inputs):
        """
        Chain `input` through each child module in order, exactly like
        `torch.nn.Sequential`: the output of one layer becomes the input of
        the next, and the last layer's output is returned. Unlike
        `torch.nn.Sequential`, keyword arguments given via
        `broadcasted_inputs` are forwarded unchanged to every layer
        (i.e. "broadcasted").
        """
        carry = input
        for layer in self:
            carry = layer(*splat(carry), **broadcasted_inputs)
        return carry
| |
|
| |
|
def splat(input):
    """Return `input` unchanged if it is a list or tuple; otherwise wrap it in a 1-tuple."""
    if isinstance(input, (list, tuple)):
        return input
    return (input,)
| |
|
| |
|
@dataclass(kw_only=True)
class RopeScaling:
    """
    RoPE scaling parameters. The defaults are what was selected in Llama 3.1.
    """

    # NOTE(review): field semantics below are assumed from the Llama 3.1 RoPE
    # scaling scheme; the consumer of this dataclass is not visible in this
    # file — confirm against it.
    # Overall down-scaling factor applied to low-frequency components (assumed).
    factor: float = 8.0
    # Lower bound of the frequency transition band (assumed).
    low_freq_factor: float = 1.0
    # Upper bound of the frequency transition band (assumed).
    high_freq_factor: float = 4.0
    # Context length the unscaled frequencies were trained for (assumed).
    original_context_len: int = 8192
| |
|
| |
|
def default_rope_frequencies(
    head_dim: int,
    theta: float = 10000.0,
) -> torch.Tensor:
    """
    Compute the original (unscaled) RoPE frequencies used in e.g. Llama 2.

    Args:
        head_dim: the size of a single attention head.
        theta: a hyperparameter controlling how fast the embeddings rotate.

    Returns:
        The frequencies for the RoPE embeddings, one per pair of head dims.
    """
    # Exponents 0/d, 2/d, 4/d, ... for the even dimensions of the head.
    exponents = torch.arange(0, head_dim, 2, dtype=torch.int64).float() / head_dim
    return 1.0 / (theta**exponents)
| |
|
def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    half = x.shape[-1] // 2
    first, second = x[..., :half], x[..., half:]
    return torch.cat((-second, first), dim=-1)
| |
|
| |
|
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Applies Rotary Position Embedding to the query and key tensors.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        position_ids (`torch.Tensor`, *optional*):
            Deprecated and unused.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            Dimension along which `cos` and `sin` are unsqueezed so they
            broadcast against q and k. Use 1 when q/k are laid out as
            [batch, heads, seq, head_dim] (cos/sin are [batch, seq, head_dim]);
            use 2 when q/k are [batch, seq, heads, head_dim].

    Returns:
        `tuple(torch.Tensor)`: the query and key tensors rotated with RoPE.
    """
    # Insert a broadcast axis so cos/sin line up with the heads dimension.
    cos, sin = cos.unsqueeze(unsqueeze_dim), sin.unsqueeze(unsqueeze_dim)
    rotated_q = q * cos + rotate_half(q) * sin
    rotated_k = k * cos + rotate_half(k) * sin
    return rotated_q, rotated_k
| |
|
| |
|
| |
|
def transition(x_0, sigma, maskable_mask, mask_token_id, mask_block_size: int = 1):
    """Apply masking to input tokens. If mask_block_size > 1, use block masking for all rows."""
    # Anything other than single-token masking is delegated to block masking.
    if mask_block_size != 1:
        return block_masking(x_0, sigma, maskable_mask, mask_token_id, mask_block_size)

    # Token-level masking: each maskable position is independently replaced
    # by the mask token with probability `sigma`.
    noise = torch.rand(*x_0.shape, device=x_0.device)
    move_indices = (noise < sigma) & maskable_mask
    return torch.where(move_indices, mask_token_id, x_0)
| |
|
| |
|
def block_masking(x_0, sigma, maskable_mask, mask_token_id, mask_block_size):
    """
    XLA-compatible block masking applied uniformly to all rows in the batch.
    Uses efficient tensor operations (no dynamic loops) so it can be traced.
    """
    batch_size, seq_len = x_0.shape
    device = x_0.device

    # A sequence too short to hold even one block is returned untouched.
    if seq_len < mask_block_size:
        return x_0

    # Sliding windows of length `mask_block_size`, one per start position.
    num_windows = seq_len - mask_block_size + 1
    starts = torch.arange(num_windows, device=device)
    offsets = torch.arange(mask_block_size, device=device)
    # (num_windows, mask_block_size): absolute positions each window covers.
    window_positions = starts.unsqueeze(1) + offsets.unsqueeze(0)

    # A window is eligible only when every position inside it is maskable.
    per_window_maskable = (
        maskable_mask.unsqueeze(1)
        .expand(-1, num_windows, -1)
        .gather(2, window_positions.unsqueeze(0).expand(batch_size, -1, -1))
    )
    fully_maskable = per_window_maskable.all(dim=2)

    # Per-window probability chosen so that the marginal masking rate of a
    # single position still comes out to `sigma` despite overlapping windows.
    effective_sigma = 1 - (1 - sigma) ** (1 / mask_block_size)
    chosen = (
        torch.rand(batch_size, num_windows, device=device) < effective_sigma
    ) & fully_maskable

    # Expand chosen windows back to a per-position mask: a position is masked
    # if ANY chosen window covers it.
    positions = torch.arange(seq_len, device=device).view(1, 1, seq_len)
    # (1, num_windows, seq_len): True where a window covers a position.
    coverage = (positions == window_positions.unsqueeze(0).unsqueeze(3)).any(dim=2)
    final_mask = (chosen.unsqueeze(2) & coverage).any(dim=1)

    return torch.where(final_mask, mask_token_id, x_0)
| |
|
| |
|
def prefix_input_ids(input_ids, maskable_mask, apply_prefix):
    """Draw a random prefix length per row and, where `apply_prefix` is True,
    clear the prefix positions from `maskable_mask` so the prefix is never masked.

    Returns the updated maskable mask.
    """
    batch_size, seq_len = input_ids.shape
    device = input_ids.device

    # One random prefix length per row, in [1, seq_len).
    prefix_lengths = torch.randint(1, seq_len, (batch_size,), device=device)

    # (batch, seq_len): True at positions that fall inside the row's prefix.
    positions = torch.arange(seq_len, device=device).unsqueeze(0)
    in_prefix = positions < prefix_lengths.unsqueeze(1)

    # A position stays maskable unless this row applies a prefix covering it.
    return maskable_mask & ~(apply_prefix.unsqueeze(1) & in_prefix)
| |
|
| |
|
def truncate_input_ids(input_ids, apply_truncate, pad_token_id):
    """Truncate each selected row at a random position, filling the suffix with
    the pad token.

    Returns `input_ids` with, for rows where `apply_truncate` is True, every
    position at or after a random cut point replaced by `pad_token_id`.
    """
    batch_size, seq_len = input_ids.shape
    device = input_ids.device

    # One random cut point per row, in [1, seq_len).
    cut_points = torch.randint(1, seq_len, (batch_size,), device=device)

    # (batch, seq_len): True at positions at or after the row's cut point.
    positions = torch.arange(seq_len, device=device).unsqueeze(0)
    past_cut = positions >= cut_points.unsqueeze(1)

    # Only rows flagged by `apply_truncate` are padded out.
    return torch.where(apply_truncate.unsqueeze(1) & past_cut, pad_token_id, input_ids)
| |
|