chinoll commited on
Commit
b35c785
1 Parent(s): fabd8fb

Upload 10 files

Browse files
.gitattributes CHANGED
@@ -32,3 +32,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
32
  *.zip filter=lfs diff=lfs merge=lfs -text
33
  *.zst filter=lfs diff=lfs merge=lfs -text
34
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
32
  *.zip filter=lfs diff=lfs merge=lfs -text
33
  *.zst filter=lfs diff=lfs merge=lfs -text
34
  *tfevents* filter=lfs diff=lfs merge=lfs -text
35
+ tokenizer.json filter=lfs diff=lfs merge=lfs -text
config.json ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_name_or_path": "bigscience/bloomz-3b",
3
+ "apply_residual_connection_post_layernorm": false,
4
+ "architectures": [
5
+ "BloomForCausalLM"
6
+ ],
7
+ "auto_map": {
8
+ "AutoModel": "modeling_chatsakura.SakuraForCausalLM",
9
+ "AutoModelForCausalLM": "modeling_chatsakura.SakuraForCausalLM"
10
+ },
11
+ "attention_dropout": 0.0,
12
+ "attention_softmax_in_fp32": true,
13
+ "bias_dropout_fusion": true,
14
+ "bos_token_id": 1,
15
+ "eos_token_id": 2,
16
+ "hidden_dropout": 0.0,
17
+ "hidden_size": 2560,
18
+ "initializer_range": 0.02,
19
+ "layer_norm_epsilon": 1e-05,
20
+ "masked_softmax_fusion": true,
21
+ "model_type": "bloom",
22
+ "n_head": 32,
23
+ "n_inner": null,
24
+ "n_layer": 30,
25
+ "offset_alibi": 100,
26
+ "pad_token_id": 3,
27
+ "pretraining_tp": 4,
28
+ "seq_length": 2048,
29
+ "skip_bias_add": true,
30
+ "skip_bias_add_qkv": false,
31
+ "slow_but_exact": false,
32
+ "torch_dtype": "float16",
33
+ "transformers_version": "4.27.1",
34
+ "unk_token_id": 0,
35
+ "use_cache": true,
36
+ "vocab_size": 250880
37
+ }
gptq.py ADDED
@@ -0,0 +1,163 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import time
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import transformers
7
+
8
+ from quant import *
9
+
10
+
11
+ DEBUG = False
12
+
13
+ torch.backends.cuda.matmul.allow_tf32 = False
14
+ torch.backends.cudnn.allow_tf32 = False
15
+
16
+
17
+ class GPTQ:
18
+ def __init__(self, layer):
19
+ self.layer = layer
20
+ self.dev = self.layer.weight.device
21
+ W = layer.weight.data.clone()
22
+ if isinstance(self.layer, nn.Conv2d):
23
+ W = W.flatten(1)
24
+ if isinstance(self.layer, transformers.Conv1D):
25
+ W = W.t()
26
+ self.rows = W.shape[0]
27
+ self.columns = W.shape[1]
28
+ self.H = torch.zeros((self.columns, self.columns), device=self.dev)
29
+ self.nsamples = 0
30
+
31
+ def add_batch(self, inp, out):
32
+ if DEBUG:
33
+ self.inp1 = inp
34
+ self.out1 = out
35
+ if len(inp.shape) == 2:
36
+ inp = inp.unsqueeze(0)
37
+ tmp = inp.shape[0]
38
+ if isinstance(self.layer, nn.Linear) or isinstance(self.layer, transformers.Conv1D):
39
+ if len(inp.shape) == 3:
40
+ inp = inp.reshape((-1, inp.shape[-1]))
41
+ inp = inp.t()
42
+ if isinstance(self.layer, nn.Conv2d):
43
+ unfold = nn.Unfold(
44
+ self.layer.kernel_size,
45
+ dilation=self.layer.dilation,
46
+ padding=self.layer.padding,
47
+ stride=self.layer.stride
48
+ )
49
+ inp = unfold(inp)
50
+ inp = inp.permute([1, 0, 2])
51
+ inp = inp.flatten(1)
52
+ self.H *= self.nsamples / (self.nsamples + tmp)
53
+ self.nsamples += tmp
54
+ # inp = inp.float()
55
+ inp = math.sqrt(2 / self.nsamples) * inp.float()
56
+ # self.H += 2 / self.nsamples * inp.matmul(inp.t())
57
+ self.H += inp.matmul(inp.t())
58
+
59
+ def fasterquant(
60
+ self, blocksize=128, percdamp=.01, groupsize=-1
61
+ ):
62
+ W = self.layer.weight.data.clone()
63
+ if isinstance(self.layer, nn.Conv2d):
64
+ W = W.flatten(1)
65
+ if isinstance(self.layer, transformers.Conv1D):
66
+ W = W.t()
67
+ W = W.float()
68
+
69
+ tick = time.time()
70
+
71
+ if not self.quantizer.ready():
72
+ self.quantizer.find_params(W, weight=True)
73
+
74
+ H = self.H
75
+ del self.H
76
+ dead = torch.diag(H) == 0
77
+ H[dead, dead] = 1
78
+ W[:, dead] = 0
79
+
80
+ Losses = torch.zeros_like(W)
81
+ Q = torch.zeros_like(W)
82
+
83
+ damp = percdamp * torch.mean(torch.diag(H))
84
+ diag = torch.arange(self.columns, device=self.dev)
85
+ H[diag, diag] += damp
86
+ H = torch.linalg.cholesky(H)
87
+ H = torch.cholesky_inverse(H)
88
+ H = torch.linalg.cholesky(H, upper=True)
89
+ Hinv = H
90
+
91
+ scale = []
92
+ zero = []
93
+ now_idx = 1
94
+
95
+ for i1 in range(0, self.columns, blocksize):
96
+ i2 = min(i1 + blocksize, self.columns)
97
+ count = i2 - i1
98
+
99
+ W1 = W[:, i1:i2].clone()
100
+ Q1 = torch.zeros_like(W1)
101
+ Err1 = torch.zeros_like(W1)
102
+ Losses1 = torch.zeros_like(W1)
103
+ Hinv1 = Hinv[i1:i2, i1:i2]
104
+
105
+ for i in range(count):
106
+ w = W1[:, i]
107
+ d = Hinv1[i, i]
108
+
109
+ if groupsize != -1:
110
+ if (i1 + i) % groupsize == 0:
111
+ self.quantizer.find_params(W[:, (i1 + i):(i1 + i + groupsize)], weight=True)
112
+
113
+ if ((i1 + i) // groupsize) - now_idx == -1:
114
+ scale.append(self.quantizer.scale)
115
+ zero.append(self.quantizer.zero)
116
+ now_idx += 1
117
+
118
+ q = quantize(
119
+ w.unsqueeze(1), self.quantizer.scale, self.quantizer.zero, self.quantizer.maxq
120
+ ).flatten()
121
+ Q1[:, i] = q
122
+ Losses1[:, i] = (w - q) ** 2 / d ** 2
123
+
124
+ err1 = (w - q) / d
125
+ W1[:, i:] -= err1.unsqueeze(1).matmul(Hinv1[i, i:].unsqueeze(0))
126
+ Err1[:, i] = err1
127
+
128
+ Q[:, i1:i2] = Q1
129
+ Losses[:, i1:i2] = Losses1 / 2
130
+
131
+ W[:, i2:] -= Err1.matmul(Hinv[i1:i2, i2:])
132
+
133
+ if DEBUG:
134
+ self.layer.weight.data[:, :i2] = Q[:, :i2]
135
+ self.layer.weight.data[:, i2:] = W[:, i2:]
136
+ print(torch.sum((self.layer(self.inp1) - self.out1) ** 2))
137
+ print(torch.sum(Losses))
138
+
139
+ torch.cuda.synchronize()
140
+ print('time %.2f' % (time.time() - tick))
141
+ print('error', torch.sum(Losses).item())
142
+
143
+ if isinstance(self.layer, transformers.Conv1D):
144
+ Q = Q.t()
145
+ self.layer.weight.data = Q.reshape(self.layer.weight.shape).to(self.layer.weight.data.dtype)
146
+ if DEBUG:
147
+ print(torch.sum((self.layer(self.inp1) - self.out1) ** 2))
148
+
149
+ if scale == []:
150
+ scale.append(self.quantizer.scale)
151
+ zero.append(self.quantizer.zero)
152
+ scale = torch.cat(scale,dim=1)
153
+ zero = torch.cat(zero,dim=1)
154
+ return scale,zero
155
+
156
+ def free(self):
157
+ if DEBUG:
158
+ self.inp1 = None
159
+ self.out1 = None
160
+ self.H = None
161
+ self.Losses = None
162
+ self.Trace = None
163
+ torch.cuda.empty_cache()
modeling_chatsakura.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ from gptq import *
5
+ from modelutils import *
6
+ from quant import *
7
+ from transformers import BloomForCausalLM as LM
8
+
9
+ class SakuraForCausalLM(LM):
10
+ def __init__(self,*args,**kwargs):
11
+ def noop(*args, **kwargs):
12
+ pass
13
+ torch.nn.init.kaiming_uniform_ = noop
14
+ torch.nn.init.uniform_ = noop
15
+ torch.nn.init.normal_ = noop
16
+ torch.set_default_dtype(torch.half)
17
+ transformers.modeling_utils._init_weights = False
18
+ torch.set_default_dtype(torch.half)
19
+ super().__init__(*args,**kwargs)
20
+ torch.set_default_dtype(torch.float)
21
+ self.eval()
22
+ layers = find_layers(self)
23
+ for name in ['lm_head']:
24
+ if name in layers:
25
+ del layers[name]
26
+ make_quant(self, layers, 4, -1)
modelutils.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn as nn
2
+
3
+ def find_layers(module, layers=[nn.Conv2d, nn.Linear], name=''):
4
+ if type(module) in layers:
5
+ return {name: module}
6
+ res = {}
7
+ for name1, child in module.named_children():
8
+ res.update(find_layers(
9
+ child, layers=layers, name=name + '.' + name1 if name != '' else name1
10
+ ))
11
+ return res
quant.py ADDED
@@ -0,0 +1,305 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ import torch.nn as nn
4
+ import math
5
+
6
+ def quantize(x, scale, zero, maxq):
7
+ q = torch.clamp(torch.round(x / scale) + zero, 0, maxq)
8
+ return scale * (q - zero)
9
+
10
+ class Quantizer(nn.Module):
11
+
12
+ def __init__(self, shape=1):
13
+ super(Quantizer, self).__init__()
14
+ self.register_buffer('maxq', torch.tensor(0))
15
+ self.register_buffer('scale', torch.zeros(shape))
16
+ self.register_buffer('zero', torch.zeros(shape))
17
+
18
+ def configure(
19
+ self,
20
+ bits, perchannel=False, sym=True,
21
+ mse=False, norm=2.4, grid=100, maxshrink=.8
22
+ ):
23
+ self.maxq = torch.tensor(2 ** bits - 1)
24
+ self.perchannel = perchannel
25
+ self.sym = sym
26
+ self.mse = mse
27
+ self.norm = norm
28
+ self.grid = grid
29
+ self.maxshrink = maxshrink
30
+
31
+ def find_params(self, x, weight=False):
32
+ dev = x.device
33
+ self.maxq = self.maxq.to(dev)
34
+
35
+ shape = x.shape
36
+ if self.perchannel:
37
+ if weight:
38
+ x = x.flatten(1)
39
+ else:
40
+ if len(shape) == 4:
41
+ x = x.permute([1, 0, 2, 3])
42
+ x = x.flatten(1)
43
+ if len(shape) == 3:
44
+ x = x.reshape((-1, shape[-1])).t()
45
+ if len(shape) == 2:
46
+ x = x.t()
47
+ else:
48
+ x = x.flatten().unsqueeze(0)
49
+
50
+ tmp = torch.zeros(x.shape[0], device=dev)
51
+ xmin = torch.minimum(x.min(1)[0], tmp)
52
+ xmax = torch.maximum(x.max(1)[0], tmp)
53
+
54
+ if self.sym:
55
+ xmax = torch.maximum(torch.abs(xmin), xmax)
56
+ tmp = xmin < 0
57
+ if torch.any(tmp):
58
+ xmin[tmp] = -xmax[tmp]
59
+ tmp = (xmin == 0) & (xmax == 0)
60
+ xmin[tmp] = -1
61
+ xmax[tmp] = +1
62
+
63
+ self.scale = (xmax - xmin) / self.maxq
64
+ if self.sym:
65
+ self.zero = torch.full_like(self.scale, (self.maxq + 1) / 2)
66
+ else:
67
+ self.zero = torch.round(-xmin / self.scale)
68
+
69
+ if self.mse:
70
+ best = torch.full([x.shape[0]], float('inf'), device=dev)
71
+ for i in range(int(self.maxshrink * self.grid)):
72
+ p = 1 - i / self.grid
73
+ xmin1 = p * xmin
74
+ xmax1 = p * xmax
75
+ scale1 = (xmax1 - xmin1) / self.maxq
76
+ zero1 = torch.round(-xmin1 / scale1) if not self.sym else self.zero
77
+ q = quantize(x, scale1.unsqueeze(1), zero1.unsqueeze(1), self.maxq)
78
+ q -= x
79
+ q.abs_()
80
+ q.pow_(self.norm)
81
+ err = torch.sum(q, 1)
82
+ tmp = err < best
83
+ if torch.any(tmp):
84
+ best[tmp] = err[tmp]
85
+ self.scale[tmp] = scale1[tmp]
86
+ self.zero[tmp] = zero1[tmp]
87
+ if not self.perchannel:
88
+ if weight:
89
+ tmp = shape[0]
90
+ else:
91
+ tmp = shape[1] if len(shape) != 3 else shape[2]
92
+ self.scale = self.scale.repeat(tmp)
93
+ self.zero = self.zero.repeat(tmp)
94
+
95
+ if weight:
96
+ shape = [-1] + [1] * (len(shape) - 1)
97
+ self.scale = self.scale.reshape(shape)
98
+ self.zero = self.zero.reshape(shape)
99
+ return
100
+ if len(shape) == 4:
101
+ self.scale = self.scale.reshape((1, -1, 1, 1))
102
+ self.zero = self.zero.reshape((1, -1, 1, 1))
103
+ if len(shape) == 3:
104
+ self.scale = self.scale.reshape((1, 1, -1))
105
+ self.zero = self.zero.reshape((1, 1, -1))
106
+ if len(shape) == 2:
107
+ self.scale = self.scale.unsqueeze(0)
108
+ self.zero = self.zero.unsqueeze(0)
109
+
110
+ def quantize(self, x):
111
+ if self.ready():
112
+ return quantize(x, self.scale, self.zero, self.maxq)
113
+ return x
114
+
115
+ def enabled(self):
116
+ return self.maxq > 0
117
+
118
+ def ready(self):
119
+ return torch.all(self.scale != 0)
120
+
121
+
122
+ try:
123
+ import quant_cuda
124
+ except:
125
+ import os
126
+ import sys
127
+ argv = sys.argv
128
+ sys.argv = ['quant.py','install']
129
+ dir_path = os.path.dirname(os.path.realpath(__file__))
130
+ from setuptools import setup, Extension
131
+ from torch.utils import cpp_extension
132
+ os.chdir(dir_path)
133
+ setup(
134
+ name='quant_cuda',
135
+ ext_modules=[cpp_extension.CUDAExtension(
136
+ 'quant_cuda', ['quant_cuda.cpp', 'quant_cuda_kernel.cu']
137
+ )],
138
+ cmdclass={'build_ext': cpp_extension.BuildExtension}
139
+ )
140
+ os.chdir(os.getcwd())
141
+ sys.argv = argv
142
+ for i in sys.path:
143
+ if i.endswith("site-packages"):
144
+ for j in os.listdir(i):
145
+ if j.find("quant_cuda") != -1:
146
+ sys.path.append(os.path.join(i,j))
147
+ break
148
+ break
149
+ import quant_cuda
150
+
151
+
152
+ # Assumes layer is perfectly divisible into 256 * 256 blocks
153
+ class QuantLinear(nn.Module):
154
+ def __init__(self, bits, groupsize, infeatures, outfeatures):
155
+ super().__init__()
156
+ if bits not in [2,3,4,8]:
157
+ raise NotImplementedError("Only 2,3,4,8 bits are supported.")
158
+ self.infeatures = infeatures
159
+ self.outfeatures = outfeatures
160
+ self.bits = bits
161
+ if groupsize != -1 and groupsize < 32 and groupsize != int(math.pow(2,int(math.log2(groupsize)))):
162
+ raise NotImplementedError("groupsize supports powers of 2 greater than 32. (e.g. : 32,64,128,etc)")
163
+ groupsize = groupsize if groupsize != -1 else infeatures
164
+ self.groupsize = groupsize
165
+ self.register_buffer('qzeros', torch.zeros((math.ceil(infeatures/groupsize),outfeatures // 256 * (bits * 8)), dtype=torch.int))
166
+ self.register_buffer('scales', torch.zeros((math.ceil(infeatures/groupsize),outfeatures)))
167
+ self.register_buffer('bias', torch.zeros(outfeatures))
168
+ self.register_buffer(
169
+ 'qweight', torch.zeros((infeatures // 256 * (bits * 8), outfeatures), dtype=torch.int)
170
+ )
171
+ self._initialized_quant_state = False
172
+
173
+ def pack(self, linear, scales, zeros):
174
+ scales = scales.t().contiguous()
175
+ zeros = zeros.t().contiguous()
176
+ scale_zeros = zeros * scales
177
+ self.scales = scales.clone()
178
+ if linear.bias is not None:
179
+ self.bias = linear.bias.clone()
180
+
181
+ intweight = []
182
+ for idx in range(self.infeatures):
183
+ g_idx = idx // self.groupsize
184
+ intweight.append(torch.round((linear.weight.data[:,idx] + scale_zeros[g_idx]) / self.scales[g_idx]).to(torch.int)[:,None])
185
+ intweight = torch.cat(intweight,dim=1)
186
+ intweight = intweight.t().contiguous()
187
+ intweight = intweight.numpy().astype(np.uint32)
188
+ qweight = np.zeros(
189
+ (intweight.shape[0] // 256 * (self.bits * 8), intweight.shape[1]), dtype=np.uint32
190
+ )
191
+ i = 0
192
+ row = 0
193
+ while row < qweight.shape[0]:
194
+ if self.bits in [2,4,8]:
195
+ for j in range(i, i + (32//self.bits)):
196
+ qweight[row] |= intweight[j] << (self.bits * (j - i))
197
+ i += 32//self.bits
198
+ row += 1
199
+ elif self.bits == 3:
200
+ for j in range(i, i + 10):
201
+ qweight[row] |= intweight[j] << (3 * (j - i))
202
+ i += 10
203
+ qweight[row] |= intweight[i] << 30
204
+ row += 1
205
+ qweight[row] |= (intweight[i] >> 2) & 1
206
+ i += 1
207
+ for j in range(i, i + 10):
208
+ qweight[row] |= intweight[j] << (3 * (j - i) + 1)
209
+ i += 10
210
+ qweight[row] |= intweight[i] << 31
211
+ row += 1
212
+ qweight[row] |= (intweight[i] >> 1) & 0x3
213
+ i += 1
214
+ for j in range(i, i + 10):
215
+ qweight[row] |= intweight[j] << (3 * (j - i) + 2)
216
+ i += 10
217
+ row += 1
218
+ else:
219
+ raise NotImplementedError("Only 2,3,4,8 bits are supported.")
220
+
221
+ qweight = qweight.astype(np.int32)
222
+ self.qweight = torch.from_numpy(qweight)
223
+
224
+ zeros -= 1;
225
+ zeros = zeros.numpy().astype(np.uint32)
226
+ qzeros = np.zeros((zeros.shape[0], zeros.shape[1] // 256 * (self.bits * 8)), dtype=np.uint32)
227
+ i = 0
228
+ col = 0
229
+ while col < qzeros.shape[1]:
230
+ if self.bits in [2,4,8]:
231
+ for j in range(i, i + (32//self.bits)):
232
+ qzeros[:, col] |= zeros[:, j] << (self.bits * (j - i))
233
+ i += 32//self.bits
234
+ col += 1
235
+ elif self.bits == 3:
236
+ for j in range(i, i + 10):
237
+ qzeros[:, col] |= zeros[:, j] << (3 * (j - i))
238
+ i += 10
239
+ qzeros[:, col] |= zeros[:, i] << 30
240
+ col += 1
241
+ qzeros[:, col] |= (zeros[:, i] >> 2) & 1
242
+ i += 1
243
+ for j in range(i, i + 10):
244
+ qzeros[:, col] |= zeros[:, j] << (3 * (j - i) + 1)
245
+ i += 10
246
+ qzeros[:, col] |= zeros[:, i] << 31
247
+ col += 1
248
+ qzeros[:, col] |= (zeros[:, i] >> 1) & 0x3
249
+ i += 1
250
+ for j in range(i, i + 10):
251
+ qzeros[:, col] |= zeros[:, j] << (3 * (j - i) + 2)
252
+ i += 10
253
+ col += 1
254
+ else:
255
+ raise NotImplementedError("Only 2,3,4,8 bits are supported.")
256
+
257
+ qzeros = qzeros.astype(np.int32)
258
+ self.qzeros = torch.from_numpy(qzeros)
259
+
260
+ def forward(self, x):
261
+ intermediate_dtype = torch.float32
262
+
263
+ if not self._initialized_quant_state:
264
+ # Do we even have a bias? Check for at least one non-zero element.
265
+ if self.bias is not None and bool(torch.any(self.bias != 0)):
266
+ # Then make sure it's the right type.
267
+ self.bias.data = self.bias.data.to(intermediate_dtype)
268
+ else:
269
+ self.bias = None
270
+
271
+ outshape = list(x.shape)
272
+ outshape[-1] = self.outfeatures
273
+ x = x.reshape(-1, x.shape[-1])
274
+ if self.bias is None:
275
+ y = torch.zeros(x.shape[0], outshape[-1], dtype=intermediate_dtype, device=x.device)
276
+ else:
277
+ y = self.bias.clone().repeat(x.shape[0], 1)
278
+
279
+ output_dtype = x.dtype
280
+ x = x.to(intermediate_dtype)
281
+ if self.bits == 2:
282
+ quant_cuda.vecquant2matmul(x, self.qweight, y, self.scales, self.qzeros, self.groupsize)
283
+ elif self.bits == 3:
284
+ quant_cuda.vecquant3matmul(x, self.qweight, y, self.scales, self.qzeros, self.groupsize)
285
+ elif self.bits == 4:
286
+ quant_cuda.vecquant4matmul(x, self.qweight, y, self.scales, self.qzeros, self.groupsize)
287
+ elif self.bits == 8:
288
+ quant_cuda.vecquant8matmul(x, self.qweight, y, self.scales, self.qzeros, self.groupsize)
289
+ else:
290
+ raise NotImplementedError("Only 2,3,4,8 bits are supported.")
291
+ y = y.to(output_dtype)
292
+ return y.reshape(outshape)
293
+
294
+ def make_quant(module, names, bits, groupsize, name=''):
295
+ if isinstance(module, QuantLinear):
296
+ return
297
+ for attr in dir(module):
298
+ tmp = getattr(module, attr)
299
+ name1 = name + '.' + attr if name != '' else attr
300
+ if name1 in names:
301
+ setattr(
302
+ module, attr, QuantLinear(bits, groupsize, tmp.in_features, tmp.out_features)
303
+ )
304
+ for name1, child in module.named_children():
305
+ make_quant(child, names, bits, groupsize, name + '.' + name1 if name != '' else name1)
quant_cuda.cpp ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #include <torch/all.h>
2
+ #include <torch/python.h>
3
+ #include <c10/cuda/CUDAGuard.h>
4
+
5
+ void vecquant2matmul_cuda(
6
+ torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
7
+ torch::Tensor scales, torch::Tensor zeros,
8
+ int groupsize
9
+ );
10
+
11
+ void vecquant2matmul(
12
+ torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
13
+ torch::Tensor scales, torch::Tensor zeros,
14
+ int groupsize
15
+ ) {
16
+ const at::cuda::OptionalCUDAGuard device_guard(device_of(vec));
17
+ vecquant2matmul_cuda(vec, mat, mul, scales, zeros,groupsize);
18
+ }
19
+
20
+ void vecquant3matmul_cuda(
21
+ torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
22
+ torch::Tensor scales, torch::Tensor zeros,
23
+ int groupsize
24
+ );
25
+
26
+ void vecquant3matmul(
27
+ torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
28
+ torch::Tensor scales, torch::Tensor zeros,
29
+ int groupsize
30
+ ) {
31
+ const at::cuda::OptionalCUDAGuard device_guard(device_of(vec));
32
+ vecquant3matmul_cuda(vec, mat, mul, scales, zeros, groupsize);
33
+ }
34
+
35
+ void vecquant4matmul_cuda(
36
+ torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
37
+ torch::Tensor scales, torch::Tensor zeros,
38
+ int groupsize
39
+ );
40
+
41
+ void vecquant4matmul(
42
+ torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
43
+ torch::Tensor scales, torch::Tensor zeros,
44
+ int groupsize
45
+ ) {
46
+ const at::cuda::OptionalCUDAGuard device_guard(device_of(vec));
47
+ vecquant4matmul_cuda(vec, mat, mul, scales, zeros, groupsize);
48
+ }
49
+
50
+ void vecquant8matmul_cuda(
51
+ torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
52
+ torch::Tensor scales, torch::Tensor zeros,
53
+ int groupsize
54
+ );
55
+
56
+ void vecquant8matmul(
57
+ torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
58
+ torch::Tensor scales, torch::Tensor zeros,
59
+ int groupsize
60
+ ) {
61
+ const at::cuda::OptionalCUDAGuard device_guard(device_of(vec));
62
+ vecquant8matmul_cuda(vec, mat, mul, scales, zeros, groupsize);
63
+ }
64
+
65
+ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
66
+ m.def("vecquant2matmul", &vecquant2matmul, "Vector 2-bit Quantized Matrix Multiplication (CUDA)");
67
+ m.def("vecquant3matmul", &vecquant3matmul, "Vector 3-bit Quantized Matrix Multiplication (CUDA)");
68
+ m.def("vecquant4matmul", &vecquant4matmul, "Vector 4-bit Quantized Matrix Multiplication (CUDA)");
69
+ m.def("vecquant8matmul", &vecquant8matmul, "Vector 8-bit Quantized Matrix Multiplication (CUDA)");
70
+ }
quant_cuda_kernel.cu ADDED
@@ -0,0 +1,524 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #include <torch/all.h>
2
+ #include <torch/python.h>
3
+ #include <cuda.h>
4
+ #include <cuda_runtime.h>
5
+
6
+ // atomicAdd for double-precision floating-point numbers on hardware with
7
+ // compute capability < 6.0 from:
8
+ // https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#atomic-functions
9
+ #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 600
10
+ __device__ double atomicAdd(
11
+ double* address,
12
+ double val
13
+ ) {
14
+ unsigned long long int* address_as_ull = (unsigned long long int*)address;
15
+ unsigned long long int old = *address_as_ull, assumed;
16
+
17
+ do {
18
+ assumed = old;
19
+ old = atomicCAS(
20
+ address_as_ull,
21
+ assumed,
22
+ __double_as_longlong(val + __longlong_as_double(assumed))
23
+ );
24
+
25
+ // Note: uses integer comparison to avoid hang in case of NaN (since NaN != NaN)
26
+ } while (assumed != old);
27
+
28
+ return __longlong_as_double(old);
29
+ }
30
+ #endif
31
+
32
+ template <typename scalar_t>
33
+ __global__ void VecQuant2MatMulKernel(
34
+ const scalar_t* __restrict__ vec,
35
+ const int* __restrict__ mat,
36
+ scalar_t* __restrict__ mul,
37
+ const scalar_t* __restrict__ scales,
38
+ const int* __restrict__ zeros,
39
+ int batch,
40
+ int vec_height,
41
+ int height,
42
+ int width,
43
+ int zero_width,
44
+ int groupsize
45
+ );
46
+
47
+ template <typename scalar_t>
48
+ __global__ void VecQuant3MatMulKernel(
49
+ const scalar_t* __restrict__ vec,
50
+ const int* __restrict__ mat,
51
+ scalar_t* __restrict__ mul,
52
+ const scalar_t* __restrict__ scales,
53
+ const int* __restrict__ zeros,
54
+ int batch,
55
+ int vec_height,
56
+ int height,
57
+ int width,
58
+ int zero_width,
59
+ int groupsize
60
+ );
61
+
62
+ template <typename scalar_t>
63
+ __global__ void VecQuant4MatMulKernel(
64
+ const scalar_t* __restrict__ vec,
65
+ const int* __restrict__ mat,
66
+ scalar_t* __restrict__ mul,
67
+ const scalar_t* __restrict__ scales,
68
+ const int* __restrict__ zeros,
69
+ int batch,
70
+ int vec_height,
71
+ int height,
72
+ int width,
73
+ int zero_width,
74
+ int groupsize
75
+ );
76
+
77
+ template <typename scalar_t>
78
+ __global__ void VecQuant8MatMulKernel(
79
+ const scalar_t* __restrict__ vec,
80
+ const int* __restrict__ mat,
81
+ scalar_t* __restrict__ mul,
82
+ const scalar_t* __restrict__ scales,
83
+ const int* __restrict__ zeros,
84
+ int batch,
85
+ int vec_height,
86
+ int height,
87
+ int width,
88
+ int zero_width,
89
+ int groupsize
90
+ );
91
+
92
+ const int BLOCKWIDTH = 256;
93
+ const int BLOCKHEIGHT2 = 16;
94
+ const int BLOCKHEIGHT3 = 24;
95
+ const int BLOCKHEIGHT4 = 32;
96
+ const int BLOCKHEIGHT8 = 64;
97
+
98
+ __device__ inline unsigned int as_unsigned(int i) {
99
+ return *reinterpret_cast<unsigned int*>(&i);
100
+ }
101
+
102
+ void vecquant2matmul_cuda(
103
+ torch::Tensor vec,
104
+ torch::Tensor mat,
105
+ torch::Tensor mul,
106
+ torch::Tensor scales,
107
+ torch::Tensor zeros,
108
+ int groupsize
109
+ ) {
110
+ int batch = vec.size(0);
111
+ int vec_height = vec.size(1);
112
+ int height = mat.size(0);
113
+ int width = mat.size(1);
114
+ int zero_width = zeros.size(1);
115
+
116
+ dim3 blocks(
117
+ (height + BLOCKHEIGHT2 - 1) / BLOCKHEIGHT2,
118
+ (width + BLOCKWIDTH - 1) / BLOCKWIDTH,
119
+ batch
120
+ );
121
+ dim3 threads(BLOCKWIDTH);
122
+
123
+ AT_DISPATCH_FLOATING_TYPES(
124
+ vec.type(), "vecquant2matmul_cuda", ([&] {
125
+ VecQuant2MatMulKernel<<<blocks, threads>>>(
126
+ vec.data<scalar_t>(), mat.data<int>(), mul.data<scalar_t>(),
127
+ scales.data<scalar_t>(), zeros.data<int>(),
128
+ batch, vec_height, height, width, zero_width, groupsize
129
+ );
130
+ })
131
+ );
132
+ }
133
+
134
+ template <typename scalar_t>
135
+ __global__ void VecQuant2MatMulKernel(
136
+ const scalar_t* __restrict__ vec,
137
+ const int* __restrict__ mat,
138
+ scalar_t* __restrict__ mul,
139
+ const scalar_t* __restrict__ scales,
140
+ const int* __restrict__ zeros,
141
+ int batch,
142
+ int vec_height,
143
+ int height,
144
+ int width,
145
+ int zero_width,
146
+ int groupsize
147
+ ) {
148
+ int b = blockIdx.z;
149
+ int h = BLOCKHEIGHT2 * blockIdx.x;
150
+ int w = BLOCKWIDTH * blockIdx.y + threadIdx.x;
151
+
152
+ __shared__ scalar_t blockvec[BLOCKWIDTH];
153
+ blockvec[threadIdx.x] = vec[b * vec_height + blockIdx.x * BLOCKWIDTH + threadIdx.x];
154
+ __syncthreads();
155
+
156
+ scalar_t res = 0;
157
+ int i = width * h + w;
158
+ int g_h = h * 16;
159
+ int k = 0;
160
+
161
+ int z_w = w / 16;
162
+ int z_mod = (w % 16) * 2;
163
+
164
+ unsigned int tmp;
165
+
166
+ while (k < BLOCKWIDTH) {
167
+ tmp = as_unsigned(mat[i]);
168
+
169
+ int g = (g_h + k) / groupsize;
170
+ scalar_t scale = scales[g * width + w];
171
+ scalar_t zero = scale * scalar_t((as_unsigned(zeros[g * zero_width + z_w]) >> z_mod & 0x3) + 1);
172
+
173
+ res += (scale * scalar_t((tmp >> 0) & 0x3) - zero) * blockvec[k + 0];
174
+ res += (scale * scalar_t((tmp >> 2) & 0x3) - zero) * blockvec[k + 1];
175
+ res += (scale * scalar_t((tmp >> 4) & 0x3) - zero) * blockvec[k + 2];
176
+ res += (scale * scalar_t((tmp >> 6) & 0x3) - zero) * blockvec[k + 3];
177
+ res += (scale * scalar_t((tmp >> 8) & 0x3) - zero) * blockvec[k + 4];
178
+ res += (scale * scalar_t((tmp >> 10) & 0x3) - zero) * blockvec[k + 5];
179
+ res += (scale * scalar_t((tmp >> 12) & 0x3) - zero) * blockvec[k + 6];
180
+ res += (scale * scalar_t((tmp >> 14) & 0x3) - zero) * blockvec[k + 7];
181
+ res += (scale * scalar_t((tmp >> 16) & 0x3) - zero) * blockvec[k + 8];
182
+ res += (scale * scalar_t((tmp >> 18) & 0x3) - zero) * blockvec[k + 9];
183
+ res += (scale * scalar_t((tmp >> 20) & 0x3) - zero) * blockvec[k + 10];
184
+ res += (scale * scalar_t((tmp >> 22) & 0x3) - zero) * blockvec[k + 11];
185
+ res += (scale * scalar_t((tmp >> 24) & 0x3) - zero) * blockvec[k + 12];
186
+ res += (scale * scalar_t((tmp >> 26) & 0x3) - zero) * blockvec[k + 13];
187
+ res += (scale * scalar_t((tmp >> 28) & 0x3) - zero) * blockvec[k + 14];
188
+ res += (scale * scalar_t((tmp >> 30) & 0x3) - zero) * blockvec[k + 15];
189
+
190
+ i += width;
191
+ k += 16;
192
+ }
193
+
194
+ atomicAdd(&mul[b * width + w], res);
195
+ }
196
+
197
+ void vecquant3matmul_cuda(
198
+ torch::Tensor vec,
199
+ torch::Tensor mat,
200
+ torch::Tensor mul,
201
+ torch::Tensor scales,
202
+ torch::Tensor zeros,
203
+ int groupsize
204
+ ) {
205
+ int batch = vec.size(0);
206
+ int vec_height = vec.size(1);
207
+ int height = mat.size(0);
208
+ int width = mat.size(1);
209
+ int zero_width = zeros.size(1);
210
+
211
+ dim3 blocks(
212
+ (height + BLOCKHEIGHT3 - 1) / BLOCKHEIGHT3,
213
+ (width + BLOCKWIDTH - 1) / BLOCKWIDTH,
214
+ batch
215
+ );
216
+ dim3 threads(BLOCKWIDTH);
217
+
218
+ AT_DISPATCH_FLOATING_TYPES(
219
+ vec.type(), "vecquant3matmul_cuda", ([&] {
220
+ VecQuant3MatMulKernel<<<blocks, threads>>>(
221
+ vec.data<scalar_t>(), mat.data<int>(), mul.data<scalar_t>(),
222
+ scales.data<scalar_t>(), zeros.data<int>(),
223
+ batch, vec_height, height, width, zero_width, groupsize
224
+ );
225
+ })
226
+ );
227
+ }
228
+
229
+ template <typename scalar_t>
230
+ __global__ void VecQuant3MatMulKernel(
231
+ const scalar_t* __restrict__ vec,
232
+ const int* __restrict__ mat,
233
+ scalar_t* __restrict__ mul,
234
+ const scalar_t* __restrict__ scales,
235
+ const int* __restrict__ zeros,
236
+ int batch,
237
+ int vec_height,
238
+ int height,
239
+ int width,
240
+ int zero_width,
241
+ int groupsize
242
+ ) {
243
+ int b = blockIdx.z;
244
+ int h = BLOCKHEIGHT3 * blockIdx.x;
245
+ int w = BLOCKWIDTH * blockIdx.y + threadIdx.x;
246
+
247
+ __shared__ scalar_t blockvec[BLOCKWIDTH];
248
+ blockvec[threadIdx.x] = vec[b * vec_height + blockIdx.x * BLOCKWIDTH + threadIdx.x];
249
+ __syncthreads();
250
+
251
+ scalar_t res = 0;
252
+ int i = width * h + w;
253
+ int g_h = (h / 3) * 32;
254
+ int k = 0;
255
+
256
+ int z_w = (w / 32) * 3; // ((w / 256) * 24) / 3
257
+ int z_mod = w % 32;
258
+ int z_bit;
259
+
260
+ if (z_mod != 10){
261
+ if (z_mod != 21){
262
+ z_bit = z_mod;
263
+ if (z_bit > 21){
264
+ z_bit -= 22;
265
+ z_bit *= 3;
266
+ z_bit += 2;
267
+ z_w += 2;
268
+ } else if (z_bit > 10){
269
+ z_bit -= 11;
270
+ z_bit *= 3;
271
+ z_bit += 1;
272
+ z_w += 1;
273
+ } else {
274
+ z_bit *= 3;
275
+ }
276
+ } else {
277
+ z_w += 1;
278
+ }
279
+ }
280
+
281
+ unsigned int tmp1;
282
+ unsigned int tmp2;
283
+ unsigned int tmp;
284
+ unsigned int z_tmp;
285
+
286
+ while (k < BLOCKWIDTH) {
287
+ tmp1 = as_unsigned(mat[i]);
288
+
289
+ int g = (g_h + k) / groupsize;
290
+ scalar_t scale = scales[g * width + w];
291
+ scalar_t zero;
292
+ if (z_mod == 10) {
293
+ z_tmp = (as_unsigned(zeros[g * zero_width + z_w]) >> 30) | ((as_unsigned(zeros[g * zero_width + (z_w + 1)]) << 2) & 0x4);
294
+ zero = scale * scalar_t((z_tmp) + 1);
295
+ } else if (z_mod == 21){
296
+ z_tmp = (as_unsigned(zeros[g * zero_width + z_w]) >> 31) | ((as_unsigned(zeros[g * zero_width + (z_w + 1)]) << 1) & 0x6);
297
+ zero = scale * scalar_t((z_tmp) + 1);
298
+ } else {
299
+ zero = scale * scalar_t(((as_unsigned(zeros[g * zero_width + z_w]) >> z_bit) & 0x7) + 1);
300
+ }
301
+
302
+ res += (scale * scalar_t((tmp1 >> 0) & 0x7) - zero) * blockvec[k + 0];
303
+ res += (scale * scalar_t((tmp1 >> 3) & 0x7) - zero) * blockvec[k + 1];
304
+ res += (scale * scalar_t((tmp1 >> 6) & 0x7) - zero) * blockvec[k + 2];
305
+ res += (scale * scalar_t((tmp1 >> 9) & 0x7) - zero) * blockvec[k + 3];
306
+ res += (scale * scalar_t((tmp1 >> 12) & 0x7) - zero) * blockvec[k + 4];
307
+ res += (scale * scalar_t((tmp1 >> 15) & 0x7) - zero) * blockvec[k + 5];
308
+ res += (scale * scalar_t((tmp1 >> 18) & 0x7) - zero) * blockvec[k + 6];
309
+ res += (scale * scalar_t((tmp1 >> 21) & 0x7) - zero) * blockvec[k + 7];
310
+ res += (scale * scalar_t((tmp1 >> 24) & 0x7) - zero) * blockvec[k + 8];
311
+ res += (scale * scalar_t((tmp1 >> 27) & 0x7) - zero) * blockvec[k + 9];
312
+
313
+ i += width;
314
+ tmp2 = as_unsigned(mat[i]);
315
+ tmp = (tmp1 >> 30) | ((tmp2 << 2) & 0x4);
316
+ tmp2 >>= 1;
317
+ res += (scale * scalar_t(tmp) - zero) * blockvec[k + 10];
318
+ k += 11;
319
+
320
+ res += (scale * scalar_t((tmp2 >> 0) & 0x7) - zero) * blockvec[k + 0];
321
+ res += (scale * scalar_t((tmp2 >> 3) & 0x7) - zero) * blockvec[k + 1];
322
+ res += (scale * scalar_t((tmp2 >> 6) & 0x7) - zero) * blockvec[k + 2];
323
+ res += (scale * scalar_t((tmp2 >> 9) & 0x7) - zero) * blockvec[k + 3];
324
+ res += (scale * scalar_t((tmp2 >> 12) & 0x7) - zero) * blockvec[k + 4];
325
+ res += (scale * scalar_t((tmp2 >> 15) & 0x7) - zero) * blockvec[k + 5];
326
+ res += (scale * scalar_t((tmp2 >> 18) & 0x7) - zero) * blockvec[k + 6];
327
+ res += (scale * scalar_t((tmp2 >> 21) & 0x7) - zero) * blockvec[k + 7];
328
+ res += (scale * scalar_t((tmp2 >> 24) & 0x7) - zero) * blockvec[k + 8];
329
+ res += (scale * scalar_t((tmp2 >> 27) & 0x7) - zero) * blockvec[k + 9];
330
+
331
+ i += width;
332
+ tmp1 = as_unsigned(mat[i]);
333
+ tmp = (tmp2 >> 30) | ((tmp1 << 1) & 0x6);
334
+ tmp1 >>= 2;
335
+ res += (scale * scalar_t(tmp) - zero) * blockvec[k + 10];
336
+ k += 11;
337
+
338
+ res += (scale * scalar_t((tmp1 >> 0) & 0x7) - zero) * blockvec[k + 0];
339
+ res += (scale * scalar_t((tmp1 >> 3) & 0x7) - zero) * blockvec[k + 1];
340
+ res += (scale * scalar_t((tmp1 >> 6) & 0x7) - zero) * blockvec[k + 2];
341
+ res += (scale * scalar_t((tmp1 >> 9) & 0x7) - zero) * blockvec[k + 3];
342
+ res += (scale * scalar_t((tmp1 >> 12) & 0x7) - zero) * blockvec[k + 4];
343
+ res += (scale * scalar_t((tmp1 >> 15) & 0x7) - zero) * blockvec[k + 5];
344
+ res += (scale * scalar_t((tmp1 >> 18) & 0x7) - zero) * blockvec[k + 6];
345
+ res += (scale * scalar_t((tmp1 >> 21) & 0x7) - zero) * blockvec[k + 7];
346
+ res += (scale * scalar_t((tmp1 >> 24) & 0x7) - zero) * blockvec[k + 8];
347
+ res += (scale * scalar_t((tmp1 >> 27) & 0x7) - zero) * blockvec[k + 9];
348
+
349
+ i += width;
350
+ k += 10;
351
+ }
352
+
353
+ atomicAdd(&mul[b * width + w], res);
354
+ }
355
+
356
+ void vecquant4matmul_cuda(
357
+ torch::Tensor vec,
358
+ torch::Tensor mat,
359
+ torch::Tensor mul,
360
+ torch::Tensor scales,
361
+ torch::Tensor zeros,
362
+ int groupsize
363
+ ) {
364
+ int batch = vec.size(0);
365
+ int vec_height = vec.size(1);
366
+ int height = mat.size(0);
367
+ int width = mat.size(1);
368
+ int zero_width = zeros.size(1);
369
+
370
+ dim3 blocks(
371
+ (height + BLOCKHEIGHT4 - 1) / BLOCKHEIGHT4,
372
+ (width + BLOCKWIDTH - 1) / BLOCKWIDTH,
373
+ batch
374
+ );
375
+ dim3 threads(BLOCKWIDTH);
376
+
377
+ AT_DISPATCH_FLOATING_TYPES(
378
+ vec.type(), "vecquant4matmul_cuda", ([&] {
379
+ VecQuant4MatMulKernel<<<blocks, threads>>>(
380
+ vec.data<scalar_t>(), mat.data<int>(), mul.data<scalar_t>(),
381
+ scales.data<scalar_t>(), zeros.data<int>(),
382
+ batch, vec_height, height, width, zero_width, groupsize
383
+ );
384
+ })
385
+ );
386
+ }
387
+
388
+ template <typename scalar_t>
389
+ __global__ void VecQuant4MatMulKernel(
390
+ const scalar_t* __restrict__ vec,
391
+ const int* __restrict__ mat,
392
+ scalar_t* __restrict__ mul,
393
+ const scalar_t* __restrict__ scales,
394
+ const int* __restrict__ zeros,
395
+ int batch,
396
+ int vec_height,
397
+ int height,
398
+ int width,
399
+ int zero_width,
400
+ int groupsize
401
+ ) {
402
+ int b = blockIdx.z;
403
+ int h = BLOCKHEIGHT4 * blockIdx.x;
404
+ int w = BLOCKWIDTH * blockIdx.y + threadIdx.x;
405
+
406
+ __shared__ scalar_t blockvec[BLOCKWIDTH];
407
+ blockvec[threadIdx.x] = vec[b * vec_height + blockIdx.x * BLOCKWIDTH + threadIdx.x];
408
+ __syncthreads();
409
+
410
+ scalar_t res = 0;
411
+ int i = width * h + w;
412
+ int g_h = h * 8;
413
+ int k = 0;
414
+
415
+ int z_w = w / 8;
416
+ int z_mod = (w % 8) * 4;
417
+
418
+ unsigned int tmp;
419
+
420
+ while (k < BLOCKWIDTH) {
421
+ tmp = as_unsigned(mat[i]);
422
+
423
+ int g = (g_h + k) / groupsize;
424
+ scalar_t scale = scales[g * width + w];
425
+ scalar_t zero = scale * scalar_t(((as_unsigned(zeros[g * zero_width + z_w]) >> z_mod) & 0xF) + 1);
426
+
427
+ res += (scale * scalar_t((tmp >> 0) & 0xF) - zero) * blockvec[k + 0];
428
+ res += (scale * scalar_t((tmp >> 4) & 0xF) - zero) * blockvec[k + 1];
429
+ res += (scale * scalar_t((tmp >> 8) & 0xF) - zero) * blockvec[k + 2];
430
+ res += (scale * scalar_t((tmp >> 12) & 0xF) - zero) * blockvec[k + 3];
431
+ res += (scale * scalar_t((tmp >> 16) & 0xF) - zero) * blockvec[k + 4];
432
+ res += (scale * scalar_t((tmp >> 20) & 0xF) - zero) * blockvec[k + 5];
433
+ res += (scale * scalar_t((tmp >> 24) & 0xF) - zero) * blockvec[k + 6];
434
+ res += (scale * scalar_t((tmp >> 28) & 0xF) - zero) * blockvec[k + 7];
435
+
436
+ i += width;
437
+ k += 8;
438
+ }
439
+
440
+ atomicAdd(&mul[b * width + w], res);
441
+ }
442
+
443
+ void vecquant8matmul_cuda(
444
+ torch::Tensor vec,
445
+ torch::Tensor mat,
446
+ torch::Tensor mul,
447
+ torch::Tensor scales,
448
+ torch::Tensor zeros,
449
+ int groupsize
450
+ ) {
451
+ int batch = vec.size(0);
452
+ int vec_height = vec.size(1);
453
+ int height = mat.size(0);
454
+ int width = mat.size(1);
455
+ int zero_width = zeros.size(1);
456
+
457
+ dim3 blocks(
458
+ (height + BLOCKHEIGHT8 - 1) / BLOCKHEIGHT8,
459
+ (width + BLOCKWIDTH - 1) / BLOCKWIDTH,
460
+ batch
461
+ );
462
+ dim3 threads(BLOCKWIDTH);
463
+
464
+ AT_DISPATCH_FLOATING_TYPES(
465
+ vec.type(), "vecquant8matmul_cuda", ([&] {
466
+ VecQuant8MatMulKernel<<<blocks, threads>>>(
467
+ vec.data<scalar_t>(), mat.data<int>(), mul.data<scalar_t>(),
468
+ scales.data<scalar_t>(), zeros.data<int>(),
469
+ batch, vec_height, height, width, zero_width, groupsize
470
+ );
471
+ })
472
+ );
473
+ }
474
+
475
+ template <typename scalar_t>
476
+ __global__ void VecQuant8MatMulKernel(
477
+ const scalar_t* __restrict__ vec,
478
+ const int* __restrict__ mat,
479
+ scalar_t* __restrict__ mul,
480
+ const scalar_t* __restrict__ scales,
481
+ const int* __restrict__ zeros,
482
+ int batch,
483
+ int vec_height,
484
+ int height,
485
+ int width,
486
+ int zero_width,
487
+ int groupsize
488
+ ) {
489
+ int b = blockIdx.z;
490
+ int h = BLOCKHEIGHT8 * blockIdx.x;
491
+ int w = BLOCKWIDTH * blockIdx.y + threadIdx.x;
492
+
493
+ __shared__ scalar_t blockvec[BLOCKWIDTH];
494
+ blockvec[threadIdx.x] = vec[b * vec_height + blockIdx.x * BLOCKWIDTH + threadIdx.x];
495
+ __syncthreads();
496
+
497
+ scalar_t res = 0;
498
+ int i = width * h + w;
499
+ int g_h = h * 4;
500
+ int k = 0;
501
+
502
+ int z_w = w / 4;
503
+ int z_mod = (w % 4) * 8;
504
+
505
+ unsigned int tmp;
506
+
507
+ while (k < BLOCKWIDTH) {
508
+ tmp = as_unsigned(mat[i]);
509
+
510
+ int g = (g_h + k) / groupsize;
511
+ scalar_t scale = scales[g * width + w];
512
+ scalar_t zero = scale * scalar_t(((as_unsigned(zeros[g * zero_width + z_w]) >> z_mod) & 0xFF) + 1);
513
+
514
+ res += (scale * scalar_t((tmp >> 0) & 0xFF) - zero) * blockvec[k + 0];
515
+ res += (scale * scalar_t((tmp >> 8) & 0xFF) - zero) * blockvec[k + 1];
516
+ res += (scale * scalar_t((tmp >> 16) & 0xFF) - zero) * blockvec[k + 2];
517
+ res += (scale * scalar_t((tmp >> 24) & 0xFF) - zero) * blockvec[k + 3];
518
+
519
+ i += width;
520
+ k += 4;
521
+ }
522
+
523
+ atomicAdd(&mul[b * width + w], res);
524
+ }
special_tokens_map.json ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ {
2
+ "bos_token": "<s>",
3
+ "eos_token": "</s>",
4
+ "pad_token": "<pad>",
5
+ "unk_token": "<unk>"
6
+ }
tokenizer.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3fa39cd4b1500feb205bcce3b9703a4373414cafe4970e0657b413f7ddd2a9d3
3
+ size 14500438
tokenizer_config.json ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "add_prefix_space": false,
3
+ "bos_token": "<s>",
4
+ "eos_token": "</s>",
5
+ "model_max_length": 2048,
6
+ "pad_token": "<pad>",
7
+ "padding_side": "right",
8
+ "special_tokens_map_file": null,
9
+ "tokenizer_class": "BloomTokenizer",
10
+ "unk_token": "<unk>"
11
+ }