"""
Full definition of a RWKV Language Model, all of it in this single file.

References:
1) the official RWKV PyTorch implementation released by Bo Peng:
https://github.com/BlinkDL/RWKV-LM/blob/main/RWKV-v4neo/src/model.py
2) huggingface/transformers PyTorch implementation:
https://github.com/huggingface/transformers/blob/main/src/transformers/models/rwkv/modeling_rwkv.py
"""

import math
import time
import os
import inspect
from dataclasses import dataclass
from unittest.mock import patch

import torch
import torch.nn as nn
from torch.nn import functional as F

PREV_X_TIME = 0
NUM_STATE = 1
DEN_STATE = 2
MAX_STATE = 3
PREV_X_CHANNEL = 4
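
# These five constants index the third dimension of the single recurrent-state
# tensor used for token-by-token inference, of shape
# (n_layer, batch_size, 5, n_embd) — see RWKV.init_state below. PREV_X_TIME and
# PREV_X_CHANNEL hold the previous token's input for the token shift in
# TimeMixing and ChannelMixing; NUM_STATE, DEN_STATE and MAX_STATE carry the
# numerator, denominator and running maximum (for numerical stability) of the
# WKV recurrence.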


class LayerNorm(nn.Module):
    """ LayerNorm but with an optional bias. PyTorch doesn't support simply bias=False """

    def __init__(self, ndim, bias):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(ndim))
        self.bias = nn.Parameter(torch.zeros(ndim)) if bias else None

    def forward(self, input):
        return F.layer_norm(input, self.weight.shape, self.weight, self.bias, 1e-5)


class CudaNotAvailable:
    """ Context manager that temporarily patches torch.cuda.is_available() to
    return False (used below to load huggingface checkpoints on CPU). """

    def __enter__(self):
        self.patcher = patch("torch.cuda.is_available", return_value=False)
        self.patcher.start()
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.patcher.stop()


class L2Wrap(torch.autograd.Function):
    """ Passes the loss through unchanged in forward, but adds a small gradient
    in backward that pushes the largest logit at every position towards zero,
    keeping the logits from growing without bound. """

    @staticmethod
    def forward(ctx, loss, y):
        ctx.save_for_backward(y)
        return loss

    @staticmethod
    def backward(ctx, grad_output):
        y = ctx.saved_tensors[0]
        # normalize the penalty by batch_size * seq_len
        factor = 1e-4 / (y.shape[0] * y.shape[1])
        maxx, ids = torch.max(y, -1, keepdim=True)
        gy = torch.zeros_like(y)
        gy.scatter_(-1, ids, maxx * factor)
        return (grad_output, gy)
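
# A minimal sanity check of the extra gradient (hypothetical, shown commented out):
#
#   loss = torch.tensor(1.0, requires_grad=True)
#   y = torch.randn(2, 3, 5, requires_grad=True)
#   L2Wrap.apply(loss, y).backward()
#   assert y.grad.count_nonzero() == 2 * 3   # one nudged logit per position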


class ChannelMixing(nn.Module):
    """ The RWKV feed-forward ("channel mixing") block. """

    def __init__(self, config, layer_id):
        super().__init__()
        # shifts the sequence right by one token (zero-pads the first position)
        self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
        self.layer_id = layer_id

        n_embd = config.n_embd
        intermediate_size = (
            config.intermediate_size if config.intermediate_size is not None else 4 * n_embd
        )

        self.key_proj = nn.Linear(n_embd, intermediate_size, bias=False)
        self.value_proj = nn.Linear(intermediate_size, n_embd, bias=False)
        self.receptance_proj = nn.Linear(n_embd, n_embd, bias=False)

        self.time_mix_key = nn.Parameter(torch.empty(1, 1, n_embd))
        self.time_mix_receptance = nn.Parameter(torch.empty(1, 1, n_embd))

    def forward(self, x, state=None):
        # token shift: the previous token comes from the state in recurrent mode,
        # or from shifting the whole sequence in parallel mode
        if state is not None:
            prev_x = state[self.layer_id, :, [PREV_X_CHANNEL], :]
            state[self.layer_id, :, [PREV_X_CHANNEL], :] = x
        else:
            prev_x = self.time_shift(x)

        # receptance: r_t = sigmoid(W_r (mu_r * x_t + (1 - mu_r) * x_{t-1}))
        receptance = x * self.time_mix_receptance + prev_x * (1 - self.time_mix_receptance)
        receptance = self.receptance_proj(receptance)
        receptance = torch.sigmoid(receptance)

        # key: k_t = W_k (mu_k * x_t + (1 - mu_k) * x_{t-1})
        key = x * self.time_mix_key + prev_x * (1 - self.time_mix_key)
        key = self.key_proj(key)

        # squared-ReLU feed-forward, gated by the receptance
        value = self.value_proj(torch.square(torch.relu(key)))

        out = receptance * value
        return out, state
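
# Sanity sketch (hypothetical): the stateful token-by-token path should match
# the parallel time_shift path.
#
#   cfg = RWKVConfig(n_layer=1, n_embd=8, intermediate_size=32)
#   m = ChannelMixing(cfg, layer_id=0)
#   nn.init.uniform_(m.time_mix_key); nn.init.uniform_(m.time_mix_receptance)
#   x = torch.randn(1, 4, 8)
#   y_parallel, _ = m(x)                              # whole sequence at once
#   state = torch.zeros(1, 1, 5, 8)                   # (n_layer, batch, 5, n_embd)
#   y_stepwise = torch.cat([m(x[:, [t]], state)[0] for t in range(4)], dim=1)
#   assert torch.allclose(y_parallel, y_stepwise, atol=1e-5)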


class TimeMixing(nn.Module):
    """ The RWKV attention-like ("time mixing") block. """

    def __init__(self, config, layer_id):
        super().__init__()
        self.config = config
        # shifts the sequence right by one token (zero-pads the first position)
        self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
        self.layer_id = layer_id

        n_embd = config.n_embd
        attn_sz = n_embd

        self.key_proj = nn.Linear(n_embd, attn_sz, bias=False)
        self.value_proj = nn.Linear(n_embd, attn_sz, bias=False)
        self.receptance_proj = nn.Linear(n_embd, attn_sz, bias=False)
        self.output_proj = nn.Linear(attn_sz, n_embd, bias=False)

        self.time_decay = nn.Parameter(torch.empty(attn_sz))   # w in the WKV formula
        self.time_first = nn.Parameter(torch.empty(attn_sz))   # u in the WKV formula
        self.time_mix_key = nn.Parameter(torch.empty(1, 1, n_embd))
        self.time_mix_value = nn.Parameter(torch.empty(1, 1, n_embd))
        self.time_mix_receptance = nn.Parameter(torch.empty(1, 1, n_embd))

    def forward(self, x, state=None):
        # token shift, as in ChannelMixing
        if state is not None:
            prev_x = state[self.layer_id, :, [PREV_X_TIME], :]
            state[self.layer_id, :, [PREV_X_TIME], :] = x
        else:
            prev_x = self.time_shift(x)

        key = x * self.time_mix_key + prev_x * (1 - self.time_mix_key)
        key = self.key_proj(key)

        value = x * self.time_mix_value + prev_x * (1 - self.time_mix_value)
        value = self.value_proj(value)

        receptance = x * self.time_mix_receptance + prev_x * (1 - self.time_mix_receptance)
        receptance = self.receptance_proj(receptance)
        receptance = torch.sigmoid(receptance)

        wkv, state = self.wkv_function(
            key, value,
            use_customized_cuda_kernel=self.config.use_customized_cuda_kernel,
            state=state,
        )

        rwkv = receptance * wkv
        rwkv = self.output_proj(rwkv)

        return rwkv, state

    def wkv_function(self, key, value, use_customized_cuda_kernel, state=None):
        # parallel mode over a full training sequence: use the fused CUDA kernel
        if state is None and use_customized_cuda_kernel:
            B, T, C = key.size()
            return WKVKernel.apply(B, T, C, self.time_decay, self.time_first, key, value), None

        # otherwise run the recurrence step by step in pure PyTorch
        else:
            _, seq_length, _ = key.size()
            output = torch.zeros_like(key)

            stateless = False   # parallel mode without the CUDA kernel (was: debug_mode)
            if state is None:
                stateless = True
                num_state = torch.zeros_like(key[:, 0], dtype=torch.float32)
                den_state = torch.zeros_like(key[:, 0], dtype=torch.float32)
                max_state = torch.zeros_like(key[:, 0], dtype=torch.float32) - 1e38
            else:
                num_state = state[self.layer_id, :, NUM_STATE, :]
                den_state = state[self.layer_id, :, DEN_STATE, :]
                max_state = state[self.layer_id, :, MAX_STATE, :]

            time_decay = -torch.exp(self.time_decay)

            for current_index in range(seq_length):
                current_key = key[:, current_index].float()
                current_value = value[:, current_index]

                # wkv output at this step, with a running max for numerical stability
                max_for_output = torch.maximum(max_state, current_key + self.time_first)
                e1 = torch.exp(max_state - max_for_output)
                e2 = torch.exp(current_key + self.time_first - max_for_output)
                numerator = e1 * num_state + e2 * current_value
                denominator = e1 * den_state + e2
                output[:, current_index] = (numerator / denominator).to(output.dtype)

                # decay the state and absorb the current token for the next step
                max_for_state = torch.maximum(max_state + time_decay, current_key)
                e1 = torch.exp(max_state + time_decay - max_for_state)
                e2 = torch.exp(current_key - max_for_state)
                num_state = e1 * num_state + e2 * current_value
                den_state = e1 * den_state + e2
                max_state = max_for_state

            if stateless:
                return output, None
            else:
                state[self.layer_id, :, NUM_STATE, :] = num_state
                state[self.layer_id, :, DEN_STATE, :] = den_state
                state[self.layer_id, :, MAX_STATE, :] = max_state
                return output, state
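
    # For reference, the loop above computes, in a numerically stabilized way,
    #
    #   wkv_t = ( sum_{i<t} exp(-(t-1-i)*w + k_i) * v_i  +  exp(u + k_t) * v_t )
    #           / ( sum_{i<t} exp(-(t-1-i)*w + k_i)      +  exp(u + k_t) )
    #
    # with w = exp(time_decay) applied as a negative decay and u = time_first.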


class Block(nn.Module):
    """ One RWKV block: time mixing then channel mixing, each behind a
    pre-LayerNorm with a residual connection. """

    def __init__(self, config, layer_id):
        super().__init__()
        self.ln_1 = LayerNorm(config.n_embd, bias=config.bias)
        self.attn = TimeMixing(config, layer_id)
        self.ln_2 = LayerNorm(config.n_embd, bias=config.bias)
        self.ffn = ChannelMixing(config, layer_id)

    def forward(self, x, state=None):
        # time mixing
        residual = x
        x, state = self.attn(self.ln_1(x), state=state)
        x = x + residual

        # channel mixing
        residual = x
        x, state = self.ffn(self.ln_2(x), state=state)
        x = x + residual

        return x, state


@dataclass
class RWKVConfig:
    block_size: int = 1024
    vocab_size: int = 50304
    n_layer: int = 12
    n_embd: int = 768
    bias: bool = True                    # bias in LayerNorms (all Linears are bias-free)
    intermediate_size: int = None        # ChannelMixing hidden size; None means 4 * n_embd
    use_customized_cuda_kernel: bool = True
    dtype: str = "float16"               # "float32" | "float16" | "bfloat16"
    rescale_every: int = 6               # for fp16 inference: rescale weights every N blocks
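
# A minimal usage sketch (hypothetical sizes; CPU-only, so the custom CUDA
# kernel is disabled):
#
#   config = RWKVConfig(block_size=64, vocab_size=256, n_layer=2, n_embd=64,
#                       use_customized_cuda_kernel=False, dtype="float32")
#   model = RWKV(config)
#   idx = torch.randint(0, 256, (1, 16))
#   logits, loss = model(idx, targets=idx)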


class RWKV(nn.Module):

    def __init__(self, config, lr_init=0.0008):
        super().__init__()
        assert config.vocab_size is not None
        assert config.block_size is not None
        self.config = config
        self.lr_init = lr_init   # used by _init_weights to scale the embedding init
        self.rwkv = nn.ModuleDict(dict(
            wte = nn.Embedding(config.vocab_size, config.n_embd),
            ln_p = LayerNorm(config.n_embd, bias=config.bias),
            h = nn.ModuleList([Block(config, layer_id) for layer_id in range(config.n_layer)]),
            ln_f = LayerNorm(config.n_embd, bias=config.bias),
        ))
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

        self.apply(self._init_weights)
        print("number of parameters: %.2fM" % (self.get_num_params()/1e6,))

        if self.config.use_customized_cuda_kernel:
            self.load_cuda_kernel(config.dtype)

    def get_num_params(self, non_embedding=True):
        """
        Return the number of parameters in the model.
        For non-embedding count (default), the token embeddings get subtracted.
        """
        n_params = sum(p.numel() for p in self.parameters())
        if non_embedding:
            n_params -= self.rwkv.wte.weight.numel()
        return n_params

    def _init_weights(self, module):
        # RWKV-specific init for the time-mixing parameters, following Bo Peng's scheme
        if isinstance(module, TimeMixing):
            layer_id = module.layer_id
            n_layer = self.config.n_layer
            n_embd = self.config.n_embd
            attn_sz = n_embd

            with torch.no_grad():
                ratio_0_to_1 = layer_id / (n_layer - 1)            # 0 in the first layer, 1 in the last
                ratio_1_to_almost0 = 1.0 - (layer_id / n_layer)    # 1 in the first layer, ~0 in the last
                ddd = torch.ones(1, 1, n_embd)
                for i in range(n_embd):
                    ddd[0, 0, i] = i / n_embd

                decay_speed = torch.ones(attn_sz)
                for h in range(attn_sz):
                    decay_speed[h] = -5 + 8 * (h / (attn_sz - 1)) ** (0.7 + 1.3 * ratio_0_to_1)
                module.time_decay = nn.Parameter(decay_speed)

                zigzag = torch.tensor([(i + 1) % 3 - 1 for i in range(attn_sz)]) * 0.5
                module.time_first = nn.Parameter(torch.ones(attn_sz) * math.log(0.3) + zigzag)
                module.time_mix_key = nn.Parameter(torch.pow(ddd, ratio_1_to_almost0))
                module.time_mix_value = nn.Parameter(torch.pow(ddd, ratio_1_to_almost0) + 0.3 * ratio_0_to_1)
                module.time_mix_receptance = nn.Parameter(torch.pow(ddd, 0.5 * ratio_1_to_almost0))

        elif isinstance(module, ChannelMixing):
            layer_id = module.layer_id
            n_layer = self.config.n_layer
            n_embd = self.config.n_embd

            with torch.no_grad():
                ratio_1_to_almost0 = 1.0 - (layer_id / n_layer)
                ddd = torch.ones(1, 1, n_embd)
                for i in range(n_embd):
                    ddd[0, 0, i] = i / n_embd
                module.time_mix_key = nn.Parameter(torch.pow(ddd, ratio_1_to_almost0))
                module.time_mix_receptance = nn.Parameter(torch.pow(ddd, ratio_1_to_almost0))

        elif isinstance(module, (nn.Embedding, nn.Linear)):
            weight = module.weight
            shape = weight.shape
            gain = 1.0
            scale = 1.0

            # recover the fully qualified name of this weight to decide how to init it
            current_module_name = ""
            for _name, _parameters in self.named_parameters():
                if id(_parameters) == id(weight):
                    current_module_name = _name

            if isinstance(module, nn.Embedding):
                gain = math.sqrt(max(shape[0], shape[1]))
                scale = -1 * self.lr_init   # negative scale selects the uniform init below

            elif isinstance(module, nn.Linear):
                if shape[0] > shape[1]:
                    gain = math.sqrt(shape[0] / shape[1])

                # these projections start at zero
                for name in [".attn.key_proj.", ".attn.receptance_proj.", ".attn.output_proj.",
                             ".ffn.value_proj.", ".ffn.receptance_proj."]:
                    if name in current_module_name:
                        scale = 0

                if current_module_name == 'lm_head.weight':
                    scale = 0.5

            if scale == 0:
                nn.init.zeros_(weight)
            elif scale < 0:
                nn.init.uniform_(weight, a=scale, b=-scale)
            else:
                nn.init.orthogonal_(weight, gain=gain * scale)

    def forward(self, idx, targets=None, state=None, return_state=False):
        b, t = idx.size()
        assert t <= self.config.block_size, f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"

        x = self.rwkv.wte(idx)
        x = self.rwkv.ln_p(x)

        for block_idx, block in enumerate(self.rwkv.h):
            x, state = block(x, state)
            # during stateful inference, compensate for the weight rescaling done
            # by scale_parameters() by halving the activations every rescale_every blocks
            if state is not None:
                if (
                    self.config.rescale_every > 0
                    and (block_idx + 1) % self.config.rescale_every == 0
                ):
                    x = x / 2
        x = self.rwkv.ln_f(x)

        if targets is not None:
            # if we are given some desired targets also calculate the loss
            logits = self.lm_head(x)
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
            if self.training:
                loss = L2Wrap.apply(loss, logits)
        else:
            # inference-time mini-optimization: only forward the lm_head on the last position
            logits = self.lm_head(x[:, [-1], :])
            loss = None

        if return_state:
            return logits, loss, state
        else:
            return logits, loss

    def crop_block_size(self, block_size):
        # model surgery to decrease the block size if necessary
        assert block_size <= self.config.block_size
        self.config.block_size = block_size

    @classmethod
    def from_pretrained(cls, model_type, use_customized_cuda_kernel=True, dtype="float16"):
        assert model_type in {
            "RWKV/rwkv-4-169m-pile",
            "RWKV/rwkv-4-430m-pile",
            "RWKV/rwkv-4-1b5-pile",
            "RWKV/rwkv-4-3b-pile",
            "RWKV/rwkv-4-7b-pile",
            "RWKV/rwkv-4-14b-pile",
            "RWKV/rwkv-raven-1b5",
            "RWKV/rwkv-raven-3b",
            "RWKV/rwkv-raven-7b",
        }
        print("loading weights from pretrained RWKV: %s" % model_type)

        from transformers import RwkvForCausalLM, RwkvConfig
        hf_config = RwkvConfig.from_pretrained(model_type)
        # hide CUDA so the huggingface model loads on CPU without its own kernel
        with CudaNotAvailable():
            hf_model = RwkvForCausalLM.from_pretrained(model_type)

        config = {
            "vocab_size": 50277,   # all of the above checkpoints use the NeoX (Pile) tokenizer
            "n_layer": hf_config.num_hidden_layers,
            "n_embd": hf_config.hidden_size,
            "intermediate_size": hf_config.intermediate_size,
            "use_customized_cuda_kernel": use_customized_cuda_kernel,
            "dtype": dtype,
        }
        config = RWKVConfig(**config)
        model = RWKV(config)
        num_layers = config.n_layer

        # map our parameter names to the huggingface ones
        mapping = {
            "rwkv.wte.weight": "rwkv.embeddings.weight",
            "rwkv.ln_p.weight": "rwkv.blocks.0.pre_ln.weight",
            "rwkv.ln_p.bias": "rwkv.blocks.0.pre_ln.bias",
            "rwkv.ln_f.weight": "rwkv.ln_out.weight",
            "rwkv.ln_f.bias": "rwkv.ln_out.bias",
            "lm_head.weight": "head.weight",
            **{f"rwkv.h.{layer_id}.ln_{norm_id}.weight": f"rwkv.blocks.{layer_id}.ln{norm_id}.weight" for layer_id in range(num_layers) for norm_id in [1, 2]},
            **{f"rwkv.h.{layer_id}.ln_{norm_id}.bias": f"rwkv.blocks.{layer_id}.ln{norm_id}.bias" for layer_id in range(num_layers) for norm_id in [1, 2]},
            **{f"rwkv.h.{layer_id}.attn.{_type}": f"rwkv.blocks.{layer_id}.attention.{_type}" for layer_id in range(num_layers) for _type in ["time_decay", "time_first", "time_mix_key", "time_mix_value", "time_mix_receptance"]},
            **{f"rwkv.h.{layer_id}.attn.{_type}_proj.weight": f"rwkv.blocks.{layer_id}.attention.{_type}.weight" for layer_id in range(num_layers) for _type in ["key", "value", "receptance", "output"]},
            **{f"rwkv.h.{layer_id}.ffn.{_type}": f"rwkv.blocks.{layer_id}.feed_forward.{_type}" for layer_id in range(num_layers) for _type in ["time_mix_key", "time_mix_receptance"]},
            **{f"rwkv.h.{layer_id}.ffn.{_type}_proj.weight": f"rwkv.blocks.{layer_id}.feed_forward.{_type}.weight" for layer_id in range(num_layers) for _type in ["key", "value", "receptance"]},
        }

        # sanity check: the mapping covers exactly the huggingface state dict
        mapped_set = [mapping[x] for x in model.state_dict().keys()]
        assert set(mapped_set) == set(hf_model.state_dict().keys())
        sd = model.state_dict()
        hf_sd = hf_model.state_dict()

        for k1, k2 in mapping.items():
            assert sd[k1].shape == hf_sd[k2].shape, (k1, k2)
            sd[k1].copy_(hf_sd[k2])
        return model
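
    # Example (hypothetical; requires the `transformers` package):
    #
    #   model = RWKV.from_pretrained("RWKV/rwkv-4-169m-pile",
    #                                use_customized_cuda_kernel=False,
    #                                dtype="float32").eval()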

    def configure_optimizers(self, weight_decay, learning_rate, betas, device_type):
        # start with all of the candidate parameters that require grad
        param_dict = {pn: p for pn, p in self.named_parameters()}
        param_dict = {pn: p for pn, p in param_dict.items() if p.requires_grad}

        # any parameter that is 2D or higher will be weight decayed, otherwise not;
        # i.e. all weight tensors in matmuls + embeddings decay, all others don't
        decay_params = [p for n, p in param_dict.items() if p.dim() >= 2]
        nodecay_params = [p for n, p in param_dict.items() if p.dim() < 2]
        optim_groups = [
            {'params': decay_params, 'weight_decay': weight_decay},
            {'params': nodecay_params, 'weight_decay': 0.0}
        ]
        num_decay_params = sum(p.numel() for p in decay_params)
        num_nodecay_params = sum(p.numel() for p in nodecay_params)
        print(f"num decayed parameter tensors: {len(decay_params)}, with {num_decay_params:,} parameters")
        print(f"num non-decayed parameter tensors: {len(nodecay_params)}, with {num_nodecay_params:,} parameters")

        # use fused AdamW if it is available and we are running on CUDA
        fused_available = 'fused' in inspect.signature(torch.optim.AdamW).parameters
        use_fused = fused_available and device_type == 'cuda'
        extra_args = dict(fused=True) if use_fused else dict()
        optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas, **extra_args)
        print(f"using fused AdamW: {use_fused}")

        return optimizer
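
    # e.g. (hypothetical hyperparameters):
    #   optimizer = model.configure_optimizers(weight_decay=0.1, learning_rate=8e-4,
    #                                          betas=(0.9, 0.99), device_type="cuda")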

    def estimate_mfu(self, fwdbwd_per_iter, dt):
        """ estimate model flops utilization (MFU) in units of A100 peak FLOPS """
        # each of the L layers has 13*D^2 weight-matrix parameters:
        # 4 D x D matrices in time mixing, plus D x 4D key, 4D x D value and
        # D x D receptance in channel mixing; the lm_head adds V*D.
        # a forward pass costs roughly 2 flops per parameter per token
        cfg = self.config
        L, V, D = cfg.n_layer, cfg.vocab_size, cfg.n_embd
        flops_per_token = 2 * (V * D + 13 * (D ** 2) * L)
        # backward is roughly twice the forward cost
        flops_per_fwdbwd = 3 * flops_per_token
        flops_per_iter = flops_per_fwdbwd * fwdbwd_per_iter

        # express our flops throughput as a ratio of the GPU's peak flops
        flops_achieved = flops_per_iter * (1.0 / dt)
        if cfg.dtype in ('bfloat16', 'float16'):
            flops_promised = 312e12   # A100 peak is 312 TFLOPS for both half-precision formats
        else:
            flops_promised = 19.5e12  # A100 fp32 peak
        mfu = flops_achieved / flops_promised
        return mfu

    def init_state(self, batch_size, device):
        # one state tensor of shape (n_layer, batch_size, 5, n_embd);
        # see the index constants at the top of the file
        n_state = len([PREV_X_TIME, NUM_STATE, DEN_STATE, MAX_STATE, PREV_X_CHANNEL])
        state = torch.zeros(
            (self.config.n_layer, batch_size, n_state, self.config.n_embd),
            dtype=torch.float32, device=device,
        )
        # the running max of the WKV recurrence starts at a very negative value
        state[:, :, MAX_STATE, :] -= 1e30
        return state

    def scale_parameters(self):
        # for fp16 inference: divide the output weights of deeper blocks by powers
        # of two so their activations stay in half-precision range (cf. forward())
        if self.config.rescale_every > 0:
            with torch.no_grad():
                for block_id, block in enumerate(self.rwkv.h):
                    block.attn.output_proj.weight.div_(2 ** int(block_id // self.config.rescale_every))
                    block.ffn.value_proj.weight.div_(2 ** int(block_id // self.config.rescale_every))
            self.scaled = True

    def unscale_parameters(self):
        # getattr makes this a no-op if scale_parameters was never called
        if self.config.rescale_every > 0 and getattr(self, "scaled", False):
            with torch.no_grad():
                for block_id, block in enumerate(self.rwkv.h):
                    block.attn.output_proj.weight.mul_(2 ** int(block_id // self.config.rescale_every))
                    block.ffn.value_proj.weight.mul_(2 ** int(block_id // self.config.rescale_every))
            self.scaled = False
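
    # Hypothetical fp16 inference recipe: rescale once after loading, generate,
    # then restore the original weights:
    #
    #   model = model.half().cuda()
    #   model.scale_parameters()
    #   out = model.generate(idx, max_new_tokens=32)
    #   model.unscale_parameters()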

    @torch.no_grad()
    def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
        """
        idx: (batch_size, seq_len)
        Feed the prompt through one token at a time to build up the recurrent
        state, then sample max_new_tokens new tokens.
        """
        batch_size, seq_len = idx.shape
        state = self.init_state(batch_size, idx.device)
        # ingest the prompt token by token
        for seq_id in range(seq_len):
            logits, _, state = self(idx[:, [seq_id]], state=state, return_state=True)

        for _ in range(max_new_tokens):
            # pluck the logits at the final step and scale by desired temperature
            logits = logits[:, -1, :] / temperature
            # optionally crop the logits to only the top k options
            if top_k is not None:
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                logits[logits < v[:, [-1]]] = -float('Inf')
            # apply softmax to convert logits to (normalized) probabilities
            probs = F.softmax(logits, dim=-1)
            # sample from the distribution
            idx_next = torch.multinomial(probs, num_samples=1)
            # append the sampled index to the running sequence and advance the state
            idx = torch.cat((idx, idx_next), dim=1)
            logits, _, state = self(idx_next, state=state, return_state=True)
        return idx
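
    # End-to-end sketch (hypothetical; the pile checkpoints use the GPT-NeoX tokenizer):
    #
    #   from transformers import AutoTokenizer
    #   tokenizer = AutoTokenizer.from_pretrained("RWKV/rwkv-4-169m-pile")
    #   idx = torch.tensor([tokenizer.encode("The capital of France is")])
    #   out = model.generate(idx, max_new_tokens=20, temperature=1.0, top_k=50)
    #   print(tokenizer.decode(out[0].tolist()))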

    def load_cuda_kernel(self, dtype):
        # JIT-compile the fused WKV kernel for the configured precision
        from torch.utils.cpp_extension import load
        T_MAX = self.config.block_size
        RWKV_FLOAT_MODE = dtype
        if RWKV_FLOAT_MODE == "bfloat16":
            wkv_cuda = load(name=f"wkv_{T_MAX}_bf16", sources=["cuda/wkv_op_bf16.cpp", "cuda/wkv_cuda_bf16.cu"], verbose=True, extra_cuda_cflags=["-t 4", "-std=c++17", "-res-usage", "--maxrregcount 60", "--use_fast_math", "-O3", "-Xptxas -O3", "--extra-device-vectorization", f"-DTmax={T_MAX}"])
            class WKV(torch.autograd.Function):
                @staticmethod
                def forward(ctx, B, T, C, w, u, k, v):
                    ctx.B = B
                    ctx.T = T
                    ctx.C = C
                    assert T <= T_MAX
                    assert B * C % min(C, 32) == 0
                    w = -torch.exp(w.float().contiguous())
                    u = u.contiguous().bfloat16()
                    k = k.contiguous()
                    v = v.contiguous()
                    y = torch.empty((B, T, C), device=w.device, memory_format=torch.contiguous_format, dtype=torch.bfloat16)
                    wkv_cuda.forward(B, T, C, w, u, k, v, y)
                    ctx.save_for_backward(w, u, k, v, y)
                    return y

                @staticmethod
                def backward(ctx, gy):
                    B = ctx.B
                    T = ctx.T
                    C = ctx.C
                    assert T <= T_MAX
                    assert B * C % min(C, 32) == 0
                    w, u, k, v, y = ctx.saved_tensors
                    gw = torch.empty((B, C), device=gy.device, memory_format=torch.contiguous_format, dtype=torch.bfloat16)
                    gu = torch.empty((B, C), device=gy.device, memory_format=torch.contiguous_format, dtype=torch.bfloat16)
                    gk = torch.empty((B, T, C), device=gy.device, memory_format=torch.contiguous_format, dtype=torch.bfloat16)
                    gv = torch.empty((B, T, C), device=gy.device, memory_format=torch.contiguous_format, dtype=torch.bfloat16)
                    wkv_cuda.backward(B, T, C, w, u, k, v, y, gy.contiguous(), gw, gu, gk, gv)
                    # gradients for the shared w and u vectors are summed over the batch
                    gw = torch.sum(gw, dim=0)
                    gu = torch.sum(gu, dim=0)
                    return (None, None, None, gw, gu, gk, gv)
        else:
            wkv_cuda = load(name=f"wkv_{T_MAX}", sources=["cuda/wkv_op.cpp", "cuda/wkv_cuda.cu"], verbose=True, extra_cuda_cflags=["-res-usage", "--maxrregcount 60", "--use_fast_math", "-O3", "-Xptxas -O3", "--extra-device-vectorization", f"-DTmax={T_MAX}"])
            class WKV(torch.autograd.Function):
                @staticmethod
                def forward(ctx, B, T, C, w, u, k, v):
                    ctx.B = B
                    ctx.T = T
                    ctx.C = C
                    assert T <= T_MAX
                    assert B * C % min(C, 32) == 0
                    if "32" in RWKV_FLOAT_MODE:
                        w = -torch.exp(w.contiguous())
                        u = u.contiguous()
                        k = k.contiguous()
                        v = v.contiguous()
                    else:
                        # the kernel does its math in fp32 even for fp16 inputs
                        w = -torch.exp(w.float().contiguous())
                        u = u.float().contiguous()
                        k = k.float().contiguous()
                        v = v.float().contiguous()
                    y = torch.empty((B, T, C), device=w.device, memory_format=torch.contiguous_format)
                    wkv_cuda.forward(B, T, C, w, u, k, v, y)
                    ctx.save_for_backward(w, u, k, v, y)
                    if "32" in RWKV_FLOAT_MODE:
                        return y
                    elif RWKV_FLOAT_MODE == "float16":
                        return y.half()

                @staticmethod
                def backward(ctx, gy):
                    B = ctx.B
                    T = ctx.T
                    C = ctx.C
                    assert T <= T_MAX
                    assert B * C % min(C, 32) == 0
                    w, u, k, v, y = ctx.saved_tensors
                    gw = torch.empty((B, C), device=gy.device, memory_format=torch.contiguous_format)
                    gu = torch.empty((B, C), device=gy.device, memory_format=torch.contiguous_format)
                    gk = torch.empty((B, T, C), device=gy.device, memory_format=torch.contiguous_format)
                    gv = torch.empty((B, T, C), device=gy.device, memory_format=torch.contiguous_format)
                    if "32" in RWKV_FLOAT_MODE:
                        wkv_cuda.backward(B, T, C, w, u, k, v, y, gy.contiguous(), gw, gu, gk, gv)
                    else:
                        wkv_cuda.backward(B, T, C, w, u, k, v, y, gy.float().contiguous(), gw, gu, gk, gv)
                    gw = torch.sum(gw, dim=0)
                    gu = torch.sum(gu, dim=0)
                    if "32" in RWKV_FLOAT_MODE:
                        return (None, None, None, gw, gu, gk, gv)
                    elif RWKV_FLOAT_MODE == "float16":
                        return (None, None, None, gw.half(), gu.half(), gk.half(), gv.half())

        # expose the compiled kernel to TimeMixing.wkv_function
        global WKVKernel
        WKVKernel = WKV
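

if __name__ == "__main__":
    # Tiny CPU smoke test (hypothetical sizes; the custom CUDA kernel is disabled
    # so the pure-PyTorch WKV path is exercised).
    config = RWKVConfig(block_size=32, vocab_size=128, n_layer=2, n_embd=16,
                        use_customized_cuda_kernel=False, dtype="float32")
    model = RWKV(config)
    idx = torch.randint(0, 128, (2, 8))
    logits, loss = model(idx, targets=idx)
    print(logits.shape, float(loss))
    print(model.generate(idx, max_new_tokens=4).shape)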