import math
from typing import Optional, Tuple, Union

import torch
import torch.nn.functional as F
from torch import nn, einsum
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

from einops import rearrange
from rotary_embedding_torch import RotaryEmbedding

from transformers import PreTrainedModel, PretrainedConfig
from transformers.modeling_outputs import MaskedLMOutput, SequenceClassifierOutput

# helper functions

def exists(val):
    return val is not None

def default(val, d):
    return val if exists(val) else d

def padding_to_multiple_of(n, mult):
    remainder = n % mult
    if remainder == 0:
        return 0
    return mult - remainder

# scalenorm

class ScaleNorm(nn.Module):
    def __init__(self, dim, eps = 1e-5):
        super().__init__()
        self.scale = dim ** -0.5
        self.eps = eps
        self.g = nn.Parameter(torch.ones(1))

    def forward(self, x):
        norm = torch.norm(x, dim = -1, keepdim = True) * self.scale
        return x / norm.clamp(min = self.eps) * self.g

# absolute positional encodings

class ScaledSinuEmbedding(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.scale = nn.Parameter(torch.ones(1,))
        inv_freq = 1. / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer('inv_freq', inv_freq)

    def forward(self, x):
        n, device = x.shape[1], x.device
        t = torch.arange(n, device = device).type_as(self.inv_freq)
        sinu = einsum('i , j -> i j', t, self.inv_freq)
        emb = torch.cat((sinu.sin(), sinu.cos()), dim = -1)
        return emb * self.scale

# T5 relative positional bias

class T5RelativePositionBias(nn.Module):
    def __init__(
        self,
        scale,
        causal = False,
        num_buckets = 32,
        max_distance = 128
    ):
        super().__init__()
        self.scale = scale
        self.causal = causal
        self.num_buckets = num_buckets
        self.max_distance = max_distance
        self.relative_attention_bias = nn.Embedding(num_buckets, 1)

    @staticmethod
    def _relative_position_bucket(
        relative_position,
        causal = True,
        num_buckets = 32,
        max_distance = 128
    ):
        ret = 0
        n = -relative_position
        if not causal:
            num_buckets //= 2
            ret += (n < 0).long() * num_buckets
            n = torch.abs(n)
        else:
            n = torch.max(n, torch.zeros_like(n))

        max_exact = num_buckets // 2
        is_small = n < max_exact

        val_if_large = max_exact + (
            torch.log(n.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact)
        ).long()
        val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1))

        ret += torch.where(is_small, n, val_if_large)
        return ret

    def forward(self, x):
        i, j, device = *x.shape[-2:], x.device
        q_pos = torch.arange(i, dtype = torch.long, device = device)
        k_pos = torch.arange(j, dtype = torch.long, device = device)
        rel_pos = rearrange(k_pos, 'j -> 1 j') - rearrange(q_pos, 'i -> i 1')
        rp_bucket = self._relative_position_bucket(rel_pos, causal = self.causal, num_buckets = self.num_buckets, max_distance = self.max_distance)
        values = self.relative_attention_bias(rp_bucket)
        bias = rearrange(values, 'i j 1 -> i j')
        return bias * self.scale

# offset and scale

class OffsetScale(nn.Module):
    def __init__(self, dim, heads = 1):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(heads, dim))
        self.beta = nn.Parameter(torch.zeros(heads, dim))
        nn.init.normal_(self.gamma, std = 0.02)

    def forward(self, x):
        out = einsum('... d, h d -> ... h d', x, self.gamma) + self.beta
        return out.unbind(dim = -2)
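# Usage sketch (illustrative only, not part of the original module): FLASH below projects the
# input once with `to_qk` and then uses OffsetScale with heads = 4 to derive the four views
# (quad_q, lin_q, quad_k, lin_k) via cheap per-head scales and offsets. The sizes here are
# assumed values for illustration:
#
#   qk_offset_scale = OffsetScale(128, heads = 4)
#   qk = torch.randn(2, 1024, 128)                      # (batch, seq, query_key_dim)
#   quad_q, lin_q, quad_k, lin_k = qk_offset_scale(qk)  # four tensors, each (2, 1024, 128)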
# activation functions

class ReLUSquared(nn.Module):
    def forward(self, x):
        return F.relu(x) ** 2

class LaplacianAttnFn(nn.Module):
    """ https://arxiv.org/abs/2209.10655 claims this is more stable than Relu squared """

    def forward(self, x):
        mu = math.sqrt(0.5)
        std = math.sqrt((4 * math.pi) ** -1)
        return (1 + torch.special.erf((x - mu) / (std * math.sqrt(2)))) * 0.5

# FLASH

class FLASH(nn.Module):
    def __init__(
        self,
        *,
        dim,
        group_size = 256,
        query_key_dim = 128,
        expansion_factor = 2.,
        causal = False,
        dropout = 0.,
        rotary_pos_emb = None,
        norm_klass = nn.LayerNorm,
        shift_tokens = False,
        laplace_attn_fn = False,
        reduce_group_non_causal_attn = True
    ):
        super().__init__()
        hidden_dim = int(dim * expansion_factor)
        self.group_size = group_size
        self.causal = causal
        self.shift_tokens = shift_tokens

        self.attn_fn = ReLUSquared() if not laplace_attn_fn else LaplacianAttnFn()

        # positional embeddings

        self.rotary_pos_emb = rotary_pos_emb
        self.rel_pos_bias = T5RelativePositionBias(query_key_dim ** 0.5, causal = causal)

        # norm

        self.norm = norm_klass(dim)
        self.dropout = nn.Dropout(dropout)

        # whether to reduce groups in non causal linear attention

        self.reduce_group_non_causal_attn = reduce_group_non_causal_attn

        # projections

        self.to_hidden = nn.Sequential(
            nn.Linear(dim, hidden_dim * 2),
            nn.SiLU()
        )

        self.to_qk = nn.Sequential(
            nn.Linear(dim, query_key_dim),
            nn.SiLU()
        )

        self.qk_offset_scale = OffsetScale(query_key_dim, heads = 4)
        self.to_out = nn.Linear(hidden_dim, dim)

    def forward(
        self,
        x,
        *,
        mask = None
    ):
        """
        b - batch
        n - sequence length (within groups)
        g - group dimension
        d - feature dimension (keys)
        e - feature dimension (values)
        i - sequence dimension (source)
        j - sequence dimension (target)
        """

        b, n, device, g = x.shape[0], x.shape[-2], x.device, self.group_size

        # prenorm

        normed_x = self.norm(x)

        # do token shift - a great, costless trick from an independent AI researcher in Shenzhen

        if self.shift_tokens:
            x_shift, x_pass = normed_x.chunk(2, dim = -1)
            x_shift = F.pad(x_shift, (0, 0, 1, -1), value = 0.)
            normed_x = torch.cat((x_shift, x_pass), dim = -1)

        # initial projections

        v, gate = self.to_hidden(normed_x).chunk(2, dim = -1)
        qk = self.to_qk(normed_x)

        # offset and scale

        quad_q, lin_q, quad_k, lin_k = self.qk_offset_scale(qk)

        # mask out linear attention keys

        if exists(mask):
            lin_mask = rearrange(mask, '... -> ... 1')
            lin_k = lin_k.masked_fill(~lin_mask.bool(), 0.)

        # rotate queries and keys

        if exists(self.rotary_pos_emb):
            quad_q, lin_q, quad_k, lin_k = map(self.rotary_pos_emb.rotate_queries_or_keys, (quad_q, lin_q, quad_k, lin_k))

        # padding for groups

        padding = padding_to_multiple_of(n, g)

        if padding > 0:
            quad_q, quad_k, lin_q, lin_k, v = map(lambda t: F.pad(t, (0, 0, 0, padding), value = 0.), (quad_q, quad_k, lin_q, lin_k, v))
            mask = default(mask, torch.ones((b, n), device = device, dtype = torch.bool))
            mask = F.pad(mask, (0, padding), value = False)

        # group along sequence

        quad_q, quad_k, lin_q, lin_k, v = map(lambda t: rearrange(t, 'b (n g) d -> b n g d', g = self.group_size), (quad_q, quad_k, lin_q, lin_k, v))

        if exists(mask):
            mask = rearrange(mask, 'b (g j) -> b g 1 j', j = g)

        # calculate quadratic attention output

        sim = einsum('... i d, ... j d -> ... i j', quad_q, quad_k) / g

        sim = sim + self.rel_pos_bias(sim)

        attn = self.attn_fn(sim)
        attn = self.dropout(attn)

        if exists(mask):
            attn = attn.masked_fill(~mask.bool(), 0.)
        if self.causal:
            causal_mask = torch.ones((g, g), dtype = torch.bool, device = device).triu(1)
            attn = attn.masked_fill(causal_mask, 0.)

        quad_out = einsum('... i j, ... j d -> ... i d', attn, v)

        # calculate linear attention output

        if self.causal:
            lin_kv = einsum('b g n d, b g n e -> b g d e', lin_k, v) / g

            # exclusive cumulative sum along group dimension

            lin_kv = lin_kv.cumsum(dim = 1)
            lin_kv = F.pad(lin_kv, (0, 0, 0, 0, 1, -1), value = 0.)

            lin_out = einsum('b g d e, b g n d -> b g n e', lin_kv, lin_q)
        else:
            context_einsum_eq = 'b d e' if self.reduce_group_non_causal_attn else 'b g d e'
            lin_kv = einsum(f'b g n d, b g n e -> {context_einsum_eq}', lin_k, v) / n
            lin_out = einsum(f'b g n d, {context_einsum_eq} -> b g n e', lin_q, lin_kv)

        # fold back groups into full sequence, and excise out padding

        quad_attn_out, lin_attn_out = map(lambda t: rearrange(t, 'b g n d -> b (g n) d')[:, :n], (quad_out, lin_out))

        # gate

        out = gate * (quad_attn_out + lin_attn_out)

        # projection out and residual

        return self.to_out(out) + x

# FLASH Transformer

class FLASHTransformer(nn.Module):
    def __init__(
        self,
        *,
        dim,
        num_tokens,
        depth,
        group_size = 256,
        query_key_dim = 128,
        expansion_factor = 2.,
        causal = False,
        attn_dropout = 0.,
        norm_type = 'scalenorm',
        shift_tokens = True,
        laplace_attn_fn = False,
        reduce_group_non_causal_attn = True
    ):
        super().__init__()
        assert norm_type in ('scalenorm', 'layernorm'), 'norm_type must be one of scalenorm or layernorm'

        if norm_type == 'scalenorm':
            norm_klass = ScaleNorm
        elif norm_type == 'layernorm':
            norm_klass = nn.LayerNorm

        self.token_emb = nn.Embedding(num_tokens, dim)
        self.abs_pos_emb = ScaledSinuEmbedding(dim)
        self.group_size = group_size

        # max rotary embedding dimensions of 32, partial Rotary embeddings, from Wang et al - GPT-J
        rotary_pos_emb = RotaryEmbedding(dim = min(32, query_key_dim))

        self.layers = nn.ModuleList([
            FLASH(
                dim = dim,
                group_size = group_size,
                query_key_dim = query_key_dim,
                expansion_factor = expansion_factor,
                causal = causal,
                dropout = attn_dropout,
                rotary_pos_emb = rotary_pos_emb,
                norm_klass = norm_klass,
                shift_tokens = shift_tokens,
                reduce_group_non_causal_attn = reduce_group_non_causal_attn,
                laplace_attn_fn = laplace_attn_fn
            ) for _ in range(depth)
        ])

        self.to_logits = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_tokens)
        )

    def forward(
        self,
        x,
        *,
        mask = None
    ):
        x = self.token_emb(x)
        x = self.abs_pos_emb(x) + x

        for flash in self.layers:
            x = flash(x, mask = mask)

        # return both the token logits and the final normed hidden states
        x_norm = self.to_logits[0](x)
        logits = self.to_logits[1](x_norm)
        return logits, x_norm

class FLASHTransformerConfig(PretrainedConfig):
    model_type = "flash_transformer"

    def __init__(
        self,
        hidden_size=512,
        vocab_size=4096,
        num_layers=12,
        group_size=256,
        query_key_dim=128,
        expansion_factor=2.0,
        causal=False,
        attn_dropout=0.1,
        norm_type="scalenorm",
        shift_tokens=True,
        laplace_attn_fn=False,
        reduce_group_non_causal_attn=True,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.hidden_size = hidden_size
        self.vocab_size = vocab_size
        self.num_layers = num_layers
        self.group_size = group_size
        self.query_key_dim = query_key_dim
        self.expansion_factor = expansion_factor
        self.causal = causal
        self.attn_dropout = attn_dropout
        self.norm_type = norm_type
        self.shift_tokens = shift_tokens
        self.laplace_attn_fn = laplace_attn_fn
        self.reduce_group_non_causal_attn = reduce_group_non_causal_attn
class FLASHTransformerForPretrained(PreTrainedModel):
    config_class = FLASHTransformerConfig
    base_model_prefix = "flash_transformer"

    def __init__(self, config):
        super().__init__(config)
        self.model = FLASHTransformer(
            dim=config.hidden_size,
            num_tokens=config.vocab_size,
            depth=config.num_layers,
            group_size=config.group_size,
            query_key_dim=config.query_key_dim,
            expansion_factor=config.expansion_factor,
            causal=config.causal,
            attn_dropout=config.attn_dropout,
            norm_type=config.norm_type,
            shift_tokens=config.shift_tokens,
            laplace_attn_fn=config.laplace_attn_fn,
            reduce_group_non_causal_attn=config.reduce_group_non_causal_attn
        )

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None
    ) -> Union[Tuple, MaskedLMOutput]:
        logits, x = self.model(input_ids, mask=attention_mask)
        # no loss is computed here; downstream heads consume the logits and final hidden states
        return MaskedLMOutput(logits=logits, hidden_states=x, loss=None, attentions=None)

class FLASHTransformerForSequenceClassification(FLASHTransformerForPretrained):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.config = config

        self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
        if getattr(config, "use_mlp_classifier", False):
            self.score = nn.Sequential(
                nn.Linear(config.hidden_size, config.hidden_size),
                nn.GELU(),
                nn.Dropout(0.1),
                nn.Linear(config.hidden_size, self.num_labels, bias=False),
            )

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, SequenceClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in
            `[0, ..., config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed
            (Mean-Square loss), If `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
""" # 获取基模型输出 outputs = super().forward( input_ids, attention_mask=attention_mask, position_ids=position_ids, inputs_embeds=inputs_embeds, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, ) hidden_states = outputs["hidden_states"] input_mask_expanded = attention_mask.unsqueeze(-1) # 维度匹配 mean_pooled = torch.sum(hidden_states * input_mask_expanded, dim=1) / input_mask_expanded.sum(dim=1) # 计算加权平均 logits = self.score(mean_pooled) loss = None if labels is not None: labels = labels.to(logits.device) if self.config.problem_type is None: if self.num_labels == 1: self.config.problem_type = "regression" elif self.num_labels > 1 and ( labels.dtype == torch.long or labels.dtype == torch.int ): self.config.problem_type = "single_label_classification" else: self.config.problem_type = "multi_label_classification" if self.config.problem_type == "regression": loss_fct = MSELoss() if self.num_labels == 1: loss = loss_fct(logits.squeeze(), labels.squeeze()) else: loss = loss_fct(logits, labels) elif self.config.problem_type == "single_label_classification": loss_fct = CrossEntropyLoss() loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) elif self.config.problem_type == "multi_label_classification": loss_fct = BCEWithLogitsLoss() loss = loss_fct(logits, labels) if not return_dict: output = (logits,) return ((loss,) + output) if loss is not None else output return SequenceClassifierOutput(loss=loss, logits=logits)