# Ababababababbababa's picture
# Duplicate from arbml/Ashaar
# 6faf7e7
# raw
# history blame
# No virus
# 7.22 kB
"""
OpenAI's GPT-2 ported to PyTorch.
"""
import math
import attr
import torch
from torch import nn
from torch.nn import functional as F
import torch.utils.checkpoint
@attr.s(auto_attribs=True, frozen=True)
class HParams:
    """Immutable bundle of hyper-parameters that sizes the GPT-2 ``Model``.

    Instances are frozen, so a configuration cannot be mutated after the
    model has been built from it.
    """

    # Vocabulary size: rows of the token embedding, width of the logits.
    n_vocab: int
    # Number of rows in the position embedding table.
    n_ctx: int
    # Width of the token/position embeddings.
    n_embed: int
    # Width of the transformer blocks; when it differs from ``n_embed``
    # the model adds input/output projections between the two widths.
    n_hidden: int
    # Number of attention heads; ``n_hidden`` must be divisible by it.
    n_head: int
    # Number of transformer blocks.
    n_layer: int
    # When true, every block call goes through torch.utils.checkpoint.
    gradient_checkpointing: bool = False
class Model(nn.Module):
    """GPT-2 style language model: embeddings, a stack of blocks, tied logits.

    The output logits are computed against the transposed token embedding
    matrix, so input and output embeddings share their weights.
    """

    def __init__(self, hparams: HParams):
        """Build the embeddings, blocks and optional width projections.

        Args:
            hparams: model sizes; see ``HParams``.
        """
        super().__init__()
        self.hparams = hparams
        self.wpe = nn.Embedding(hparams.n_ctx, hparams.n_embed)
        nn.init.normal_(self.wpe.weight, std=0.01)
        self.wte = nn.Embedding(hparams.n_vocab, hparams.n_embed)
        nn.init.normal_(self.wte.weight, std=0.02)
        self.blocks = nn.ModuleList(
            [Block(hparams) for _ in range(hparams.n_layer)])
        self.ln_f = Norm(self.hparams.n_hidden)
        if hparams.n_hidden != hparams.n_embed:
            # Blocks run at a different width than the embeddings, so
            # project in before the blocks and back out before the logits.
            self.in_proj = Conv1D(hparams.n_embed, hparams.n_hidden)
            self.out_proj = Conv1D(hparams.n_hidden, hparams.n_embed)
        else:
            self.in_proj = self.out_proj = None

    def forward(self, x, past=None):
        """Compute next-token logits for a batch of token ids.

        Args:
            x: integer tensor of token ids with shape [batch, sequence].
            past: optional cache of keys/values from earlier positions,
                indexed as ``past[:, layer]``; its second-to-last dimension
                is the number of already processed positions.

        Returns:
            A dict with
            ``'logits'``: tensor of shape [batch, sequence, n_vocab], and
            ``'presents'``: the per-layer keys/values of the positions in
            ``x`` only (not merged with ``past``), stacked along dim 1.
        """
        # Embedding
        past_length = 0 if past is None else past.shape[-2]
        batch_size, n_ctx = x.shape
        position = position_for(batch_size, n_ctx, past_length, x.device)
        h = self.wte(x) + self.wpe(position)
        assert h.shape == (batch_size, n_ctx, self.hparams.n_embed)
        # Explicit None check: do not rely on the truthiness of a module.
        if self.in_proj is not None:
            h = self.in_proj(h)
        # Transformer
        presents = []
        for i, block in enumerate(self.blocks):
            layer_past = None if past is None else past[:, i]
            if self.hparams.gradient_checkpointing:
                h, present = torch.utils.checkpoint.checkpoint(
                    block, h, layer_past)
            else:
                h, present = block(h, past=layer_past)
            presents.append(present)
        h = self.ln_f(h)
        if self.out_proj is not None:
            h = self.out_proj(h)
        # Output logits (weights tied to the token embedding)
        h_flat = h.reshape([batch_size * n_ctx, self.hparams.n_embed])
        logits = torch.matmul(h_flat, self.wte.weight.t())
        logits = logits.reshape([batch_size, n_ctx, self.hparams.n_vocab])
        return {
            'presents': torch.stack(presents, dim=1),
            'logits': logits,
        }
class Block(nn.Module):
    """One transformer layer: pre-norm attention and pre-norm MLP,
    each wrapped in a residual connection.
    """

    def __init__(self, hparams: HParams):
        """Create the two norms, the MLP (4x wider inside) and attention."""
        super().__init__()
        self.ln_1 = Norm(hparams.n_hidden)
        self.ln_2 = Norm(hparams.n_hidden)
        self.mlp = MLP(hparams.n_hidden, hparams.n_hidden * 4)
        self.attn = Attention(hparams)

    def forward(self, x, past=None):
        """Apply the layer to ``x``.

        Args:
            x: activations entering the layer.
            past: optional cached keys/values for this layer, forwarded
                unchanged to the attention module; ``None`` means no cache.

        Returns:
            Tuple ``(x, present)``: the transformed activations and the
            keys/values the attention module produced for this call.
        """
        a, present = self.attn(self.ln_1(x), past=past)
        x = x + a
        m = self.mlp(self.ln_2(x))
        x = x + m
        return x, present
class Norm(nn.Module):
    """ Normalize to mean = 0, std = 1, then do a diagonal affine transform.

    Statistics are taken along ``dim`` using the biased variance (mean of
    squared deviations). The learned gain ``g`` and bias ``b`` have
    ``n_features`` entries and always broadcast against the last axis.
    """

    def __init__(self, n_features, *, dim=-1, epsilon=1e-5):
        """
        Args:
            n_features: size of the last axis of the inputs.
            dim: axis along which mean and variance are computed.
            epsilon: added to the variance before the reciprocal square
                root, to avoid division by zero.
        """
        super().__init__()
        self.n_features = n_features
        self.dim = dim
        self.epsilon = epsilon
        self.g = nn.Parameter(torch.ones(n_features))
        self.b = nn.Parameter(torch.zeros(n_features))

    def extra_repr(self):
        """Show the configuration in ``repr(module)`` / ``print(model)``."""
        return f'{self.n_features}, dim={self.dim}, epsilon={self.epsilon}'

    def forward(self, x):
        """Return ``x`` normalized along ``dim``, scaled by ``g``, shifted
        by ``b``. The last axis of ``x`` must have ``n_features`` entries.
        """
        assert x.shape[-1] == self.n_features
        u = torch.mean(x, dim=self.dim, keepdim=True)
        xmu = x - u
        s = torch.mean(xmu * xmu, dim=self.dim, keepdim=True)
        return xmu * torch.rsqrt(s + self.epsilon) * self.g + self.b
class MLP(nn.Module):
    """Position-wise feed-forward net: expand, apply GELU, project back."""

    def __init__(self, n_features, n_hidden):
        """
        Args:
            n_features: input and output width.
            n_hidden: width of the intermediate activation.
        """
        super().__init__()
        self.c_fc = Conv1D(n_features, n_hidden)
        self.c_proj = Conv1D(n_hidden, n_features)

    def forward(self, x):
        """Map ``x`` (last axis ``n_features``) to the same shape."""
        x = gelu(self.c_fc(x))
        x = self.c_proj(x)
        return x
class Attention(nn.Module):
    """Causal multi-head self-attention with an optional key/value cache."""

    def __init__(self, hparams: HParams):
        """Create the fused q/k/v projection and the output projection.

        ``hparams.n_hidden`` must be divisible by ``hparams.n_head``.
        """
        super().__init__()
        assert hparams.n_hidden % hparams.n_head == 0
        self.hparams = hparams
        self.c_attn = Conv1D(hparams.n_hidden, hparams.n_hidden * 3)
        self.c_proj = Conv1D(hparams.n_hidden, hparams.n_hidden)

    def forward(self, x, past=None):
        """Attend over ``x`` and, if given, the cached positions in ``past``.

        Args:
            x: tensor of shape [batch, sequence, n_hidden].
            past: optional cache of shape
                [batch, 2, heads, past_sequence, n_hidden // n_head],
                where the size-2 axis holds [keys, values].

        Returns:
            Tuple ``(a, present)``: ``a`` has the shape of ``x``;
            ``present`` holds the keys/values of the positions in ``x``
            only, laid out like ``past``.
        """
        assert len(x.shape) == 3  # [batch, sequence, features]
        assert x.shape[-1] == self.hparams.n_hidden
        if past is not None:
            # Should be [batch, 2, heads, sequence, features], where 2 is
            # [k, v] and features is the per-head size: the cache stores
            # keys/values after they have been split into heads.
            assert len(past.shape) == 5
            head_features = self.hparams.n_hidden // self.hparams.n_head
            assert past.shape[-1] == head_features
        c = self.c_attn(x)
        q, k, v = map(self.split_heads, torch.split(c, x.shape[-1], dim=2))
        # The cache entry is taken before merging with ``past``, so it
        # covers only the new positions.
        present = torch.stack([k, v], dim=1)
        if past is not None:
            pk, pv = past[:, 0], past[:, 1]
            k = torch.cat([pk, k], dim=-2)
            v = torch.cat([pv, v], dim=-2)
        a = self.multihead_attn(q, k, v)
        a = self.merge_heads(a)
        a = self.c_proj(a)
        return a, present

    def split_heads(self, x):
        """ From [batch, sequence, features] to
        [batch, heads, sequence, features / heads].
        """
        return self.split_states(x, self.hparams.n_head).permute(0, 2, 1, 3)

    @staticmethod
    def split_states(x, n):
        """ Reshape the last dimension of x into [n, x.shape[-1]/n].
        """
        *start, m = x.shape
        return x.reshape(start + [n, m // n])

    def merge_heads(self, x):
        """ Reverse of split_heads.
        """
        return self.merge_states(x.permute(0, 2, 1, 3))

    @staticmethod
    def merge_states(x):
        """ Smash the last two dimensions of x into a single dimension.
        """
        *start, a, b = x.shape
        return x.reshape(start + [a * b])

    def mask_attn_weights(self, w):
        """Keep the allowed attention weights and set the rest to -1e4,
        so that they vanish in the following softmax.
        """
        # w has shape [batch, heads, dst_sequence, src_sequence],
        # where information flows from src to dst.
        _, _, nd, ns = w.shape
        b = self.attention_mask(nd, ns, dtype=w.dtype, device=w.device)
        b = b.reshape((1, 1, nd, ns))
        w = w * b - 1e4 * (1 - b)
        return w

    @staticmethod
    def attention_mask(nd, ns, *, dtype, device=None):
        """ 1's in the lower triangle, counting from the lower right corner.
        Same as tf.matrix_band_part(tf.ones([nd, ns]), -1, ns-nd),
        but doesn't produce garbage on TPUs.
        """
        # Build the index tensors on the target device right away instead
        # of creating them on the default device and copying the result.
        i = torch.arange(0, nd, device=device).unsqueeze(1)
        j = torch.arange(ns, device=device)
        return (i >= j - ns + nd).to(dtype=dtype)

    def multihead_attn(self, q, k, v):
        """Scaled dot-product attention with the causal mask applied."""
        # q, k, v have shape [batch, heads, sequence, features]
        w = torch.matmul(q, k.permute(0, 1, 3, 2))
        w = w / math.sqrt(v.shape[-1])
        w = self.mask_attn_weights(w)
        w = F.softmax(w, dim=-1)
        a = torch.matmul(w, v)
        return a
class Conv1D(nn.Linear):
    """``nn.Linear`` with GPT-2 style initialization: weights drawn from
    N(0, 0.02^2), bias set to zero.
    """

    def reset_parameters(self):
        """Initialize the weight and, when the layer has one, the bias."""
        nn.init.normal_(self.weight, std=0.02)
        # nn.Linear stores None as bias when built with bias=False.
        if self.bias is not None:
            nn.init.zeros_(self.bias)
def gelu(x, c=math.sqrt(2 / math.pi)):
    """Tanh approximation of the GELU activation, applied element-wise.

    Args:
        x: input tensor.
        c: scale inside the tanh; defaults to sqrt(2 / pi), evaluated once
            when the function is defined.

    Returns:
        Tensor with the same shape as ``x``.
    """
    inner = c * (x + 0.044715 * torch.pow(x, 3))
    return 0.5 * x * (1 + torch.tanh(inner))
def position_for(batch_size, n_steps, past_length, device=None):
    """Return position indices for ``n_steps`` new tokens.

    Args:
        batch_size: number of rows in the result.
        n_steps: number of positions per row.
        past_length: index of the first position, i.e. how many positions
            came before.
        device: device on which the tensor is created.

    Returns:
        Integer tensor of shape [batch_size, n_steps]; every row is
        ``past_length, ..., past_length + n_steps - 1``.
    """
    positions = torch.arange(
        past_length, past_length + n_steps, device=device)
    return positions.unsqueeze(0).repeat(batch_size, 1)