| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| | import os
|
| | import torch
|
| | import torch.nn as nn
|
| | import torch.nn.functional as F
|
| | from typing import Optional, List, Tuple, Union
|
| | import math
|
| | import torch.utils.checkpoint
|
| | from transformers import PreTrainedModel, PretrainedConfig
|
| | from transformers.modeling_outputs import CausalLMOutputWithPast
|
| |
|
class TernaryConfig(PretrainedConfig):
    """Hyper-parameters for the ternary-weight transformer defined in this file.

    Every argument is stored verbatim as an attribute of the same name; any
    additional keyword arguments are forwarded to ``PretrainedConfig``.

    Args:
        vocab_size: Rows of the token-embedding table and width of the logits.
        hidden_size: Width of the residual stream.
        num_hidden_layers: Number of transformer blocks.
        num_attention_heads: Attention heads per block; the per-head width is
            ``hidden_size // num_attention_heads``.
        intermediate_size: Inner width of the feed-forward network.
        max_position_embeddings: Number of positions in the precomputed
            rotary table.
        rms_norm_eps: Epsilon passed to the RMSNorm layers.
        dropout_rate: Dropout probability applied to each residual branch.
        window_size: Number of most recent key/value positions an attention
            layer keeps when it extends its cache.
    """

    model_type = "ternary_transformer"

    def __init__(
        self,
        vocab_size=50257,
        hidden_size=3072,
        num_hidden_layers=24,
        num_attention_heads=32,
        intermediate_size=12288,
        max_position_embeddings=2048,
        rms_norm_eps=1e-6,
        dropout_rate=0.1,
        window_size=512,
        **kwargs
    ):
        # The base class consumes the generic Hugging Face options first.
        super().__init__(**kwargs)

        # Vocabulary and model shape.
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size

        # Positional range, normalisation and regularisation.
        self.max_position_embeddings = max_position_embeddings
        self.rms_norm_eps = rms_norm_eps
        self.dropout_rate = dropout_rate

        # Key/value cache length used by the attention layers.
        self.window_size = window_size
|
| |
|
class BitLinear(nn.Linear):
    """Linear layer whose weights are ternarised on the fly in ``forward``.

    The full-precision weight stays the trainable parameter. On every call
    it is divided by its mean absolute value, rounded and clipped to
    ``{-1, 0, 1}``, then rescaled. A straight-through estimator makes the
    forward pass use the quantised weight while the gradient flows to the
    full-precision one.

    Args:
        in_features: Size of each input sample.
        out_features: Size of each output sample.
        bias: Whether to learn an additive bias.
        num_layers: Depth used to shrink the initial weight standard
            deviation to ``0.02 / sqrt(2 * num_layers)``.
    """

    def __init__(self, in_features, out_features, bias=False, num_layers=24):
        super().__init__(in_features, out_features, bias)
        init_std = 0.02 / math.sqrt(2 * num_layers)
        nn.init.normal_(self.weight, mean=0.0, std=init_std)

    def forward(self, x):
        """Apply the layer to ``x`` using ternarised weights and clipped inputs."""
        weight = self.weight

        # Per-tensor scale: mean absolute weight, plus a small constant so the
        # division below is defined for an all-zero weight.
        scale = weight.abs().mean() + 1e-9
        ternary = torch.round(weight / scale).clamp(-1, 1)
        # Straight-through: value is `ternary * scale`, gradient is that of `weight`.
        weight_ste = weight + (ternary * scale - weight).detach()

        # Centre the activations over the feature axis, then clip to
        # [-1.5, 1.5], again with a straight-through gradient.
        centered = x - x.mean(dim=-1, keepdim=True)
        clipped = centered.clamp(-1.5, 1.5)
        act_ste = centered + (clipped - centered).detach()

        return F.linear(act_ste, weight_ste, self.bias)
|
| |
|
class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last axis with a learned gain.

    Args:
        dim: Length of the last axis; also the size of the gain vector.
        eps: Constant added to the mean square before the reciprocal root.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        # The gain starts at one, so a fresh layer only normalises.
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        """Return ``x`` scaled to unit RMS along the last axis, times the gain."""
        mean_square = x.pow(2).mean(-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + self.eps)
        normalised = x * inv_rms
        return normalised * self.weight
|
| |
|
def precompute_freqs_cis(dim: int, seq_len: int, theta: float = 10000.0):
    """Build the table of unit complex rotations used by the rotary embedding.

    Args:
        dim: Per-head width; one frequency is produced for every index in
            ``range(0, dim, 2)``.
        seq_len: Number of positions (rows) in the table.
        theta: Base of the geometric frequency progression.

    Returns:
        A complex tensor with ``seq_len`` rows whose entry ``[p, i]`` is
        ``exp(1j * p * theta ** (-2 * i / dim))``.
    """
    exponents = torch.arange(0, dim, 2).float() / dim
    inv_freq = 1.0 / (theta ** exponents)
    positions = torch.arange(seq_len).float()
    # angles[p, i] = position p times frequency i.
    angles = torch.outer(positions, inv_freq)
    unit_magnitude = torch.ones_like(angles)
    return torch.polar(unit_magnitude, angles)
|
| |
|
def apply_rotary_emb(xq, xk, freqs_cis):
    """Rotate query and key tensors by position-dependent complex phases.

    Consecutive pairs along the last axis are treated as the real and
    imaginary parts of a complex number and multiplied by the matching
    entry of ``freqs_cis``.

    Args:
        xq: Query tensor indexed as ``(batch, heads, seq, head_dim)``; the
            last axis must have even length.
        xk: Key tensor with the same sequence length as ``xq``.
        freqs_cis: Complex rotation table with at least ``seq`` rows and
            ``head_dim // 2`` columns; row ``i`` is applied to position ``i``
            of both tensors.

    Returns:
        ``(xq_rotated, xk_rotated)`` with the shapes and dtypes of the inputs.

    Raises:
        ValueError: If ``xq`` and ``xk`` differ in sequence length, or if
            ``freqs_cis`` has fewer rows than the sequence length. Without
            this check a one-row table would be silently broadcast, giving
            every position the same rotation.
    """
    seq_len = xq.shape[2]
    if xk.shape[2] != seq_len:
        raise ValueError(
            f"xq and xk must have the same sequence length, got {seq_len} and {xk.shape[2]}"
        )
    if freqs_cis.shape[0] < seq_len:
        raise ValueError(
            f"freqs_cis covers {freqs_cis.shape[0]} positions but the sequence has {seq_len}"
        )

    # The rotation is computed in float32 and cast back at the end.
    xq_c = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_c = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))

    # Add broadcast axes for batch and heads.
    rotation = freqs_cis[None, None, :seq_len, :]

    xq_out = torch.view_as_real(xq_c * rotation).flatten(3)
    xk_out = torch.view_as_real(xk_c * rotation).flatten(3)
    return xq_out.type_as(xq), xk_out.type_as(xk)
|
| |
|
class MultiHeadAttention(nn.Module):
    """Causal multi-head self-attention with rotary positions and a bounded cache.

    The four projections are ``BitLinear`` layers. When a cache is supplied,
    new keys and values are appended and only the most recent
    ``config.window_size`` positions are kept.

    Keys are cached *before* the rotary rotation and rotated again on every
    call using their index inside the current window. Rotary attention
    scores depend only on the difference between query and key positions,
    so this gives the same scores as absolute positions while the cache is
    un-truncated, and keeps the offsets correct after old entries have been
    dropped (where a position derived from the cache length would stall).
    """

    def __init__(self, config: TernaryConfig):
        super().__init__()
        self.n_heads = config.num_attention_heads
        self.head_dim = config.hidden_size // config.num_attention_heads
        self.q_proj = BitLinear(config.hidden_size, config.hidden_size, num_layers=config.num_hidden_layers)
        self.k_proj = BitLinear(config.hidden_size, config.hidden_size, num_layers=config.num_hidden_layers)
        self.v_proj = BitLinear(config.hidden_size, config.hidden_size, num_layers=config.num_hidden_layers)
        self.out_proj = BitLinear(config.hidden_size, config.hidden_size, num_layers=config.num_hidden_layers)
        self.scale = self.head_dim ** -0.5
        self.window_size = config.window_size

    def forward(self, x, freqs_cis, pos_offset, past_kv=None):
        """Attend over ``x`` and, if given, the cached keys/values.

        Args:
            x: Input of shape ``(batch, seq, hidden)``.
            freqs_cis: Full rotary table; it is sliced here.
            pos_offset: Position of the first token of ``x``. Used only when
                ``past_kv`` is ``None``; with a cache, positions are taken
                relative to the cached window instead.
            past_kv: Optional ``(keys, values)`` pair returned by a previous
                call, each of shape ``(batch, heads, cached, head_dim)``.

        Returns:
            ``(output, new_kv)`` where ``output`` has the shape of ``x`` and
            ``new_kv`` is the detached ``(keys, values)`` pair to pass to the
            next call. The cached keys are not rotated.

        Raises:
            ValueError: If the positions needed exceed the rows of
                ``freqs_cis``.
        """
        B, T, D = x.shape
        q = self.q_proj(x).view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(x).view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(x).view(B, T, self.n_heads, self.head_dim).transpose(1, 2)

        if past_kv is None:
            q_start = k_start = pos_offset
        else:
            pk, pv = past_kv
            # Keep at least T keys so every query still sees its own key;
            # otherwise a chunk longer than the window would leave some rows
            # fully masked and softmax would return NaN for them.
            keep = max(self.window_size, T)
            k = torch.cat([pk, k], dim=2)[:, :, -keep:]
            v = torch.cat([pv, v], dim=2)[:, :, -keep:]
            # Window-relative positions: oldest kept key is 0, the new
            # tokens occupy the last T slots.
            k_start = 0
            q_start = k.size(2) - T
        S = k.size(2)
        new_kv = (k.detach(), v.detach())

        # In both branches q_start + T == k_start + S, so one bound suffices.
        if q_start + T > freqs_cis.size(0):
            raise ValueError(
                f"rotary table has {freqs_cis.size(0)} positions but position "
                f"{q_start + T - 1} is required"
            )

        if S == T:
            # Queries and keys share the same positions.
            q, k_rot = apply_rotary_emb(q, k, freqs_cis[q_start:q_start + T])
        else:
            q, _ = apply_rotary_emb(q, q, freqs_cis[q_start:q_start + T])
            k_rot, _ = apply_rotary_emb(k, k, freqs_cis[k_start:k_start + S])

        attn = torch.matmul(q, k_rot.transpose(-2, -1)) * self.scale
        # Query i sits at key index S - T + i; mask every key after it.
        mask = torch.triu(
            torch.full((T, S), float('-inf'), device=x.device),
            diagonal=S - T + 1,
        ).unsqueeze(0).unsqueeze(0)
        attn = F.softmax((attn + mask).float(), dim=-1).type_as(x)
        out = torch.matmul(attn, v).transpose(1, 2).reshape(B, T, D)
        return self.out_proj(out), new_kv
|
| |
|
class SwiGLUFeedForward(nn.Module):
    """Gated feed-forward network: ``w2(silu(w1(x)) * w3(x))``.

    ``w1`` and ``w3`` map ``hidden_size`` to ``intermediate_size``; ``w2``
    maps back. All three are ``BitLinear`` layers.
    """

    def __init__(self, config: TernaryConfig):
        super().__init__()
        depth = config.num_hidden_layers
        self.w1 = BitLinear(config.hidden_size, config.intermediate_size, num_layers=depth)
        self.w3 = BitLinear(config.hidden_size, config.intermediate_size, num_layers=depth)
        self.w2 = BitLinear(config.intermediate_size, config.hidden_size, num_layers=depth)

    def forward(self, x):
        """Return the gated projection of ``x``, with the same shape as ``x``."""
        gate = F.silu(self.w1(x))
        value = self.w3(x)
        return self.w2(gate * value)
|
| |
|
class TransformerBlock(nn.Module):
    """Pre-norm transformer layer: attention then feed-forward, each residual.

    Each sub-layer reads an RMS-normalised copy of the stream, and its
    output goes through dropout before being added back.
    """

    def __init__(self, config: TernaryConfig):
        super().__init__()
        self.attn = MultiHeadAttention(config)
        self.ffn = SwiGLUFeedForward(config)
        self.norm1 = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.norm2 = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.dropout = nn.Dropout(config.dropout_rate)

    def forward(self, x, freqs_cis, pos_offset, past_kv=None):
        """Run one layer.

        Args:
            x: Residual stream entering the layer.
            freqs_cis: Rotary table, forwarded to the attention layer.
            pos_offset: Position offset, forwarded to the attention layer.
            past_kv: Optional cache, forwarded to the attention layer.

        Returns:
            ``(x, new_kv)``: the updated stream and the attention layer's cache.
        """
        # Attention sub-layer.
        attn_out, new_kv = self.attn(self.norm1(x), freqs_cis, pos_offset, past_kv)
        x = x + self.dropout(attn_out)

        # Feed-forward sub-layer.
        ffn_out = self.ffn(self.norm2(x))
        x = x + self.dropout(ffn_out)
        return x, new_kv
|
| |
|
class TernaryTransformer(PreTrainedModel):
    """Decoder-only language model built from ``TransformerBlock`` layers.

    Token embeddings feed a stack of blocks, a final RMSNorm and a linear
    head whose weight is the embedding matrix itself (tied). The rotary
    table is precomputed once and kept as a non-persistent buffer.
    """

    config_class = TernaryConfig
    supports_gradient_checkpointing = True

    def __init__(self, config: TernaryConfig):
        super().__init__(config)
        self.token_emb = nn.Embedding(config.vocab_size, config.hidden_size)
        self.blocks = nn.ModuleList([TransformerBlock(config) for _ in range(config.num_hidden_layers)])
        self.ln_f = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.register_buffer(
            "freqs_cis",
            precompute_freqs_cis(config.hidden_size // config.num_attention_heads, config.max_position_embeddings),
            persistent=False,
        )
        self.post_init()
        # Tie the output projection to the input embedding after
        # initialisation so both names refer to one parameter.
        self.lm_head.weight = self.token_emb.weight
        self.gradient_checkpointing = False

    def _set_gradient_checkpointing(self, module, value=False):
        """Set the model-level checkpointing flag when visiting the model or a block."""
        if isinstance(module, (TernaryTransformer, TransformerBlock)):
            self.gradient_checkpointing = value

    def forward(self, input_ids, labels=None, past_key_values=None, return_dict=True, **kwargs):
        """Compute next-token logits and, if ``labels`` is given, the LM loss.

        Args:
            input_ids: Token ids of shape ``(batch, seq)``.
            labels: Optional target ids with the shape of ``input_ids``. They
                are shifted here, so position ``t`` is scored against
                ``labels[:, t + 1]``.
            past_key_values: Optional per-layer sequence of ``(keys, values)``
                pairs from a previous call.
            return_dict: If explicitly ``False``, return a plain tuple of the
                non-``None`` outputs in the order loss, logits,
                past_key_values. ``True`` or ``None`` return a
                ``CausalLMOutputWithPast``.
            **kwargs: Accepted and ignored.

        Returns:
            ``CausalLMOutputWithPast`` or a tuple, as selected by
            ``return_dict``. Caches are returned in eval mode or whenever
            ``past_key_values`` was supplied.
        """
        x = self.token_emb(input_ids)

        # The offset is the number of cached positions in layer 0.
        # NOTE(review): the attention layer trims its cache to
        # config.window_size, so once the window is full this is the cache
        # length rather than the true index of the next token.
        has_past = bool(past_key_values) and past_key_values[0] is not None
        pos_offset = past_key_values[0][0].size(2) if has_past else 0

        use_checkpoint = self.gradient_checkpointing and self.training
        collect_cache = (not self.training) or bool(past_key_values)

        new_kvs = []
        for i, block in enumerate(self.blocks):
            if use_checkpoint:
                # The cache is not forwarded on the checkpointed path.
                x, kv = torch.utils.checkpoint.checkpoint(
                    block, x, self.freqs_cis, pos_offset, None, use_reentrant=False
                )
            else:
                layer_past = past_key_values[i] if past_key_values else None
                x, kv = block(x, self.freqs_cis, pos_offset, layer_past)
            if collect_cache:
                new_kvs.append(kv)

        logits = self.lm_head(self.ln_f(x))

        loss = None
        if labels is not None:
            # Shift so each position predicts the following token.
            loss = F.cross_entropy(
                logits[:, :-1, :].reshape(-1, self.config.vocab_size),
                labels[:, 1:].reshape(-1),
            )

        output = CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=new_kvs if new_kvs else None,
        )
        # Only an explicit False selects the tuple form; None keeps the
        # structured output, matching the previous behaviour for that value.
        if return_dict is not None and not return_dict:
            return output.to_tuple()
        return output