BBuf committed
Commit 407fecd
1 Parent(s): 8fafa0d

Upload 11 files

Files changed (4)
  1. config.json +1 -0
  2. configuration_rwkv5.py +2 -0
  3. cpp_kernels.py +55 -0
  4. modeling_rwkv5.py +41 -45
config.json CHANGED
@@ -21,5 +21,6 @@
   "tie_word_embeddings": false,
   "transformers_version": "4.33.1",
   "use_cache": true,
+  "use_cache_kernel": true,
   "vocab_size": 65536
 }
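
The new use_cache_kernel flag gates the JIT-compiled state-cache kernel added in this commit. A minimal loading sketch, assuming the repo is pulled with trust_remote_code; the repo id below is the doc checkpoint RWKV/rwkv-5-world referenced in modeling_rwkv5.py, not necessarily this exact repo:

from transformers import AutoConfig, AutoModelForCausalLM

config = AutoConfig.from_pretrained("RWKV/rwkv-5-world", trust_remote_code=True)
config.use_cache_kernel = False  # fall back to the pure-PyTorch path, e.g. when nvcc is unavailable
model = AutoModelForCausalLM.from_pretrained(
    "RWKV/rwkv-5-world", config=config, trust_remote_code=True
)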
configuration_rwkv5.py CHANGED
@@ -101,6 +101,7 @@ class Rwkv5Config(PretrainedConfig):
         eos_token_id=0,
         rescale_every=6,
         tie_word_embeddings=False,
+        use_cache_kernel=True,
         use_cache=True,
         model_version="5_2",
         **kwargs,
@@ -114,6 +115,7 @@ class Rwkv5Config(PretrainedConfig):
         self.intermediate_size = None
         self.layer_norm_epsilon = layer_norm_epsilon
         self.rescale_every = rescale_every
+        self.use_cache_kernel = use_cache_kernel
         self.use_cache = use_cache
 
         self.bos_token_id = bos_token_id
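
For reference, a small sketch of constructing the config with the new keyword directly; the flat import path is an assumption for running the file outside the Hub module layout:

from configuration_rwkv5 import Rwkv5Config

config = Rwkv5Config(use_cache_kernel=True, use_cache=True, model_version="5_2")
assert config.use_cache_kernel is True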
cpp_kernels.py ADDED
@@ -0,0 +1,55 @@
+from torch.utils import cpp_extension
+import pathlib
+import os
+import subprocess
+
+def _get_cuda_bare_metal_version(cuda_dir):
+    raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"],
+                                         universal_newlines=True)
+    output = raw_output.split()
+    release_idx = output.index("release") + 1
+    release = output[release_idx].split(".")
+    bare_metal_major = release[0]
+    bare_metal_minor = release[1][0]
+
+    return raw_output, bare_metal_major, bare_metal_minor
+
+def _create_build_dir(buildpath):
+    try:
+        os.mkdir(buildpath)
+    except OSError:
+        if not os.path.isdir(buildpath):
+            print(f"Creation of the build directory {buildpath} failed")
+
+# Check if cuda 11 is installed for compute capability 8.0
+cc_flag = []
+_, bare_metal_major, bare_metal_minor = _get_cuda_bare_metal_version(cpp_extension.CUDA_HOME)
+if int(bare_metal_major) >= 11:
+    cc_flag.append('-gencode')
+    cc_flag.append('arch=compute_80,code=sm_80')
+    if int(bare_metal_minor) >= 7:
+        cc_flag.append('-gencode')
+        cc_flag.append('arch=compute_90,code=sm_90')
+
+# Build path
+srcpath = pathlib.Path(__file__).parent.absolute()
+buildpath = srcpath / 'build'
+_create_build_dir(buildpath)
+
+def _cpp_extention_load_helper(name, sources, extra_cuda_flags):
+    return cpp_extension.load(
+        name=name,
+        sources=sources,
+        build_directory=buildpath,
+        extra_cflags=['-O3', ],
+        extra_cuda_cflags=['-O3',
+                           '-gencode', 'arch=compute_70,code=sm_70',
+                           '--use_fast_math'] + extra_cuda_flags + cc_flag,
+        verbose=1
+    )
+
+extra_flags = []
+
+cache_wkv5_sources = ["./rwkv5_op.cpp",
+                      "./rwkv5.cu"]
+cache_wkv5 = _cpp_extention_load_helper("cache_wkv5", cache_wkv5_sources, extra_flags)
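
Importing cpp_kernels compiles the extension at import time via torch.utils.cpp_extension.load, so nvcc (CUDA_HOME) and the rwkv5_op.cpp / rwkv5.cu sources must be available. The extra -gencode flags depend only on the parsed nvcc version; a small sketch of that selection logic, with illustrative version strings that are not taken from the commit:

def select_cc_flags(bare_metal_major, bare_metal_minor):
    # Mirrors the arch selection above: sm_80 when the major version is 11+,
    # sm_90 additionally when the minor version is 7+.
    cc_flag = []
    if int(bare_metal_major) >= 11:
        cc_flag += ["-gencode", "arch=compute_80,code=sm_80"]
        if int(bare_metal_minor) >= 7:
            cc_flag += ["-gencode", "arch=compute_90,code=sm_90"]
    return cc_flag

print(select_cc_flags("11", "4"))  # ['-gencode', 'arch=compute_80,code=sm_80']
print(select_cc_flags("11", "8"))  # also appends arch=compute_90,code=sm_90
print(select_cc_flags("10", "2"))  # [] -> only the baseline compute_70 flags in the loader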
modeling_rwkv5.py CHANGED
@@ -36,6 +36,7 @@ from transformers.utils import (
     logging,
 )
 from .configuration_rwkv5 import Rwkv5Config
+from .cpp_kernels import cache_wkv5
 logger = logging.get_logger(__name__)
 
 _CHECKPOINT_FOR_DOC = "RWKV/rwkv-5-world"
@@ -45,42 +46,29 @@ RWKV_PRETRAINED_MODEL_ARCHIVE_LIST = [
 
 ]
 
-rwkv5_cuda_kernel = None
-
-def load_wkv5_cuda_kernel(config):
-    global rwkv5_cuda_kernel
-    if config.model_version == "5_2" and torch.cuda.is_available():
-        HEAD_SIZE = args.attention_hidden_size // args.head_size
-        module_root = pathlib.Path(__file__).parent
-        rwkv5_cuda_kernel = load(name="rwkv5", sources=[f"{module_root}/rwkv5_op.cpp", f"{module_root}/rwkv5.cu"],
-                                 verbose=True, extra_cuda_cflags=["-res-usage", "--use_fast_math", "-O3", "-Xptxas -O3" if os.name != "nt" else "", "--extra-device-vectorization", f"-D_N_={HEAD_SIZE}"])
-
-class RWKV_5(torch.autograd.Function):
-    @staticmethod
-    def forward(ctx, B, T, C, H, state, r, k, v, w, u):
-        with torch.no_grad():
-            assert HEAD_SIZE == C // H
-            ctx.B = B
-            ctx.T = T
-            ctx.C = C
-            ctx.H = H
-            assert state.dtype == torch.float32
-            assert w.dtype == torch.float32
-            assert r.is_contiguous()
-            assert k.is_contiguous()
-            assert v.is_contiguous()
-            assert w.is_contiguous()
-            assert u.is_contiguous()
-            assert state.is_contiguous()
-
-            y = torch.empty((B, T, C), device=w.device, dtype=r.dtype, memory_format=torch.contiguous_format)
-            if r.dtype == torch.bfloat16:
-                rwkv5_cuda_kernel.forward_bf16(B, T, C, H, state, r, k, v, w, u, y)
-            elif r.dtype == torch.float16:
-                rwkv5_cuda_kernel.forward_fp16(B, T, C, H, state, r, k, v, w, u, y)
-            elif r.dtype == torch.float32:
-                rwkv5_cuda_kernel.forward_fp32(B, T, C, H, state, r, k, v, w, u, y)
-            return y, state
+def rwkv_linear_attention_v5_2_cuda(B, T, C, H, state, r, k, v, w, u, cache_kernels):
+    assert HEAD_SIZE == C // H
+    ctx.B = B
+    ctx.T = T
+    ctx.C = C
+    ctx.H = H
+    assert state.dtype == torch.float32
+    assert w.dtype == torch.float32
+    assert r.is_contiguous()
+    assert k.is_contiguous()
+    assert v.is_contiguous()
+    assert w.is_contiguous()
+    assert u.is_contiguous()
+    assert state.is_contiguous()
+
+    y = torch.empty((B, T, C), device=w.device, dtype=r.dtype, memory_format=torch.contiguous_format)
+    if r.dtype == torch.bfloat16:
+        cache_kernels.forward_bf16(B, T, C, H, state, r, k, v, w, u, y)
+    elif r.dtype == torch.float16:
+        cache_kernels.forward_fp16(B, T, C, H, state, r, k, v, w, u, y)
+    elif r.dtype == torch.float32:
+        cache_kernels.forward_fp32(B, T, C, H, state, r, k, v, w, u, y)
+    return y, state
 
 def rwkv_linear_attention_v5_0(H, S, T, hidden, time_decay, time_first, receptance, key, value, lxw, lxb, ow, state, return_state=False, seq_mode=True):
     time_decay = torch.exp(-torch.exp(time_decay.float())).reshape(-1,1,1)
@@ -153,12 +141,20 @@ class RwkvSelfAttention(nn.Module):
         super().__init__()
         self.config = config
         self.layer_id = layer_id
-        kernel_loaded = rwkv5_cuda_kernel is not None
-        if torch.cuda.is_available() and not kernel_loaded:
-            try:
-                load_wkv5_cuda_kernel(config)
-            except Exception:
-                logger.info("Could not load the custom CUDA kernel for RWKV5 attention.")
+        if config.use_cache_kernel:
+            # pre check if the support files existing
+            module_root = pathlib.Path(__file__).parent
+            src_files = ("rwkv5_op.cpp", "rwkv5.cu")
+            if any(not (module_root/src).is_file() for src in src_files):
+                warnings.warn("State cache kernel source files (.cpp and .cu) not found.")
+                self.cache_kernels = None
+            else:
+                try:
+                    from .cpp_kernels import cache_wkv5
+                    self.cache_kernels = cache_wkv5
+                except ImportError:
+                    warnings.warn("Failed to import KV cache kernels.")
+                    self.cache_kernels = None
         self.hidden_size = config.hidden_size
         # https://github.com/BlinkDL/RWKV-LM/blob/main/RWKV-v4neo/src/model.py#L146
         num_attention_heads = self.hidden_size // config.head_size
@@ -206,7 +202,7 @@
         gate = hidden* self.time_mix_gate + shifted * (1 - self.time_mix_gate)
         gate = F.silu(self.gate(gate))
 
-        if rwkv5_cuda_kernel is None:
+        if self.cache_kernels is None:
             if hidden.size(1) == 1 and state is not None:
                 receptance = self.receptance(receptance).to(torch.float32).view(H, 1, S)
                 key = self.key(key).to(torch.float32).view(H, S, 1)
@@ -235,8 +231,8 @@
         receptance, key, value, state = self.extract_key_value(H, S, T, hidden, state=state)
         layer_state = state[1][:, :, :, :, self.layer_id] if state is not None else None
         if self.config.model_version == "5_2":
-            if rwkv5_cuda_kernel is not None and seq_mode:
-                rwkv, layer_state = RWKV_5.apply(1, T, self.hidden_size, H, layer_state.transpose(-1, -2).contiguous(),
+            if self.cache_kernels is not None and seq_mode:
+                rwkv, layer_state = rwkv_linear_attention_v5_2_cuda(1, T, self.hidden_size, H, layer_state.transpose(-1, -2).contiguous(),
                                                                     receptance, key, value, self.time_decay, self.time_faaaa,)
                 layer_state = layer_state.transpose(-1,-2)
                 rwkv = rwkv.reshape(T, H*N)
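
Note that the new rwkv_linear_attention_v5_2_cuda keeps the HEAD_SIZE assert and the ctx.* assignments from the removed autograd.Function even though neither name exists in the new scope, the call site still omits the trailing cache_kernels argument, and self.cache_kernels is never set when config.use_cache_kernel is False. A hedged sketch of how the helper could read with those leftovers removed (not part of this commit):

import torch

def rwkv_linear_attention_v5_2_cuda(B, T, C, H, state, r, k, v, w, u, cache_kernels):
    # The kernel expects a float32 state and contiguous inputs.
    assert state.dtype == torch.float32
    assert w.dtype == torch.float32
    for tensor in (r, k, v, w, u, state):
        assert tensor.is_contiguous()

    # Preallocate the output; the extension writes into it in place.
    y = torch.empty((B, T, C), device=w.device, dtype=r.dtype, memory_format=torch.contiguous_format)
    if r.dtype == torch.bfloat16:
        cache_kernels.forward_bf16(B, T, C, H, state, r, k, v, w, u, y)
    elif r.dtype == torch.float16:
        cache_kernels.forward_fp16(B, T, C, H, state, r, k, v, w, u, y)
    elif r.dtype == torch.float32:
        cache_kernels.forward_fp32(B, T, C, H, state, r, k, v, w, u, y)
    return y, state

With this signature, the call in forward would also pass self.cache_kernels as the final argument.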