# DNAFlash / dnaflash.py
import math
from typing import Optional, Tuple, Union

import torch
import torch.nn.functional as F
from torch import nn, einsum
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

from einops import rearrange
from rotary_embedding_torch import RotaryEmbedding

from transformers import PreTrainedModel, PretrainedConfig
from transformers.modeling_outputs import MaskedLMOutput, SequenceClassifierOutput
# helper functions
def exists(val):
return val is not None
def default(val, d):
return val if exists(val) else d
def padding_to_multiple_of(n, mult):
remainder = n % mult
if remainder == 0:
return 0
return mult - remainder
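# e.g. padding_to_multiple_of(1000, 256) == 24: a 1000-token sequence is padded to 1024
# so it splits evenly into groups of 256; lengths already divisible by `mult` need no padding.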
# scalenorm
class ScaleNorm(nn.Module):
def __init__(self, dim, eps = 1e-5):
super().__init__()
self.scale = dim ** -0.5
self.eps = eps
self.g = nn.Parameter(torch.ones(1))
def forward(self, x):
norm = torch.norm(x, dim = -1, keepdim = True) * self.scale
return x / norm.clamp(min = self.eps) * self.g
# absolute positional encodings
class ScaledSinuEmbedding(nn.Module):
def __init__(self, dim):
super().__init__()
self.scale = nn.Parameter(torch.ones(1,))
inv_freq = 1. / (10000 ** (torch.arange(0, dim, 2).float() / dim))
self.register_buffer('inv_freq', inv_freq)
def forward(self, x):
n, device = x.shape[1], x.device
t = torch.arange(n, device = device).type_as(self.inv_freq)
sinu = einsum('i , j -> i j', t, self.inv_freq)
emb = torch.cat((sinu.sin(), sinu.cos()), dim = -1)
return emb * self.scale
# T5 relative positional bias
class T5RelativePositionBias(nn.Module):
def __init__(
self,
scale,
causal = False,
num_buckets = 32,
max_distance = 128
):
super().__init__()
self.scale = scale
self.causal = causal
self.num_buckets = num_buckets
self.max_distance = max_distance
self.relative_attention_bias = nn.Embedding(num_buckets, 1)
@staticmethod
def _relative_position_bucket(
relative_position,
causal = True,
num_buckets = 32,
max_distance = 128
):
ret = 0
n = -relative_position
if not causal:
num_buckets //= 2
ret += (n < 0).long() * num_buckets
n = torch.abs(n)
else:
n = torch.max(n, torch.zeros_like(n))
max_exact = num_buckets // 2
is_small = n < max_exact
val_if_large = max_exact + (
torch.log(n.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact)
).long()
val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1))
ret += torch.where(is_small, n, val_if_large)
return ret
def forward(self, x):
i, j, device = *x.shape[-2:], x.device
q_pos = torch.arange(i, dtype = torch.long, device = device)
k_pos = torch.arange(j, dtype = torch.long, device = device)
rel_pos = rearrange(k_pos, 'j -> 1 j') - rearrange(q_pos, 'i -> i 1')
rp_bucket = self._relative_position_bucket(rel_pos, causal = self.causal, num_buckets = self.num_buckets, max_distance = self.max_distance)
values = self.relative_attention_bias(rp_bucket)
bias = rearrange(values, 'i j 1 -> i j')
return bias * self.scale
# per-head offset and scale applied to the shared query/key projection
class OffsetScale(nn.Module):
def __init__(self, dim, heads = 1):
super().__init__()
self.gamma = nn.Parameter(torch.ones(heads, dim))
self.beta = nn.Parameter(torch.zeros(heads, dim))
nn.init.normal_(self.gamma, std = 0.02)
def forward(self, x):
out = einsum('... d, h d -> ... h d', x, self.gamma) + self.beta
return out.unbind(dim = -2)
# activation functions
class ReLUSquared(nn.Module):
def forward(self, x):
return F.relu(x) ** 2
class LaplacianAttnFn(nn.Module):
""" https://arxiv.org/abs/2209.10655 claims this is more stable than Relu squared """
def forward(self, x):
mu = math.sqrt(0.5)
std = math.sqrt((4 * math.pi) ** -1)
return (1 + torch.special.erf((x - mu) / (std * math.sqrt(2)))) * 0.5
class FLASH(nn.Module):
def __init__(
self,
*,
dim,
group_size = 256,
query_key_dim = 128,
expansion_factor = 2.,
causal = False,
dropout = 0.,
rotary_pos_emb = None,
norm_klass = nn.LayerNorm,
shift_tokens = False,
laplace_attn_fn = False,
reduce_group_non_causal_attn = True
):
super().__init__()
hidden_dim = int(dim * expansion_factor)
self.group_size = group_size
self.causal = causal
self.shift_tokens = shift_tokens
self.attn_fn = ReLUSquared() if not laplace_attn_fn else LaplacianAttnFn()
# positional embeddings
self.rotary_pos_emb = rotary_pos_emb
self.rel_pos_bias = T5RelativePositionBias(query_key_dim ** 0.5, causal = causal)
# norm
self.norm = norm_klass(dim)
self.dropout = nn.Dropout(dropout)
# whether to reduce groups in non causal linear attention
self.reduce_group_non_causal_attn = reduce_group_non_causal_attn
# projections
self.to_hidden = nn.Sequential(
nn.Linear(dim, hidden_dim * 2),
nn.SiLU()
)
self.to_qk = nn.Sequential(
nn.Linear(dim, query_key_dim),
nn.SiLU()
)
self.qk_offset_scale = OffsetScale(query_key_dim, heads = 4)
self.to_out = nn.Linear(hidden_dim, dim)
def forward(
self,
x,
*,
mask = None
):
"""
b - batch
n - sequence length (within groups)
g - group dimension
d - feature dimension (keys)
e - feature dimension (values)
i - sequence dimension (source)
j - sequence dimension (target)
"""
b, n, device, g = x.shape[0], x.shape[-2], x.device, self.group_size
# prenorm
normed_x = self.norm(x)
# do token shift - a great, costless trick from an independent AI researcher in Shenzhen
if self.shift_tokens:
x_shift, x_pass = normed_x.chunk(2, dim = -1)
x_shift = F.pad(x_shift, (0, 0, 1, -1), value = 0.)
normed_x = torch.cat((x_shift, x_pass), dim = -1)
# initial projections
v, gate = self.to_hidden(normed_x).chunk(2, dim = -1)
qk = self.to_qk(normed_x)
# offset and scale
quad_q, lin_q, quad_k, lin_k = self.qk_offset_scale(qk)
# mask out linear attention keys
if exists(mask):
lin_mask = rearrange(mask, '... -> ... 1')
lin_k = lin_k.masked_fill(~lin_mask.bool(), 0.)
# rotate queries and keys
if exists(self.rotary_pos_emb):
quad_q, lin_q, quad_k, lin_k = map(self.rotary_pos_emb.rotate_queries_or_keys, (quad_q, lin_q, quad_k, lin_k))
# padding for groups
padding = padding_to_multiple_of(n, g)
if padding > 0:
quad_q, quad_k, lin_q, lin_k, v = map(lambda t: F.pad(t, (0, 0, 0, padding), value = 0.), (quad_q, quad_k, lin_q, lin_k, v))
mask = default(mask, torch.ones((b, n), device = device, dtype = torch.bool))
mask = F.pad(mask, (0, padding), value = False)
# group along sequence
        quad_q, quad_k, lin_q, lin_k, v = map(lambda t: rearrange(t, 'b (g n) d -> b g n d', n = self.group_size), (quad_q, quad_k, lin_q, lin_k, v))
if exists(mask):
mask = rearrange(mask, 'b (g j) -> b g 1 j', j = g)
# calculate quadratic attention output
sim = einsum('... i d, ... j d -> ... i j', quad_q, quad_k) / g
sim = sim + self.rel_pos_bias(sim)
attn = self.attn_fn(sim)
attn = self.dropout(attn)
if exists(mask):
attn = attn.masked_fill(~mask.bool(), 0.)
if self.causal:
causal_mask = torch.ones((g, g), dtype = torch.bool, device = device).triu(1)
attn = attn.masked_fill(causal_mask.bool(), 0.)
quad_out = einsum('... i j, ... j d -> ... i d', attn, v)
# calculate linear attention output
if self.causal:
lin_kv = einsum('b g n d, b g n e -> b g d e', lin_k, v) / g
# exclusive cumulative sum along group dimension
lin_kv = lin_kv.cumsum(dim = 1)
lin_kv = F.pad(lin_kv, (0, 0, 0, 0, 1, -1), value = 0.)
lin_out = einsum('b g d e, b g n d -> b g n e', lin_kv, lin_q)
else:
context_einsum_eq = 'b d e' if self.reduce_group_non_causal_attn else 'b g d e'
lin_kv = einsum(f'b g n d, b g n e -> {context_einsum_eq}', lin_k, v) / n
lin_out = einsum(f'b g n d, {context_einsum_eq} -> b g n e', lin_q, lin_kv)
# fold back groups into full sequence, and excise out padding
quad_attn_out, lin_attn_out = map(lambda t: rearrange(t, 'b g n d -> b (g n) d')[:, :n], (quad_out, lin_out))
# gate
out = gate * (quad_attn_out + lin_attn_out)
# projection out and residual
return self.to_out(out) + x
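# Shape note (illustrative): a FLASH block maps (batch, seq_len, dim) -> (batch, seq_len, dim).
# For example, FLASH(dim=512, group_size=256) applied to a (2, 1000, 512) input pads the
# sequence to 1024 internally, attends within four groups of 256, trims the padding, and
# returns a (2, 1000, 512) tensor with the residual connection already added.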
# FLASH Transformer
class FLASHTransformer(nn.Module):
def __init__(
self,
*,
dim,
num_tokens,
depth,
group_size = 256,
query_key_dim = 128,
expansion_factor = 2.,
causal = False,
attn_dropout = 0.,
norm_type = 'scalenorm',
shift_tokens = True,
laplace_attn_fn = False,
reduce_group_non_causal_attn = True
):
super().__init__()
assert norm_type in ('scalenorm', 'layernorm'), 'norm_type must be one of scalenorm or layernorm'
if norm_type == 'scalenorm':
norm_klass = ScaleNorm
elif norm_type == 'layernorm':
norm_klass = nn.LayerNorm
self.token_emb = nn.Embedding(num_tokens, dim)
self.abs_pos_emb = ScaledSinuEmbedding(dim)
self.group_size = group_size
rotary_pos_emb = RotaryEmbedding(dim = min(32, query_key_dim))
# max rotary embedding dimensions of 32, partial Rotary embeddings, from Wang et al - GPT-J
self.layers = nn.ModuleList([FLASH(dim = dim, group_size = group_size, query_key_dim = query_key_dim, expansion_factor = expansion_factor, causal = causal, dropout = attn_dropout, rotary_pos_emb = rotary_pos_emb, norm_klass = norm_klass, shift_tokens = shift_tokens, reduce_group_non_causal_attn = reduce_group_non_causal_attn, laplace_attn_fn = laplace_attn_fn) for _ in range(depth)])
self.to_logits = nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, num_tokens)
)
def forward(
self,
x,
*,
mask = None
):
x = self.token_emb(x)
x = self.abs_pos_emb(x) + x
for flash in self.layers:
x = flash(x, mask = mask)
        # apply the final LayerNorm and vocabulary projection separately so that both
        # the logits and the layer-normed hidden states can be returned
        x_norm = self.to_logits[0](x)
        logits = self.to_logits[1](x_norm)
        return logits, x_norm
class FLASHTransformerConfig(PretrainedConfig):
model_type = "flash_transformer"
def __init__(
self,
hidden_size=512,
vocab_size=4096,
num_layers=12,
group_size=256,
query_key_dim=128,
expansion_factor=2.0,
causal=False,
attn_dropout=0.1,
norm_type="scalenorm",
shift_tokens=True,
laplace_attn_fn=False,
reduce_group_non_causal_attn=True,
**kwargs
):
super().__init__(**kwargs)
self.hidden_size = hidden_size
self.vocab_size = vocab_size
self.num_layers = num_layers
self.group_size = group_size
self.query_key_dim = query_key_dim
self.expansion_factor = expansion_factor
self.causal = causal
self.attn_dropout = attn_dropout
self.norm_type = norm_type
self.shift_tokens = shift_tokens
self.laplace_attn_fn = laplace_attn_fn
self.reduce_group_non_causal_attn = reduce_group_non_causal_attn
class FLASHTransformerForPretrained(PreTrainedModel):
config_class = FLASHTransformerConfig
base_model_prefix = "flash_transformer"
def __init__(self, config):
super().__init__(config)
self.model = FLASHTransformer(
dim=config.hidden_size,
num_tokens=config.vocab_size,
depth=config.num_layers,
group_size=config.group_size,
query_key_dim=config.query_key_dim,
expansion_factor=config.expansion_factor,
causal=config.causal,
attn_dropout=config.attn_dropout,
norm_type=config.norm_type,
shift_tokens=config.shift_tokens,
laplace_attn_fn=config.laplace_attn_fn,
reduce_group_non_causal_attn=config.reduce_group_non_causal_attn
)
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None
    ) -> Union[Tuple, MaskedLMOutput]:
        # labels, position_ids, inputs_embeds and the output_* / return_dict flags are
        # accepted for Hugging Face API compatibility but are not used by this forward pass
        logits, x = self.model(input_ids, mask=attention_mask)
        return MaskedLMOutput(logits=logits, hidden_states=x, loss=None, attentions=None)
class FLASHTransformerForSequenceClassification(FLASHTransformerForPretrained):
def __init__(self, config):
super().__init__(config)
self.num_labels = config.num_labels
self.config = config
        if getattr(config, "use_mlp_classifier", False):
            self.score = nn.Sequential(
                nn.Linear(config.hidden_size, config.hidden_size),
                nn.GELU(),
                nn.Dropout(0.1),
                nn.Linear(config.hidden_size, self.num_labels, bias=False),
            )
        else:
            self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, SequenceClassifierOutput]:
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
"""
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        # run the base model
        outputs = super().forward(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = outputs["hidden_states"]
        # mean-pool the hidden states over the non-padded positions
        if attention_mask is None:
            attention_mask = torch.ones(hidden_states.shape[:2], device=hidden_states.device)
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(hidden_states.size()).to(hidden_states.dtype)
        mean_pooled = torch.sum(hidden_states * input_mask_expanded, dim=1) / input_mask_expanded.sum(dim=1).clamp(min=1e-9)
logits = self.score(mean_pooled)
loss = None
if labels is not None:
labels = labels.to(logits.device)
if self.config.problem_type is None:
if self.num_labels == 1:
self.config.problem_type = "regression"
elif self.num_labels > 1 and (
labels.dtype == torch.long or labels.dtype == torch.int
):
self.config.problem_type = "single_label_classification"
else:
self.config.problem_type = "multi_label_classification"
if self.config.problem_type == "regression":
loss_fct = MSELoss()
if self.num_labels == 1:
loss = loss_fct(logits.squeeze(), labels.squeeze())
else:
loss = loss_fct(logits, labels)
elif self.config.problem_type == "single_label_classification":
loss_fct = CrossEntropyLoss()
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
elif self.config.problem_type == "multi_label_classification":
loss_fct = BCEWithLogitsLoss()
loss = loss_fct(logits, labels)
if not return_dict:
output = (logits,)
return ((loss,) + output) if loss is not None else output
return SequenceClassifierOutput(loss=loss, logits=logits)
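# Minimal usage sketch: instantiates a deliberately small configuration, runs a masked-LM
# style forward pass and a sequence-classification forward pass on random token ids, and
# prints the output shapes. The sizes below are arbitrary smoke-test values, not the
# hyperparameters of any released DNAFlash checkpoint.
if __name__ == "__main__":
    config = FLASHTransformerConfig(
        hidden_size=64,
        vocab_size=128,
        num_layers=2,
        group_size=16,
        query_key_dim=32,
        num_labels=2,
    )
    input_ids = torch.randint(0, config.vocab_size, (2, 40))
    attention_mask = torch.ones_like(input_ids)

    model = FLASHTransformerForPretrained(config)
    out = model(input_ids=input_ids, attention_mask=attention_mask)
    print(out.logits.shape)  # (2, 40, 128): per-token vocabulary logits

    classifier = FLASHTransformerForSequenceClassification(config)
    labels = torch.randint(0, config.num_labels, (2,))
    cls_out = classifier(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
    print(cls_out.logits.shape, cls_out.loss)  # (2, 2) logits and a scalar cross-entropy loss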