Upload 11 files
Changed files:
- config.json (+1, -0)
- configuration_rwkv5.py (+2, -0)
- cpp_kernels.py (+55, -0)
- modeling_rwkv5.py (+41, -45)
config.json
CHANGED
@@ -21,5 +21,6 @@
   "tie_word_embeddings": false,
   "transformers_version": "4.33.1",
   "use_cache": true,
+  "use_cache_kernel": true,
   "vocab_size": 65536
 }
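The new "use_cache_kernel" entry is read by Rwkv5Config (next file) and toggles the JIT-compiled CUDA state-cache kernel. A minimal usage sketch, assuming the RWKV/rwkv-5-world checkpoint referenced in the modeling file; the snippet is illustrative and not part of this commit:

# Hypothetical usage sketch: disable the CUDA state-cache kernel at load time,
# e.g. on a machine without nvcc, and fall back to the pure-PyTorch path.
from transformers import AutoConfig, AutoModelForCausalLM

config = AutoConfig.from_pretrained("RWKV/rwkv-5-world", trust_remote_code=True)
print(config.use_cache_kernel)   # True, per the value added in config.json

config.use_cache_kernel = False  # skip JIT compilation of rwkv5_op.cpp / rwkv5.cu
model = AutoModelForCausalLM.from_pretrained(
    "RWKV/rwkv-5-world", config=config, trust_remote_code=True
)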
configuration_rwkv5.py
CHANGED
@@ -101,6 +101,7 @@ class Rwkv5Config(PretrainedConfig):
         eos_token_id=0,
         rescale_every=6,
         tie_word_embeddings=False,
+        use_cache_kernel=True,
         use_cache=True,
         model_version="5_2",
         **kwargs,
@@ -114,6 +115,7 @@ class Rwkv5Config(PretrainedConfig):
         self.intermediate_size = None
         self.layer_norm_epsilon = layer_norm_epsilon
         self.rescale_every = rescale_every
+        self.use_cache_kernel = use_cache_kernel
         self.use_cache = use_cache

         self.bos_token_id = bos_token_id
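For reference, a short sketch of the new constructor argument, assuming configuration_rwkv5.py is importable on the local path; the other values mirror the defaults shown above and the snippet is not part of the commit:

# Hypothetical usage sketch: keep the ordinary state cache but turn off
# the JIT-compiled CUDA cache kernel.
from configuration_rwkv5 import Rwkv5Config

config = Rwkv5Config(
    use_cache=True,           # keep the recurrent state cache
    use_cache_kernel=False,   # do not JIT-compile the CUDA kernel
    model_version="5_2",
)
assert config.use_cache_kernel is False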
cpp_kernels.py
ADDED
@@ -0,0 +1,55 @@
+from torch.utils import cpp_extension
+import pathlib
+import os
+import subprocess
+
+def _get_cuda_bare_metal_version(cuda_dir):
+    raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"],
+                                         universal_newlines=True)
+    output = raw_output.split()
+    release_idx = output.index("release") + 1
+    release = output[release_idx].split(".")
+    bare_metal_major = release[0]
+    bare_metal_minor = release[1][0]
+
+    return raw_output, bare_metal_major, bare_metal_minor
+
+def _create_build_dir(buildpath):
+    try:
+        os.mkdir(buildpath)
+    except OSError:
+        if not os.path.isdir(buildpath):
+            print(f"Creation of the build directory {buildpath} failed")
+
+# Check if cuda 11 is installed for compute capability 8.0
+cc_flag = []
+_, bare_metal_major, bare_metal_minor = _get_cuda_bare_metal_version(cpp_extension.CUDA_HOME)
+if int(bare_metal_major) >= 11:
+    cc_flag.append('-gencode')
+    cc_flag.append('arch=compute_80,code=sm_80')
+    if int(bare_metal_minor) >= 7:
+        cc_flag.append('-gencode')
+        cc_flag.append('arch=compute_90,code=sm_90')
+
+# Build path
+srcpath = pathlib.Path(__file__).parent.absolute()
+buildpath = srcpath / 'build'
+_create_build_dir(buildpath)
+
+def _cpp_extention_load_helper(name, sources, extra_cuda_flags):
+    return cpp_extension.load(
+        name=name,
+        sources=sources,
+        build_directory=buildpath,
+        extra_cflags=['-O3', ],
+        extra_cuda_cflags=['-O3',
+                           '-gencode', 'arch=compute_70,code=sm_70',
+                           '--use_fast_math'] + extra_cuda_flags + cc_flag,
+        verbose=1
+    )
+
+extra_flags = []
+
+cache_wkv5_sources = ["./rwkv5_op.cpp",
+                      "./rwkv5.cu"]
+cache_wkv5 = _cpp_extention_load_helper("cache_wkv5", cache_wkv5_sources, extra_flags)
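Note that this module compiles rwkv5_op.cpp and rwkv5.cu at import time via torch.utils.cpp_extension.load, so importing it requires nvcc and a valid CUDA_HOME. A minimal sketch of a guarded import, mirroring the fallback applied in modeling_rwkv5.py below; illustrative only:

# Hypothetical usage sketch: importing cpp_kernels triggers the nvcc build,
# so guard it and fall back to the pure-PyTorch path when the build fails.
try:
    from cpp_kernels import cache_wkv5   # first import compiles into ./build
except Exception as exc:                 # missing nvcc, CUDA_HOME, or sources
    cache_wkv5 = None
    print(f"cache_wkv5 kernels unavailable, falling back to PyTorch: {exc}")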
modeling_rwkv5.py
CHANGED
@@ -36,6 +36,7 @@ from transformers.utils import (
     logging,
 )
 from .configuration_rwkv5 import Rwkv5Config
+from .cpp_kernels import cache_wkv5
 logger = logging.get_logger(__name__)

 _CHECKPOINT_FOR_DOC = "RWKV/rwkv-5-world"
@@ -45,42 +46,29 @@ RWKV_PRETRAINED_MODEL_ARCHIVE_LIST = [

 ]

-        [23 removed lines not rendered in this view]
-        assert v.is_contiguous()
-        assert w.is_contiguous()
-        assert u.is_contiguous()
-        assert state.is_contiguous()
-
-        y = torch.empty((B, T, C), device=w.device, dtype=r.dtype, memory_format=torch.contiguous_format)
-        if r.dtype == torch.bfloat16:
-            rwkv5_cuda_kernel.forward_bf16(B, T, C, H, state, r, k, v, w, u, y)
-        elif r.dtype == torch.float16:
-            rwkv5_cuda_kernel.forward_fp16(B, T, C, H, state, r, k, v, w, u, y)
-        elif r.dtype == torch.float32:
-            rwkv5_cuda_kernel.forward_fp32(B, T, C, H, state, r, k, v, w, u, y)
-        return y, state
+def rwkv_linear_attention_v5_2_cuda(B, T, C, H, state, r, k, v, w, u, cache_kernels):
+    assert HEAD_SIZE == C // H
+    ctx.B = B
+    ctx.T = T
+    ctx.C = C
+    ctx.H = H
+    assert state.dtype == torch.float32
+    assert w.dtype == torch.float32
+    assert r.is_contiguous()
+    assert k.is_contiguous()
+    assert v.is_contiguous()
+    assert w.is_contiguous()
+    assert u.is_contiguous()
+    assert state.is_contiguous()
+
+    y = torch.empty((B, T, C), device=w.device, dtype=r.dtype, memory_format=torch.contiguous_format)
+    if r.dtype == torch.bfloat16:
+        cache_kernels.forward_bf16(B, T, C, H, state, r, k, v, w, u, y)
+    elif r.dtype == torch.float16:
+        cache_kernels.forward_fp16(B, T, C, H, state, r, k, v, w, u, y)
+    elif r.dtype == torch.float32:
+        cache_kernels.forward_fp32(B, T, C, H, state, r, k, v, w, u, y)
+    return y, state

 def rwkv_linear_attention_v5_0(H, S, T, hidden, time_decay, time_first, receptance, key, value, lxw, lxb, ow, state, return_state=False, seq_mode=True):
     time_decay = torch.exp(-torch.exp(time_decay.float())).reshape(-1,1,1)
@@ -153,12 +141,20 @@ class RwkvSelfAttention(nn.Module):
         super().__init__()
         self.config = config
         self.layer_id = layer_id
-        [6 removed lines not rendered in this view]
+        if config.use_cache_kernel:
+            # pre check if the support files existing
+            module_root = pathlib.Path(__file__).parent
+            src_files = ("rwkv5_op.cpp", "rwkv5.cu")
+            if any(not (module_root/src).is_file() for src in src_files):
+                warnings.warn("State cache kernel source files (.cpp and .cu) not found.")
+                self.cache_kernels = None
+            else:
+                try:
+                    from .cpp_kernels import cache_wkv5
+                    self.cache_kernels = cache_wkv5
+                except ImportError:
+                    warnings.warn("Failed to import KV cache kernels.")
+                    self.cache_kernels = None
         self.hidden_size = config.hidden_size
         # https://github.com/BlinkDL/RWKV-LM/blob/main/RWKV-v4neo/src/model.py#L146
         num_attention_heads = self.hidden_size // config.head_size
@@ -206,7 +202,7 @@ class RwkvSelfAttention(nn.Module):
         gate = hidden* self.time_mix_gate + shifted * (1 - self.time_mix_gate)
         gate = F.silu(self.gate(gate))

-        if [rest of removed line not rendered in this view]
+        if self.cache_kernels is None:
             if hidden.size(1) == 1 and state is not None:
                 receptance = self.receptance(receptance).to(torch.float32).view(H, 1, S)
                 key = self.key(key).to(torch.float32).view(H, S, 1)
@@ -235,8 +231,8 @@ class RwkvSelfAttention(nn.Module):
         receptance, key, value, state = self.extract_key_value(H, S, T, hidden, state=state)
         layer_state = state[1][:, :, :, :, self.layer_id] if state is not None else None
         if self.config.model_version == "5_2":
-            if [rest of removed line not rendered in this view]
-                rwkv, layer_state = [rest of removed line not rendered in this view]
+            if self.cache_kernels is not None and seq_mode:
+                rwkv, layer_state = rwkv_linear_attention_v5_2_cuda(1, T, self.hidden_size, H, layer_state.transpose(-1, -2).contiguous(),
                     receptance, key, value, self.time_decay, self.time_faaaa,)
                 layer_state = layer_state.transpose(-1,-2)
                 rwkv = rwkv.reshape(T, H*N)
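The CUDA branch above is taken only when the cache kernels compiled successfully and seq_mode is true; otherwise the pure-PyTorch path in extract_key_value is used. A rough end-to-end sketch that exercises the cached-state path, assuming the RWKV/rwkv-5-world checkpoint and a CUDA device; illustrative only, not part of this commit:

# Hypothetical usage sketch: generation with the recurrent state cache enabled.
# Model id, dtype, and device are assumptions.
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

model_id = "RWKV/rwkv-5-world"
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_id, trust_remote_code=True, torch_dtype=torch.float16
).to("cuda")

inputs = tokenizer("The quick brown fox", return_tensors="pt").to("cuda")
output = model.generate(**inputs, max_new_tokens=32, use_cache=True)
print(tokenizer.decode(output[0]))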