# Uploaded by Norquinal ("Upload model", commit cd0221e).
# Copyright (c) Together
# This software is distributed under the terms of the Apache License, Version 2.0
# Author: Michael Poli
import torch
import torch.nn as nn
import torch.nn.functional as F
# `conv1d_cpp` is an optional compiled extension; if it cannot be imported the
# module still loads. Only ImportError is tolerated so that unrelated failures
# (and KeyboardInterrupt/SystemExit) are not silently swallowed.
try:
    import conv1d_cpp  # noqa: F401
except ImportError:
    pass
from .utils import column_split
def canonicalize_modal_system(poles, residues):
    """Bring a modal system, given as poles and residues, into canonical form.

    Not implemented yet: every call raises ``NotImplementedError``.

    Args:
        poles (Tensor): Poles of the modal system.
        residues (Tensor): Residues of the modal system.

    Returns:
        Tuple[Tensor, Tensor]: Canonicalized poles and residues.
    """
    raise NotImplementedError()
# Prefill strategies accepted as `iir_prefill_style` by HyenaInferenceEngine;
# its constructor asserts membership in this list.
IIR_PREFILL_MODES = [
    "recurrence", "modal-fft", "hybrid-modal-recurrence",
    "modal-scan", "canonical-fft", "iir-fir-caching",
]
class HyenaInferenceEngine:
    """Prefill and decoding helpers for Hyena layers.

    Provides the full-sequence short FIR (`parallel_fir`) and long convolution
    (`parallel_iir`) used during prefill, the single-token updates
    (`step_fir`, `step_iir`) used during decoding, and strategies for
    prefilling the modal IIR state into `inference_params.state_dict`.
    """

    def __init__(
        self, fir_fn=None, fftconv_fn=None, iir_prefill_style="modal-fft", layer_idx=None
    ) -> None:
        """
        Args:
            fir_fn: Short-convolution callable stored on the engine.
            fftconv_fn: FFT-convolution callable used by the flashfft branch of
                `prefill_via_modal_fft`.
            iir_prefill_style: Must be one of `IIR_PREFILL_MODES`.
            layer_idx: Key for this layer's state in `inference_params.state_dict`.
        """
        self.fir_fn = fir_fn
        self.fftconv_fn = fftconv_fn
        assert (
            iir_prefill_style in IIR_PREFILL_MODES
        ), f"iir_prefill_style must be one of {IIR_PREFILL_MODES}"
        self.iir_prefill_style = iir_prefill_style
        self.layer_idx = layer_idx
        # When True, parallel_iir drops intermediates and empties the CUDA cache.
        self.low_mem_mode = False

    def parallel_fir(
        self,
        fir_fn,
        u,
        weight,
        bias,
        L,
        fir_length=3,
        inference_params=None,
        prefill_mode=None,
        padding_mask=None,
    ):
        """Apply the short (FIR) convolution over a full sequence.

        Args:
            fir_fn: `torch.nn.functional.conv1d` is dispatched as a depthwise conv
                with `weight`/`bias` and `fir_length - 1` padding; any other
                callable is invoked as `fir_fn(u)`.
            u: Input in `B, L, D` layout.
            weight, bias: Depthwise conv parameters (conv1d path only).
            L: Number of output positions to keep.
            fir_length: Length of the short filter.
            inference_params: When not None, the last `fir_length - 1` inputs are
                returned as the FIR cache.
            prefill_mode: Unused; kept for interface compatibility.
            padding_mask: Optional tensor multiplied into the output as
                `padding_mask[:, None]` (presumably `B, L` — confirm).

        Returns:
            `(z_pre, fir_state)`: `z_pre` is `B, D, L`; `fir_state` is
            `B, D, fir_length - 1`, or None without `inference_params`.
        """
        is_conv1d = fir_fn is F.conv1d
        if is_conv1d:
            u = u.permute(0, 2, 1)  # B, D, L
            z_pre = fir_fn(
                u,
                weight,
                bias,
                stride=1,
                padding=fir_length - 1,
                groups=u.shape[1],
            )[..., :L]
        else:
            z_pre = fir_fn(u)[:, :L]  # B, L, D
            z_pre = z_pre.permute(0, 2, 1)

        # handle padding post fir, the only place with biases
        if isinstance(padding_mask, torch.Tensor):
            z_pre = z_pre * padding_mask[:, None]

        fir_state = None
        if inference_params is not None:
            # Slice from an explicit start index: `-fir_length + 1` evaluates to 0
            # when fir_length == 1 and would keep the whole sequence.
            state_len = fir_length - 1
            if is_conv1d:  # u is B, D, L here
                fir_state = u[..., max(u.shape[-1] - state_len, 0) :]
            else:  # u is B, L, D here
                fir_state = u[:, max(u.shape[1] - state_len, 0) :].permute(0, 2, 1)
        return z_pre, fir_state

    def parallel_iir(
        self,
        z_pre,
        h,
        D,
        L,
        poles,
        t,
        dims,
        layer_idx,
        inference_params=None,
        prefill_style="fft",
        fftconv_fn=None,
        padding_mask=None,
        use_flashfft=False,
        column_split_hyena=False,
        long_fir_threshold=None,
    ):
        """Apply the long convolutional filter over a full sequence.

        Splits `z_pre` (`B, 3 * hidden, L`) into `x2, x1, v` and returns
        `(h * (x1 v) + D x1 v) x2`, where `*` is a causal convolution along the
        last axis, permuted to `B, L, hidden`. The convolution uses
        `fftconv_fn` (flashfft, even L only), an FFT, or a direct `conv1d`
        truncated to `long_fir_threshold` taps. When `inference_params` is
        given, the IIR state is also prefilled according to `prefill_style`
        (`"fft"` or `"recurrence"`).
        """
        fft_size = 2 * L
        hidden_size, num_attention_heads, hidden_size_per_attention_head, _, _ = dims
        # Compatibility with training infra that column splits the projections
        if column_split_hyena:
            z = z_pre.reshape(
                z_pre.shape[0],
                num_attention_heads,
                3 * hidden_size_per_attention_head,
                z_pre.shape[2],
            )
            x2, x1, v = (
                z[:, :, :hidden_size_per_attention_head],
                z[
                    :,
                    :,
                    hidden_size_per_attention_head : 2 * hidden_size_per_attention_head,
                ],
                z[:, :, 2 * hidden_size_per_attention_head :],
            )
            x2, x1, v = (
                x2.reshape(x2.shape[0], -1, x2.shape[-1]),
                x1.reshape(x1.shape[0], -1, x1.shape[-1]),
                v.reshape(v.shape[0], -1, v.shape[-1]),
            )
        else:
            x2, x1, v = z_pre.split([hidden_size, hidden_size, hidden_size], dim=1)

        x1v = x1 * v

        # Spectrum of x1v, reused by the modal-FFT prefill when it is computed
        # here. It stays None on the flashfft and direct-conv paths; the
        # prefill then computes it itself (previously this was unbound on the
        # direct-conv path and raised UnboundLocalError).
        X_s = None
        if use_flashfft and (L % 2) == 0:  # only works with even L
            y = fftconv_fn(
                x1v.to(dtype=torch.bfloat16).contiguous(),
                h.to(dtype=torch.float32),
            )
        elif long_fir_threshold is None:
            # Pre-scaling H by 1 / fft_size pairs with norm="forward" below.
            H = torch.fft.rfft(h.to(dtype=torch.float32), n=fft_size) / fft_size
            X_s = torch.fft.fft(x1v.to(dtype=torch.float32), n=fft_size)
            X = X_s[..., : H.shape[-1]]
            if len(z_pre.shape) > 3:
                H = H.unsqueeze(1)
            y = torch.fft.irfft(X * H, n=fft_size, norm="forward")[..., :L]
        else:
            assert h.shape[0] == 1, "batch size must be 1 for long_fir_threshold"
            h = h[0][:, None]  # rearrange to d, 1, l for depthwise conv1d
            h = h[..., :long_fir_threshold]
            y = F.conv1d(
                x1v,
                h.to(dtype=x1v.dtype),
                stride=1,
                groups=x1v.shape[1],
                padding=h.shape[-1] - 1,
            )[..., :L]

        y = y.to(dtype=x1v.dtype)
        y = (y + x1v * D.unsqueeze(-1)) * x2

        if inference_params is not None:
            if prefill_style == "fft":
                self.prefill_via_modal_fft(
                    inference_params=inference_params,
                    x1v=x1v,
                    X_s=X_s,
                    L=L,
                    t=t,
                    poles=poles,
                    dims=dims,
                    layer_idx=layer_idx,
                    use_flashfft=use_flashfft,
                )
            elif prefill_style == "recurrence":
                self.prefill_via_direct_recurrence(
                    inference_params=inference_params,
                    x1v=x1v,
                    L=L,
                    poles=poles,
                )
            else:
                raise NotImplementedError

        if self.low_mem_mode:
            del z_pre, x2, x1, v, x1v, h
            torch.cuda.empty_cache()
        return y.permute(0, 2, 1)

    def step_fir(self, u, fir_state, weight, bias=None):
        """Step the FIR filter by one token.

        Note:
            `fir_state` holds the last `short_filter_length - 1` inputs, oldest
            first: `u_{L-k+1}, ..., u_{L-1}` for filter length `k`.
            We assume `short_filter_weight` is `[d, 1, short_filter_len]`
            (SISO / multi SISO layout).

        Returns:
            `(y, fir_state)` with the window shifted to include `u`.
        """
        h0, h = weight[..., 0, -1], weight[..., 0, :-1]
        h0, h = h0[None], h[None]
        y = h0 * u + torch.sum(fir_state * h, dim=-1)
        # `bias` defaults to None; adding None would raise a TypeError.
        if bias is not None:
            y = y + bias
        # Shift the window left (roll returns a new tensor) and append u.
        fir_state = torch.roll(fir_state, -1, dims=2)
        fir_state[..., -1] = u
        return y, fir_state

    def step_iir(self, x2, x1, v, D, residues, poles, iir_state, iir_groups=1):
        """Advance the modal IIR state by one token and return the gated output.

        `residues` and `poles` are real (re, im) pairs converted with
        `torch.view_as_complex`; the update is `state <- poles * state + x1 v`
        and the output is `x2 * (Re(sum(residues * state)) + D x1 v)`.
        """
        # Fail fast before doing any work on an unsupported configuration.
        if iir_groups > 1:
            raise NotImplementedError
        x1v = x1 * v
        residues, poles = (
            torch.view_as_complex(residues.to(torch.float32)),
            torch.view_as_complex(poles.to(torch.float32)),
        )
        # squeeze the dummy seqlen dimension
        # D, state_dim, 1 -> 1, D, state_dim
        residues, poles = residues[..., 0][None], poles[..., 0][None]
        iir_state = poles * iir_state + x1v[..., None]
        res_state = torch.sum(residues * iir_state, dim=-1).real
        y = x2 * (res_state + D * x1v)
        return y, iir_state

    def prefill_via_fir_caching(self, u, inference_params, L, *args, **kwargs):
        """Turns the IIR filter into a FIR and uses a cache for decoding."""
        raise NotImplementedError(":)")

    def prefill_via_direct_recurrence(self, inference_params, x1v, L, poles, *args, **kwargs):
        """
        Compute the IIR state via explicit SSM recurrence (modal form).

        Applies the same complex update as `step_iir`,
        `state <- poles * state + x1v[:, :, i]` for `i = 0..L-1`, from a zero
        initial state. The final state is
        `sum_k poles**(L-1-k) * x1v[:, :, k]`, stored as a complex tensor in
        `inference_params.state_dict[self.layer_idx]`.
        """
        # (re, im) pairs -> complex; D, state_dim, 1 -> D, state_dim
        poles = torch.view_as_complex(poles.to(torch.float32))[..., 0]
        x1v_ = x1v.to(dtype=torch.float32)[..., None]  # b, d, l, 1
        # Start from zero. The previous version seeded the state with x1v[:, :, 0]
        # and then added it again at i = 0, counting the first input twice.
        # Broadcasting yields b, d, state_dim after the first step.
        state = torch.zeros((), dtype=poles.dtype, device=poles.device)
        for i in range(L):
            state = poles * state + x1v_[:, :, i]
        inference_params.state_dict[self.layer_idx] = state

    def prefill_via_hybrid_recurrence(
        self, inference_params, u, log_poles, x1v_f_a, L, *args, **kwargs
    ):
        """
        Compute the IIR state via hybrid recurrence-convolution over blocks
        """
        raise NotImplementedError(":)")

    def prefill_via_scan(self, u, inference_params=None, *args, **kwargs):
        raise NotImplementedError

    def prefill_via_canonical_fft(self, u, inference_params=None, *args, **kwargs):
        """
        Compute the IIR state via a single FFT with the denominator of the SSM in companion form.
        This is the most memory efficient "parallelized" prefilling method for Hyena.
        From: https://arxiv.org/abs/2310.18780
        """
        raise NotImplementedError(":)")

    def prefill_via_modal_fft(
        self,
        inference_params,
        x1v,
        L,
        poles,
        t,
        dims,
        layer_idx,
        X_s=None,
        use_flashfft=False,
        state_dtype=torch.complex64,
        *args,
        **kwargs,
    ):
        """
        Compute the IIR state via a single FFT, using the poles of the SSM in modal form.

        The state is the value at index `L - 1` of the causal convolution of
        `x1v` with `poles**t`. It is stored in
        `inference_params.state_dict[layer_idx]` with dtype `state_dtype`.
        If `X_s` (the length-`2 * L` FFT of `x1v`) is not supplied, it is
        computed here.
        """
        # When the model has a long convolution derived from a SSM in modal form and prefill_style is "fft",
        # we split the filter into poles and residues and reuse FFT computation on the input.
        # This optimization is currently not supported when using flashfftconv.
        hidden_size, _, _, state_size, hyena_filter_groups = dims
        if use_flashfft:
            if hyena_filter_groups > 1:
                raise NotImplementedError
            # using real states
            # NOTE(review): this raises the real and imaginary parts of each pole
            # to the power t independently, which does not match the complex
            # power used in the branch below and in step_iir. Confirm this is
            # intended.
            poles = poles.squeeze().reshape(poles.shape[0], -1)[..., None]
            state_s = poles**t
            x1v = x1v[:, :, None].repeat(1, 1, 2 * state_size, 1)
            x1v = x1v.reshape(x1v.shape[0], -1, x1v.shape[-1])
            state_s = state_s[None]
            state = self.fftconv_fn(
                x1v.contiguous(),
                state_s.to(dtype=torch.float32),
            )
            state = state[..., L - 1].reshape(x1v.shape[0], hidden_size, state_size, 2)
            state = torch.view_as_complex(state.contiguous())
            # Use the same key as the non-flashfft branch (previously self.layer_idx).
            inference_params.state_dict[layer_idx] = state.to(dtype=state_dtype)
        else:
            fft_size = 2 * L
            if X_s is None:
                # Spectrum not provided by the caller (e.g. parallel_iir's
                # direct-conv path): compute it the same way parallel_iir does.
                X_s = torch.fft.fft(x1v.to(dtype=torch.float32), n=fft_size)
            bs = x1v.shape[0]
            poles = torch.view_as_complex(poles.to(torch.float32))
            state_s = poles**t
            state_S = torch.fft.fft(state_s, n=fft_size).repeat(
                bs, 1, 1, 1
            )  # B, D, state_dim, 2 * L
            if hyena_filter_groups > 1:
                state_S = state_S.repeat_interleave(hidden_size // hyena_filter_groups, 1)
            state = torch.fft.ifft(X_s[..., None, :] * state_S, n=fft_size)
            inference_params.state_dict[layer_idx] = state[..., L - 1].to(dtype=state_dtype)

    def _compute_state(self, log_poles, u, t, L, *args, **kwargs):
        """
        Compute the IIR state given an input `u` and log_poles of the modal system.

        Returns the first `L` positions of the causal convolution of `u` with
        `exp(log_poles * t)`, shaped `[batch, hidden_size, state_dim, L]`. When
        `t` holds `0, 1, ..., L - 1`, entry `j` is
        `sum_{k <= j} u[..., k] * exp(log_poles * (j - k))`.
        """
        bs = u.shape[0]
        fft_size = 2 * L
        # Full complex FFT: the kernel exp(log_poles * t) is complex in general,
        # so its spectrum is not Hermitian. An rfft of u has only
        # fft_size // 2 + 1 bins and would not broadcast against all fft_size
        # bins of X.
        U = torch.fft.fft(u.to(torch.float32), n=fft_size)
        x = (log_poles * t).exp()
        # [batch, hidden_size, state_dim, 2 * seqlen]
        X = torch.fft.fft(x, n=fft_size).repeat(bs, 1, 1, 1)
        state = torch.fft.ifft(U[..., None, :] * X, n=fft_size)[..., :L]
        return state