Upload 10 files
Browse files
- .gitattributes +1 -0
- config.json +37 -0
- gptq.py +163 -0
- modeling_chatsakura.py +26 -0
- modelutils.py +11 -0
- quant.py +305 -0
- quant_cuda.cpp +70 -0
- quant_cuda_kernel.cu +524 -0
- special_tokens_map.json +6 -0
- tokenizer.json +3 -0
- tokenizer_config.json +11 -0
.gitattributes
CHANGED
@@ -32,3 +32,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+tokenizer.json filter=lfs diff=lfs merge=lfs -text
config.json
ADDED
@@ -0,0 +1,37 @@
{
  "_name_or_path": "bigscience/bloomz-3b",
  "apply_residual_connection_post_layernorm": false,
  "architectures": [
    "BloomForCausalLM"
  ],
  "auto_map": {
    "AutoModel": "modeling_chatsakura.SakuraForCausalLM",
    "AutoModelForCausalLM": "modeling_chatsakura.SakuraForCausalLM"
  },
  "attention_dropout": 0.0,
  "attention_softmax_in_fp32": true,
  "bias_dropout_fusion": true,
  "bos_token_id": 1,
  "eos_token_id": 2,
  "hidden_dropout": 0.0,
  "hidden_size": 2560,
  "initializer_range": 0.02,
  "layer_norm_epsilon": 1e-05,
  "masked_softmax_fusion": true,
  "model_type": "bloom",
  "n_head": 32,
  "n_inner": null,
  "n_layer": 30,
  "offset_alibi": 100,
  "pad_token_id": 3,
  "pretraining_tp": 4,
  "seq_length": 2048,
  "skip_bias_add": true,
  "skip_bias_add_qkv": false,
  "slow_but_exact": false,
  "torch_dtype": "float16",
  "transformers_version": "4.27.1",
  "unk_token_id": 0,
  "use_cache": true,
  "vocab_size": 250880
}
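The "auto_map" block above routes both AutoModel and AutoModelForCausalLM to the custom SakuraForCausalLM class shipped in modeling_chatsakura.py, so the quantized wrapper is constructed instead of the stock BloomForCausalLM. A minimal loading sketch follows; the repo id is a placeholder (not part of this upload) and loading custom classes requires opting into remote code execution.

# Hypothetical usage sketch; the repo id below is a placeholder, not taken from this diff.
from transformers import AutoTokenizer, AutoModelForCausalLM

repo = "your-namespace/chatsakura-3b-int8"   # placeholder repo id
tokenizer = AutoTokenizer.from_pretrained(repo)
model = AutoModelForCausalLM.from_pretrained(
    repo,
    trust_remote_code=True,  # required so auto_map can resolve modeling_chatsakura.SakuraForCausalLM
)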
gptq.py
ADDED
@@ -0,0 +1,163 @@
import math
import time

import torch
import torch.nn as nn
import transformers

from quant import *


DEBUG = False

torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False


class GPTQ:
    def __init__(self, layer):
        self.layer = layer
        self.dev = self.layer.weight.device
        W = layer.weight.data.clone()
        if isinstance(self.layer, nn.Conv2d):
            W = W.flatten(1)
        if isinstance(self.layer, transformers.Conv1D):
            W = W.t()
        self.rows = W.shape[0]
        self.columns = W.shape[1]
        self.H = torch.zeros((self.columns, self.columns), device=self.dev)
        self.nsamples = 0

    def add_batch(self, inp, out):
        if DEBUG:
            self.inp1 = inp
            self.out1 = out
        if len(inp.shape) == 2:
            inp = inp.unsqueeze(0)
        tmp = inp.shape[0]
        if isinstance(self.layer, nn.Linear) or isinstance(self.layer, transformers.Conv1D):
            if len(inp.shape) == 3:
                inp = inp.reshape((-1, inp.shape[-1]))
            inp = inp.t()
        if isinstance(self.layer, nn.Conv2d):
            unfold = nn.Unfold(
                self.layer.kernel_size,
                dilation=self.layer.dilation,
                padding=self.layer.padding,
                stride=self.layer.stride
            )
            inp = unfold(inp)
            inp = inp.permute([1, 0, 2])
            inp = inp.flatten(1)
        self.H *= self.nsamples / (self.nsamples + tmp)
        self.nsamples += tmp
        # inp = inp.float()
        inp = math.sqrt(2 / self.nsamples) * inp.float()
        # self.H += 2 / self.nsamples * inp.matmul(inp.t())
        self.H += inp.matmul(inp.t())

    def fasterquant(
        self, blocksize=128, percdamp=.01, groupsize=-1
    ):
        W = self.layer.weight.data.clone()
        if isinstance(self.layer, nn.Conv2d):
            W = W.flatten(1)
        if isinstance(self.layer, transformers.Conv1D):
            W = W.t()
        W = W.float()

        tick = time.time()

        if not self.quantizer.ready():
            self.quantizer.find_params(W, weight=True)

        H = self.H
        del self.H
        dead = torch.diag(H) == 0
        H[dead, dead] = 1
        W[:, dead] = 0

        Losses = torch.zeros_like(W)
        Q = torch.zeros_like(W)

        damp = percdamp * torch.mean(torch.diag(H))
        diag = torch.arange(self.columns, device=self.dev)
        H[diag, diag] += damp
        H = torch.linalg.cholesky(H)
        H = torch.cholesky_inverse(H)
        H = torch.linalg.cholesky(H, upper=True)
        Hinv = H

        scale = []
        zero = []
        now_idx = 1

        for i1 in range(0, self.columns, blocksize):
            i2 = min(i1 + blocksize, self.columns)
            count = i2 - i1

            W1 = W[:, i1:i2].clone()
            Q1 = torch.zeros_like(W1)
            Err1 = torch.zeros_like(W1)
            Losses1 = torch.zeros_like(W1)
            Hinv1 = Hinv[i1:i2, i1:i2]

            for i in range(count):
                w = W1[:, i]
                d = Hinv1[i, i]

                if groupsize != -1:
                    if (i1 + i) % groupsize == 0:
                        self.quantizer.find_params(W[:, (i1 + i):(i1 + i + groupsize)], weight=True)

                    if ((i1 + i) // groupsize) - now_idx == -1:
                        scale.append(self.quantizer.scale)
                        zero.append(self.quantizer.zero)
                        now_idx += 1

                q = quantize(
                    w.unsqueeze(1), self.quantizer.scale, self.quantizer.zero, self.quantizer.maxq
                ).flatten()
                Q1[:, i] = q
                Losses1[:, i] = (w - q) ** 2 / d ** 2

                err1 = (w - q) / d
                W1[:, i:] -= err1.unsqueeze(1).matmul(Hinv1[i, i:].unsqueeze(0))
                Err1[:, i] = err1

            Q[:, i1:i2] = Q1
            Losses[:, i1:i2] = Losses1 / 2

            W[:, i2:] -= Err1.matmul(Hinv[i1:i2, i2:])

            if DEBUG:
                self.layer.weight.data[:, :i2] = Q[:, :i2]
                self.layer.weight.data[:, i2:] = W[:, i2:]
                print(torch.sum((self.layer(self.inp1) - self.out1) ** 2))
                print(torch.sum(Losses))

        torch.cuda.synchronize()
        print('time %.2f' % (time.time() - tick))
        print('error', torch.sum(Losses).item())

        if isinstance(self.layer, transformers.Conv1D):
            Q = Q.t()
        self.layer.weight.data = Q.reshape(self.layer.weight.shape).to(self.layer.weight.data.dtype)
        if DEBUG:
            print(torch.sum((self.layer(self.inp1) - self.out1) ** 2))

        if scale == []:
            scale.append(self.quantizer.scale)
            zero.append(self.quantizer.zero)
        scale = torch.cat(scale, dim=1)
        zero = torch.cat(zero, dim=1)
        return scale, zero

    def free(self):
        if DEBUG:
            self.inp1 = None
            self.out1 = None
        self.H = None
        self.Losses = None
        self.Trace = None
        torch.cuda.empty_cache()
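For orientation, here is a minimal driver sketch for the GPTQ class above. It is not part of the upload, assumes a CUDA device, and uses random tensors in place of real calibration activations.

# Hedged driver sketch (not in the repo): quantize one Linear layer with GPTQ.
import torch
import torch.nn as nn
from gptq import GPTQ
from quant import Quantizer

layer = nn.Linear(2560, 2560).cuda()        # assumes a CUDA device is available
gptq = GPTQ(layer)
gptq.quantizer = Quantizer()
gptq.quantizer.configure(8, perchannel=True, sym=False, mse=False)

for _ in range(16):                         # toy calibration batches
    x = torch.randn(4, 2560, device='cuda')
    gptq.add_batch(x, layer(x))             # accumulates the running Hessian self.H

scale, zero = gptq.fasterquant(percdamp=0.01, groupsize=128)   # layer weights updated in place
gptq.free()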
modeling_chatsakura.py
ADDED
@@ -0,0 +1,26 @@
import torch
import torch.nn as nn

from gptq import *
from modelutils import *
from quant import *
from transformers import BloomForCausalLM as LM

class SakuraForCausalLM(LM):
    def __init__(self, *args, **kwargs):
        def noop(*args, **kwargs):
            pass
        torch.nn.init.kaiming_uniform_ = noop
        torch.nn.init.uniform_ = noop
        torch.nn.init.normal_ = noop
        torch.set_default_dtype(torch.half)
        transformers.modeling_utils._init_weights = False
        torch.set_default_dtype(torch.half)
        super().__init__(*args, **kwargs)
        torch.set_default_dtype(torch.float)
        self.eval()
        layers = find_layers(self)
        for name in ['lm_head']:
            if name in layers:
                del layers[name]
        make_quant(self, layers, 8, 128)
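Two things are worth noting about this wrapper: the name transformers is only in scope through the star import from gptq (which imports it), and the constructor stubs out the weight-init functions before building the BLOOM graph in half precision, then swaps every Linear except lm_head for an 8-bit QuantLinear with group size 128. A config-only sketch of that effect, with toy sizes chosen here for illustration and no checkpoint involved:

# Hedged illustration: build the class from a small config and check that
# Linear layers were replaced by QuantLinear (no weights are loaded).
from transformers import BloomConfig
from modeling_chatsakura import SakuraForCausalLM
from quant import QuantLinear

cfg = BloomConfig(hidden_size=256, n_head=8, n_layer=2, vocab_size=512)
model = SakuraForCausalLM(cfg)
print(sum(isinstance(m, QuantLinear) for m in model.modules()))  # > 0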
modelutils.py
ADDED
@@ -0,0 +1,11 @@
import torch.nn as nn

def find_layers(module, layers=[nn.Conv2d, nn.Linear], name=''):
    if type(module) in layers:
        return {name: module}
    res = {}
    for name1, child in module.named_children():
        res.update(find_layers(
            child, layers=layers, name=name + '.' + name1 if name != '' else name1
        ))
    return res
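A quick illustration of find_layers (not part of the upload): it walks an arbitrary module tree and returns a dict mapping dotted child names to the Linear/Conv2d modules it finds.

# Toy example: find_layers on a small Sequential model.
import torch.nn as nn
from modelutils import find_layers

toy = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 4))
print(find_layers(toy))   # {'0': Linear(8 -> 16), '2': Linear(16 -> 4)}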
quant.py
ADDED
@@ -0,0 +1,305 @@
import numpy as np
import torch
import torch.nn as nn
import math

def quantize(x, scale, zero, maxq):
    q = torch.clamp(torch.round(x / scale) + zero, 0, maxq)
    return scale * (q - zero)

class Quantizer(nn.Module):

    def __init__(self, shape=1):
        super(Quantizer, self).__init__()
        self.register_buffer('maxq', torch.tensor(0))
        self.register_buffer('scale', torch.zeros(shape))
        self.register_buffer('zero', torch.zeros(shape))

    def configure(
        self,
        bits, perchannel=False, sym=True,
        mse=False, norm=2.4, grid=100, maxshrink=.8
    ):
        self.maxq = torch.tensor(2 ** bits - 1)
        self.perchannel = perchannel
        self.sym = sym
        self.mse = mse
        self.norm = norm
        self.grid = grid
        self.maxshrink = maxshrink

    def find_params(self, x, weight=False):
        dev = x.device
        self.maxq = self.maxq.to(dev)

        shape = x.shape
        if self.perchannel:
            if weight:
                x = x.flatten(1)
            else:
                if len(shape) == 4:
                    x = x.permute([1, 0, 2, 3])
                    x = x.flatten(1)
                if len(shape) == 3:
                    x = x.reshape((-1, shape[-1])).t()
                if len(shape) == 2:
                    x = x.t()
        else:
            x = x.flatten().unsqueeze(0)

        tmp = torch.zeros(x.shape[0], device=dev)
        xmin = torch.minimum(x.min(1)[0], tmp)
        xmax = torch.maximum(x.max(1)[0], tmp)

        if self.sym:
            xmax = torch.maximum(torch.abs(xmin), xmax)
            tmp = xmin < 0
            if torch.any(tmp):
                xmin[tmp] = -xmax[tmp]
        tmp = (xmin == 0) & (xmax == 0)
        xmin[tmp] = -1
        xmax[tmp] = +1

        self.scale = (xmax - xmin) / self.maxq
        if self.sym:
            self.zero = torch.full_like(self.scale, (self.maxq + 1) / 2)
        else:
            self.zero = torch.round(-xmin / self.scale)

        if self.mse:
            best = torch.full([x.shape[0]], float('inf'), device=dev)
            for i in range(int(self.maxshrink * self.grid)):
                p = 1 - i / self.grid
                xmin1 = p * xmin
                xmax1 = p * xmax
                scale1 = (xmax1 - xmin1) / self.maxq
                zero1 = torch.round(-xmin1 / scale1) if not self.sym else self.zero
                q = quantize(x, scale1.unsqueeze(1), zero1.unsqueeze(1), self.maxq)
                q -= x
                q.abs_()
                q.pow_(self.norm)
                err = torch.sum(q, 1)
                tmp = err < best
                if torch.any(tmp):
                    best[tmp] = err[tmp]
                    self.scale[tmp] = scale1[tmp]
                    self.zero[tmp] = zero1[tmp]
        if not self.perchannel:
            if weight:
                tmp = shape[0]
            else:
                tmp = shape[1] if len(shape) != 3 else shape[2]
            self.scale = self.scale.repeat(tmp)
            self.zero = self.zero.repeat(tmp)

        if weight:
            shape = [-1] + [1] * (len(shape) - 1)
            self.scale = self.scale.reshape(shape)
            self.zero = self.zero.reshape(shape)
            return
        if len(shape) == 4:
            self.scale = self.scale.reshape((1, -1, 1, 1))
            self.zero = self.zero.reshape((1, -1, 1, 1))
        if len(shape) == 3:
            self.scale = self.scale.reshape((1, 1, -1))
            self.zero = self.zero.reshape((1, 1, -1))
        if len(shape) == 2:
            self.scale = self.scale.unsqueeze(0)
            self.zero = self.zero.unsqueeze(0)

    def quantize(self, x):
        if self.ready():
            return quantize(x, self.scale, self.zero, self.maxq)
        return x

    def enabled(self):
        return self.maxq > 0

    def ready(self):
        return torch.all(self.scale != 0)


try:
    import quant_cuda
except:
    import os
    import sys
    argv = sys.argv
    sys.argv = ['quant.py', 'install']
    dir_path = os.path.dirname(os.path.realpath(__file__))
    from setuptools import setup, Extension
    from torch.utils import cpp_extension
    os.chdir(dir_path)
    setup(
        name='quant_cuda',
        ext_modules=[cpp_extension.CUDAExtension(
            'quant_cuda', ['quant_cuda.cpp', 'quant_cuda_kernel.cu']
        )],
        cmdclass={'build_ext': cpp_extension.BuildExtension}
    )
    os.chdir(os.getcwd())
    sys.argv = argv
    for i in sys.path:
        if i.endswith("site-packages"):
            for j in os.listdir(i):
                if j.find("quant_cuda") != -1:
                    sys.path.append(os.path.join(i, j))
                    break
            break
    import quant_cuda


# Assumes layer is perfectly divisible into 256 * 256 blocks
class QuantLinear(nn.Module):
    def __init__(self, bits, groupsize, infeatures, outfeatures):
        super().__init__()
        if bits not in [2, 3, 4, 8]:
            raise NotImplementedError("Only 2,3,4,8 bits are supported.")
        self.infeatures = infeatures
        self.outfeatures = outfeatures
        self.bits = bits
        if groupsize != -1 and groupsize < 32 and groupsize != int(math.pow(2, int(math.log2(groupsize)))):
            raise NotImplementedError("groupsize supports powers of 2 greater than 32. (e.g. : 32,64,128,etc)")
        groupsize = groupsize if groupsize != -1 else infeatures
        self.groupsize = groupsize
        self.register_buffer('qzeros', torch.zeros((math.ceil(infeatures / groupsize), outfeatures // 256 * (bits * 8)), dtype=torch.int))
        self.register_buffer('scales', torch.zeros((math.ceil(infeatures / groupsize), outfeatures)))
        self.register_buffer('bias', torch.zeros(outfeatures))
        self.register_buffer(
            'qweight', torch.zeros((infeatures // 256 * (bits * 8), outfeatures), dtype=torch.int)
        )
        self._initialized_quant_state = False

    def pack(self, linear, scales, zeros):
        scales = scales.t().contiguous()
        zeros = zeros.t().contiguous()
        scale_zeros = zeros * scales
        self.scales = scales.clone()
        if linear.bias is not None:
            self.bias = linear.bias.clone()

        intweight = []
        for idx in range(self.infeatures):
            g_idx = idx // self.groupsize
            intweight.append(torch.round((linear.weight.data[:, idx] + scale_zeros[g_idx]) / self.scales[g_idx]).to(torch.int)[:, None])
        intweight = torch.cat(intweight, dim=1)
        intweight = intweight.t().contiguous()
        intweight = intweight.numpy().astype(np.uint32)
        qweight = np.zeros(
            (intweight.shape[0] // 256 * (self.bits * 8), intweight.shape[1]), dtype=np.uint32
        )
        i = 0
        row = 0
        while row < qweight.shape[0]:
            if self.bits in [2, 4, 8]:
                for j in range(i, i + (32 // self.bits)):
                    qweight[row] |= intweight[j] << (self.bits * (j - i))
                i += 32 // self.bits
                row += 1
            elif self.bits == 3:
                for j in range(i, i + 10):
                    qweight[row] |= intweight[j] << (3 * (j - i))
                i += 10
                qweight[row] |= intweight[i] << 30
                row += 1
                qweight[row] |= (intweight[i] >> 2) & 1
                i += 1
                for j in range(i, i + 10):
                    qweight[row] |= intweight[j] << (3 * (j - i) + 1)
                i += 10
                qweight[row] |= intweight[i] << 31
                row += 1
                qweight[row] |= (intweight[i] >> 1) & 0x3
                i += 1
                for j in range(i, i + 10):
                    qweight[row] |= intweight[j] << (3 * (j - i) + 2)
                i += 10
                row += 1
            else:
                raise NotImplementedError("Only 2,3,4,8 bits are supported.")

        qweight = qweight.astype(np.int32)
        self.qweight = torch.from_numpy(qweight)

        zeros -= 1
        zeros = zeros.numpy().astype(np.uint32)
        qzeros = np.zeros((zeros.shape[0], zeros.shape[1] // 256 * (self.bits * 8)), dtype=np.uint32)
        i = 0
        col = 0
        while col < qzeros.shape[1]:
            if self.bits in [2, 4, 8]:
                for j in range(i, i + (32 // self.bits)):
                    qzeros[:, col] |= zeros[:, j] << (self.bits * (j - i))
                i += 32 // self.bits
                col += 1
            elif self.bits == 3:
                for j in range(i, i + 10):
                    qzeros[:, col] |= zeros[:, j] << (3 * (j - i))
                i += 10
                qzeros[:, col] |= zeros[:, i] << 30
                col += 1
                qzeros[:, col] |= (zeros[:, i] >> 2) & 1
                i += 1
                for j in range(i, i + 10):
                    qzeros[:, col] |= zeros[:, j] << (3 * (j - i) + 1)
                i += 10
                qzeros[:, col] |= zeros[:, i] << 31
                col += 1
                qzeros[:, col] |= (zeros[:, i] >> 1) & 0x3
                i += 1
                for j in range(i, i + 10):
                    qzeros[:, col] |= zeros[:, j] << (3 * (j - i) + 2)
                i += 10
                col += 1
            else:
                raise NotImplementedError("Only 2,3,4,8 bits are supported.")

        qzeros = qzeros.astype(np.int32)
        self.qzeros = torch.from_numpy(qzeros)

    def forward(self, x):
        intermediate_dtype = torch.float32

        if not self._initialized_quant_state:
            # Do we even have a bias? Check for at least one non-zero element.
            if self.bias is not None and bool(torch.any(self.bias != 0)):
                # Then make sure it's the right type.
                self.bias.data = self.bias.data.to(intermediate_dtype)
            else:
                self.bias = None

        outshape = list(x.shape)
        outshape[-1] = self.outfeatures
        x = x.reshape(-1, x.shape[-1])
        if self.bias is None:
            y = torch.zeros(x.shape[0], outshape[-1], dtype=intermediate_dtype, device=x.device)
        else:
            y = self.bias.clone().repeat(x.shape[0], 1)

        output_dtype = x.dtype
        x = x.to(intermediate_dtype)
        if self.bits == 2:
            quant_cuda.vecquant2matmul(x, self.qweight, y, self.scales, self.qzeros, self.groupsize)
        elif self.bits == 3:
            quant_cuda.vecquant3matmul(x, self.qweight, y, self.scales, self.qzeros, self.groupsize)
        elif self.bits == 4:
            quant_cuda.vecquant4matmul(x, self.qweight, y, self.scales, self.qzeros, self.groupsize)
        elif self.bits == 8:
            quant_cuda.vecquant8matmul(x, self.qweight, y, self.scales, self.qzeros, self.groupsize)
        else:
            raise NotImplementedError("Only 2,3,4,8 bits are supported.")
        y = y.to(output_dtype)
        return y.reshape(outshape)

def make_quant(module, names, bits, groupsize, name=''):
    if isinstance(module, QuantLinear):
        return
    for attr in dir(module):
        tmp = getattr(module, attr)
        name1 = name + '.' + attr if name != '' else attr
        if name1 in names:
            setattr(
                module, attr, QuantLinear(bits, groupsize, tmp.in_features, tmp.out_features)
            )
    for name1, child in module.named_children():
        make_quant(child, names, bits, groupsize, name + '.' + name1 if name != '' else name1)
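As a rough sketch of the intended flow (inferred from the code above, not documented in the upload): a floating-point Linear whose weights have already been rounded onto the quantization grid can be packed into a QuantLinear on the CPU; running the packed layer's forward additionally requires the quant_cuda extension and a GPU.

# Hedged packing sketch: per-channel 8-bit params, one group over the whole input dim.
import torch
import torch.nn as nn
from quant import Quantizer, QuantLinear, quantize

linear = nn.Linear(256, 256)
q = Quantizer()
q.configure(8, perchannel=True, sym=False, mse=False)
q.find_params(linear.weight.data, weight=True)

# Round-trip the weights through the quantizer so pack() recovers exact integers.
linear.weight.data = quantize(linear.weight.data, q.scale, q.zero, q.maxq)

ql = QuantLinear(8, -1, linear.in_features, linear.out_features)  # groupsize -1 -> one group
ql.pack(linear, q.scale, q.zero)
print(ql.qweight.shape, ql.qzeros.shape, ql.scales.shape)
# expected: torch.Size([64, 256]) torch.Size([1, 64]) torch.Size([1, 256])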
quant_cuda.cpp
ADDED
@@ -0,0 +1,70 @@
#include <torch/all.h>
#include <torch/python.h>
#include <c10/cuda/CUDAGuard.h>

void vecquant2matmul_cuda(
  torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
  torch::Tensor scales, torch::Tensor zeros,
  int groupsize
);

void vecquant2matmul(
  torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
  torch::Tensor scales, torch::Tensor zeros,
  int groupsize
) {
  const at::cuda::OptionalCUDAGuard device_guard(device_of(vec));
  vecquant2matmul_cuda(vec, mat, mul, scales, zeros, groupsize);
}

void vecquant3matmul_cuda(
  torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
  torch::Tensor scales, torch::Tensor zeros,
  int groupsize
);

void vecquant3matmul(
  torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
  torch::Tensor scales, torch::Tensor zeros,
  int groupsize
) {
  const at::cuda::OptionalCUDAGuard device_guard(device_of(vec));
  vecquant3matmul_cuda(vec, mat, mul, scales, zeros, groupsize);
}

void vecquant4matmul_cuda(
  torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
  torch::Tensor scales, torch::Tensor zeros,
  int groupsize
);

void vecquant4matmul(
  torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
  torch::Tensor scales, torch::Tensor zeros,
  int groupsize
) {
  const at::cuda::OptionalCUDAGuard device_guard(device_of(vec));
  vecquant4matmul_cuda(vec, mat, mul, scales, zeros, groupsize);
}

void vecquant8matmul_cuda(
  torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
  torch::Tensor scales, torch::Tensor zeros,
  int groupsize
);

void vecquant8matmul(
  torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
  torch::Tensor scales, torch::Tensor zeros,
  int groupsize
) {
  const at::cuda::OptionalCUDAGuard device_guard(device_of(vec));
  vecquant8matmul_cuda(vec, mat, mul, scales, zeros, groupsize);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("vecquant2matmul", &vecquant2matmul, "Vector 2-bit Quantized Matrix Multiplication (CUDA)");
  m.def("vecquant3matmul", &vecquant3matmul, "Vector 3-bit Quantized Matrix Multiplication (CUDA)");
  m.def("vecquant4matmul", &vecquant4matmul, "Vector 4-bit Quantized Matrix Multiplication (CUDA)");
  m.def("vecquant8matmul", &vecquant8matmul, "Vector 8-bit Quantized Matrix Multiplication (CUDA)");
}
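quant.py above builds this extension lazily the first time "import quant_cuda" fails. An equivalent standalone build script would look roughly like the following; it is an assumed convenience mirroring the inline setup() call, not a file in this upload. Run it as "python setup.py install" from the repo directory.

# setup.py sketch for building the quant_cuda extension ahead of time.
from setuptools import setup
from torch.utils import cpp_extension

setup(
    name='quant_cuda',
    ext_modules=[cpp_extension.CUDAExtension(
        'quant_cuda', ['quant_cuda.cpp', 'quant_cuda_kernel.cu']
    )],
    cmdclass={'build_ext': cpp_extension.BuildExtension},
)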
quant_cuda_kernel.cu
ADDED
@@ -0,0 +1,524 @@
#include <torch/all.h>
#include <torch/python.h>
#include <cuda.h>
#include <cuda_runtime.h>

// atomicAdd for double-precision floating-point numbers on hardware with
// compute capability < 6.0 from:
// https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#atomic-functions
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 600
__device__ double atomicAdd(
    double* address,
    double val
) {
  unsigned long long int* address_as_ull = (unsigned long long int*)address;
  unsigned long long int old = *address_as_ull, assumed;

  do {
    assumed = old;
    old = atomicCAS(
      address_as_ull,
      assumed,
      __double_as_longlong(val + __longlong_as_double(assumed))
    );

  // Note: uses integer comparison to avoid hang in case of NaN (since NaN != NaN)
  } while (assumed != old);

  return __longlong_as_double(old);
}
#endif

template <typename scalar_t>
__global__ void VecQuant2MatMulKernel(
    const scalar_t* __restrict__ vec,
    const int* __restrict__ mat,
    scalar_t* __restrict__ mul,
    const scalar_t* __restrict__ scales,
    const int* __restrict__ zeros,
    int batch,
    int vec_height,
    int height,
    int width,
    int zero_width,
    int groupsize
);

template <typename scalar_t>
__global__ void VecQuant3MatMulKernel(
    const scalar_t* __restrict__ vec,
    const int* __restrict__ mat,
    scalar_t* __restrict__ mul,
    const scalar_t* __restrict__ scales,
    const int* __restrict__ zeros,
    int batch,
    int vec_height,
    int height,
    int width,
    int zero_width,
    int groupsize
);

template <typename scalar_t>
__global__ void VecQuant4MatMulKernel(
    const scalar_t* __restrict__ vec,
    const int* __restrict__ mat,
    scalar_t* __restrict__ mul,
    const scalar_t* __restrict__ scales,
    const int* __restrict__ zeros,
    int batch,
    int vec_height,
    int height,
    int width,
    int zero_width,
    int groupsize
);

template <typename scalar_t>
__global__ void VecQuant8MatMulKernel(
    const scalar_t* __restrict__ vec,
    const int* __restrict__ mat,
    scalar_t* __restrict__ mul,
    const scalar_t* __restrict__ scales,
    const int* __restrict__ zeros,
    int batch,
    int vec_height,
    int height,
    int width,
    int zero_width,
    int groupsize
);

const int BLOCKWIDTH = 256;
const int BLOCKHEIGHT2 = 16;
const int BLOCKHEIGHT3 = 24;
const int BLOCKHEIGHT4 = 32;
const int BLOCKHEIGHT8 = 64;

__device__ inline unsigned int as_unsigned(int i) {
  return *reinterpret_cast<unsigned int*>(&i);
}

void vecquant2matmul_cuda(
  torch::Tensor vec,
  torch::Tensor mat,
  torch::Tensor mul,
  torch::Tensor scales,
  torch::Tensor zeros,
  int groupsize
) {
  int batch = vec.size(0);
  int vec_height = vec.size(1);
  int height = mat.size(0);
  int width = mat.size(1);
  int zero_width = zeros.size(1);

  dim3 blocks(
    (height + BLOCKHEIGHT2 - 1) / BLOCKHEIGHT2,
    (width + BLOCKWIDTH - 1) / BLOCKWIDTH,
    batch
  );
  dim3 threads(BLOCKWIDTH);

  AT_DISPATCH_FLOATING_TYPES(
    vec.type(), "vecquant2matmul_cuda", ([&] {
      VecQuant2MatMulKernel<<<blocks, threads>>>(
        vec.data<scalar_t>(), mat.data<int>(), mul.data<scalar_t>(),
        scales.data<scalar_t>(), zeros.data<int>(),
        batch, vec_height, height, width, zero_width, groupsize
      );
    })
  );
}

template <typename scalar_t>
__global__ void VecQuant2MatMulKernel(
    const scalar_t* __restrict__ vec,
    const int* __restrict__ mat,
    scalar_t* __restrict__ mul,
    const scalar_t* __restrict__ scales,
    const int* __restrict__ zeros,
    int batch,
    int vec_height,
    int height,
    int width,
    int zero_width,
    int groupsize
) {
  int b = blockIdx.z;
  int h = BLOCKHEIGHT2 * blockIdx.x;
  int w = BLOCKWIDTH * blockIdx.y + threadIdx.x;

  __shared__ scalar_t blockvec[BLOCKWIDTH];
  blockvec[threadIdx.x] = vec[b * vec_height + blockIdx.x * BLOCKWIDTH + threadIdx.x];
  __syncthreads();

  scalar_t res = 0;
  int i = width * h + w;
  int g_h = h * 16;
  int k = 0;

  int z_w = w / 16;
  int z_mod = (w % 16) * 2;

  unsigned int tmp;

  while (k < BLOCKWIDTH) {
    tmp = as_unsigned(mat[i]);

    int g = (g_h + k) / groupsize;
    scalar_t scale = scales[g * width + w];
    scalar_t zero = scale * scalar_t((as_unsigned(zeros[g * zero_width + z_w]) >> z_mod & 0x3) + 1);

    res += (scale * scalar_t((tmp >> 0) & 0x3) - zero) * blockvec[k + 0];
    res += (scale * scalar_t((tmp >> 2) & 0x3) - zero) * blockvec[k + 1];
    res += (scale * scalar_t((tmp >> 4) & 0x3) - zero) * blockvec[k + 2];
    res += (scale * scalar_t((tmp >> 6) & 0x3) - zero) * blockvec[k + 3];
    res += (scale * scalar_t((tmp >> 8) & 0x3) - zero) * blockvec[k + 4];
    res += (scale * scalar_t((tmp >> 10) & 0x3) - zero) * blockvec[k + 5];
    res += (scale * scalar_t((tmp >> 12) & 0x3) - zero) * blockvec[k + 6];
    res += (scale * scalar_t((tmp >> 14) & 0x3) - zero) * blockvec[k + 7];
    res += (scale * scalar_t((tmp >> 16) & 0x3) - zero) * blockvec[k + 8];
    res += (scale * scalar_t((tmp >> 18) & 0x3) - zero) * blockvec[k + 9];
    res += (scale * scalar_t((tmp >> 20) & 0x3) - zero) * blockvec[k + 10];
    res += (scale * scalar_t((tmp >> 22) & 0x3) - zero) * blockvec[k + 11];
    res += (scale * scalar_t((tmp >> 24) & 0x3) - zero) * blockvec[k + 12];
    res += (scale * scalar_t((tmp >> 26) & 0x3) - zero) * blockvec[k + 13];
    res += (scale * scalar_t((tmp >> 28) & 0x3) - zero) * blockvec[k + 14];
    res += (scale * scalar_t((tmp >> 30) & 0x3) - zero) * blockvec[k + 15];

    i += width;
    k += 16;
  }

  atomicAdd(&mul[b * width + w], res);
}

void vecquant3matmul_cuda(
  torch::Tensor vec,
  torch::Tensor mat,
  torch::Tensor mul,
  torch::Tensor scales,
  torch::Tensor zeros,
  int groupsize
) {
  int batch = vec.size(0);
  int vec_height = vec.size(1);
  int height = mat.size(0);
  int width = mat.size(1);
  int zero_width = zeros.size(1);

  dim3 blocks(
    (height + BLOCKHEIGHT3 - 1) / BLOCKHEIGHT3,
    (width + BLOCKWIDTH - 1) / BLOCKWIDTH,
    batch
  );
  dim3 threads(BLOCKWIDTH);

  AT_DISPATCH_FLOATING_TYPES(
    vec.type(), "vecquant3matmul_cuda", ([&] {
      VecQuant3MatMulKernel<<<blocks, threads>>>(
        vec.data<scalar_t>(), mat.data<int>(), mul.data<scalar_t>(),
        scales.data<scalar_t>(), zeros.data<int>(),
        batch, vec_height, height, width, zero_width, groupsize
      );
    })
  );
}

template <typename scalar_t>
__global__ void VecQuant3MatMulKernel(
    const scalar_t* __restrict__ vec,
    const int* __restrict__ mat,
    scalar_t* __restrict__ mul,
    const scalar_t* __restrict__ scales,
    const int* __restrict__ zeros,
    int batch,
    int vec_height,
    int height,
    int width,
    int zero_width,
    int groupsize
) {
  int b = blockIdx.z;
  int h = BLOCKHEIGHT3 * blockIdx.x;
  int w = BLOCKWIDTH * blockIdx.y + threadIdx.x;

  __shared__ scalar_t blockvec[BLOCKWIDTH];
  blockvec[threadIdx.x] = vec[b * vec_height + blockIdx.x * BLOCKWIDTH + threadIdx.x];
  __syncthreads();

  scalar_t res = 0;
  int i = width * h + w;
  int g_h = (h / 3) * 32;
  int k = 0;

  int z_w = (w / 32) * 3; // ((w / 256) * 24) / 3
  int z_mod = w % 32;
  int z_bit;

  if (z_mod != 10){
    if (z_mod != 21){
      z_bit = z_mod;
      if (z_bit > 21){
        z_bit -= 22;
        z_bit *= 3;
        z_bit += 2;
        z_w += 2;
      } else if (z_bit > 10){
        z_bit -= 11;
        z_bit *= 3;
        z_bit += 1;
        z_w += 1;
      } else {
        z_bit *= 3;
      }
    } else {
      z_w += 1;
    }
  }

  unsigned int tmp1;
  unsigned int tmp2;
  unsigned int tmp;
  unsigned int z_tmp;

  while (k < BLOCKWIDTH) {
    tmp1 = as_unsigned(mat[i]);

    int g = (g_h + k) / groupsize;
    scalar_t scale = scales[g * width + w];
    scalar_t zero;
    if (z_mod == 10) {
      z_tmp = (as_unsigned(zeros[g * zero_width + z_w]) >> 30) | ((as_unsigned(zeros[g * zero_width + (z_w + 1)]) << 2) & 0x4);
      zero = scale * scalar_t((z_tmp) + 1);
    } else if (z_mod == 21){
      z_tmp = (as_unsigned(zeros[g * zero_width + z_w]) >> 31) | ((as_unsigned(zeros[g * zero_width + (z_w + 1)]) << 1) & 0x6);
      zero = scale * scalar_t((z_tmp) + 1);
    } else {
      zero = scale * scalar_t(((as_unsigned(zeros[g * zero_width + z_w]) >> z_bit) & 0x7) + 1);
    }

    res += (scale * scalar_t((tmp1 >> 0) & 0x7) - zero) * blockvec[k + 0];
    res += (scale * scalar_t((tmp1 >> 3) & 0x7) - zero) * blockvec[k + 1];
    res += (scale * scalar_t((tmp1 >> 6) & 0x7) - zero) * blockvec[k + 2];
    res += (scale * scalar_t((tmp1 >> 9) & 0x7) - zero) * blockvec[k + 3];
    res += (scale * scalar_t((tmp1 >> 12) & 0x7) - zero) * blockvec[k + 4];
    res += (scale * scalar_t((tmp1 >> 15) & 0x7) - zero) * blockvec[k + 5];
    res += (scale * scalar_t((tmp1 >> 18) & 0x7) - zero) * blockvec[k + 6];
    res += (scale * scalar_t((tmp1 >> 21) & 0x7) - zero) * blockvec[k + 7];
    res += (scale * scalar_t((tmp1 >> 24) & 0x7) - zero) * blockvec[k + 8];
    res += (scale * scalar_t((tmp1 >> 27) & 0x7) - zero) * blockvec[k + 9];

    i += width;
    tmp2 = as_unsigned(mat[i]);
    tmp = (tmp1 >> 30) | ((tmp2 << 2) & 0x4);
    tmp2 >>= 1;
    res += (scale * scalar_t(tmp) - zero) * blockvec[k + 10];
    k += 11;

    res += (scale * scalar_t((tmp2 >> 0) & 0x7) - zero) * blockvec[k + 0];
    res += (scale * scalar_t((tmp2 >> 3) & 0x7) - zero) * blockvec[k + 1];
    res += (scale * scalar_t((tmp2 >> 6) & 0x7) - zero) * blockvec[k + 2];
    res += (scale * scalar_t((tmp2 >> 9) & 0x7) - zero) * blockvec[k + 3];
    res += (scale * scalar_t((tmp2 >> 12) & 0x7) - zero) * blockvec[k + 4];
    res += (scale * scalar_t((tmp2 >> 15) & 0x7) - zero) * blockvec[k + 5];
    res += (scale * scalar_t((tmp2 >> 18) & 0x7) - zero) * blockvec[k + 6];
    res += (scale * scalar_t((tmp2 >> 21) & 0x7) - zero) * blockvec[k + 7];
    res += (scale * scalar_t((tmp2 >> 24) & 0x7) - zero) * blockvec[k + 8];
    res += (scale * scalar_t((tmp2 >> 27) & 0x7) - zero) * blockvec[k + 9];

    i += width;
    tmp1 = as_unsigned(mat[i]);
    tmp = (tmp2 >> 30) | ((tmp1 << 1) & 0x6);
    tmp1 >>= 2;
    res += (scale * scalar_t(tmp) - zero) * blockvec[k + 10];
    k += 11;

    res += (scale * scalar_t((tmp1 >> 0) & 0x7) - zero) * blockvec[k + 0];
    res += (scale * scalar_t((tmp1 >> 3) & 0x7) - zero) * blockvec[k + 1];
    res += (scale * scalar_t((tmp1 >> 6) & 0x7) - zero) * blockvec[k + 2];
    res += (scale * scalar_t((tmp1 >> 9) & 0x7) - zero) * blockvec[k + 3];
    res += (scale * scalar_t((tmp1 >> 12) & 0x7) - zero) * blockvec[k + 4];
    res += (scale * scalar_t((tmp1 >> 15) & 0x7) - zero) * blockvec[k + 5];
    res += (scale * scalar_t((tmp1 >> 18) & 0x7) - zero) * blockvec[k + 6];
    res += (scale * scalar_t((tmp1 >> 21) & 0x7) - zero) * blockvec[k + 7];
    res += (scale * scalar_t((tmp1 >> 24) & 0x7) - zero) * blockvec[k + 8];
    res += (scale * scalar_t((tmp1 >> 27) & 0x7) - zero) * blockvec[k + 9];

    i += width;
    k += 10;
  }

  atomicAdd(&mul[b * width + w], res);
}

void vecquant4matmul_cuda(
  torch::Tensor vec,
  torch::Tensor mat,
  torch::Tensor mul,
  torch::Tensor scales,
  torch::Tensor zeros,
  int groupsize
) {
  int batch = vec.size(0);
  int vec_height = vec.size(1);
  int height = mat.size(0);
  int width = mat.size(1);
  int zero_width = zeros.size(1);

  dim3 blocks(
    (height + BLOCKHEIGHT4 - 1) / BLOCKHEIGHT4,
    (width + BLOCKWIDTH - 1) / BLOCKWIDTH,
    batch
  );
  dim3 threads(BLOCKWIDTH);

  AT_DISPATCH_FLOATING_TYPES(
    vec.type(), "vecquant4matmul_cuda", ([&] {
      VecQuant4MatMulKernel<<<blocks, threads>>>(
        vec.data<scalar_t>(), mat.data<int>(), mul.data<scalar_t>(),
        scales.data<scalar_t>(), zeros.data<int>(),
        batch, vec_height, height, width, zero_width, groupsize
      );
    })
  );
}

template <typename scalar_t>
__global__ void VecQuant4MatMulKernel(
    const scalar_t* __restrict__ vec,
    const int* __restrict__ mat,
    scalar_t* __restrict__ mul,
    const scalar_t* __restrict__ scales,
    const int* __restrict__ zeros,
    int batch,
    int vec_height,
    int height,
    int width,
    int zero_width,
    int groupsize
) {
  int b = blockIdx.z;
  int h = BLOCKHEIGHT4 * blockIdx.x;
  int w = BLOCKWIDTH * blockIdx.y + threadIdx.x;

  __shared__ scalar_t blockvec[BLOCKWIDTH];
  blockvec[threadIdx.x] = vec[b * vec_height + blockIdx.x * BLOCKWIDTH + threadIdx.x];
  __syncthreads();

  scalar_t res = 0;
  int i = width * h + w;
  int g_h = h * 8;
  int k = 0;

  int z_w = w / 8;
  int z_mod = (w % 8) * 4;

  unsigned int tmp;

  while (k < BLOCKWIDTH) {
    tmp = as_unsigned(mat[i]);

    int g = (g_h + k) / groupsize;
    scalar_t scale = scales[g * width + w];
    scalar_t zero = scale * scalar_t(((as_unsigned(zeros[g * zero_width + z_w]) >> z_mod) & 0xF) + 1);

    res += (scale * scalar_t((tmp >> 0) & 0xF) - zero) * blockvec[k + 0];
    res += (scale * scalar_t((tmp >> 4) & 0xF) - zero) * blockvec[k + 1];
    res += (scale * scalar_t((tmp >> 8) & 0xF) - zero) * blockvec[k + 2];
    res += (scale * scalar_t((tmp >> 12) & 0xF) - zero) * blockvec[k + 3];
    res += (scale * scalar_t((tmp >> 16) & 0xF) - zero) * blockvec[k + 4];
    res += (scale * scalar_t((tmp >> 20) & 0xF) - zero) * blockvec[k + 5];
    res += (scale * scalar_t((tmp >> 24) & 0xF) - zero) * blockvec[k + 6];
    res += (scale * scalar_t((tmp >> 28) & 0xF) - zero) * blockvec[k + 7];

    i += width;
    k += 8;
  }

  atomicAdd(&mul[b * width + w], res);
}

void vecquant8matmul_cuda(
  torch::Tensor vec,
  torch::Tensor mat,
  torch::Tensor mul,
  torch::Tensor scales,
  torch::Tensor zeros,
  int groupsize
) {
  int batch = vec.size(0);
  int vec_height = vec.size(1);
  int height = mat.size(0);
  int width = mat.size(1);
  int zero_width = zeros.size(1);

  dim3 blocks(
    (height + BLOCKHEIGHT8 - 1) / BLOCKHEIGHT8,
    (width + BLOCKWIDTH - 1) / BLOCKWIDTH,
    batch
  );
  dim3 threads(BLOCKWIDTH);

  AT_DISPATCH_FLOATING_TYPES(
    vec.type(), "vecquant8matmul_cuda", ([&] {
      VecQuant8MatMulKernel<<<blocks, threads>>>(
        vec.data<scalar_t>(), mat.data<int>(), mul.data<scalar_t>(),
        scales.data<scalar_t>(), zeros.data<int>(),
        batch, vec_height, height, width, zero_width, groupsize
      );
    })
  );
}

template <typename scalar_t>
__global__ void VecQuant8MatMulKernel(
    const scalar_t* __restrict__ vec,
    const int* __restrict__ mat,
    scalar_t* __restrict__ mul,
    const scalar_t* __restrict__ scales,
    const int* __restrict__ zeros,
    int batch,
    int vec_height,
    int height,
    int width,
    int zero_width,
    int groupsize
) {
  int b = blockIdx.z;
  int h = BLOCKHEIGHT8 * blockIdx.x;
  int w = BLOCKWIDTH * blockIdx.y + threadIdx.x;

  __shared__ scalar_t blockvec[BLOCKWIDTH];
  blockvec[threadIdx.x] = vec[b * vec_height + blockIdx.x * BLOCKWIDTH + threadIdx.x];
  __syncthreads();

  scalar_t res = 0;
  int i = width * h + w;
  int g_h = h * 4;
  int k = 0;

  int z_w = w / 4;
  int z_mod = (w % 4) * 8;

  unsigned int tmp;

  while (k < BLOCKWIDTH) {
    tmp = as_unsigned(mat[i]);

    int g = (g_h + k) / groupsize;
    scalar_t scale = scales[g * width + w];
    scalar_t zero = scale * scalar_t(((as_unsigned(zeros[g * zero_width + z_w]) >> z_mod) & 0xFF) + 1);

    res += (scale * scalar_t((tmp >> 0) & 0xFF) - zero) * blockvec[k + 0];
    res += (scale * scalar_t((tmp >> 8) & 0xFF) - zero) * blockvec[k + 1];
    res += (scale * scalar_t((tmp >> 16) & 0xFF) - zero) * blockvec[k + 2];
    res += (scale * scalar_t((tmp >> 24) & 0xFF) - zero) * blockvec[k + 3];

    i += width;
    k += 4;
  }

  atomicAdd(&mul[b * width + w], res);
}
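For the 8-bit path, BLOCKHEIGHT8 = 64 against BLOCKWIDTH = 256 reflects the packing density: each int32 element of mat holds four 8-bit weights, so 256 input rows occupy 64 packed rows. A hedged direct-call sketch follows; it assumes the extension has been built, a CUDA device is present, and uses all-zero packed data, so it only demonstrates the shapes and calling convention rather than a meaningful result.

# Shapes follow the kernel's expectations: mat is (infeatures/4, outfeatures) int32,
# zeros is (groups, outfeatures/4) int32, scales is (groups, outfeatures) float32.
import torch
import quant_cuda   # the extension compiled from quant_cuda.cpp / quant_cuda_kernel.cu

infeatures, outfeatures = 256, 256
x = torch.randn(1, infeatures, device='cuda', dtype=torch.float32)
qweight = torch.zeros(infeatures // 4, outfeatures, dtype=torch.int32, device='cuda')
scales = torch.ones(1, outfeatures, dtype=torch.float32, device='cuda')
qzeros = torch.zeros(1, outfeatures // 4, dtype=torch.int32, device='cuda')
y = torch.zeros(1, outfeatures, dtype=torch.float32, device='cuda')

quant_cuda.vecquant8matmul(x, qweight, y, scales, qzeros, infeatures)  # groupsize = infeatures
print(y.shape)   # torch.Size([1, 256])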
special_tokens_map.json
ADDED
@@ -0,0 +1,6 @@
{
  "bos_token": "<s>",
  "eos_token": "</s>",
  "pad_token": "<pad>",
  "unk_token": "<unk>"
}
tokenizer.json
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:3fa39cd4b1500feb205bcce3b9703a4373414cafe4970e0657b413f7ddd2a9d3
size 14500438
tokenizer_config.json
ADDED
@@ -0,0 +1,11 @@
{
  "add_prefix_space": false,
  "bos_token": "<s>",
  "eos_token": "</s>",
  "model_max_length": 2048,
  "pad_token": "<pad>",
  "padding_side": "right",
  "special_tokens_map_file": null,
  "tokenizer_class": "BloomTokenizer",
  "unk_token": "<unk>"
}