import math

import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch import nn, einsum, Tensor
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

from einops import rearrange
from rotary_embedding_torch import RotaryEmbedding

from transformers import PreTrainedModel, PretrainedConfig
from transformers.modeling_outputs import MaskedLMOutput, SequenceClassifierOutput

from typing import Optional, Tuple, Union, Any


def exists(val):
    return val is not None


def default(val, d):
    return val if exists(val) else d


def padding_to_multiple_of(n, mult):
    remainder = n % mult
    if remainder == 0:
        return 0
    return mult - remainder
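
# For example, padding_to_multiple_of(700, 256) == 68, since 700 % 256 == 188 and
# 256 - 188 == 68; FLASH uses this to pad sequences up to a whole number of groups.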


class ScaleNorm(nn.Module):
    def __init__(self, dim, eps = 1e-5):
        super().__init__()
        self.scale = dim ** -0.5
        self.eps = eps
        self.g = nn.Parameter(torch.ones(1))

    def forward(self, x):
        norm = torch.norm(x, dim = -1, keepdim = True) * self.scale
        return x / norm.clamp(min = self.eps) * self.g
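
# ScaleNorm normalizes by the l2 norm of the feature vector scaled by dim ** -0.5,
# i.e. y = g * x / max(||x||_2 * dim ** -0.5, eps), using a single learned scalar g
# instead of LayerNorm's per-feature affine parameters.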


class ScaledSinuEmbedding(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.scale = nn.Parameter(torch.ones(1,))
        inv_freq = 1. / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer('inv_freq', inv_freq)

    def forward(self, x):
        n, device = x.shape[1], x.device
        t = torch.arange(n, device = device).type_as(self.inv_freq)
        sinu = einsum('i , j -> i j', t, self.inv_freq)
        emb = torch.cat((sinu.sin(), sinu.cos()), dim = -1)
        return emb * self.scale
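
# Given input of shape (batch, n, dim), the returned embedding has shape (n, dim) and is
# broadcast-added to the token embeddings; the single learned `scale` lets the model tune
# how strongly absolute positions contribute.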


class T5RelativePositionBias(nn.Module):
    def __init__(
        self,
        scale,
        causal = False,
        num_buckets = 32,
        max_distance = 128
    ):
        super().__init__()
        self.scale = scale
        self.causal = causal
        self.num_buckets = num_buckets
        self.max_distance = max_distance
        self.relative_attention_bias = nn.Embedding(num_buckets, 1)

    @staticmethod
    def _relative_position_bucket(
        relative_position,
        causal = True,
        num_buckets = 32,
        max_distance = 128
    ):
        ret = 0
        n = -relative_position
        if not causal:
            # in the bidirectional case, half the buckets are reserved for positions to the right
            num_buckets //= 2
            ret += (n < 0).long() * num_buckets
            n = torch.abs(n)
        else:
            n = torch.max(n, torch.zeros_like(n))

        # the first half of the buckets cover exact distances, the rest grow logarithmically
        # up to max_distance
        max_exact = num_buckets // 2
        is_small = n < max_exact

        val_if_large = max_exact + (
            torch.log(n.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact)
        ).long()
        val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1))

        ret += torch.where(is_small, n, val_if_large)
        return ret

    def forward(self, x):
        i, j, device = *x.shape[-2:], x.device
        q_pos = torch.arange(i, dtype = torch.long, device = device)
        k_pos = torch.arange(j, dtype = torch.long, device = device)
        rel_pos = rearrange(k_pos, 'j -> 1 j') - rearrange(q_pos, 'i -> i 1')
        rp_bucket = self._relative_position_bucket(rel_pos, causal = self.causal, num_buckets = self.num_buckets, max_distance = self.max_distance)
        values = self.relative_attention_bias(rp_bucket)
        bias = rearrange(values, 'i j 1 -> i j')
        return bias * self.scale
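
# `forward` takes the quadratic attention logits (shape (..., i, j)) only to read off i, j and
# the device, and returns an (i, j) bias of learned per-bucket scalars that is added to the
# logits of every group.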


class OffsetScale(nn.Module):
    def __init__(self, dim, heads = 1):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(heads, dim))
        self.beta = nn.Parameter(torch.zeros(heads, dim))
        nn.init.normal_(self.gamma, std = 0.02)

    def forward(self, x):
        out = einsum('... d, h d -> ... h d', x, self.gamma) + self.beta
        return out.unbind(dim = -2)
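
# With heads = 4, OffsetScale turns the single shared qk projection into four cheap per-head
# variants via an elementwise scale and offset, unbound as (quad_q, lin_q, quad_k, lin_k) in FLASH.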


class ReLUSquared(nn.Module):
    def forward(self, x):
        return F.relu(x) ** 2


class LaplacianAttnFn(nn.Module):
    """ https://arxiv.org/abs/2209.10655 claims this is more stable than Relu squared """

    def forward(self, x):
        mu = math.sqrt(0.5)
        std = math.sqrt((4 * math.pi) ** -1)
        return (1 + torch.special.erf((x - mu) / (std * math.sqrt(2)))) * 0.5
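
# This is the Laplace attention function from the paper referenced above: a Gaussian-CDF-shaped
# squashing whose output stays in (0, 1), offered here as a drop-in alternative to squared ReLU
# (selected via `laplace_attn_fn = True`).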


class FLASH(nn.Module):
    def __init__(
        self,
        *,
        dim,
        group_size = 256,
        query_key_dim = 128,
        expansion_factor = 2.,
        causal = False,
        dropout = 0.,
        rotary_pos_emb = None,
        norm_klass = nn.LayerNorm,
        shift_tokens = False,
        laplace_attn_fn = False,
        reduce_group_non_causal_attn = True
    ):
        super().__init__()
        hidden_dim = int(dim * expansion_factor)
        self.group_size = group_size
        self.causal = causal
        self.shift_tokens = shift_tokens

        self.attn_fn = ReLUSquared() if not laplace_attn_fn else LaplacianAttnFn()

        # positional information: optional rotary embeddings plus a T5-style relative position bias
        self.rotary_pos_emb = rotary_pos_emb
        self.rel_pos_bias = T5RelativePositionBias(query_key_dim ** 0.5, causal = causal)

        self.norm = norm_klass(dim)
        self.dropout = nn.Dropout(dropout)

        self.reduce_group_non_causal_attn = reduce_group_non_causal_attn

        # projections: a gated hidden projection (values and gate) and a single shared qk projection
        self.to_hidden = nn.Sequential(
            nn.Linear(dim, hidden_dim * 2),
            nn.SiLU()
        )

        self.to_qk = nn.Sequential(
            nn.Linear(dim, query_key_dim),
            nn.SiLU()
        )

        self.qk_offset_scale = OffsetScale(query_key_dim, heads = 4)
        self.to_out = nn.Linear(hidden_dim, dim)

    def forward(
        self,
        x,
        *,
        mask = None
    ):
        """
        b - batch
        n - sequence length (within groups)
        g - group dimension
        d - feature dimension (keys)
        e - feature dimension (values)
        i - sequence dimension (source)
        j - sequence dimension (target)
        """

        b, n, device, g = x.shape[0], x.shape[-2], x.device, self.group_size

        # prenorm
        normed_x = self.norm(x)

        # optionally shift half of the features one position forward along the sequence
        if self.shift_tokens:
            x_shift, x_pass = normed_x.chunk(2, dim = -1)
            x_shift = F.pad(x_shift, (0, 0, 1, -1), value = 0.)
            normed_x = torch.cat((x_shift, x_pass), dim = -1)

        # values / gate and the shared query-key projection
        v, gate = self.to_hidden(normed_x).chunk(2, dim = -1)
        qk = self.to_qk(normed_x)

        # offset and scale the shared qk into quadratic and linear queries / keys
        quad_q, lin_q, quad_k, lin_k = self.qk_offset_scale(qk)

        # zero out padded positions in the linear attention keys
        if exists(mask):
            lin_mask = rearrange(mask, '... -> ... 1')
            lin_k = lin_k.masked_fill(~lin_mask.bool(), 0.)

        # rotary embeddings
        if exists(self.rotary_pos_emb):
            quad_q, lin_q, quad_k, lin_k = map(self.rotary_pos_emb.rotate_queries_or_keys, (quad_q, lin_q, quad_k, lin_k))

        # pad the sequence to a multiple of the group size
        padding = padding_to_multiple_of(n, g)

        if padding > 0:
            quad_q, quad_k, lin_q, lin_k, v = map(lambda t: F.pad(t, (0, 0, 0, padding), value = 0.), (quad_q, quad_k, lin_q, lin_k, v))

            mask = default(mask, torch.ones((b, n), device = device, dtype = torch.bool))
            mask = F.pad(mask, (0, padding), value = False)

        # split the sequence into non-overlapping groups of size `group_size`
        quad_q, quad_k, lin_q, lin_k, v = map(lambda t: rearrange(t, 'b (g n) d -> b g n d', n = self.group_size), (quad_q, quad_k, lin_q, lin_k, v))

        if exists(mask):
            mask = rearrange(mask, 'b (g j) -> b g 1 j', j = g)

        # quadratic attention within each group
        sim = einsum('... i d, ... j d -> ... i j', quad_q, quad_k) / g

        sim = sim + self.rel_pos_bias(sim)

        attn = self.attn_fn(sim)
        attn = self.dropout(attn)

        if exists(mask):
            attn = attn.masked_fill(~mask.bool(), 0.)

        if self.causal:
            causal_mask = torch.ones((g, g), dtype = torch.bool, device = device).triu(1)
            attn = attn.masked_fill(causal_mask.bool(), 0.)

        quad_out = einsum('... i j, ... j d -> ... i d', attn, v)

        # linear attention across groups
        if self.causal:
            lin_kv = einsum('b g n d, b g n e -> b g d e', lin_k, v) / g

            # exclusive cumulative sum over groups, so each group's linear attention only sees
            # key / value statistics from earlier groups
            lin_kv = lin_kv.cumsum(dim = 1)
            lin_kv = F.pad(lin_kv, (0, 0, 0, 0, 1, -1), value = 0.)

            lin_out = einsum('b g d e, b g n d -> b g n e', lin_kv, lin_q)
        else:
            context_einsum_eq = 'b d e' if self.reduce_group_non_causal_attn else 'b g d e'
            lin_kv = einsum(f'b g n d, b g n e -> {context_einsum_eq}', lin_k, v) / n
            lin_out = einsum(f'b g n d, {context_einsum_eq} -> b g n e', lin_q, lin_kv)

        # fold the groups back into a single sequence and drop the padding
        quad_attn_out, lin_attn_out = map(lambda t: rearrange(t, 'b g n d -> b (g n) d')[:, :n], (quad_out, lin_out))

        # gate, project out and add the residual
        out = gate * (quad_attn_out + lin_attn_out)

        return self.to_out(out) + x
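
# Rough shape walk-through (illustrative): for x of shape (b, n, dim), FLASH computes quadratic
# attention inside non-overlapping groups of `group_size` tokens and linear attention across
# groups, sums the two, multiplies by the SiLU-activated gate and projects back to `dim`,
# returning a tensor of shape (b, n, dim) with the residual already added.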


class FLASHTransformer(nn.Module):
    def __init__(
        self,
        *,
        dim,
        num_tokens,
        depth,
        group_size = 256,
        query_key_dim = 128,
        expansion_factor = 2.,
        causal = False,
        attn_dropout = 0.,
        norm_type = 'scalenorm',
        shift_tokens = True,
        laplace_attn_fn = False,
        reduce_group_non_causal_attn = True
    ):
        super().__init__()
        assert norm_type in ('scalenorm', 'layernorm'), 'norm_type must be one of scalenorm or layernorm'

        if norm_type == 'scalenorm':
            norm_klass = ScaleNorm
        elif norm_type == 'layernorm':
            norm_klass = nn.LayerNorm

        self.token_emb = nn.Embedding(num_tokens, dim)
        self.abs_pos_emb = ScaledSinuEmbedding(dim)
        self.group_size = group_size

        rotary_pos_emb = RotaryEmbedding(dim = min(32, query_key_dim))

        self.layers = nn.ModuleList([
            FLASH(
                dim = dim,
                group_size = group_size,
                query_key_dim = query_key_dim,
                expansion_factor = expansion_factor,
                causal = causal,
                dropout = attn_dropout,
                rotary_pos_emb = rotary_pos_emb,
                norm_klass = norm_klass,
                shift_tokens = shift_tokens,
                reduce_group_non_causal_attn = reduce_group_non_causal_attn,
                laplace_attn_fn = laplace_attn_fn
            ) for _ in range(depth)
        ])

        self.to_logits = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_tokens)
        )

    def forward(
        self,
        x,
        *,
        mask = None
    ):
        x = self.token_emb(x)
        x = self.abs_pos_emb(x) + x

        for flash in self.layers:
            x = flash(x, mask = mask)

        x_norm = self.to_logits[0](x)
        logits = self.to_logits[1](x_norm)
        return logits, x_norm
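
# Usage sketch (illustrative, not executed on import):
#
#   model = FLASHTransformer(dim = 512, num_tokens = 4096, depth = 12)
#   tokens = torch.randint(0, 4096, (1, 1024))
#   logits, hidden = model(tokens)   # (1, 1024, 4096) and (1, 1024, 512)
#
# The second return value is the LayerNorm'd hidden state just before the output projection,
# which the HuggingFace wrappers below reuse for pooling.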


class FLASHTransformerConfig(PretrainedConfig):
    model_type = "flash_transformer"

    def __init__(
        self,
        hidden_size=512,
        vocab_size=4096,
        num_layers=12,
        group_size=256,
        query_key_dim=128,
        expansion_factor=2.0,
        causal=False,
        attn_dropout=0.1,
        norm_type="scalenorm",
        shift_tokens=True,
        laplace_attn_fn=False,
        reduce_group_non_causal_attn=True,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.hidden_size = hidden_size
        self.vocab_size = vocab_size
        self.num_layers = num_layers
        self.group_size = group_size
        self.query_key_dim = query_key_dim
        self.expansion_factor = expansion_factor
        self.causal = causal
        self.attn_dropout = attn_dropout
        self.norm_type = norm_type
        self.shift_tokens = shift_tokens
        self.laplace_attn_fn = laplace_attn_fn
        self.reduce_group_non_causal_attn = reduce_group_non_causal_attn
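
# A minimal configuration sketch (the defaults shown are the ones defined above):
#
#   config = FLASHTransformerConfig(hidden_size=512, vocab_size=4096, num_layers=12)
#   model = FLASHTransformerForPretrained(config)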


class FLASHTransformerForPretrained(PreTrainedModel):
    config_class = FLASHTransformerConfig
    base_model_prefix = "flash_transformer"

    def __init__(self, config):
        super().__init__(config)
        self.model = FLASHTransformer(
            dim=config.hidden_size,
            num_tokens=config.vocab_size,
            depth=config.num_layers,
            group_size=config.group_size,
            query_key_dim=config.query_key_dim,
            expansion_factor=config.expansion_factor,
            causal=config.causal,
            attn_dropout=config.attn_dropout,
            norm_type=config.norm_type,
            shift_tokens=config.shift_tokens,
            laplace_attn_fn=config.laplace_attn_fn,
            reduce_group_non_causal_attn=config.reduce_group_non_causal_attn
        )

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None
    ) -> Union[Tuple, MaskedLMOutput]:
        # position_ids, inputs_embeds, labels and the output_* flags are accepted for
        # compatibility with the HuggingFace API but are not used by the underlying model
        logits, x = self.model(input_ids, mask=attention_mask)
        return MaskedLMOutput(logits=logits, hidden_states=x, loss=None, attentions=None)
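
# Note: `hidden_states` here carries the final normed hidden tensor itself, rather than the
# usual tuple of per-layer states; the sequence classification head below relies on this.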


class FLASHTransformerForSequenceClassification(FLASHTransformerForPretrained):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.config = config

        self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
        if getattr(config, "use_mlp_classifier", False):
            self.score = nn.Sequential(
                nn.Linear(config.hidden_size, config.hidden_size),
                nn.GELU(),
                nn.Dropout(0.1),
                nn.Linear(config.hidden_size, self.num_labels, bias=False),
            )

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, SequenceClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        outputs = super().forward(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = outputs["hidden_states"]

        # mean-pool the hidden states over non-padded positions
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(hidden_states.size()).to(hidden_states.dtype)
        mean_pooled = torch.sum(hidden_states * input_mask_expanded, dim=1) / input_mask_expanded.sum(dim=1).clamp(min=1e-9)
        logits = self.score(mean_pooled)

        loss = None
        if labels is not None:
            labels = labels.to(logits.device)

            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (
                    labels.dtype == torch.long or labels.dtype == torch.int
                ):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(logits, labels)

        if not return_dict:
            output = (logits,)
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutput(loss=loss, logits=logits)
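

if __name__ == "__main__":
    # Minimal smoke test (a sketch under assumed small hyperparameters, not part of the
    # original module): run a tiny randomly initialized model on random token ids.
    config = FLASHTransformerConfig(hidden_size=64, vocab_size=100, num_layers=2, group_size=16)
    model = FLASHTransformerForPretrained(config)
    input_ids = torch.randint(0, config.vocab_size, (2, 37))
    attention_mask = torch.ones_like(input_ids)
    out = model(input_ids=input_ids, attention_mask=attention_mask)
    print(out.logits.shape)         # expected: torch.Size([2, 37, 100])
    print(out.hidden_states.shape)  # expected: torch.Size([2, 37, 64])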