""" |
|
|
Credits to https://github.com/karpathy/minGPT |
|
|
""" |
|
|
|
|
|
from dataclasses import dataclass |
|
|
import math |
|
|
from typing import Optional |
|
|
|
|
|
from einops import rearrange |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
from torch.nn import functional as F |
|
|
|
|
|
from .kv_caching import KeysValues, KVCache |
|
|
|
|
|
class Transformer(nn.Module): |
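    """GPT-style backbone: embedding dropout, a stack of pre-LayerNorm blocks, and a final
    LayerNorm. `generate_empty_keys_values` allocates the per-layer cache used for
    incremental decoding."""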
    def __init__(self, config: dict) -> None:
        super().__init__()
        self.config = config
        # max_tokens is derived here and, since the dict is shared, is also visible to SelfAttention.
        self.config["max_tokens"] = config["tokens_per_block"] * config["max_blocks"]
        self.drop = nn.Dropout(config["embed_pdrop"])
        self.blocks = nn.ModuleList([Block(config) for _ in range(config["num_layers"])])
        self.ln_f = nn.LayerNorm(config["embed_dim"])

    def generate_empty_keys_values(self, n: int, max_tokens: int) -> KeysValues:
        device = self.ln_f.weight.device
        return KeysValues(n, self.config["num_heads"], max_tokens, self.config["embed_dim"], self.config["num_layers"], device)

    def forward(self, sequences: torch.Tensor, past_keys_values: Optional[KeysValues] = None) -> torch.Tensor:
        assert past_keys_values is None or len(past_keys_values) == len(self.blocks)
        x = self.drop(sequences)
        for i, block in enumerate(self.blocks):
            x = block(x, None if past_keys_values is None else past_keys_values[i])

        x = self.ln_f(x)
        return x


class Block(nn.Module):
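    """Pre-LayerNorm transformer block: x = x + attn(ln1(x)), followed by x = x + mlp(ln2(x))."""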
    def __init__(self, config: dict) -> None:
        super().__init__()
        self.ln1 = nn.LayerNorm(config["embed_dim"])
        self.ln2 = nn.LayerNorm(config["embed_dim"])
        self.attn = SelfAttention(config)
        self.mlp = nn.Sequential(
            nn.Linear(config["embed_dim"], 4 * config["embed_dim"]),
            nn.GELU(),
            nn.Linear(4 * config["embed_dim"], config["embed_dim"]),
            nn.Dropout(config["resid_pdrop"]),
        )

    def forward(self, x: torch.Tensor, past_keys_values: Optional[KeysValues] = None) -> torch.Tensor:
        x_attn = self.attn(self.ln1(x), past_keys_values)
        x = x + x_attn
        x = x + self.mlp(self.ln2(x))
        return x


class SelfAttention(nn.Module):
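    """Multi-head self-attention with a fixed causal or block-causal mask and optional
    key/value caching for incremental decoding."""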
    def __init__(self, config: dict) -> None:
        super().__init__()
        assert config["embed_dim"] % config["num_heads"] == 0
        assert config["attention"] in ('causal', 'block_causal')
        self.num_heads = config["num_heads"]
        self.key = nn.Linear(config["embed_dim"], config["embed_dim"])
        self.query = nn.Linear(config["embed_dim"], config["embed_dim"])
        self.value = nn.Linear(config["embed_dim"], config["embed_dim"])
        self.attn_drop = nn.Dropout(config["attn_pdrop"])
        self.resid_drop = nn.Dropout(config["resid_pdrop"])
        self.proj = nn.Linear(config["embed_dim"], config["embed_dim"])

        # Lower-triangular causal mask; the block-causal variant additionally lets every
        # token attend to all tokens within its own block.
        causal_mask = torch.tril(torch.ones(config["max_tokens"], config["max_tokens"]))
        block_causal_mask = torch.max(
            causal_mask,
            torch.block_diag(*[torch.ones(config["tokens_per_block"], config["tokens_per_block"]) for _ in range(config["max_blocks"])]),
        )
        self.register_buffer('mask', causal_mask if config["attention"] == 'causal' else block_causal_mask)

    def forward(self, x: torch.Tensor, kv_cache: Optional[KVCache] = None) -> torch.Tensor:
        B, T, C = x.size()
        if kv_cache is not None:
            b, nh, L, c = kv_cache.shape  # (batch, heads, cached length, head dim)
            assert nh == self.num_heads and b == B and c * nh == C
        else:
            L = 0

        q = self.query(x).view(B, T, self.num_heads, C // self.num_heads).transpose(1, 2)  # (B, nh, T, hd)
        k = self.key(x).view(B, T, self.num_heads, C // self.num_heads).transpose(1, 2)    # (B, nh, T, hd)
        v = self.value(x).view(B, T, self.num_heads, C // self.num_heads).transpose(1, 2)  # (B, nh, T, hd)

        if kv_cache is not None:
            kv_cache.update(k, v)
            k, v = kv_cache.get()

        att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
        # The T new queries occupy rows L:L+T of the precomputed mask; keys cover the L cached
        # positions plus the T new ones.
        att = att.masked_fill(self.mask[L:L + T, :L + T] == 0, float('-inf'))
        att = F.softmax(att, dim=-1)
        att = self.attn_drop(att)
        y = att @ v
        y = rearrange(y, 'b h t e -> b t (h e)')

        y = self.resid_drop(self.proj(y))

        return y
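

# ----------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the original module). It assumes only
# the config keys read above; "max_tokens" is derived inside Transformer.__init__. The
# incremental-decoding loop additionally assumes that the KVCache objects from
# .kv_caching append on update() and return the filled prefix on get(), which is what
# SelfAttention.forward relies on.
if __name__ == "__main__":
    config = {
        "tokens_per_block": 4,
        "max_blocks": 3,
        "attention": "causal",  # or 'block_causal'
        "num_layers": 2,
        "num_heads": 4,
        "embed_dim": 32,
        "embed_pdrop": 0.1,
        "resid_pdrop": 0.1,
        "attn_pdrop": 0.1,
    }
    model = Transformer(config).eval()

    # Full-sequence forward pass without a cache (inputs are already-embedded tokens).
    x = torch.randn(2, config["max_tokens"], config["embed_dim"])
    print(model(x).shape)  # torch.Size([2, 12, 32])

    # Incremental decoding: allocate an empty cache, then feed one position at a time.
    kv = model.generate_empty_keys_values(n=2, max_tokens=config["max_tokens"])
    for t in range(x.size(1)):
        y_t = model(x[:, t:t + 1], past_keys_values=kv)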