| | import math |
| | import os |
| | from dataclasses import dataclass |
| | from typing import Optional |
| |
|
| | from huggingface_hub import hf_hub_download |
| | import lm_eval as evaluator |
| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| | from safetensors.torch import load_file |
| | from torchtune.modules import RotaryPositionalEmbeddings |
| | from transformers import ( |
| | AutoConfig, |
| | AutoModel, |
| | AutoModelForCausalLM, |
| | PreTrainedModel, |
| | PretrainedConfig, |
| | ) |
| | from transformers.modeling_outputs import CausalLMOutput |
| |
|
# Optional fast kernels. Both names stay bound even when the import fails:
# `FlashFFTConv` is referenced in a type annotation that Python evaluates when
# the STU class is defined, so leaving it unbound would turn the advertised
# PyTorch fallback into a NameError at import time.
try:
    from flashfftconv import FlashFFTConv

    flash_fft_available = True
except ImportError as e:
    print(f"Unable to import FlashFFTConv: {e}. Falling back to PyTorch implementation.")
    FlashFFTConv = None
    flash_fft_available = False

try:
    from flash_attn import flash_attn_func
except ImportError as e:
    print(f"Unable to import Triton-based flash attention: {e}. No alternative currently available.")
    # Keep the name bound so callers can test `flash_attn_func is None`.
    flash_attn_func = None

# Process-wide settings for the evaluation harness / tokenizers.
os.environ["HF_ALLOW_CODE_EVAL"] = "1"
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Shared mean-reduced cross-entropy used by FlashSTU.forward when labels are given.
loss_fn = nn.CrossEntropyLoss()
| |
|
| |
|
def nearest_power_of_two(n: int, round_up: bool = False) -> int:
    """Return a power of two adjacent to ``n``.

    With ``round_up=False`` this is the largest power of two <= n; with
    ``round_up=True`` it is the smallest power of two >= n. Any ``n <= 1``
    maps to 1.
    """
    if n <= 1:
        return 1
    if round_up:
        return 1 << (n - 1).bit_length()
    return 1 << (n.bit_length() - 1)
| |
|
| |
|
def find_multiple(n: int, k: int) -> int:
    """Round ``n`` up to the next multiple of ``k`` (``n`` itself if already a multiple).

    Example: ``find_multiple(2730, 256) == 2816``.
    """
    quotient, remainder = divmod(n, k)
    return n if remainder == 0 else (quotient + 1) * k
| |
|
| |
|
def get_hankel(seq_len: int, use_hankel_L: bool = False) -> torch.Tensor:
    """Build the (seq_len, seq_len) float64 Hankel matrix used to derive spectral filters.

    Every entry depends only on s = i + j, with 1-based indices i, j:

    * default:             Z[i, j] = 2 / (s^3 - s)
    * ``use_hankel_L``:    Z[i, j] = ((-1)^(s - 2) + 1) * 8 / ((s + 3)(s - 1)(s + 1))

    Any truthy ``use_hankel_L`` selects the second variant; any falsy value
    (including ``None``) selects the default.
    """
    entries = torch.arange(1, seq_len + 1, dtype=torch.float64)
    i_plus_j = entries.reshape(-1, 1) + entries.reshape(1, -1)

    if use_hankel_L:
        # (-1)^(s-2) + 1 is 2 for even s and 0 for odd s.
        sgn = (-1.0) ** (i_plus_j - 2.0) + 1.0
        denom = (i_plus_j + 3.0) * (i_plus_j - 1.0) * (i_plus_j + 1.0)
        return sgn * (8.0 / denom)
    return 2.0 / (i_plus_j**3 - i_plus_j)
| |
|
| |
|
def get_spectral_filters(
    seq_len: int,
    K: int,
    use_hankel_L: bool = False,
    device: torch.device = None,
    dtype: torch.dtype = torch.float64,
) -> torch.Tensor:
    """Return the top-K eigenvectors of the Hankel matrix, each scaled by eigenvalue**0.25.

    Args:
        seq_len: Filter length (size of the Hankel matrix).
        K: Number of filters to keep; must satisfy 1 <= K <= seq_len.
        use_hankel_L: Selects the Hankel variant (see ``get_hankel``).
        device: Device for the eigendecomposition and the result.
        dtype: dtype of the returned filters (the decomposition itself runs in float64).

    Returns:
        Tensor of shape (seq_len, K). ``torch.linalg.eigh`` sorts eigenvalues in
        ascending order, so the last column is the dominant filter.

    Raises:
        ValueError: if K is outside [1, seq_len]. (K=0 would otherwise select every
            eigenvector, because ``sigma[-0:]`` is the whole tensor.)
    """
    if not 1 <= K <= seq_len:
        raise ValueError(f"K must be in [1, seq_len={seq_len}], got {K}")

    Z = get_hankel(seq_len, use_hankel_L).to(device)
    sigma, phi = torch.linalg.eigh(Z)
    sigma_k, phi_k = sigma[-K:], phi[:, -K:]
    # Eigenvalues that come out slightly negative (presumably round-off, as these
    # Hankel matrices are meant to be PSD) would make the fourth root NaN; clamp them.
    phi_k = phi_k * sigma_k.clamp(min=0.0) ** 0.25
    return phi_k.to(device=device, dtype=dtype)
| |
|
| |
|
class BaseConfigForCausalLM(PretrainedConfig):
    """Shared PretrainedConfig base for causal-LM configs; subclasses add @dataclass fields."""

    model_type = "base_model"

    def __init__(self, **kwargs):
        # Every keyword is handed through untouched to PretrainedConfig.
        super().__init__(**kwargs)
| |
|
| |
|
@dataclass
class FlashSTUConfig(BaseConfigForCausalLM):
    """Hyper-parameters for the FlashSTU model.

    The dataclass fields mirror the ``__init__`` defaults. ``__post_init__`` then:

    * converts a string ``torch_dtype`` into a ``torch.dtype``;
    * replaces ``num_local_heads`` of -1 (or None) with ``num_heads``;
    * derives ``inter_dim`` from ``mlp_scale * dim`` when it is None;
    * sets ``head_dim = dim // num_heads``.

    NOTE(review): ``@dataclass`` keeps the hand-written ``__init__`` but generates
    ``__repr__`` / ``__eq__`` from the fields below, replacing PretrainedConfig's.
    """

    model_type = "FlashSTU"

    bsz: int = 1
    dim: int = 1024
    r: int = 1024
    num_heads: int = 12
    num_local_heads: Optional[int] = -1
    num_layers: int = 12
    seq_len: int = 4096
    n: int = 8191
    window_size: int = 2048
    vocab_size: int = 200064
    inter_dim: Optional[int] = 3072
    mlp_scale: Optional[float] = 12.0
    weight_tying: Optional[bool] = True
    bias: Optional[bool] = False
    rope_theta: Optional[float] = 10000.0
    softcap: Optional[float] = 50.0
    num_eigh: Optional[int] = 24
    use_hankel_L: Optional[bool] = False
    use_flash_fft: Optional[bool] = True
    use_tensordot: Optional[bool] = True
    use_attn: Optional[bool] = True
    use_alibi: Optional[bool] = False
    torch_dtype: torch.dtype = torch.bfloat16
    device: torch.device = None

    def __init__(
        self,
        bsz: int = 1,
        dim: int = 1024,
        r: int = 1024,
        num_heads: int = 12,
        num_local_heads: Optional[int] = -1,
        num_layers: int = 12,
        seq_len: int = 4096,
        n: int = 8191,
        window_size: int = 2048,
        vocab_size: int = 200064,
        inter_dim: Optional[int] = 3072,
        mlp_scale: Optional[float] = 12.0,
        weight_tying: Optional[bool] = True,
        bias: Optional[bool] = False,
        rope_theta: Optional[float] = 10000.0,
        softcap: Optional[float] = 50.0,
        num_eigh: Optional[int] = 24,
        use_hankel_L: Optional[bool] = False,
        use_flash_fft: Optional[bool] = True,
        use_tensordot: Optional[bool] = True,
        use_attn: Optional[bool] = True,
        use_alibi: Optional[bool] = False,
        torch_dtype: torch.dtype = torch.bfloat16,
        device: torch.device = None,
        **kwargs,
    ):
        # Unknown keywords are handled by PretrainedConfig.
        super().__init__(**kwargs)

        self.bsz = bsz
        self.dim = dim
        self.r = r
        self.num_heads = num_heads
        self.num_local_heads = num_local_heads
        self.num_layers = num_layers
        self.seq_len = seq_len
        self.n = n
        self.window_size = window_size
        self.vocab_size = vocab_size
        self.inter_dim = inter_dim
        self.mlp_scale = mlp_scale
        self.weight_tying = weight_tying
        self.bias = bias
        self.rope_theta = rope_theta
        self.softcap = softcap
        self.num_eigh = num_eigh
        self.use_hankel_L = use_hankel_L
        self.use_flash_fft = use_flash_fft
        self.use_tensordot = use_tensordot
        self.use_attn = use_attn
        self.use_alibi = use_alibi
        self.torch_dtype = torch_dtype
        self.device = device

        # The dataclass-generated __init__ is not used, so run post-init explicitly.
        self.__post_init__()

    def __post_init__(self):
        """Normalise ``torch_dtype`` and fill in derived sizes.

        Raises:
            ValueError: if ``torch_dtype`` is a string that does not name a torch dtype.
        """
        if isinstance(self.torch_dtype, str):
            # Serialised configs may store the dtype by name; accept "bfloat16" or "torch.bfloat16".
            dtype_name = self.torch_dtype.removeprefix("torch.")
            resolved = getattr(torch, dtype_name, None)
            if not isinstance(resolved, torch.dtype):
                raise ValueError(f"Invalid torch_dtype string: {self.torch_dtype}")
            self.torch_dtype = resolved

        if self.num_local_heads is None or self.num_local_heads == -1:
            self.num_local_heads = self.num_heads
        if self.inter_dim is None:
            hidden_dim = self.mlp_scale * self.dim
            num_hidden = int(2 * hidden_dim / 3)
            self.inter_dim = find_multiple(num_hidden, 256)
        # NOTE(review): integer division; dim is assumed divisible by num_heads.
        self.head_dim = self.dim // self.num_heads

    @classmethod
    def from_name(cls, name: str):
        """Build a config from a named preset.

        Raises:
            NotImplementedError: always; no presets are defined yet.
        """
        raise NotImplementedError(f"FlashSTUConfig.from_name({name!r}) is not implemented yet")
| |
|
| |
|
class MLP(nn.Module):
    """Feed-forward block: Linear(dim -> inter_dim), tanh-approximate GELU, Linear(inter_dim -> dim)."""

    def __init__(self, config: FlashSTUConfig) -> None:
        super().__init__()
        model_dim, hidden_dim = config.dim, config.inter_dim
        self.w1 = nn.Linear(model_dim, hidden_dim)
        self.w2 = nn.Linear(hidden_dim, model_dim)
        # Tag read by FlashSTU.init_weights to shrink the init std of this projection.
        self.w2.SCALE_INIT = 1

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        hidden = F.gelu(self.w1(x), approximate="tanh")
        return self.w2(hidden)
| |
|
| |
|
class SlidingWindowAttention(nn.Module):
    """Causal multi-head attention over a left window of ``window_size`` tokens (flash-attn).

    Positions are encoded with rotary embeddings unless ALiBi is enabled, in which
    case ALiBi slopes are passed to ``flash_attn_func`` instead.

    NOTE(review): ``self.softcap`` is stored but not forwarded to ``flash_attn_func``.
    """

    def __init__(self, config):
        super().__init__()
        self.wq = nn.Linear(config.dim, config.dim)
        self.wk = nn.Linear(config.dim, config.dim)
        self.wv = nn.Linear(config.dim, config.dim)
        self.wo = nn.Linear(config.dim, config.dim)
        # Tag read by FlashSTU.init_weights to shrink the init std of this projection.
        self.wo.SCALE_INIT = 1

        self.dim = config.dim
        self.head_dim = config.head_dim
        self.num_heads = config.num_heads
        self.num_local_heads = config.num_local_heads
        self.window_size = config.window_size
        self.softcap = config.softcap

        # Plain tensor attribute (not a buffer): Module.to() does not move it, so
        # forward() moves it to the activations' device on demand.
        self.alibi_slopes = self._get_alibi_slopes(self.num_heads) if config.use_alibi else None
        self.rotary_emb = RotaryPositionalEmbeddings(
            dim=self.head_dim,
            max_seq_len=config.seq_len,
            base=config.rope_theta,
        )

    def forward(self, x):
        bsz, seq_len, _ = x.shape

        q, k, v = self.wq(x), self.wk(x), self.wv(x)
        q = q.view(bsz, seq_len, self.num_heads, self.head_dim)
        k = k.view(bsz, seq_len, self.num_local_heads, self.head_dim)
        v = v.view(bsz, seq_len, self.num_local_heads, self.head_dim)

        alibi_slopes = self.alibi_slopes
        if alibi_slopes is None:
            q, k = self.rotary_emb(q), self.rotary_emb(k)
        else:
            # No-op when the slopes already live on q's device.
            alibi_slopes = alibi_slopes.to(device=q.device)

        y = flash_attn_func(
            q=q,
            k=k,
            v=v,
            causal=True,
            window_size=(self.window_size, 0),
            alibi_slopes=alibi_slopes,
        )

        out = y.reshape(bsz, seq_len, -1)
        return self.wo(out)

    def _generate_slopes(self, n: int):
        """Geometric ALiBi slopes for ``n`` heads: start**1, ..., start**n with start = 2**(-8/n)."""
        start = 2 ** (-(2 ** -(math.log2(n) - 3)))
        return [start * (start**i) for i in range(n)]

    def _get_alibi_slopes(self, num_heads: int, interpolation_factor: float = 0.25, device=None):
        """Return a float32 tensor of ``num_heads`` ALiBi slopes scaled by ``interpolation_factor``.

        For non-power-of-two head counts, the slopes for the nearest lower power of
        two are extended with every other slope of the 2x-sized sequence.
        ``device`` defaults to CUDA when available, otherwise CPU.
        """
        if math.log2(num_heads).is_integer():
            slopes = self._generate_slopes(num_heads)
        else:
            n = nearest_power_of_two(num_heads, round_up=False)
            slopes_power_of_two = self._generate_slopes(n)
            extra_slopes = self._generate_slopes(2 * n)
            extra_slopes_trunc = extra_slopes[0::2][: num_heads - n]
            slopes = slopes_power_of_two + extra_slopes_trunc
        if device is None:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # flash-attn expects float32 slopes.
        slopes = torch.tensor(slopes, dtype=torch.float32, device=device)
        return slopes * interpolation_factor
| |
|
| |
|
class STU(nn.Module):
    """Spectral Transform Unit: convolves its input with fixed spectral filters and mixes
    the result with learned projections.

    The filters are not built here. They must be attached to ``stu_filters`` (and
    optionally ``stu_filters_fft``) before ``forward`` is called; FlashSTU.setup_filters
    does this. Two parameterisations exist:

    * ``use_tensordot=True``: project the input (``M_inputs``) and the filters
      (``M_filters``) first, then run one causal convolution per channel.
    * ``use_tensordot=False``: convolve with every filter separately and contract the
      (K, d_in) axes with ``M_phi_plus`` / ``M_phi_minus``.

    With ``use_hankel_L`` only the "plus" branch contributes to the output.
    """

    def __init__(self, config):
        super().__init__()

        # Attached after construction; forward() raises if still unset.
        self.stu_filters = None
        self.stu_filters_fft = None

        self.n = config.n
        self.num_eigh = config.num_eigh
        self.d_in = config.dim
        self.d_out = config.dim
        self.r = config.r
        self.use_hankel_L = config.use_hankel_L
        self.use_tensordot = config.use_tensordot
        self.flash_fft = (
            FlashFFTConv(self.n, dtype=torch.bfloat16) if config.use_flash_fft and flash_fft_available else None
        )

        if self.use_tensordot:
            self.M_inputs = nn.Parameter(torch.zeros(self.d_in, self.d_out))
            self.M_filters = nn.Parameter(torch.zeros(self.num_eigh, self.d_in))
        else:
            self.M_phi_plus = nn.Parameter(torch.zeros(self.num_eigh, self.d_in, self.d_out))
            if not self.use_hankel_L:
                self.M_phi_minus = nn.Parameter(torch.zeros(self.num_eigh, self.d_in, self.d_out))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the STU to ``x`` of shape (B, L, d_in); returns (B, L, d_out).

        Raises:
            RuntimeError: if ``stu_filters`` has not been attached yet.
        """
        if self.stu_filters is None:
            raise RuntimeError("STU.stu_filters is not set; attach filters (e.g. FlashSTU.setup_filters) first")

        if self.use_tensordot:
            x_proj = x @ self.M_inputs
            phi_proj = self.stu_filters @ self.M_filters
            if self.flash_fft is not None:
                spectral_plus, spectral_minus = self.flash_conv(x_proj, phi_proj, self.flash_fft, self.use_tensordot)
            else:
                spectral_plus, spectral_minus = self.conv(x_proj, phi_proj, self.n, self.use_tensordot)
        else:
            if self.flash_fft is not None:
                U_plus, U_minus = self.flash_conv(x, self.stu_filters, self.flash_fft, self.use_tensordot)
            else:
                U_plus, U_minus = self.conv(x, self.stu_filters, self.n, self.use_tensordot)

            # Contract the (filter, channel) axes in one matmul.
            B, L, K, D = U_plus.shape
            spectral_plus = U_plus.reshape(B, L, K * D) @ self.M_phi_plus.reshape(K * D, self.d_out)
            if not self.use_hankel_L:
                spectral_minus = U_minus.reshape(B, L, K * D) @ self.M_phi_minus.reshape(K * D, self.d_out)

        return spectral_plus if self.use_hankel_L else spectral_plus + spectral_minus

    def conv(
        self, u: torch.Tensor, v: torch.Tensor, n: int, use_tensordot: bool = True
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Causal FFT convolution with a "negative featurization" second branch.

        A copy of ``u`` has every odd time step multiplied by -1 (``sgn``). Both copies
        are convolved with ``v`` in float32, the first ``seq_len`` outputs are kept, and
        the second result is multiplied by ``sgn`` again. Slicing ``[:seq_len]`` gives
        the causal alignment; the conventional ``[seq_len-1:2*seq_len-1]`` slice would be
        flipped and leak future tokens.

        Args:
            u: Input of shape (bsz, seq_len, d_in).
            v: Filters with time along dim 0: (L_v, d_out) when ``use_tensordot`` is
                True, otherwise (L_v, K).
            n: FFT length; >= seq_len + L_v - 1 avoids circular wrap-around
                (typically 2 * seq_len - 1).
            use_tensordot: Selects the layout of ``v`` and of the outputs.

        Returns:
            (U_plus, U_minus) in ``u``'s dtype: (bsz, seq_len, d_out) when
            ``use_tensordot`` is True, otherwise (bsz, seq_len, K, d_in).
        """
        bsz, seq_len, d_in = u.shape

        sgn = torch.full((1, seq_len, 1), 1, device=u.device)
        sgn[:, 1::2] *= -1

        if use_tensordot:
            _, d_out = v.shape
            v = v.view(1, -1, d_out, 1).to(torch.float32).contiguous()
        else:
            _, K = v.shape
            sgn = sgn.unsqueeze(-1)
            v = v.view(1, -1, K, 1, 1).to(torch.float32).contiguous()
            u = u.view(bsz, -1, 1, d_in).expand(bsz, -1, K, d_in)

        v_fft = torch.fft.rfft(v, n=n, dim=1)

        # Last axis holds the (plain, sign-flipped) pair so both are convolved together.
        U = torch.stack([u, u * sgn], dim=-1).to(torch.float32).contiguous()
        U_fft = torch.fft.rfft(U, n=n, dim=1)

        U_conv = torch.fft.irfft(v_fft * U_fft, n=n, dim=1)[:, :seq_len].to(u.dtype)
        U_plus, U_minus = torch.unbind(U_conv, dim=-1)
        U_minus = U_minus * sgn

        return U_plus.type_as(u), U_minus.type_as(u)

    def flash_conv(
        self,
        u: torch.Tensor,
        v: torch.Tensor,
        flash_fft: "FlashFFTConv",
        use_tensordot: bool = True,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Causal convolution of ``u`` with ``v`` using FlashFFTConv.

        Inputs are right-padded to the next power of two. A second copy of ``u`` has
        its odd time steps sign-flipped ("negative featurization"), so that U_plus and
        U_minus both come out of a single FlashFFTConv call. The input is run in
        bfloat16 and the filters in float32.

        Args:
            u: Input of shape (B, L, d_in).
            v: Filters with time along dim 0: (L_v, d_in) when ``use_tensordot`` is
                True, otherwise (L_v, K).
            flash_fft: FlashFFTConv instance that performs the convolution.
            use_tensordot: Selects the layout of ``v`` and of the outputs.

        Returns:
            (U_plus, U_minus): each (B, L, d_in) when ``use_tensordot`` is True,
            otherwise (B, L, K, d_in).
        """
        bsz, seq_len, d_in = u.shape
        _, K = v.shape

        padded_len = nearest_power_of_two(seq_len, round_up=True)
        pad_len = padded_len - seq_len

        sgn = torch.full((1, 1, padded_len), 1, device=u.device)
        sgn[:, :, 1::2] = -1

        if use_tensordot:
            u_padded = F.pad(u.transpose(1, 2), (0, pad_len)).to(torch.bfloat16)
            v_padded = F.pad(v.transpose(0, 1), (0, pad_len)).to(torch.float32)
            u_conv = torch.stack([u_padded, u_padded * sgn], dim=0).reshape(2 * bsz, d_in, padded_len)
        else:
            # Channel c = d * K + k pairs input channel d with filter k on both sides.
            u_k_padded = F.pad(u.transpose(1, 2), (0, pad_len)).repeat_interleave(K, dim=1)
            v_padded = F.pad(v.transpose(0, 1), (0, pad_len)).to(torch.float32).repeat(d_in, 1)
            u_conv = torch.stack([u_k_padded, u_k_padded * sgn], dim=0).reshape(2 * bsz, K * d_in, padded_len)

        U_conv = flash_fft(u_conv.to(torch.bfloat16), v_padded.to(torch.float32))

        # Keep the causal prefix, then split the plain / sign-flipped halves.
        U_conv = U_conv[..., :seq_len]
        u_plus, u_minus = torch.chunk(U_conv, 2, dim=0)

        if use_tensordot:
            u_minus = u_minus * sgn[:, :, :seq_len]
            U_plus, U_minus = u_plus.transpose(1, 2), u_minus.transpose(1, 2)
        else:
            sgn = sgn[:, :, :seq_len].unsqueeze(-1).transpose(1, 2)
            U_plus = u_plus.view(bsz, d_in, K, seq_len).permute(0, 3, 2, 1).contiguous()
            U_minus = u_minus.view(bsz, d_in, K, seq_len).permute(0, 3, 2, 1).contiguous() * sgn

        return U_plus, U_minus
| |
|
| |
|
class SlidingWindowAttentionLayer(nn.Module):
    """Pre-norm residual block: sliding-window attention, then an MLP."""

    def __init__(self, config):
        super().__init__()
        self.swa_norm = nn.LayerNorm(config.dim)
        self.swa = SlidingWindowAttention(config)
        self.mlp_norm = nn.LayerNorm(config.dim)
        self.mlp = MLP(config)

    def forward(self, x):
        hidden = x + self.swa(self.swa_norm(x))
        return hidden + self.mlp(self.mlp_norm(hidden))
| |
|
| |
|
class STULayer(nn.Module):
    """Pre-norm residual block: STU sequence mixing, then an MLP."""

    def __init__(self, config):
        super().__init__()
        self.stu_norm = nn.LayerNorm(config.dim)
        self.stu = STU(config)
        self.mlp_norm = nn.LayerNorm(config.dim)
        self.mlp = MLP(config)

    def forward(self, x):
        hidden = x + self.stu(self.stu_norm(x))
        return hidden + self.mlp(self.mlp_norm(hidden))
| |
|
| |
|
class FlashSTU(nn.Module):
    """FlashSTU language model: embedding, a stack of STU / attention layers, final norm, LM head.

    Even-indexed layers are always STULayer. Odd-indexed layers are
    SlidingWindowAttentionLayer when ``config.use_attn`` is set, and STULayer otherwise.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.tok_emb = nn.Embedding(config.vocab_size, config.dim)
        self.layers = nn.ModuleList()

        for layer_idx in range(config.num_layers):
            if layer_idx % 2 == 1 and config.use_attn:
                self.layers.append(SlidingWindowAttentionLayer(config))
            else:
                self.layers.append(STULayer(config))

        self.norm_f = nn.LayerNorm(config.dim)
        self.lm_head = nn.Linear(config.dim, config.vocab_size, bias=False)

        if self.config.weight_tying:
            # Embedding and LM head share a single weight tensor.
            self.tok_emb.weight = self.lm_head.weight

        self.std = self.config.dim**-0.5

    def init_weights(self, module):
        """Initialise one module (use via ``self.apply``).

        Linear and Embedding weights are drawn from N(0, std) with std = dim**-0.5.
        Linear layers tagged with ``SCALE_INIT`` additionally scale std by
        (2 * num_layers)**-0.5. Linear biases are zeroed.
        """
        std = self.std
        if isinstance(module, nn.Linear):
            if hasattr(module, "SCALE_INIT"):
                std *= (2 * self.config.num_layers) ** -0.5
            torch.nn.init.normal_(module.weight, mean=0.0, std=std)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=std)

    def forward(self, input_ids: torch.Tensor, labels: torch.Tensor = None, **kwargs) -> CausalLMOutput:
        """Return logits of shape (B, L, vocab_size) and, if ``labels`` is given, the mean CE loss.

        ``labels`` is compared position-by-position with ``logits`` (no shift is done
        here), so callers presumably pass already-shifted targets. Extra keyword
        arguments are ignored.
        """
        x = self.tok_emb(input_ids)

        for layer in self.layers:
            x = layer(x)

        x = self.norm_f(x)
        logits = self.lm_head(x)

        loss = None
        if labels is not None:
            loss = loss_fn(logits.flatten(0, 1), labels.flatten(0, 1))

        return CausalLMOutput(
            loss=loss,
            logits=logits,
        )

    def setup_filters(
        self,
        spectral_filters: torch.Tensor,
        spectral_filters_fft: torch.Tensor,
    ):
        """Attach the same filter tensors to every STU layer (by reference, not as buffers)."""
        for layer in self.layers:
            if isinstance(layer, STULayer):
                layer.stu.stu_filters = spectral_filters
                layer.stu.stu_filters_fft = spectral_filters_fft

    def get_num_params(self):
        """Return the total number of parameters, embeddings included.

        Tied weights are counted once, because ``parameters()`` de-duplicates shared tensors.
        """
        n_params = sum(p.numel() for p in self.parameters())
        return n_params
| |
|
| |
|
def create_base_model_components(model_name_or_path=None, **kwargs):
    """Create the config and spectral filters needed to initialise a FlashSTU model.

    Args:
        model_name_or_path: If given, the config is loaded with
            ``FlashSTUConfig.from_pretrained``; otherwise a fresh ``FlashSTUConfig`` is built.
        **kwargs: Forwarded to the config constructor / ``from_pretrained``.

    Returns:
        ``(config, filters)``. ``filters`` has shape (seq_len, num_eigh) and dtype
        ``config.torch_dtype``. It is placed on ``config.device`` when set, otherwise
        on CUDA if available, otherwise on CPU.
    """
    if model_name_or_path is not None:
        config = FlashSTUConfig.from_pretrained(model_name_or_path, **kwargs)
    else:
        config = FlashSTUConfig(**kwargs)

    # Same device rule as FlashSTUForCausalLM.__init__: an explicit config.device wins.
    if config.device is not None:
        device = config.device
    else:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    filters = get_spectral_filters(
        seq_len=config.seq_len,
        K=config.num_eigh,
        use_hankel_L=config.use_hankel_L,
        device=device,
        dtype=config.torch_dtype,
    )
    assert filters.dtype == config.torch_dtype, f"filters dtype is {filters.dtype}, expected {config.torch_dtype}"
    return config, filters
| |
|
| |
|
class FlashSTUForCausalLM(PreTrainedModel):
    """Thin wrapper that exposes FlashSTU through HuggingFace's PreTrainedModel interface.

    The network lives in ``self.flash_stu``. Spectral filters are computed once in
    ``__init__`` and attached to every STU layer.
    """

    config_class = FlashSTUConfig
    base_model_prefix = "FlashSTU"

    def __init__(self, config):
        super().__init__(config)

        self.flash_stu = FlashSTU(config)
        self.flash_stu.apply(self.flash_stu.init_weights)

        device = (
            config.device
            if config.device is not None
            else torch.device("cuda" if torch.cuda.is_available() else "cpu")
        )
        torch_dtype = config.torch_dtype

        # Computed in get_spectral_filters' default dtype (float64) and cast afterwards.
        spectral_filters = get_spectral_filters(
            seq_len=config.seq_len,
            K=config.num_eigh,
            use_hankel_L=config.use_hankel_L,
            device=device,
        )
        # spectral_filters is (seq_len, num_eigh), so time is dim 0. The FFT is
        # complex; keep it complex64 rather than casting to a real dtype, which
        # would silently discard the imaginary part.
        spectral_filters_fft = torch.fft.rfft(spectral_filters, n=config.n, dim=0).to(torch.complex64)

        self.flash_stu.setup_filters(spectral_filters.to(dtype=torch_dtype), spectral_filters_fft)

    def forward(
        self, input_ids: torch.Tensor, labels: torch.Tensor = None, attention_mask: torch.Tensor = None, **kwargs
    ) -> CausalLMOutput:
        """Delegate to FlashSTU.forward. ``attention_mask`` is accepted for API compatibility but ignored."""
        outputs = self.flash_stu(input_ids, labels=labels, **kwargs)
        return outputs

    def generate(
        self,
        input_ids: torch.Tensor,
        max_length: int = 32,
        num_return_sequences: int = 4,
        temperature: float = 0.8,
        top_k: int = 50,
        top_p: float = 0.95,
        repetition_penalty: float = 1.2,
        seed: int = 42,
    ) -> torch.Tensor:
        """Sample continuations with repetition penalty, temperature, top-k and nucleus (top-p) filtering.

        The full sequence is re-run through the model at every step (no KV cache).
        Puts the model in eval mode.

        Args:
            input_ids: Prompt token ids of shape (batch_size, seq_len).
            max_length: Total length (prompt included) to generate up to.
            num_return_sequences: Copies of the whole batch to sample; rows are the
                batch repeated this many times.
            temperature: Logit divisor; higher = more random.
            top_k: Keep only the k most likely tokens (0 disables; clipped to the vocab size).
            top_p: Keep the smallest prefix of sorted tokens whose mass exceeds top_p.
            repetition_penalty: CTRL-style penalty for tokens already present (1.0 = off).
            seed: Seed for the sampling generator.

        Returns:
            Token ids of shape (batch_size * num_return_sequences, max(max_length, seq_len)).
        """
        self.eval()
        device = input_ids.device

        generated = input_ids.repeat(num_return_sequences, 1)

        sample_rng = torch.Generator(device=device)
        sample_rng.manual_seed(seed)

        with torch.no_grad():
            while generated.size(1) < max_length:
                outputs = self.flash_stu(generated)
                # float32: softmax/cumsum over a large vocabulary is too coarse in bf16.
                next_token_logits = outputs.logits[:, -1, :].float()

                if repetition_penalty != 1.0:
                    # Penalise each previously seen token once: shrink positive logits and
                    # push negative logits further down (dividing a negative logit would
                    # raise it). Duplicate ids scatter the same value, so they are harmless.
                    seen = torch.gather(next_token_logits, 1, generated)
                    seen = torch.where(seen < 0, seen * repetition_penalty, seen / repetition_penalty)
                    next_token_logits = next_token_logits.scatter(1, generated, seen)

                if temperature != 1.0:
                    next_token_logits = next_token_logits / temperature

                probs = torch.nn.functional.softmax(next_token_logits, dim=-1)

                if top_k > 0:
                    k = min(top_k, probs.size(-1))
                    kth_prob = torch.topk(probs, k)[0][..., -1, None]
                    probs = probs.masked_fill(probs < kth_prob, 0.0)

                if top_p < 1.0:
                    sorted_probs, sorted_indices = torch.sort(probs, descending=True)
                    cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
                    sorted_indices_to_remove = cumulative_probs > top_p
                    # Shift right so the token that crosses the threshold is kept; always keep the top token.
                    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
                    sorted_indices_to_remove[..., 0] = False
                    indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
                    probs = probs.masked_fill(indices_to_remove, 0.0)

                probs = probs / probs.sum(dim=-1, keepdim=True).clamp(min=1e-8)

                next_token = torch.multinomial(probs, num_samples=1, generator=sample_rng)
                generated = torch.cat([generated, next_token], dim=1)

        return generated

    def get_num_params(self):
        return self.flash_stu.get_num_params()

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        """Load a FlashSTU checkpoint from a Hub repo id or a local directory.

        Reads ``model.safetensors``, restores tied embedding / LM-head weights if only
        one of them was saved, then moves the model to CUDA (if available) in bfloat16
        and switches it to eval mode.
        """
        # Only the config is needed here: cls(config) computes the spectral filters
        # itself, so going through create_base_model_components would run the
        # eigendecomposition twice.
        config = cls.config_class.from_pretrained(pretrained_model_name_or_path, **kwargs)
        model = cls(config)

        subfolder = kwargs.get("subfolder", "")
        if os.path.isdir(pretrained_model_name_or_path):
            weights_path = os.path.join(pretrained_model_name_or_path, subfolder, "model.safetensors")
        else:
            weights_path = hf_hub_download(
                repo_id=pretrained_model_name_or_path,
                filename="model.safetensors",
                cache_dir=kwargs.get("cache_dir"),
                force_download=kwargs.get("force_download", False),
                proxies=kwargs.get("proxies", None),
                local_files_only=kwargs.get("local_files_only", False),
                # `use_auth_token` is the deprecated spelling of `token`.
                token=kwargs.get("token", kwargs.get("use_auth_token", None)),
                revision=kwargs.get("revision", None),
                subfolder=subfolder,
            )

        state_dict = load_file(weights_path)

        # Tied weights may have been saved under only one of the two names.
        tok_emb_key = "tok_emb.weight"
        lm_head_key = "lm_head.weight"

        tok_emb_present = tok_emb_key in state_dict
        lm_head_present = lm_head_key in state_dict

        if tok_emb_present and not lm_head_present:
            print(f"Reconstructing weight tying: Linking missing '{lm_head_key}' to existing '{tok_emb_key}'")
            state_dict[lm_head_key] = state_dict[tok_emb_key]
        elif lm_head_present and not tok_emb_present:
            print(f"Reconstructing weight tying: Linking missing '{tok_emb_key}' to existing '{lm_head_key}'")
            state_dict[tok_emb_key] = state_dict[lm_head_key]
        elif not tok_emb_present and not lm_head_present:
            print(
                f"Warning: Neither '{tok_emb_key}' nor '{lm_head_key}' found in state_dict. Weight tying cannot be reconstructed."
            )

        # Checkpoint keys are relative to FlashSTU; the wrapper stores it under `flash_stu`.
        final_state_dict = {f"flash_stu.{k}": v for k, v in state_dict.items()}
        model.load_state_dict(final_state_dict)

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model = model.to(device=device, dtype=torch.bfloat16)
        model.eval()

        num_params = model.get_num_params()
        print(f"\nModel loaded: {pretrained_model_name_or_path}")
        print(f"Parameter count: {num_params / 1e6:.2f}M")

        return model
| |
|
| |
|
| | |
# Default config and filters built at import time, kept as module globals.
# NOTE(review): neither global is referenced elsewhere in this file, and building
# them runs a full seq_len x seq_len eigendecomposition on import.
config, filters = create_base_model_components()

# Make the custom classes discoverable through the transformers Auto* factories.
AutoConfig.register("FlashSTU", FlashSTUConfig)
for _auto_cls, _model_cls in ((AutoModel, FlashSTU), (AutoModelForCausalLM, FlashSTUForCausalLM)):
    _auto_cls.register(FlashSTUConfig, _model_cls)
del _auto_cls, _model_cls

print("Registered FlashSTU model and configuration.")
| |
|
| |
|
def run_model_diagnostics(model, tokenizer, device):
    """Print per-token next-token losses and sample generations for a few fixed prompts.

    Args:
        model: A FlashSTUForCausalLM-like object exposing ``flash_stu`` and ``generate``.
        tokenizer: Tokenizer matching the model.
        device: Device the prompt ids are moved to.
    """
    print("\nRunning model diagnostics...")

    test_cases = [
        "2 + 2 =",
        "The capital of France is Paris. The capital of Germany is",
        "If a train travels 120 kilometers in 2 hours, its average speed is",
        "1, 2, 3, 4,",
        "The following is a detailed explanation of photosynthesis: Plants use sunlight to",
    ]

    # reduction="none" keeps one loss value per predicted position.
    loss_fct = nn.CrossEntropyLoss(reduction="none")

    with torch.no_grad():
        for prompt in test_cases:
            print(f"\nAnalyzing prompt: {prompt}")

            tokens = tokenizer(prompt, return_tensors="pt")
            input_ids = tokens["input_ids"].to(device)

            # No labels: FlashSTU.forward would compute an unshifted loss we never use.
            outputs = model.flash_stu(input_ids)

            # Position t predicts token t + 1. Compute the loss in float32 for stable values.
            shift_logits = outputs.logits[..., :-1, :].float()
            shift_labels = input_ids[..., 1:]
            token_losses = loss_fct(
                shift_logits.reshape(-1, shift_logits.size(-1)), shift_labels.reshape(-1)
            ).view(shift_labels.size())

            input_tokens = tokenizer.convert_ids_to_tokens(input_ids[0])
            print("\nToken-by-token loss:")
            for token, loss in zip(input_tokens[1:], token_losses[0]):
                print(f"{token}: {loss.item():.3f}")

            print(f"Average loss: {token_losses.mean().item():.3f}")

            temps = [0.5, 0.7, 1.0]
            print("\nGeneration temperature comparison:")
            for temp in temps:
                gen_ids = model.generate(
                    input_ids,
                    max_length=25,
                    num_return_sequences=1,
                    temperature=temp,
                    top_p=0.9,
                    repetition_penalty=1.5,
                    seed=42,
                )
                gen_text = tokenizer.decode(gen_ids[0], skip_special_tokens=True)
                print(f"\nTemp {temp}: {gen_text}")
| |
|
| |
|
def validate_model_generation():
    """Load the released FlashSTU-340M checkpoint and run ``run_model_diagnostics`` on it.

    Any failure is printed and then re-raised.
    """
    print("\nRunning generation validation test...")

    try:
        from transformers import AutoTokenizer

        checkpoint = "Hazan-Lab/FlashSTU-340M-0428"
        model = FlashSTUForCausalLM.from_pretrained(checkpoint)
        tokenizer = AutoTokenizer.from_pretrained(checkpoint, trust_remote_code=True)

        target = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model = model.to(device=target, dtype=torch.bfloat16).eval()

        print(f"\nModel loaded: {checkpoint}")
        print(f"Parameter count: {model.get_num_params() / 1e6:.2f}M")

        run_model_diagnostics(model, tokenizer, target)
    except Exception as err:
        print(f"\nError during validation: {err}")
        raise
| |
|
| |
|
| | |
# lm-eval-harness tasks to run, in order.
tasks = ["hellaswag"]

# Per-task few-shot counts; tasks absent from this mapping use the harness default.
tasks_fewshot = {"hellaswag": 0}

# Filled by the evaluation loop: task name -> that task's result dict.
all_results = {}
| |
|
| | |
validate_model_generation()

# Defaults preserve the original hard-coded settings; override via environment.
eval_cache_dir = os.environ.get(
    "FLASHSTU_EVAL_CACHE_DIR",
    "/scratch/gpfs/mn4560/hazan-lab/tensorized_filters/tensorized_filters/eval/cache",
)
eval_device = os.environ.get("FLASHSTU_EVAL_DEVICE", "cuda:0")

print("\nStarting evaluation tasks...")
for task in tasks:
    print(f"\nEvaluating task: {task}")
    eval_kwargs = dict(
        model="hf",
        model_args=(
            "pretrained=Hazan-Lab/FlashSTU-340M-0428,"
            "trust_remote_code=True,"
            "dtype=bfloat16,"
            f"cache_dir={eval_cache_dir}"
        ),
        tasks=[task],
        batch_size="auto",
        device=eval_device,
    )
    # -1 marks "not configured": leave num_fewshot to the harness default.
    few_shot_value = tasks_fewshot.get(task, -1)
    if few_shot_value != -1:
        eval_kwargs["num_fewshot"] = few_shot_value
    results = evaluator.simple_evaluate(**eval_kwargs)
    task_result = results["results"].get(task, {})
    all_results[task] = task_result
    print(f"Results for {task}:")
    print(task_result)
    print("\n" + "=" * 50 + "\n")

print("All Evaluation Results:")
for task, result in all_results.items():
    print(f"{task}: {result}")
| |
|