File size: 4,864 Bytes
4042589 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 |
import math
import logging
import torch
import torch.nn as nn
from torch.nn import functional as F
logger = logging.getLogger(__name__)
class RWKV_TimeMix(nn.Module):
def __init__(self, config, layer_id):
super().__init__()
assert config.n_attn % config.n_head == 0
self.layer_id = layer_id
self.ctx_len = config.ctx_len
self.n_head = config.n_head
self.head_size = config.n_attn // config.n_head
self.time_ww = nn.Parameter(
torch.ones(config.n_head, config.ctx_len, config.ctx_len))
self.time_gamma = nn.Parameter(torch.ones(config.ctx_len, 1))
self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
self.key = nn.Linear(config.n_embd, config.n_attn)
self.value = nn.Linear(config.n_embd, config.n_attn)
self.receptance = nn.Linear(config.n_embd, config.n_attn)
self.output = nn.Linear(config.n_attn, config.n_embd)
self.key.scale_init = 0
self.receptance.scale_init = 0
self.output.scale_init = 0
def forward(self, x):
B, T, C = x.size()
x = torch.cat(
[self.time_shift(x[:, :, :C//2]), x[:, :, C//2:]], dim=-1)
k = self.key(x)
v = self.value(x)
r = self.receptance(x)
k = torch.clamp(k, max=30, min=-60)
k = torch.exp(k)
sum_k = torch.cumsum(k, dim=1)
kv = (k * v).view(B, T, self.n_head, self.head_size)
wkv = (torch.einsum('htu,buhc->bthc', self.time_ww[:,:T,:T], kv)
).contiguous().view(B, T, -1)
rwkv = torch.sigmoid(r) * wkv / sum_k
rwkv = self.output(rwkv)
return rwkv * self.time_gamma[:T, :]
class RWKV_ChannelMix(nn.Module):
def __init__(self, config, layer_id):
super().__init__()
self.layer_id = layer_id
self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
hidden_sz = 5 * config.n_ffn // 2
self.key = nn.Linear(config.n_embd, hidden_sz)
self.value = nn.Linear(config.n_embd, hidden_sz)
self.weight = nn.Linear(hidden_sz, config.n_embd)
self.receptance = nn.Linear(config.n_embd, config.n_embd)
self.receptance.scale_init = 0
self.weight.scale_init = 0
def forward(self, x):
B, T, C = x.size()
x = torch.cat(
[self.time_shift(x[:, :, :C//2]), x[:, :, C//2:]], dim=-1)
k = self.key(x)
v = self.value(x)
r = self.receptance(x)
wkv = self.weight(F.mish(k) * v)
rwkv = torch.sigmoid(r) * wkv
return rwkv
class GPTConfig:
def __init__(self, vocab_size, ctx_len, **kwargs):
self.vocab_size = vocab_size
self.ctx_len = ctx_len
for k, v in kwargs.items():
setattr(self, k, v)
class Block(nn.Module):
def __init__(self, config, layer_id):
super().__init__()
self.config = config
self.ln1 = nn.LayerNorm(config.n_embd)
self.ln2 = nn.LayerNorm(config.n_embd)
self.attn = RWKV_TimeMix(config, layer_id)
self.mlp = RWKV_ChannelMix(config, layer_id)
def forward(self, x):
x = x + self.attn(self.ln1(x))
x = x + self.mlp(self.ln2(x))
return x
class GPT(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
self.tok_emb = nn.Embedding(config.vocab_size, config.n_embd)
self.blocks = nn.Sequential(*[Block(config, i)
for i in range(config.n_layer)])
self.ln_f = nn.LayerNorm(config.n_embd)
self.time_out = nn.Parameter(torch.ones(1, config.ctx_len, 1))
self.head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
self.head_q = nn.Linear(config.n_embd, 256)
self.head_k = nn.Linear(config.n_embd, 256)
self.register_buffer("copy_mask", torch.tril(torch.ones(config.ctx_len, config.ctx_len)))
self.ctx_len = config.ctx_len
logger.info("number of parameters: %e", sum(p.numel()
for p in self.parameters()))
def get_ctx_len(self):
return self.ctx_len
def forward(self, idx, targets=None):
B, T = idx.size()
assert T <= self.ctx_len, "Cannot forward, because len(input) > model ctx_len."
x = self.tok_emb(idx)
x = self.blocks(x)
x = self.ln_f(x)
q = self.head_q(x)[:,:T,:]
k = self.head_k(x)[:,:T,:]
c = (q @ k.transpose(-2, -1)) * (1.0 / 256)
c = c.masked_fill(self.copy_mask[:T,:T] == 0, 0)
c = c @ F.one_hot(idx, num_classes = self.config.vocab_size).float()
x = x * self.time_out[:, :T, :]
x = self.head(x) + c
loss = None
if targets is not None:
loss = F.cross_entropy(x.view(-1, x.size(-1)), targets.view(-1))
return x, loss |