Andrei Panferov committed
Commit f1a2023
Parent(s): 115e749

slightly faster inference

Files changed: inference.py (+10 -8)
inference.py CHANGED

@@ -161,6 +161,7 @@ def forward_pass_quantized_linear(
         "num_input_groups",
         "num_input_groups_next_power_of_2",
         "compute_in_fp32",
+        "has_bias",
     ],
 )
 @triton.jit
@@ -180,6 +181,7 @@ def _aqlm_gemv_simple(
     num_input_groups: tl.constexpr,
     num_input_groups_next_power_of_2: tl.constexpr,
     compute_in_fp32: tl.constexpr,
+    has_bias: tl.constexpr,
     UNUSED: tl.constexpr,
 ):
     # variables ending with "_i" mean "for i-th output unit"
@@ -188,11 +190,11 @@ def _aqlm_gemv_simple(
     # Stage 1: load input data
     input_vec = tl.load(
         input_vec_ptr
-        + tl.arange(0, num_input_groups_next_power_of_2)[:, None, None] * in_group_size
-        + tl.arange(0, in_group_size)[None, None, :],
-        mask=tl.arange(0, num_input_groups_next_power_of_2)[:, None, None] < num_input_groups,
+        + tl.arange(0, num_input_groups_next_power_of_2)[:, None, None, None] * in_group_size
+        + tl.arange(0, in_group_size)[None, None, None, :],
+        mask=tl.arange(0, num_input_groups_next_power_of_2)[:, None, None, None] < num_input_groups,
     )
-    # [in_features//in_group_size, 1, group_size]
+    # [in_features//in_group_size, 1, 1, group_size]
     # Note: we could simply load input_vec then reshape
     # input_vec = tl.load(input_vec_ptr + tl.arange(0, in_features))  # [in_features]
     # input_vec = tl.view(input_vec, [num_input_groups, 1, in_group_size])
@@ -237,19 +239,17 @@ def _aqlm_gemv_simple(
         weights_i = weights_i.to(tl.float32)
         input_vec = input_vec.to(tl.float32)
     # ^-- [in_features // in_group_size, num_codebooks, out_group_size, in_group_size]
-    weights_i = tl.sum(weights_i, axis=1)  # sum codebooks as per additive quantization
-    # ^-- [in_features // in_group_size, out_group_size, in_group_size]

     if out_group_size == 1:
         scale = tl.load(scales_ptr + pid).to(weights_i.dtype)  # scalar
         output_i = tl.sum(weights_i * input_vec) * scale
-        if bias_ptr is not None:
+        if has_bias:
             output_i += tl.load(bias_ptr + pid).to(weights_i.dtype)
         tl.store(output_vec_ptr + pid, output_i.to(input_vec.dtype))
     else:
         output_i = tl.sum(tl.sum(weights_i, axis=2) * input_vec, axis=0)  # [out_group_size]
         output_i *= tl.load(scales_ptr + pid).to(weights_i.dtype)
-        if bias_ptr is not None:
+        if has_bias:
             output_i += tl.load(bias_ptr + pid).to(weights_i.dtype)
         tl.store(output_vec_ptr + pid * out_group_size + tl.arange(0, out_group_size), output_i.to(input_vec.dtype))

@@ -296,6 +296,7 @@ def aqlm_gemv_simple(
         num_input_groups,
         next_power_of_2(num_input_groups),
         compute_in_fp32,
+        bias is not None,
     )

     return output_vec
@@ -339,6 +340,7 @@ def aqlm_gemm_stupid(
         num_input_groups,
         next_power_of_2(num_input_groups),
         compute_in_fp32,
+        bias is not None,
     )

     return output
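The speed-up in the commit title comes from the load and reduction hunks: the input vector is now loaded with an extra broadcast axis ([:, None, None, None] instead of [:, None, None]), and the explicit per-codebook sum (weights_i = tl.sum(weights_i, axis=1)) is dropped, so the codebook reduction is folded into the single tl.sum over the product. By linearity of summation the two orders are equivalent. Below is a quick sanity check of that equivalence against the shapes annotated in the kernel comments; it is illustrative PyTorch, not code from this repo.

# Illustrative sanity check (not part of the commit): summing codebooks first
# and then reducing equals broadcasting a 4-D input against the 4-D weights
# and reducing once, which is what the updated kernel does.
import torch

num_groups, num_codebooks, out_group_size, in_group_size = 8, 2, 1, 4
weights = torch.randn(num_groups, num_codebooks, out_group_size, in_group_size)
inp = torch.randn(num_groups, 1, in_group_size)  # old 3-D input layout

old = (weights.sum(dim=1) * inp).sum()    # old kernel: codebook sum, then reduce
new = (weights * inp.unsqueeze(1)).sum()  # new kernel: one fused reduction
assert torch.allclose(old, new, atol=1e-5)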
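The other half of the change turns the bias check into a compile-time constant: has_bias is declared tl.constexpr, added to the autotune key list, and the call sites pass bias is not None. Triton specializes the kernel per constexpr value, so when there is no bias the branch is compiled out instead of being decided at run time. A standalone sketch of that pattern follows (hypothetical kernel and names, assuming Triton and a CUDA device; not code from inference.py).

# Sketch of the constexpr-flag pattern used in the diff (hypothetical kernel):
# because has_bias is a tl.constexpr, Triton compiles one specialization per
# value, and the bias branch vanishes entirely when has_bias is False.
import torch
import triton
import triton.language as tl

@triton.jit
def _scale_maybe_add_bias(x_ptr, bias_ptr, out_ptr, scale,
                          BLOCK: tl.constexpr, has_bias: tl.constexpr):
    offs = tl.arange(0, BLOCK)
    x = tl.load(x_ptr + offs) * scale
    if has_bias:  # resolved at compile time, like if has_bias: above
        x += tl.load(bias_ptr + offs)
    tl.store(out_ptr + offs, x)

x = torch.randn(64, device="cuda")
bias = torch.randn(64, device="cuda")
out = torch.empty_like(x)
# Mirror the call sites above, which pass bias is not None as the flag.
_scale_maybe_add_bias[(1,)](x, bias, out, 2.0, BLOCK=64, has_bias=bias is not None)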