KaleiNeely committed on
Commit 1db31cb
1 Parent(s): 3edf4ac

Update modeling_rwkv5.py

Files changed (1)
  1. modeling_rwkv5.py +140 -189
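
The central refactor in the diff below is the linear-attention entry point: the old eleven-argument `rwkv_linear_attention` (which also folded in the group norm, the gate and the output projection) is removed, and the recurrence itself now lives in `rwkv5_linear_attention_cpu` / `Rwkv5LinearAttention`, with normalization and projection handled back in `Rwkv5SelfAttention.forward`. For readers who want to experiment with that recurrence outside the model, here is a self-contained, lightly simplified copy of the new CPU path plus a small shape check; the sizes and random inputs are illustrative only, not taken from any checkpoint.

```python
import torch


def rwkv5_linear_attention_cpu(receptance, key, value, time_decay, time_first, state):
    # Mirrors the CPU fallback introduced in this commit (lightly simplified):
    # a per-timestep recurrence over a (num_heads, head_size, head_size) state,
    # computed in float32 for numerical stability.
    batch, seq_length, hidden_size = receptance.shape
    num_heads, head_size = time_first.shape
    key = key.float().view(batch, seq_length, num_heads, head_size).transpose(1, 2).transpose(-2, -1)
    value = value.float().view(batch, seq_length, num_heads, head_size).transpose(1, 2)
    receptance = receptance.float().view(batch, seq_length, num_heads, head_size).transpose(1, 2)
    time_decay = torch.exp(-torch.exp(time_decay.float())).reshape(num_heads, -1, 1)
    time_first = time_first.float().reshape(num_heads, -1, 1)
    out = torch.zeros(batch, seq_length, num_heads, head_size)

    for t in range(seq_length):
        current_receptance = receptance[:, :, t : t + 1, :]  # (B, H, 1, S)
        current_key = key[:, :, :, t : t + 1]                # (B, H, S, 1)
        current_value = value[:, :, t : t + 1, :]            # (B, H, 1, S)
        attention_output = current_key @ current_value       # (B, H, S, S)
        out[:, t] = (current_receptance @ (time_first * attention_output + state)).squeeze(2)
        with torch.no_grad():
            state = attention_output + time_decay * state

    return out, state


# Illustrative shapes only (hypothetical, not taken from any real config).
batch, seq_length, num_heads, head_size = 1, 4, 2, 8
hidden_size = num_heads * head_size
r = torch.randn(batch, seq_length, hidden_size)
k = torch.randn(batch, seq_length, hidden_size)
v = torch.randn(batch, seq_length, hidden_size)
time_decay = torch.randn(num_heads, head_size)
time_first = torch.randn(num_heads, head_size)
state = torch.zeros(batch, num_heads, head_size, head_size)

out, state = rwkv5_linear_attention_cpu(r, k, v, time_decay, time_first, state)
print(out.shape, state.shape)  # torch.Size([1, 4, 2, 8]) torch.Size([1, 2, 8, 8])
```
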
modeling_rwkv5.py CHANGED
@@ -1,6 +1,5 @@
1
  # coding=utf-8
2
- # Copyright 2023 Bo Peng and HuggingFace Inc. team.
3
- # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4
  #
5
  # Licensed under the Apache License, Version 2.0 (the "License");
6
  # you may not use this file except in compliance with the License.
@@ -45,12 +44,6 @@ logger = logging.get_logger(__name__)
45
  _CHECKPOINT_FOR_DOC = "RWKV/rwkv-5-world-1b5"
46
  _CONFIG_FOR_DOC = "Rwkv5Config"
47
 
48
- RWKV5_PRETRAINED_MODEL_ARCHIVE_LIST = [
49
- "RWKV/rwkv-5-world-1b5",
50
- "RWKV/rwkv-5-world-3b",
51
- # See all RWKV models at https://huggingface.co/models?filter=rwkv
52
- ]
53
-
54
  rwkv5_cuda_kernel = None
55
 
56
 
@@ -60,14 +53,14 @@ def load_wkv5_cuda_kernel(head_size):
60
 
61
  global rwkv5_cuda_kernel
62
 
63
- kernel_folder = Path(__file__).resolve().parent.parent.parent / "kernels" / "rwkv5"
64
  cuda_kernel_files = [kernel_folder / f for f in ["wkv5_op.cpp", "wkv5_cuda.cu"]]
65
 
66
  # Only load the kernel if it's not been loaded yet or if we changed the context length
67
  if rwkv5_cuda_kernel is not None and rwkv5_cuda_kernel.head_size == head_size:
68
  return
69
 
70
- logger.info(f"Loading CUDA kernel for RWKV at head size of {head_size}.")
71
 
72
  flags = [
73
  "-res-usage",
@@ -87,39 +80,45 @@ def load_wkv5_cuda_kernel(head_size):
87
  rwkv5_cuda_kernel.head_size = head_size
88
 
89
 
90
- class WKV_5(torch.autograd.Function):
91
  @staticmethod
92
  def forward(ctx, receptance, key, value, time_decay, time_first, state):
93
  with torch.no_grad():
94
- Batch = key.shape[0]
95
- SequenceLength = key.shape[1]
96
- HiddenSize = key.shape[2]
97
- HeadSize = HiddenSize // time_decay.shape[0]
98
- ctx.Batch = Batch
99
- ctx.SequenceLength = SequenceLength
100
- ctx.HiddenSize = HiddenSize
101
- ctx.HeadSize = HeadSize
102
  e_time_decay = (-torch.exp(time_decay.float())).contiguous()
103
  ee_time_decay = (torch.exp(e_time_decay)).contiguous()
 
104
  ctx.save_for_backward(receptance, key, value, ee_time_decay, e_time_decay, time_first)
105
  out = torch.empty(
106
- (Batch, SequenceLength, HiddenSize),
107
  device=receptance.device,
108
  dtype=torch.bfloat16,
109
  memory_format=torch.contiguous_format,
110
  )
111
- rwkv5_cuda_kernel.forward(
112
- Batch,
113
- SequenceLength,
114
- HiddenSize,
115
- HeadSize,
 
 
116
  receptance,
117
  key,
118
  value,
119
  ee_time_decay,
120
  time_first,
121
  out,
122
- state,
123
  )
124
  return out, state
125
 
@@ -127,51 +126,55 @@ class WKV_5(torch.autograd.Function):
127
  def backward(ctx, gout):
128
  with torch.no_grad():
129
  assert gout.dtype == torch.bfloat16
130
- Batch = ctx.Batch
131
- SequenceLength = ctx.SequenceLength
132
- HiddenSize = ctx.HiddenSize
133
- HeadSize = ctx.HeadSize
134
  receptance, key, value, ee_time_decay, e_time_decay, time_first = ctx.saved_tensors
135
  greceptance = torch.empty(
136
- (Batch, SequenceLength, HiddenSize),
137
  device=gout.device,
138
  requires_grad=False,
139
  dtype=torch.bfloat16,
140
  memory_format=torch.contiguous_format,
141
  )
142
  g_key = torch.empty(
143
- (Batch, SequenceLength, HiddenSize),
144
  device=gout.device,
145
  requires_grad=False,
146
  dtype=torch.bfloat16,
147
  memory_format=torch.contiguous_format,
148
  )
149
  g_value = torch.empty(
150
- (Batch, SequenceLength, HiddenSize),
151
  device=gout.device,
152
  requires_grad=False,
153
  dtype=torch.bfloat16,
154
  memory_format=torch.contiguous_format,
155
  )
156
  g_time_decay = torch.empty(
157
- (Batch, HiddenSize),
158
  device=gout.device,
159
  requires_grad=False,
160
  dtype=torch.bfloat16,
161
  memory_format=torch.contiguous_format,
162
  )
163
  g_time_first = torch.empty(
164
- (Batch, HiddenSize),
165
  device=gout.device,
166
  requires_grad=False,
167
  dtype=torch.bfloat16,
168
  memory_format=torch.contiguous_format,
169
  )
170
- rwkv5_cuda_kernel.backward(
171
- Batch,
172
- SequenceLength,
173
- HiddenSize,
174
- HeadSize,
175
  receptance,
176
  key,
177
  value,
@@ -185,133 +188,69 @@ class WKV_5(torch.autograd.Function):
185
  g_time_decay,
186
  g_time_first,
187
  )
188
- g_time_decay = torch.sum(g_time_decay, 0).view(HeadSize, HiddenSize // HeadSize)
189
- g_time_first = torch.sum(g_time_first, 0).view(HeadSize, HiddenSize // HeadSize)
 
190
  return (None, None, None, None, greceptance, g_key, g_value, g_time_decay, g_time_first)
191
 
192
 
193
- def rwkv_linear_attention_v5_cpu(
194
- hidden,
195
- time_decay,
196
- time_first,
197
- receptance,
198
- key,
199
- value,
200
- gate,
201
- layer_norm_weight,
202
- layer_norm_bias,
203
- output_weight,
204
- state,
205
- ):
206
- Batch = hidden.shape[0]
207
- AttentionHeads = time_decay.shape[0]
208
- HeadSize = hidden.shape[-1] // AttentionHeads
209
- SequenceLength = hidden.shape[1]
210
- key = key.to(torch.float32).view(Batch, SequenceLength, AttentionHeads, HeadSize).transpose(1, 2).transpose(-2, -1)
211
- value = value.to(torch.float32).view(Batch, SequenceLength, AttentionHeads, HeadSize).transpose(1, 2)
212
- receptance = receptance.to(torch.float32).view(Batch, SequenceLength, AttentionHeads, HeadSize).transpose(1, 2)
213
- time_decay = torch.exp(-torch.exp(time_decay.float())).reshape(-1, 1, 1).reshape(AttentionHeads, -1, 1)
214
- time_first = time_first.float().reshape(-1, 1, 1).reshape(AttentionHeads, -1, 1)
215
- layer_norm_weight = layer_norm_weight.float()
216
- layer_norm_bias = layer_norm_bias.float()
217
- out = torch.zeros_like(key).reshape(Batch, SequenceLength, AttentionHeads, HeadSize)
218
- for t in range(SequenceLength):
219
- rt = receptance[:, :, t : t + 1, :]
220
- kt = key[:, :, :, t : t + 1]
221
- vt = value[:, :, t : t + 1, :]
222
- at = kt @ vt
223
- out[:, t] = (rt @ (time_first * at + state)).squeeze(2)
224
  with torch.no_grad():
225
- state = at + time_decay * state
226
-
227
- out = out.reshape(Batch * SequenceLength, AttentionHeads * HeadSize)
228
- out = F.group_norm(out, num_groups=AttentionHeads, weight=layer_norm_weight, bias=layer_norm_bias).reshape(
229
- Batch, SequenceLength, AttentionHeads * HeadSize
230
- )
231
- out = out.to(dtype=hidden.dtype) * gate
232
- out = out @ output_weight
233
 
234
  return out, state
235
 
236
-
237
- def rwkv_linear_attention(
238
- hidden,
239
- time_decay,
240
- time_first,
241
- receptance,
242
- key,
243
- value,
244
- gate,
245
- layer_norm_weight,
246
- layer_norm_bias,
247
- output_weight,
248
- state,
249
- ):
250
- Batch = hidden.shape[0]
251
- AttentionHeads = time_decay.shape[0]
252
- HeadSize = hidden.shape[-1] // AttentionHeads
253
- SequenceLength = hidden.shape[1]
254
- no_cuda = any(t.device.type != "cuda" for t in [time_decay, time_first, receptance, key, value])
255
  # Launching the CUDA kernel for just one token will actually be slower (there is no for loop in the CPU version
256
  # in this case).
257
  one_token = key.size(1) == 1
258
- if rwkv5_cuda_kernel is None or no_cuda or one_token:
259
- return rwkv_linear_attention_v5_cpu(
260
- hidden,
261
- time_decay,
262
- time_first,
263
- receptance,
264
- key,
265
- value,
266
- gate,
267
- layer_norm_weight,
268
- layer_norm_bias,
269
- output_weight,
270
- state,
271
  )
272
  else:
273
- out, state = WKV_5.apply(
274
- Batch,
275
- SequenceLength,
276
- AttentionHeads * HeadSize,
277
- AttentionHeads,
278
- receptance,
279
- key,
280
- value,
281
- time_decay,
282
- time_first,
283
- state,
284
- )
285
- out = out.reshape(Batch * SequenceLength, AttentionHeads * HeadSize)
286
- out = F.group_norm(out, num_groups=AttentionHeads, weight=layer_norm_weight, bias=layer_norm_bias).reshape(
287
- Batch, SequenceLength, AttentionHeads * HeadSize
288
- )
289
- out = out.to(dtype=hidden.dtype) * gate
290
- out = out @ output_weight
291
- return out, state
292
 
293
 
294
- class RwkvSelfAttention(nn.Module):
295
  def __init__(self, config, layer_id=0):
296
  super().__init__()
297
  self.config = config
298
  kernel_loaded = rwkv5_cuda_kernel is not None and rwkv5_cuda_kernel.head_size == config.head_size
299
  if is_ninja_available() and is_torch_cuda_available() and not kernel_loaded:
300
  try:
301
- load_wkv5_cuda_kernel(config.context_length)
302
  except Exception:
303
  logger.info("Could not load the custom CUDA kernel for RWKV5 attention.")
304
  self.layer_id = layer_id
305
  hidden_size = config.hidden_size
306
- num_attention_heads = hidden_size // config.head_size
307
- self.num_attention_heads = num_attention_heads
308
- attention_hidden_size = (
309
- config.attention_hidden_size if config.attention_hidden_size is not None else hidden_size
310
- )
311
  self.attention_hidden_size = attention_hidden_size
 
 
312
 
313
- self.time_decay = nn.Parameter(torch.empty(num_attention_heads, config.head_size))
314
- self.time_faaaa = nn.Parameter(torch.empty(num_attention_heads, config.head_size))
315
  self.time_mix_gate = nn.Parameter(torch.empty(1, 1, hidden_size))
316
 
317
  self.time_mix_key = nn.Parameter(torch.empty(1, 1, hidden_size))
@@ -324,7 +263,7 @@ class RwkvSelfAttention(nn.Module):
324
  self.receptance = nn.Linear(hidden_size, attention_hidden_size, bias=False)
325
  self.gate = nn.Linear(hidden_size, attention_hidden_size, bias=False)
326
  self.output = nn.Linear(attention_hidden_size, hidden_size, bias=False)
327
- self.ln_x = nn.GroupNorm(hidden_size // config.head_size, hidden_size)
328
 
329
  def extract_key_value(self, hidden, state=None):
330
  # Mix hidden with the previous timestep to produce key, value, receptance
@@ -336,6 +275,7 @@ class RwkvSelfAttention(nn.Module):
336
  shifted[:, 0] = state[0][:, :, self.layer_id]
337
  if len(shifted.size()) == 2:
338
  shifted = shifted.unsqueeze(1)
 
339
  key = hidden * self.time_mix_key + shifted * (1 - self.time_mix_key)
340
  value = hidden * self.time_mix_value + shifted * (1 - self.time_mix_value)
341
  receptance = hidden * self.time_mix_receptance + shifted * (1 - self.time_mix_receptance)
@@ -353,28 +293,26 @@ class RwkvSelfAttention(nn.Module):
353
 
354
  def forward(self, hidden, state=None, use_cache=False, seq_mode=True):
355
  receptance, key, value, gate, state = self.extract_key_value(hidden, state=state)
356
  layer_state = state[1][:, :, :, :, self.layer_id] if state is not None else None
357
- rwkv, layer_state = rwkv_linear_attention(
358
- hidden,
359
- self.time_decay,
360
- self.time_faaaa,
361
- receptance,
362
- key,
363
- value,
364
- gate,
365
- self.ln_x.weight,
366
- self.ln_x.bias,
367
- self.output.weight.t(),
368
- state=layer_state,
369
  )
370
 
371
  if layer_state is not None:
372
  state[1][:, :, :, :, self.layer_id] = layer_state
373
 
374
- return rwkv, state
375
-
376
 
377
- class RwkvFeedForward(nn.Module):
 
378
  def __init__(self, config, layer_id=0):
379
  super().__init__()
380
  self.config = config
@@ -416,7 +354,7 @@ class RwkvFeedForward(nn.Module):
416
  return receptance * value, state
417
 
418
 
419
- # copied from HuggingFace https://github.com/huggingface/transformers/blob/main/src/transformers/models/rwkv/modeling_rwkv.py
420
  class Rwkv5Block(nn.Module):
421
  def __init__(self, config, layer_id):
422
  super().__init__()
@@ -429,8 +367,8 @@ class Rwkv5Block(nn.Module):
429
  self.ln1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_epsilon)
430
  self.ln2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_epsilon)
431
 
432
- self.attention = RwkvSelfAttention(config, layer_id)
433
- self.feed_forward = RwkvFeedForward(config, layer_id)
434
 
435
  def forward(self, hidden, state=None, use_cache=False, output_attentions=False, seq_mode=True):
436
  if self.layer_id == 0:
@@ -450,6 +388,7 @@ class Rwkv5Block(nn.Module):
450
  return outputs
451
 
452
 
 
453
  class Rwkv5PreTrainedModel(PreTrainedModel):
454
  """
455
  An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
@@ -457,19 +396,20 @@ class Rwkv5PreTrainedModel(PreTrainedModel):
457
  """
458
 
459
  config_class = Rwkv5Config
460
- base_model_prefix = "rwkv"
461
  _no_split_modules = ["Rwkv5Block"]
462
  _keep_in_fp32_modules = ["time_decay", "time_first"]
463
  supports_gradient_checkpointing = True
464
 
465
  def _init_weights(self, module):
466
  """Initialize the weights."""
467
- if isinstance(module, RwkvSelfAttention):
468
  layer_id = module.layer_id
469
  num_hidden_layers = module.config.num_hidden_layers
470
  hidden_size = module.config.hidden_size
471
  attention_hidden_size = module.attention_hidden_size
472
- num_attention_heads = hidden_size // module.config.num_attention_heads
 
473
 
474
  ratio_0_to_1 = layer_id / (num_hidden_layers - 1) # 0 to 1
475
  ratio_1_to_almost0 = 1.0 - (layer_id / num_hidden_layers) # 1 to ~0
@@ -496,15 +436,15 @@ class Rwkv5PreTrainedModel(PreTrainedModel):
496
  )
497
 
498
  with torch.no_grad():
499
- module.time_decay.data = decay_speed.reshape(num_attention_heads, module.config.num_attention_heads)
500
- module.time_faaaa.data = tmp.reshape(num_attention_heads, module.config.num_attention_heads)
501
  module.time_mix_key.data = torch.pow(time_weight, ratio_1_to_almost0)
502
 
503
  module.time_mix_value.data = torch.pow(time_weight, ratio_1_to_almost0) + 0.3 * ratio_0_to_1
504
  module.time_mix_receptance.data = torch.pow(time_weight, 0.5 * ratio_1_to_almost0)
505
  module.time_mix_gate.data = torch.pow(time_weight, 0.5 * ratio_1_to_almost0)
506
 
507
- elif isinstance(module, RwkvFeedForward):
508
  layer_id = module.layer_id
509
  num_hidden_layers = module.config.num_hidden_layers
510
  hidden_size = module.config.hidden_size
@@ -523,11 +463,11 @@ class Rwkv5PreTrainedModel(PreTrainedModel):
523
  module.time_mix_receptance.data = torch.pow(time_weight, ratio_1_to_almost0)
524
 
525
 
526
- # copied from HuggingFace https://github.com/huggingface/transformers/blob/main/src/transformers/models/rwkv/modeling_rwkv.py
527
  @dataclass
528
  class Rwkv5Output(ModelOutput):
529
  """
530
- Class for the RWKV model outputs.
531
 
532
  Args:
533
  last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
@@ -551,7 +491,7 @@ class Rwkv5Output(ModelOutput):
551
  attentions: Optional[Tuple[torch.FloatTensor]] = None
552
 
553
 
554
- # copied from HuggingFace https://github.com/huggingface/transformers/blob/main/src/transformers/models/rwkv/modeling_rwkv.py
555
  @dataclass
556
  class Rwkv5CausalLMOutput(ModelOutput):
557
  """
@@ -582,7 +522,7 @@ class Rwkv5CausalLMOutput(ModelOutput):
582
  attentions: Optional[Tuple[torch.FloatTensor]] = None
583
 
584
 
585
- RWKV_START_DOCSTRING = r"""
586
  This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
587
  library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
588
  etc.) This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module)
@@ -595,7 +535,7 @@ RWKV_START_DOCSTRING = r"""
595
  configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
596
  """
597
 
598
- RWKV_INPUTS_DOCSTRING = r"""
599
  Args:
600
  input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
601
  `input_ids_length` = `sequence_length` if `past_key_values` is `None` else
@@ -625,8 +565,8 @@ RWKV_INPUTS_DOCSTRING = r"""
625
 
626
 
627
  @add_start_docstrings(
628
- "The bare RWKV Model transformer outputting raw hidden-states without any specific head on top.",
629
- RWKV_START_DOCSTRING,
630
  )
631
  class Rwkv5Model(Rwkv5PreTrainedModel):
632
  def __init__(self, config):
@@ -648,7 +588,7 @@ class Rwkv5Model(Rwkv5PreTrainedModel):
648
  def set_input_embeddings(self, new_embeddings):
649
  self.embeddings = new_embeddings
650
 
651
- @add_start_docstrings_to_model_forward(RWKV_INPUTS_DOCSTRING)
652
  @add_code_sample_docstrings(
653
  checkpoint=_CHECKPOINT_FOR_DOC,
654
  output_type=Rwkv5Output,
@@ -669,6 +609,7 @@ class Rwkv5Model(Rwkv5PreTrainedModel):
669
  output_hidden_states = (
670
  output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
671
  )
 
672
  # rwkv5 only supports inference in huggingface.
673
  use_cache = use_cache if use_cache is not None else self.config.use_cache
674
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
@@ -686,9 +627,10 @@ class Rwkv5Model(Rwkv5PreTrainedModel):
686
  if inputs_embeds is None:
687
  inputs_embeds = self.embeddings(input_ids)
688
 
689
- if use_cache and state is None:
690
  state = []
691
- num_attention_heads = self.config.hidden_size // self.config.num_attention_heads
 
692
  state_attn_x = torch.zeros(
693
  (inputs_embeds.size(0), self.config.hidden_size, self.config.num_hidden_layers),
694
  dtype=inputs_embeds.dtype,
@@ -698,9 +640,9 @@ class Rwkv5Model(Rwkv5PreTrainedModel):
698
  state_attn_kv = torch.zeros(
699
  (
700
  inputs_embeds.size(0),
701
- num_attention_heads,
702
- self.config.hidden_size // num_attention_heads,
703
- self.config.hidden_size // num_attention_heads,
704
  self.config.num_hidden_layers,
705
  ),
706
  dtype=torch.float32,
@@ -765,8 +707,16 @@ class Rwkv5Model(Rwkv5PreTrainedModel):
765
  block.attention.output.weight.mul_(2 ** int(block_id // self.config.rescale_every))
766
  block.feed_forward.value.weight.mul_(2 ** int(block_id // self.config.rescale_every))
767
  else:
768
- block.attention.output.weight.div_(2 ** int(block_id // self.config.rescale_every))
769
- block.feed_forward.value.weight.div_(2 ** int(block_id // self.config.rescale_every))
770
 
771
  self.layers_are_rescaled = not self.training
772
 
@@ -798,8 +748,9 @@ class Rwkv5Model(Rwkv5PreTrainedModel):
798
  The RWKV5 Model transformer with a language modeling head on top (linear layer with weights tied to the input
799
  embeddings).
800
  """,
801
- RWKV_START_DOCSTRING,
802
  )
 
803
  class Rwkv5ForCausalLM(Rwkv5PreTrainedModel):
804
  _tied_weights_keys = ["head.weight"]
805
 
@@ -831,7 +782,7 @@ class Rwkv5ForCausalLM(Rwkv5PreTrainedModel):
831
  model_inputs["state"] = state
832
  return model_inputs
833
 
834
- @add_start_docstrings_to_model_forward(RWKV_INPUTS_DOCSTRING)
835
  @add_code_sample_docstrings(
836
  checkpoint=_CHECKPOINT_FOR_DOC,
837
  output_type=Rwkv5CausalLMOutput,
@@ -857,7 +808,7 @@ class Rwkv5ForCausalLM(Rwkv5PreTrainedModel):
857
  """
858
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
859
 
860
- rwkv_outputs = self.rwkv(
861
  input_ids,
862
  inputs_embeds=inputs_embeds,
863
  state=state,
@@ -866,7 +817,7 @@ class Rwkv5ForCausalLM(Rwkv5PreTrainedModel):
866
  output_hidden_states=output_hidden_states,
867
  return_dict=return_dict,
868
  )
869
- hidden_states = rwkv_outputs[0]
870
 
871
  logits = self.head(hidden_states)
872
 
@@ -882,13 +833,13 @@ class Rwkv5ForCausalLM(Rwkv5PreTrainedModel):
882
  loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
883
 
884
  if not return_dict:
885
- output = (logits,) + rwkv_outputs[1:]
886
  return ((loss,) + output) if loss is not None else output
887
 
888
  return Rwkv5CausalLMOutput(
889
  loss=loss,
890
  logits=logits,
891
- state=rwkv_outputs.state,
892
- hidden_states=rwkv_outputs.hidden_states,
893
- attentions=rwkv_outputs.attentions,
894
  )
 
1
  # coding=utf-8
2
+ # Copyright 2024 The RWKV team and HuggingFace Inc. team.
 
3
  #
4
  # Licensed under the Apache License, Version 2.0 (the "License");
5
  # you may not use this file except in compliance with the License.
 
44
  _CHECKPOINT_FOR_DOC = "RWKV/rwkv-5-world-1b5"
45
  _CONFIG_FOR_DOC = "Rwkv5Config"
46
 
47
  rwkv5_cuda_kernel = None
48
 
49
 
 
53
 
54
  global rwkv5_cuda_kernel
55
 
56
+ kernel_folder = Path(__file__).parent.resolve()
57
  cuda_kernel_files = [kernel_folder / f for f in ["wkv5_op.cpp", "wkv5_cuda.cu"]]
58
 
59
  # Only load the kernel if it's not been loaded yet or if we changed the context length
60
  if rwkv5_cuda_kernel is not None and rwkv5_cuda_kernel.head_size == head_size:
61
  return
62
 
63
+ logger.info(f"Loading CUDA kernel for RWKV5 at head size of {head_size}.")
64
 
65
  flags = [
66
  "-res-usage",
 
80
  rwkv5_cuda_kernel.head_size = head_size
81
 
82
 
83
+ class Rwkv5LinearAttention(torch.autograd.Function):
84
  @staticmethod
85
  def forward(ctx, receptance, key, value, time_decay, time_first, state):
86
  with torch.no_grad():
87
+ assert receptance.dtype == torch.bfloat16
88
+ assert key.dtype == torch.bfloat16
89
+ assert value.dtype == torch.bfloat16
90
+ assert time_decay.dtype == torch.bfloat16
91
+ assert time_first.dtype == torch.bfloat16
92
+ assert state.dtype == torch.float32
93
+ batch, seq_length, hidden_size = key.shape
94
+ num_heads = time_decay.shape[0]
95
+ ctx.batch = batch
96
+ ctx.seq_length = seq_length
97
+ ctx.hidden_size = hidden_size
98
+ ctx.num_heads = num_heads
99
  e_time_decay = (-torch.exp(time_decay.float())).contiguous()
100
  ee_time_decay = (torch.exp(e_time_decay)).contiguous()
101
+ assert ee_time_decay.dtype == torch.float32
102
  ctx.save_for_backward(receptance, key, value, ee_time_decay, e_time_decay, time_first)
103
  out = torch.empty(
104
+ (batch, seq_length, hidden_size),
105
  device=receptance.device,
106
  dtype=torch.bfloat16,
107
  memory_format=torch.contiguous_format,
108
  )
109
+ state = state.clone()
110
+ rwkv5_cuda_kernel.forward_bf16(
111
+ batch,
112
+ seq_length,
113
+ hidden_size,
114
+ num_heads,
115
+ state,
116
  receptance,
117
  key,
118
  value,
119
  ee_time_decay,
120
  time_first,
121
  out,
 
122
  )
123
  return out, state
124
 
 
126
  def backward(ctx, gout):
127
  with torch.no_grad():
128
  assert gout.dtype == torch.bfloat16
129
+ batch = ctx.batch
130
+ seq_length = ctx.seq_length
131
+ hidden_size = ctx.hidden_size
132
+ num_heads = ctx.num_heads
133
  receptance, key, value, ee_time_decay, e_time_decay, time_first = ctx.saved_tensors
134
+
135
+ global_shape = (batch, seq_length, hidden_size)
136
+
137
+ # TODO dtype should not be forced here IMO
138
  greceptance = torch.empty(
139
+ global_shape,
140
  device=gout.device,
141
  requires_grad=False,
142
  dtype=torch.bfloat16,
143
  memory_format=torch.contiguous_format,
144
  )
145
  g_key = torch.empty(
146
+ global_shape,
147
  device=gout.device,
148
  requires_grad=False,
149
  dtype=torch.bfloat16,
150
  memory_format=torch.contiguous_format,
151
  )
152
  g_value = torch.empty(
153
+ global_shape,
154
  device=gout.device,
155
  requires_grad=False,
156
  dtype=torch.bfloat16,
157
  memory_format=torch.contiguous_format,
158
  )
159
  g_time_decay = torch.empty(
160
+ (batch, hidden_size),
161
  device=gout.device,
162
  requires_grad=False,
163
  dtype=torch.bfloat16,
164
  memory_format=torch.contiguous_format,
165
  )
166
  g_time_first = torch.empty(
167
+ (batch, hidden_size),
168
  device=gout.device,
169
  requires_grad=False,
170
  dtype=torch.bfloat16,
171
  memory_format=torch.contiguous_format,
172
  )
173
+ rwkv5_cuda_kernel.backward_bf16(
174
+ batch,
175
+ seq_length,
176
+ hidden_size,
177
+ num_heads,
178
  receptance,
179
  key,
180
  value,
 
188
  g_time_decay,
189
  g_time_first,
190
  )
191
+ head_size = hidden_size // num_heads
192
+ g_time_decay = torch.sum(g_time_decay, 0).view(num_heads, head_size)
193
+ g_time_first = torch.sum(g_time_first, 0).view(num_heads, head_size)
194
  return (None, None, None, None, greceptance, g_key, g_value, g_time_decay, g_time_first)
195
 
196
 
197
+ def rwkv5_linear_attention_cpu(receptance, key, value, time_decay, time_first, state):
198
+ input_dtype = receptance.dtype
199
+ # For CPU fallback. Will be slower and probably take more memory than the custom CUDA kernel if not executed
200
+ # within a torch.no_grad.
201
+ batch, seq_length, hidden_size = receptance.shape
202
+ num_heads, head_size = time_first.shape
203
+ key = key.float().view(batch, seq_length, num_heads, head_size).transpose(1, 2).transpose(-2, -1)
204
+ value = value.float().view(batch, seq_length, num_heads, head_size).transpose(1, 2)
205
+ receptance = receptance.float().view(batch, seq_length, num_heads, head_size).transpose(1, 2)
206
+ time_decay = torch.exp(-torch.exp(time_decay.float())).reshape(-1, 1, 1).reshape(num_heads, -1, 1)
207
+ time_first = time_first.float().reshape(-1, 1, 1).reshape(num_heads, -1, 1)
208
+ out = torch.zeros_like(key).reshape(batch, seq_length, num_heads, head_size)
209
+
210
+ for current_index in range(seq_length):
211
+ current_receptance = receptance[:, :, current_index:current_index+1, :]
212
+ current_key = key[:, :, :, current_index:current_index+1]
213
+ current_value = value[:, :, current_index:current_index+1, :]
214
+ attention_output = current_key @ current_value
215
+ out[:, current_index] = (current_receptance @ (time_first * attention_output + state)).squeeze(2)
216
  with torch.no_grad():
217
+ state = attention_output + time_decay * state
218
 
219
  return out, state
220
 
221
+ # copied from RWKV but with receptance
222
+ def RWKV5_linear_attention(training, receptance, key, value, time_decay, time_first, state):
223
+ no_cuda = any(t.device.type != "cuda" for t in [time_decay, time_first, key, value])
224
  # Launching the CUDA kernel for just one token will actually be slower (there is no for loop in the CPU version
225
  # in this case).
226
  one_token = key.size(1) == 1
227
+ if not training or rwkv5_cuda_kernel is None or no_cuda or one_token:
228
+ return rwkv5_linear_attention_cpu(
229
+ receptance, key, value, time_decay, time_first, state
230
  )
231
  else:
232
+ return Rwkv5LinearAttention.apply(receptance, key, value, time_decay, time_first, state)
233
 
234
 
235
+ class Rwkv5SelfAttention(nn.Module):
236
  def __init__(self, config, layer_id=0):
237
  super().__init__()
238
  self.config = config
239
  kernel_loaded = rwkv5_cuda_kernel is not None and rwkv5_cuda_kernel.head_size == config.head_size
240
  if is_ninja_available() and is_torch_cuda_available() and not kernel_loaded:
241
  try:
242
+ load_wkv5_cuda_kernel(config.head_size)
243
  except Exception:
244
  logger.info("Could not load the custom CUDA kernel for RWKV5 attention.")
245
  self.layer_id = layer_id
246
  hidden_size = config.hidden_size
247
+ attention_hidden_size = config.attention_hidden_size
248
  self.attention_hidden_size = attention_hidden_size
249
+ head_size = config.head_size
250
+ num_heads = attention_hidden_size // head_size
251
 
252
+ self.time_decay = nn.Parameter(torch.empty(num_heads, head_size))
253
+ self.time_faaaa = nn.Parameter(torch.empty(num_heads, head_size))
254
  self.time_mix_gate = nn.Parameter(torch.empty(1, 1, hidden_size))
255
 
256
  self.time_mix_key = nn.Parameter(torch.empty(1, 1, hidden_size))
 
263
  self.receptance = nn.Linear(hidden_size, attention_hidden_size, bias=False)
264
  self.gate = nn.Linear(hidden_size, attention_hidden_size, bias=False)
265
  self.output = nn.Linear(attention_hidden_size, hidden_size, bias=False)
266
+ self.ln_x = nn.GroupNorm(num_heads, hidden_size)
267
 
268
  def extract_key_value(self, hidden, state=None):
269
  # Mix hidden with the previous timestep to produce key, value, receptance
 
275
  shifted[:, 0] = state[0][:, :, self.layer_id]
276
  if len(shifted.size()) == 2:
277
  shifted = shifted.unsqueeze(1)
278
+
279
  key = hidden * self.time_mix_key + shifted * (1 - self.time_mix_key)
280
  value = hidden * self.time_mix_value + shifted * (1 - self.time_mix_value)
281
  receptance = hidden * self.time_mix_receptance + shifted * (1 - self.time_mix_receptance)
 
293
 
294
  def forward(self, hidden, state=None, use_cache=False, seq_mode=True):
295
  receptance, key, value, gate, state = self.extract_key_value(hidden, state=state)
296
+
297
+ B, T, C = receptance.shape
298
+ H, S = self.time_faaaa.shape
299
+
300
  layer_state = state[1][:, :, :, :, self.layer_id] if state is not None else None
301
+ out, layer_state = RWKV5_linear_attention(
302
+ self.training, receptance, key, value, self.time_decay, self.time_faaaa, layer_state
303
  )
304
 
305
  if layer_state is not None:
306
  state[1][:, :, :, :, self.layer_id] = layer_state
307
 
308
+ out = out.reshape(B * T, H * S)
309
+ out = F.group_norm(out / self.config.head_size_divisor, num_groups=H, weight=self.ln_x.weight.to(out.dtype), bias=self.ln_x.bias.to(out.dtype), eps=self.ln_x.eps).reshape(B, T, H * S)
310
+ out = out.to(dtype=hidden.dtype) * gate
311
+ out = self.output(out)
312
+ return out, state
313
 
314
+ # Copied from rwkv except for the intermediate size
315
+ class Rwkv5FeedForward(nn.Module):
316
  def __init__(self, config, layer_id=0):
317
  super().__init__()
318
  self.config = config
 
354
  return receptance * value, state
355
 
356
 
357
+ # Copied from transformers.models.rwkv.modeling_rwkv.RwkvBlock with Rwkv->Rwkv5
358
  class Rwkv5Block(nn.Module):
359
  def __init__(self, config, layer_id):
360
  super().__init__()
 
367
  self.ln1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_epsilon)
368
  self.ln2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_epsilon)
369
 
370
+ self.attention = Rwkv5SelfAttention(config, layer_id)
371
+ self.feed_forward = Rwkv5FeedForward(config, layer_id)
372
 
373
  def forward(self, hidden, state=None, use_cache=False, output_attentions=False, seq_mode=True):
374
  if self.layer_id == 0:
 
388
  return outputs
389
 
390
 
391
+ # Copied from transformers.models.rwkv.modeling_rwkv.RwkvPreTrainedModel with Rwkv->Rwkv5
392
  class Rwkv5PreTrainedModel(PreTrainedModel):
393
  """
394
  An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
 
396
  """
397
 
398
  config_class = Rwkv5Config
399
+ base_model_prefix = "rwkv5"
400
  _no_split_modules = ["Rwkv5Block"]
401
  _keep_in_fp32_modules = ["time_decay", "time_first"]
402
  supports_gradient_checkpointing = True
403
 
404
  def _init_weights(self, module):
405
  """Initialize the weights."""
406
+ if isinstance(module, Rwkv5SelfAttention):
407
  layer_id = module.layer_id
408
  num_hidden_layers = module.config.num_hidden_layers
409
  hidden_size = module.config.hidden_size
410
  attention_hidden_size = module.attention_hidden_size
411
+ head_size = module.config.head_size
412
+ num_heads = attention_hidden_size // head_size
413
 
414
  ratio_0_to_1 = layer_id / (num_hidden_layers - 1) # 0 to 1
415
  ratio_1_to_almost0 = 1.0 - (layer_id / num_hidden_layers) # 1 to ~0
 
436
  )
437
 
438
  with torch.no_grad():
439
+ module.time_decay.data = decay_speed.reshape(num_heads, head_size)
440
+ module.time_faaaa.data = tmp.reshape(num_heads, head_size)
441
  module.time_mix_key.data = torch.pow(time_weight, ratio_1_to_almost0)
442
 
443
  module.time_mix_value.data = torch.pow(time_weight, ratio_1_to_almost0) + 0.3 * ratio_0_to_1
444
  module.time_mix_receptance.data = torch.pow(time_weight, 0.5 * ratio_1_to_almost0)
445
  module.time_mix_gate.data = torch.pow(time_weight, 0.5 * ratio_1_to_almost0)
446
 
447
+ elif isinstance(module, Rwkv5FeedForward):
448
  layer_id = module.layer_id
449
  num_hidden_layers = module.config.num_hidden_layers
450
  hidden_size = module.config.hidden_size
 
463
  module.time_mix_receptance.data = torch.pow(time_weight, ratio_1_to_almost0)
464
 
465
 
466
+ # Copied from transformers.models.rwkv.modeling_rwkv.RwkvOutput with Rwkv->Rwkv5
467
  @dataclass
468
  class Rwkv5Output(ModelOutput):
469
  """
470
+ Class for the RWKV5 model outputs.
471
 
472
  Args:
473
  last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
 
491
  attentions: Optional[Tuple[torch.FloatTensor]] = None
492
 
493
 
494
+ # Copied from transformers.models.rwkv.modeling_rwkv.RwkvCausalLMOutput with Rwkv->Rwkv5
495
  @dataclass
496
  class Rwkv5CausalLMOutput(ModelOutput):
497
  """
 
522
  attentions: Optional[Tuple[torch.FloatTensor]] = None
523
 
524
 
525
+ RWKV5_START_DOCSTRING = r"""
526
  This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
527
  library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
528
  etc.) This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module)
 
535
  configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
536
  """
537
 
538
+ RWKV5_INPUTS_DOCSTRING = r"""
539
  Args:
540
  input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
541
  `input_ids_length` = `sequence_length` if `past_key_values` is `None` else
 
565
 
566
 
567
  @add_start_docstrings(
568
+ "The bare RWKV5 Model transformer outputting raw hidden-states without any specific head on top.",
569
+ RWKV5_START_DOCSTRING,
570
  )
571
  class Rwkv5Model(Rwkv5PreTrainedModel):
572
  def __init__(self, config):
 
588
  def set_input_embeddings(self, new_embeddings):
589
  self.embeddings = new_embeddings
590
 
591
+ @add_start_docstrings_to_model_forward(RWKV5_INPUTS_DOCSTRING)
592
  @add_code_sample_docstrings(
593
  checkpoint=_CHECKPOINT_FOR_DOC,
594
  output_type=Rwkv5Output,
 
609
  output_hidden_states = (
610
  output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
611
  )
612
+ # FIXME - training is supportable with the CUDA code
613
  # rwkv5 only supports inference in huggingface.
614
  use_cache = use_cache if use_cache is not None else self.config.use_cache
615
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
627
  if inputs_embeds is None:
628
  inputs_embeds = self.embeddings(input_ids)
629
 
630
+ if state is None:
631
  state = []
632
+ head_size = self.config.head_size
633
+ num_heads = self.config.attention_hidden_size // head_size
634
  state_attn_x = torch.zeros(
635
  (inputs_embeds.size(0), self.config.hidden_size, self.config.num_hidden_layers),
636
  dtype=inputs_embeds.dtype,
 
640
  state_attn_kv = torch.zeros(
641
  (
642
  inputs_embeds.size(0),
643
+ num_heads,
644
+ head_size,
645
+ head_size,
646
  self.config.num_hidden_layers,
647
  ),
648
  dtype=torch.float32,
 
707
  block.attention.output.weight.mul_(2 ** int(block_id // self.config.rescale_every))
708
  block.feed_forward.value.weight.mul_(2 ** int(block_id // self.config.rescale_every))
709
  else:
710
+ # Deal with quantization statistics
711
+ if hasattr(block.attention.output.weight, "SCB"):
712
+ block.attention.output.weight.SCB.div_(2 ** int(block_id // self.config.rescale_every))
713
+ block.feed_forward.value.weight.SCB.div_(2 ** int(block_id // self.config.rescale_every))
714
+ elif hasattr(block.attention.output.weight, "quant_state"):
715
+ self._bnb_4bit_dequantize_and_rescale(block.attention.output, block_id)
716
+ self._bnb_4bit_dequantize_and_rescale(block.feed_forward.value, block_id)
717
+ else:
718
+ block.attention.output.weight.div_(2 ** int(block_id // self.config.rescale_every))
719
+ block.feed_forward.value.weight.div_(2 ** int(block_id // self.config.rescale_every))
720
 
721
  self.layers_are_rescaled = not self.training
722
 
 
748
  The RWKV5 Model transformer with a language modeling head on top (linear layer with weights tied to the input
749
  embeddings).
750
  """,
751
+ RWKV5_START_DOCSTRING,
752
  )
753
+ # Copied from transformers.models.rwkv.modeling_rwkv.RwkvForCausalLM with Rwkv->Rwkv5
754
  class Rwkv5ForCausalLM(Rwkv5PreTrainedModel):
755
  _tied_weights_keys = ["head.weight"]
756
 
 
782
  model_inputs["state"] = state
783
  return model_inputs
784
 
785
+ @add_start_docstrings_to_model_forward(RWKV5_INPUTS_DOCSTRING)
786
  @add_code_sample_docstrings(
787
  checkpoint=_CHECKPOINT_FOR_DOC,
788
  output_type=Rwkv5CausalLMOutput,
 
808
  """
809
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
810
 
811
+ outputs = self.rwkv(
812
  input_ids,
813
  inputs_embeds=inputs_embeds,
814
  state=state,
 
817
  output_hidden_states=output_hidden_states,
818
  return_dict=return_dict,
819
  )
820
+ hidden_states = outputs[0]
821
 
822
  logits = self.head(hidden_states)
823
 
 
833
  loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
834
 
835
  if not return_dict:
836
+ output = (logits,) + outputs[1:]
837
  return ((loss,) + output) if loss is not None else output
838
 
839
  return Rwkv5CausalLMOutput(
840
  loss=loss,
841
  logits=logits,
842
+ state=outputs.state,
843
+ hidden_states=outputs.hidden_states,
844
+ attentions=outputs.attentions,
845
  )
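
Since `_CHECKPOINT_FOR_DOC` in this file points at `RWKV/rwkv-5-world-1b5`, the quickest way to exercise the updated modeling code is to load that checkpoint with `trust_remote_code=True`. A minimal sketch, assuming the Hub repository ships this `modeling_rwkv5.py` and a compatible tokenizer as remote code; the prompt and generation settings are arbitrary:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Checkpoint named in _CHECKPOINT_FOR_DOC above; trust_remote_code pulls the
# updated modeling_rwkv5.py (and the RWKV "world" tokenizer) from the Hub repo.
model_id = "RWKV/rwkv-5-world-1b5"
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.float32,  # the Python fallback runs the recurrence in float32 anyway
    trust_remote_code=True,
)
model.eval()  # per the comment in Rwkv5Model.forward, only inference is supported here

inputs = tokenizer("The RWKV architecture is", return_tensors="pt")
with torch.no_grad():
    output_ids = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```

The custom CUDA kernel is only compiled when `ninja` and a CUDA device are available (see `load_wkv5_cuda_kernel` above); otherwise the model falls back to the Python recurrence shown in `rwkv5_linear_attention_cpu`.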