KaleiNeely commited on
Commit
3a3a37f
1 Parent(s): da02d99

Update modeling_rwkv6.py

Browse files
Files changed (1) hide show
  1. modeling_rwkv6.py +19 -97
modeling_rwkv6.py CHANGED
@@ -37,6 +37,13 @@ from transformers.utils import (
37
  )
38
 
39
  from .configuration_rwkv6 import Rwkv6Config
 
 
 
 
 
 
 
40
 
41
 
42
  logger = logging.get_logger(__name__)
@@ -44,102 +51,15 @@ logger = logging.get_logger(__name__)
44
  _CHECKPOINT_FOR_DOC = "RWKV/rwkv-6-world-1b6"
45
  _CONFIG_FOR_DOC = "Rwkv6Config"
46
 
47
- rwkv6_cuda_kernel = None
48
-
49
- def load_wkv6_cuda_kernel(head_size, ctx_len):
50
- from torch.utils.cpp_extension import load as load_kernel
51
-
52
- global rwkv6_cuda_kernel
53
-
54
- kernel_folder = Path(__file__).parent.resolve()
55
- cuda_kernel_files = [kernel_folder / f for f in ["wkv6_op.cpp", "wkv6_cuda.cu"]]
56
-
57
- # Only load the kernel if it's not been loaded yet or if we changed the context length
58
- if rwkv6_cuda_kernel is not None and rwkv6_cuda_kernel.head_size == head_size:
59
- return
60
-
61
- logger.info(f"Loading CUDA kernel for RWKV at head size of {head_size}.")
62
-
63
- flags = [
64
- "-res-usage",
65
- # "--maxrregcount 60", # not sure, should we add this? its not in RWKV-LM
66
- "--use_fast_math",
67
- "-O3",
68
- "-Xptxas -O3",
69
- "--extra-device-vectorization",
70
- f"-D_N_={head_size}",
71
- f"-D_T_={ctx_len}"
72
- ]
73
- rwkv6_cuda_kernel = load_kernel(
74
- name=f"wkv_{head_size}_{ctx_len}",
75
- sources=cuda_kernel_files,
76
- verbose=(logging.get_verbosity() == logging.DEBUG),
77
- extra_cuda_cflags=flags,
78
- )
79
- rwkv6_cuda_kernel.head_size = head_size
80
- rwkv6_cuda_kernel.ctx_len = ctx_len
81
-
82
-
83
- class Rwkv6LinearAttention(torch.autograd.Function):
84
- @staticmethod
85
- def forward(ctx, receptance, key, value, time_decay, time_first, state):
86
- with torch.no_grad():
87
- assert receptance.dtype == torch.bfloat16
88
- assert key.dtype == torch.bfloat16
89
- assert value.dtype == torch.bfloat16
90
- assert time_decay.dtype == torch.bfloat16
91
- assert time_first.dtype == torch.bfloat16
92
- assert state.dtype == torch.float32
93
- #assert HEAD_SIZE == C // H
94
- Batch, SequenceLength, HiddenSize = key.shape
95
- NumHeads, HeadSize = time_decay.shape
96
- ctx.Batch = Batch
97
- ctx.SequenceLength = SequenceLength
98
- ctx.HiddenSize = HiddenSize
99
- ctx.NumHeads = NumHeads
100
- assert receptance.is_contiguous()
101
- assert key.is_contiguous()
102
- assert value.is_contiguous()
103
- assert time_decay.is_contiguous()
104
- assert time_first.is_contiguous()
105
- e_time_decay = (-torch.exp(time_decay.float())).contiguous()
106
- ctx.save_for_backward(receptance, key, value, e_time_decay, time_first)
107
- out = torch.empty((Batch, SequenceLength, HiddenSize), device=receptance.device, dtype=torch.bfloat16, memory_format=torch.contiguous_format)#.uniform_(-100, 100)
108
- # FIXME - current kernel does not handle nor update state
109
- rwkv6_cuda_kernel.forward(Batch, SequenceLength, HiddenSize, NumHeads, receptance, key, value, e_time_decay, time_first, out)
110
- return out, state
111
-
112
- @staticmethod
113
- def backward(ctx, g_out, g_state):
114
- with torch.no_grad():
115
- assert g_out.dtype == torch.bfloat16
116
- Batch = ctx.Batch
117
- SequenceLength = ctx.SequenceLength
118
- HiddenSize = ctx.HiddenSize
119
- NumHeads = ctx.NumHeads
120
- HeadSize = HiddenSize // NumHeads
121
- assert g_out.is_contiguous()
122
- receptance, key, value, e_time_decay, time_first = ctx.saved_tensors
123
- g_receptance = torch.empty((B, T, C), device=gy.device, requires_grad=False, dtype=torch.bfloat16, memory_format=torch.contiguous_format)#.uniform_(-100, 100)
124
- g_key = torch.empty((B, T, C), device=gy.device, requires_grad=False, dtype=torch.bfloat16, memory_format=torch.contiguous_format)#.uniform_(-100, 100)
125
- g_value = torch.empty((B, T, C), device=gy.device, requires_grad=False, dtype=torch.bfloat16, memory_format=torch.contiguous_format)#.uniform_(-100, 100)
126
- g_time_decay = torch.empty((B, T, C), device=gy.device, requires_grad=False, dtype=torch.bfloat16, memory_format=torch.contiguous_format)#.uniform_(-100, 100)
127
- g_time_first = torch.empty((B, C), device=gy.device, requires_grad=False, dtype=torch.bfloat16, memory_format=torch.contiguous_format)#.uniform_(-100, 100)
128
- #gs = torch.empty((B, C//H, H, H), device=gy.device, requires_grad=False, dtype=torch.float, memory_format=torch.contiguous_format)#.uniform_(-100, 100)
129
- rwkv6_cuda_kernel.backward(B, T, C, H, receptance, key, value, e_time_decay, time_first, g_out, g_receptance, g_key, g_value, g_time_decay, g_time_first)
130
- g_time_first = torch.sum(g_time_first, 0).view(NumHeads, HeadSize)
131
- return (None, None, None, None, g_receptance, g_key, g_value, g_time_decay, g_time_first, None)
132
-
133
  def rwkv6_linear_attention_cpu(receptance, key, value, time_decay, time_first, state):
134
- input_dtype = receptance.dtype
135
  # For CPU fallback. Will be slower and probably take more memory than the custom CUDA kernel if not executed
136
  # within a torch.no_grad.
137
- batch, seq_length, hidden_size = receptance.shape
138
  num_heads, head_size = time_first.shape
139
  key = key.float().view(batch, seq_length, num_heads, head_size).transpose(1, 2).transpose(-2, -1)
140
  value = value.float().view(batch, seq_length, num_heads, head_size).transpose(1, 2)
141
  receptance = receptance.float().view(batch, seq_length, num_heads, head_size).transpose(1, 2)
142
- time_decay = torch.exp(-torch.exp(time_decay.float())).view(batch, seq_length, num_heads, head_size).permute(0, 2, 3, 1) # B, H, S, T
143
  time_first = time_first.float().reshape(-1, 1, 1).reshape(num_heads, -1, 1)
144
  out = torch.zeros_like(key).reshape(batch, seq_length, num_heads, head_size)
145
 
@@ -168,24 +88,26 @@ def rwkv6_linear_attention(
168
  # Launching the CUDA kernel for just one token will actually be slower (there is no for loop in the CPU version
169
  # in this case).
170
  one_token = key.size(1) == 1
171
- if not training or rwkv6_cuda_kernel is None or no_cuda or one_token:
172
  return rwkv6_linear_attention_cpu(
173
  receptance, key, value, time_decay, time_first, state
174
  )
175
  else:
176
- return Rwkv6LinearAttention.apply(receptance, key, value, time_decay, time_first, state)
 
 
 
 
 
 
 
 
177
 
178
 
179
  class Rwkv6SelfAttention(nn.Module):
180
  def __init__(self, config, layer_id=0):
181
  super().__init__()
182
  self.config = config
183
- kernel_loaded = rwkv6_cuda_kernel is not None and rwkv6_cuda_kernel.head_size == config.head_size
184
- if is_ninja_available() and is_torch_cuda_available() and not kernel_loaded:
185
- try:
186
- load_wkv6_cuda_kernel(config.head_size, config.max_context_length) # FIXME - context_length is not a configured attribute
187
- except Exception:
188
- logger.info("Could not load the custom CUDA kernel for RWKV6 attention.")
189
  self.layer_id = layer_id
190
  hidden_size = config.hidden_size
191
  attention_hidden_size = config.attention_hidden_size
 
37
  )
38
 
39
  from .configuration_rwkv6 import Rwkv6Config
40
+ try:
41
+ from fla.ops.rwkv6.recurrent_fuse import fused_recurrent_rwkv6
42
+ except ImportError:
43
+ print("Required module is not installed. Please install it using the following commands:")
44
+ print("pip install -U git+https://github.com/sustcsonglin/flash-linear-attention")
45
+ print("Additionally, ensure you have the correct version of Triton installed:")
46
+ print("pip install triton==2.2.0")
47
 
48
 
49
  logger = logging.get_logger(__name__)
 
51
  _CHECKPOINT_FOR_DOC = "RWKV/rwkv-6-world-1b6"
52
  _CONFIG_FOR_DOC = "Rwkv6Config"
53
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
54
  def rwkv6_linear_attention_cpu(receptance, key, value, time_decay, time_first, state):
 
55
  # For CPU fallback. Will be slower and probably take more memory than the custom CUDA kernel if not executed
56
  # within a torch.no_grad.
57
+ batch, seq_length, _ = receptance.shape
58
  num_heads, head_size = time_first.shape
59
  key = key.float().view(batch, seq_length, num_heads, head_size).transpose(1, 2).transpose(-2, -1)
60
  value = value.float().view(batch, seq_length, num_heads, head_size).transpose(1, 2)
61
  receptance = receptance.float().view(batch, seq_length, num_heads, head_size).transpose(1, 2)
62
+ time_decay = torch.exp(-torch.exp(time_decay.float())).view(batch, seq_length, num_heads, head_size).permute(0, 2, 3, 1)
63
  time_first = time_first.float().reshape(-1, 1, 1).reshape(num_heads, -1, 1)
64
  out = torch.zeros_like(key).reshape(batch, seq_length, num_heads, head_size)
65
 
 
88
  # Launching the CUDA kernel for just one token will actually be slower (there is no for loop in the CPU version
89
  # in this case).
90
  one_token = key.size(1) == 1
91
+ if not training or no_cuda or one_token:
92
  return rwkv6_linear_attention_cpu(
93
  receptance, key, value, time_decay, time_first, state
94
  )
95
  else:
96
+ batch, seq_length, _ = receptance.shape
97
+ num_heads, head_size = time_first.shape
98
+ key = key.float().view(batch, seq_length, num_heads, head_size).transpose(1, 2) # B, T, H, K -> B, H, T, K
99
+ value = value.float().view(batch, seq_length, num_heads, head_size).transpose(1, 2) # B, T, H, K - > B, H, T, V
100
+ receptance = receptance.float().view(batch, seq_length, num_heads, head_size).transpose(1, 2) # B, H, T, K
101
+ time_decay = -torch.exp(time_decay.float()).view(batch, seq_length, num_heads, head_size).permute(0, 2, 1, 3) # B, T, H, K -> B, H, T, K
102
+ time_first = time_first.float().reshape(num_heads, head_size) # H, K
103
+ out, state = fused_recurrent_rwkv6(receptance, key, value, time_decay, time_first, scale=1.0, initial_state=state, output_final_state=True)
104
+ return out.transpose(1, 2), state
105
 
106
 
107
  class Rwkv6SelfAttention(nn.Module):
108
  def __init__(self, config, layer_id=0):
109
  super().__init__()
110
  self.config = config
 
 
 
 
 
 
111
  self.layer_id = layer_id
112
  hidden_size = config.hidden_size
113
  attention_hidden_size = config.attention_hidden_size