KaleiNeely commited on
Commit
b4b4bb8
1 Parent(s): 9b7bad7

Update modeling_rwkv5.py

Browse files
Files changed (1) hide show
  1. modeling_rwkv5.py +13 -9
modeling_rwkv5.py CHANGED
@@ -22,6 +22,7 @@ import torch
22
  import torch.nn.functional as F
23
  import torch.utils.checkpoint
24
  from torch import nn
 
25
 
26
  from transformers.modeling_utils import PreTrainedModel
27
  from transformers.utils import (
@@ -42,6 +43,7 @@ _CONFIG_FOR_DOC = "Rwkv5Config"
42
 
43
  RWKV5_PRETRAINED_MODEL_ARCHIVE_LIST = [
44
  "RWKV/rwkv-5-world-1b5",
 
45
  # See all RWKV models at https://huggingface.co/models?filter=rwkv
46
  ]
47
 
@@ -63,22 +65,20 @@ def rwkv_linear_attention_v5(
63
  lxb,
64
  ow,
65
  state,
66
- return_state=False,
67
- seq_mode=True,
68
  ):
69
  time_decay = torch.exp(-torch.exp(time_decay.float())).reshape(-1, 1, 1).reshape(n_head, -1, 1)
70
  time_first = time_first.float().reshape(-1, 1, 1).reshape(n_head, -1, 1)
71
  lxw = lxw.float()
72
  lxb = lxb.float()
73
- # if seq_mode:
74
- out = torch.empty((B, T, H, S), dtype=receptance.dtype, device=receptance.device)
75
  for t in range(T):
76
  rt = receptance[:, :, t : t + 1, :]
77
  kt = key[:, :, :, t : t + 1]
78
  vt = value[:, :, t : t + 1, :]
79
  at = kt @ vt
80
  out[:, t] = (rt @ (time_first * at + state)).squeeze(2)
81
- state = at + time_decay * state
 
82
 
83
  out = out.reshape(B * T, H * S)
84
  out = F.group_norm(out, num_groups=H, weight=lxw, bias=lxb).reshape(B, T, H * S)
@@ -171,8 +171,6 @@ class RwkvSelfAttention(nn.Module):
171
  self.ln_x.bias,
172
  self.output.weight.t(),
173
  state=layer_state,
174
- return_state=use_cache,
175
- seq_mode=seq_mode,
176
  )
177
 
178
  if layer_state is not None:
@@ -671,8 +669,14 @@ class Rwkv5ForCausalLM(Rwkv5PreTrainedModel):
671
 
672
  loss = None
673
  if labels is not None:
674
- # https://github.com/BlinkDL/ChatRWKV/blob/main/rwkv_pip_package/src/rwkv/model.py#L984
675
- loss = torch.tensor(0.0, device=logits.device, dtype=logits.dtype)
 
 
 
 
 
 
676
 
677
  if not return_dict:
678
  output = (logits,) + rwkv_outputs[1:]
 
22
  import torch.nn.functional as F
23
  import torch.utils.checkpoint
24
  from torch import nn
25
+ from torch.nn import CrossEntropyLoss
26
 
27
  from transformers.modeling_utils import PreTrainedModel
28
  from transformers.utils import (
 
43
 
44
  RWKV5_PRETRAINED_MODEL_ARCHIVE_LIST = [
45
  "RWKV/rwkv-5-world-1b5",
46
+ "RWKV/rwkv-5-world-3b",
47
  # See all RWKV models at https://huggingface.co/models?filter=rwkv
48
  ]
49
 
 
65
  lxb,
66
  ow,
67
  state,
 
 
68
  ):
69
  time_decay = torch.exp(-torch.exp(time_decay.float())).reshape(-1, 1, 1).reshape(n_head, -1, 1)
70
  time_first = time_first.float().reshape(-1, 1, 1).reshape(n_head, -1, 1)
71
  lxw = lxw.float()
72
  lxb = lxb.float()
73
+ out = torch.zeros_like(key).reshape(B, T, H, S)
 
74
  for t in range(T):
75
  rt = receptance[:, :, t : t + 1, :]
76
  kt = key[:, :, :, t : t + 1]
77
  vt = value[:, :, t : t + 1, :]
78
  at = kt @ vt
79
  out[:, t] = (rt @ (time_first * at + state)).squeeze(2)
80
+ with torch.no_grad():
81
+ state = at + time_decay * state
82
 
83
  out = out.reshape(B * T, H * S)
84
  out = F.group_norm(out, num_groups=H, weight=lxw, bias=lxb).reshape(B, T, H * S)
 
171
  self.ln_x.bias,
172
  self.output.weight.t(),
173
  state=layer_state,
 
 
174
  )
175
 
176
  if layer_state is not None:
 
669
 
670
  loss = None
671
  if labels is not None:
672
+ # move labels to correct device to enable model parallelism
673
+ labels = labels.to(logits.device)
674
+ # Shift so that tokens < n predict n
675
+ shift_logits = logits[..., :-1, :].contiguous()
676
+ shift_labels = labels[..., 1:].contiguous()
677
+ # Flatten the tokens
678
+ loss_fct = CrossEntropyLoss()
679
+ loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
680
 
681
  if not return_dict:
682
  output = (logits,) + rwkv_outputs[1:]