Andrei Panferov committed
Commit: 03ea233
1 Parent(s): f1a2023
inference lib
Files changed:
- inference.py: +0 -377
- modeling_llama_aqlm.py: +8 -8
inference.py
DELETED
@@ -1,377 +0,0 @@
-""" This file serves as the single entry point to efficiently run FinalizedQuantizedLinear layers"""
-import functools
-import os
-from typing import Optional
-
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-import triton
-import triton.language as tl
-
-
-class FinalizedQuantizedLinear(nn.Module):
-    def __init__(
-        self,
-        in_features: int,
-        out_features: int,
-        in_group_size: int,
-        out_group_size: int,
-        num_codebooks: int,
-        nbits_per_codebook: int,
-        bias=True,
-        device=None,
-        dtype=None,
-    ):
-        factory_kwargs = {"device": device, "dtype": dtype}
-        super().__init__()
-        self.in_features = in_features
-        self.out_features = out_features
-
-        assert self.in_features % in_group_size == 0
-        assert self.out_features % out_group_size == 0
-        num_out_groups = out_features // out_group_size
-        num_in_groups = in_features // in_group_size
-        self.out_group_size, self.in_group_size = out_group_size, in_group_size
-        self.num_codebooks = num_codebooks
-        self.nbits_per_codebook = nbits_per_codebook
-        self.codebook_size = 2**nbits_per_codebook
-
-        # CODES & CODEBOOKS
-        self.codebooks = nn.Parameter(
-            torch.empty(
-                (num_codebooks, self.codebook_size, out_group_size, in_group_size),
-                **factory_kwargs,
-            ),
-            requires_grad=True,
-        )  # [num_codebooks, codebook_size, out_group_size, in_group_size]
-        self.codes = nn.Parameter(
-            torch.empty(
-                (num_out_groups, num_in_groups, num_codebooks),
-                device=device,
-                dtype=get_int_dtype(nbits_per_codebook),
-            ),
-            requires_grad=False,
-        )  # [num_out_groups, num_in_groups, num_codebooks]
-
-        # SCALES
-        self.scales = nn.Parameter(
-            torch.empty((num_out_groups, 1, 1, 1), **factory_kwargs), requires_grad=True
-        )  # [num_out_groups, num_in_groups, 1, 1] if scale_nbits > 0 else [num_out_groups, 1, 1, 1]
-
-        # BIAS
-        if bias:
-            self.bias = nn.Parameter(torch.empty(out_features, **factory_kwargs))
-        else:
-            self.register_parameter("bias", None)
-
-    def forward(self, input: torch.Tensor) -> torch.Tensor:
-        return forward_pass_quantized_linear(input, self.codes, self.codebooks, self.scales, self.bias)
-
-
-def get_int_dtype(nbits: int) -> torch.dtype:
-    if nbits <= 8:
-        return torch.int8
-    if nbits <= 16:
-        return torch.int16
-    if nbits <= 32:
-        return torch.int32
-    if nbits <= 64:
-        return torch.int64
-    raise ValueError(f"No dtype available for {nbits}-bit codebooks")
-
-
-@torch.inference_mode()
-def unpack_int_data(data: torch.IntTensor, nbits: int) -> torch.IntTensor:
-    return data.to(torch.int64) % (2**nbits)
-
-
-@functools.lru_cache()
-def maybe_script(fn: callable) -> callable:
-    """Apply torch.jit.script to function unless one is using TPU. TPU does not support torch.jit.script."""
-    using_tpu = bool(os.environ.get("TPU_NAME"))
-    # this is a reserved variable that must be set to TPU address (e.g. grpc://11.22.33.44:1337) for TPU to function
-    should_script = int(os.environ.get("AQ_USE_JIT", not using_tpu))
-    return torch.jit.script(fn) if should_script else fn
-
-
-@maybe_script
-def _dequantize_weight(
-    codes: torch.Tensor, codebooks: torch.Tensor, scales: Optional[torch.Tensor] = None
-) -> torch.Tensor:
-    """
-    Decode float weights from quantization codes. Differentiable.
-    :param codes: tensor of integer quantization codes, shape [*dims, num_out_groups, num_in_groups, num_codebooks]
-    :param codebooks: tensor of vectors for each quantization code, [num_codebooks, codebook_size, out_group_size, in_group_size]
-    :param scales: weight will be multiplied by this factor, must be broadcastble with [*dims, out_groups, num_in_groups, out_group_size, in_group_size]
-    :return: reconstructed weight tensor of shape [*dims, num_in_groups*group_size]
-    """
-    num_out_groups, num_in_groups, num_codebooks = codes.shape[-3:]
-    num_codebooks, codebook_size, out_group_size, in_group_size = codebooks.shape
-    out_features = num_out_groups * out_group_size
-    in_features = num_in_groups * in_group_size
-    codebook_offsets = torch.arange(
-        0, num_codebooks * codebook_size, codebook_size, device=codes.device
-    )  # shape: [num_codebooks]
-    reconstructed_weight_flat = F.embedding_bag(
-        codes.flatten(0, -2) + codebook_offsets,
-        codebooks.flatten(0, 1).flatten(-2, -1),
-        mode="sum",
-    )  # [prod(dims) * num_out_groups * num_in_groups, out_group_size * in_group_size]
-
-    reconstructed_weight_groupwise = reconstructed_weight_flat.view(
-        list(codes.shape[:-3]) + [num_out_groups, num_in_groups, out_group_size, in_group_size]
-    )
-    if scales is not None:
-        reconstructed_weight_groupwise = reconstructed_weight_groupwise.mul(scales)
-    return reconstructed_weight_groupwise.swapaxes(-3, -2).reshape(list(codes.shape[:-3]) + [out_features, in_features])
-
-
-def forward_pass_quantized_linear(
-    input: torch.Tensor,
-    codes: torch.IntTensor,
-    codebooks: torch.Tensor,
-    scales: torch.Tensor,
-    bias: Optional[torch.Tensor],
-) -> torch.Tensor:
-    if input.is_cuda:
-        return triton_matmul(input, codes, codebooks, scales, bias)
-    else:
-        dequantized_weight = _dequantize_weight(
-            unpack_int_data(codes, codebooks.shape[0].bit_length() - 1),
-            codebooks,
-            scales,
-        )
-        return F.linear(input, dequantized_weight, bias)
-
-
-@triton.autotune(
-    configs=[
-        triton.Config({"UNUSED": 1}, num_stages=num_stages, num_warps=num_warps)
-        for num_stages in (1, 2, 3, 4, 5)
-        for num_warps in (1, 2, 4, 8)
-    ],
-    key=[
-        "in_features",
-        "out_features",
-        "num_codebooks",
-        "codebook_size",
-        "out_group_size",
-        "in_group_size",
-        "num_input_groups",
-        "num_input_groups_next_power_of_2",
-        "compute_in_fp32",
-        "has_bias",
-    ],
-)
-@triton.jit
-def _aqlm_gemv_simple(
-    input_vec_ptr,
-    output_vec_ptr,
-    codes_ptr,
-    codebooks_ptr,
-    scales_ptr,
-    bias_ptr,
-    in_features: tl.constexpr,
-    out_features: tl.constexpr,
-    num_codebooks: tl.constexpr,
-    codebook_size: tl.constexpr,
-    out_group_size: tl.constexpr,
-    in_group_size: tl.constexpr,
-    num_input_groups: tl.constexpr,
-    num_input_groups_next_power_of_2: tl.constexpr,
-    compute_in_fp32: tl.constexpr,
-    has_bias: tl.constexpr,
-    UNUSED: tl.constexpr,
-):
-    # variables ending with "_i" mean "for i-th output unit"
-    pid = tl.program_id(axis=0)  # [0, 1, ... {out_features-1}]
-
-    # Stage 1: load input data
-    input_vec = tl.load(
-        input_vec_ptr
-        + tl.arange(0, num_input_groups_next_power_of_2)[:, None, None, None] * in_group_size
-        + tl.arange(0, in_group_size)[None, None, None, :],
-        mask=tl.arange(0, num_input_groups_next_power_of_2)[:, None, None, None] < num_input_groups,
-    )
-    # [in_features//in_group_size, 1, 1, group_size]
-    # Note: we could simply load input_vec then reshape
-    # input_vec = tl.load(input_vec_ptr + tl.arange(0, in_features)) # [in_features]
-    # input_vec = tl.view(input_vec, [num_input_groups, 1, in_group_size])
-    # , but this does not work because tl.view may reorder elements arbitrarily; see its docstring
-
-    # Stage 2: load integer codes for the active row
-    # [in_features // in_group_size, num_codebooks]
-    codes_i_ptrs = (
-        codes_ptr
-        + pid * num_input_groups * num_codebooks
-        + tl.arange(0, num_input_groups_next_power_of_2)[:, None] * num_codebooks
-        + tl.arange(0, num_codebooks)[None, :]
-    )
-    codes_i_mask_1d = tl.arange(0, num_input_groups_next_power_of_2) < num_input_groups
-
-    codes_i = tl.load(codes_i_ptrs, mask=codes_i_mask_1d[:, None])  # [in_features//in_group_size, num_codebooks]
-    codes_i = codes_i.to(tl.int32)
-    codes_i = (codes_i) + (codes_i < 0) * codebook_size  # aka 2 ** nbits_per_codebook
-    # ^-- (because codes are int16 tensors that contain uint data)
-
-    # The following alternative does not work:
-    # codes_i = codes_i.to(tl.int32) % codebook_size # aka 2 ** nbits_per_codeboo
-
-    # shift codes_i so that codebooks after 0th point to correct indices in codebooks_ptr
-    codes_i += tl.arange(0, num_codebooks)[None, :] * codebook_size  # aka 2 ** nbits_per_codebook
-    # ^-- [in_group_size, num_codebooks]
-
-    # Stage 3: convert codes to pointers to every individual (activated) weight in codebooks
-    # [in_features // in_group_size, num_codebooks, out_group_size, in_group_size]
-    out_group_ix = tl.arange(0, out_group_size)[None, None, :, None]
-    in_group_ix = tl.arange(0, in_group_size)[None, None, None, :]
-    weight_i_ptrs = (
-        codebooks_ptr
-        + codes_i[:, :, None, None] * out_group_size * in_group_size
-        + out_group_ix * in_group_size
-        + in_group_ix
-    )
-
-    # Stage 4: reconstruct weights, multiply by inputs and write out
-    weights_i = tl.load(weight_i_ptrs, mask=codes_i_mask_1d[:, None, None, None], other=0)
-    if compute_in_fp32:
-        weights_i = weights_i.to(tl.float32)
-        input_vec = input_vec.to(tl.float32)
-    # ^-- [in_features // in_group_size, num_codebooks, out_group_size, in_group_size]
-
-    if out_group_size == 1:
-        scale = tl.load(scales_ptr + pid).to(weights_i.dtype)  # scalar
-        output_i = tl.sum(weights_i * input_vec) * scale
-        if has_bias:
-            output_i += tl.load(bias_ptr + pid).to(weights_i.dtype)
-        tl.store(output_vec_ptr + pid, output_i.to(input_vec.dtype))
-    else:
-        output_i = tl.sum(tl.sum(weights_i, axis=2) * input_vec, axis=0)  # [out_group_size]
-        output_i *= tl.load(scales_ptr + pid).to(weights_i.dtype)
-        if has_bias:
-            output_i += tl.load(bias_ptr + pid).to(weights_i.dtype)
-        tl.store(output_vec_ptr + pid * out_group_size + tl.arange(0, out_group_size), output_i.to(input_vec.dtype))
-
-
-def next_power_of_2(x):
-    return 1 if x == 0 else 2 ** (x - 1).bit_length()
-
-
-def aqlm_gemv_simple(
-    input_vec: torch.Tensor,
-    codes_i16: torch.ShortTensor,
-    codebooks: torch.Tensor,
-    scales: torch.Tensor,
-    bias: Optional[torch.Tensor],
-    compute_in_fp32: bool = True,
-):
-
-    device, dtype = codebooks.device, codebooks.dtype
-    num_codebooks, codebook_size, out_group_size, in_group_size = codebooks.shape
-    in_features = input_vec.shape[1]
-    out_features = codes_i16.shape[0] * out_group_size
-    num_input_groups = codes_i16.shape[1]
-    assert input_vec.ndim == 2 and input_vec.shape[0] == 1, "do reshape; now!"
-    assert scales.shape == (out_features // out_group_size, 1, 1, 1)
-    assert in_features % in_group_size == 0
-    assert codebooks.shape[1] < 2**32
-
-    output_vec = torch.empty(1, out_features, device=device, dtype=dtype)
-    # 1D launch kernel where each block computes output unit
-    grid = lambda META: (out_features // out_group_size,)
-    _aqlm_gemv_simple[grid](
-        input_vec,
-        output_vec,
-        codes_i16,
-        codebooks,
-        scales,
-        bias,
-        in_features,
-        out_features,
-        num_codebooks,
-        codebook_size,
-        out_group_size,
-        in_group_size,
-        num_input_groups,
-        next_power_of_2(num_input_groups),
-        compute_in_fp32,
-        bias is not None,
-    )
-
-    return output_vec
-
-
-def aqlm_gemm_stupid(
-    input: torch.Tensor,
-    codes_i16: torch.ShortTensor,
-    codebooks: torch.Tensor,
-    scales: torch.Tensor,
-    bias: Optional[torch.Tensor],
-    compute_in_fp32: bool = True,
-):
-    device, dtype = codebooks.device, codebooks.dtype
-    num_codebooks, codebook_size, out_group_size, in_group_size = codebooks.shape
-    in_features = input.shape[1]
-    out_features = codes_i16.shape[0] * out_group_size
-    num_input_groups = codes_i16.shape[1]
-    assert input.ndim == 2
-    assert scales.shape == (out_features // out_group_size, 1, 1, 1)
-    assert in_features % in_group_size == 0
-    assert codebooks.shape[1] < 2**32
-
-    output = torch.empty(input.shape[0], out_features, device=device, dtype=dtype)
-    for i in range(input.shape[0]):
-        # 1D launch kernel where each block computes output unit
-        grid = lambda META: (out_features // out_group_size,)
-        _aqlm_gemv_simple[grid](
-            input[i],
-            output[i],
-            codes_i16,
-            codebooks,
-            scales,
-            bias,
-            in_features,
-            out_features,
-            num_codebooks,
-            codebook_size,
-            out_group_size,
-            in_group_size,
-            num_input_groups,
-            next_power_of_2(num_input_groups),
-            compute_in_fp32,
-            bias is not None,
-        )
-
-    return output
-
-
-def triton_matmul(
-    input: torch.Tensor,
-    codes: torch.IntTensor,
-    codebooks: torch.Tensor,
-    scales: torch.Tensor,
-    bias: Optional[torch.Tensor],
-    compute_in_fp32: bool = True,
-) -> torch.Tensor:
-    input_shape = input.shape
-    input = input.reshape(-1, input_shape[-1])
-
-    if input.shape[0] == 1:
-        return aqlm_gemv_simple(
-            input,
-            codes,
-            codebooks,
-            scales,
-            bias,
-            compute_in_fp32,
-        ).reshape(input_shape[:-1] + (-1,))
-    else:
-        return aqlm_gemm_stupid(
-            input,
-            codes,
-            codebooks,
-            scales,
-            bias,
-            compute_in_fp32,
-        ).reshape(input_shape[:-1] + (-1,))
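For context on the file removed above: its CPU fallback, `_dequantize_weight`, reconstructs each weight group as a sum of codebook entries selected by the stored codes, using a single `F.embedding_bag` lookup. Below is a minimal, self-contained sketch of that reconstruction; the shapes are illustrative assumptions for the example and are not taken from any checkpoint in this repository.

import torch
import torch.nn.functional as F

# Illustrative AQLM-style geometry (assumed values, not from a real config):
# the weight is split into out/in groups, and each group stores one code per additive codebook.
num_out_groups, num_in_groups = 4, 8
num_codebooks, codebook_size = 2, 256
out_group_size, in_group_size = 1, 8

codes = torch.randint(codebook_size, (num_out_groups, num_in_groups, num_codebooks))
codebooks = torch.randn(num_codebooks, codebook_size, out_group_size, in_group_size)
scales = torch.randn(num_out_groups, 1, 1, 1)

# Same trick as _dequantize_weight: offset each codebook's codes into its own slice
# of the flattened codebook table, then sum the looked-up vectors within each group.
offsets = torch.arange(num_codebooks) * codebook_size
grouped = F.embedding_bag(
    (codes + offsets).flatten(0, -2),         # [num_out_groups * num_in_groups, num_codebooks]
    codebooks.flatten(0, 1).flatten(-2, -1),  # [num_codebooks * codebook_size, out_group_size * in_group_size]
    mode="sum",
)

# Apply per-output-group scales and rearrange groups back into a dense [out_features, in_features] weight.
weight = (
    grouped.view(num_out_groups, num_in_groups, out_group_size, in_group_size)
    .mul(scales)
    .swapaxes(-3, -2)
    .reshape(num_out_groups * out_group_size, num_in_groups * in_group_size)
)
print(weight.shape)  # torch.Size([4, 64])

On CUDA, `forward_pass_quantized_linear` never materializes this dense weight: the `_aqlm_gemv_simple` Triton kernel gathers and accumulates the same codebook entries on the fly, one output unit per program.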
modeling_llama_aqlm.py
CHANGED
@@ -53,7 +53,7 @@ from transformers.utils import (
 from transformers.utils.import_utils import is_torch_fx_available
 
 from .configuration_llama_aqlm import LlamaConfig
-from
+from aqlm import QuantizedLinear
 
 if is_flash_attn_2_available():
     from flash_attn import flash_attn_func, flash_attn_varlen_func
@@ -246,9 +246,9 @@ class LlamaMLP(nn.Module):
         self.config = config
         self.hidden_size = config.hidden_size
         self.intermediate_size = config.intermediate_size
-        self.gate_proj =
-        self.up_proj =
-        self.down_proj =
+        self.gate_proj = QuantizedLinear(self.hidden_size, self.intermediate_size, bias=False, **config.aqlm)
+        self.up_proj = QuantizedLinear(self.hidden_size, self.intermediate_size, bias=False, **config.aqlm)
+        self.down_proj = QuantizedLinear(self.intermediate_size, self.hidden_size, bias=False, **config.aqlm)
         self.act_fn = ACT2FN[config.hidden_act]
 
     def forward(self, x):
@@ -314,16 +314,16 @@ class LlamaAttention(nn.Module):
                 f" and `num_heads`: {self.num_heads})."
             )
 
-        self.q_proj =
+        self.q_proj = QuantizedLinear(
             self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias, **config.aqlm
         )
-        self.k_proj =
+        self.k_proj = QuantizedLinear(
             self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias, **config.aqlm
        )
-        self.v_proj =
+        self.v_proj = QuantizedLinear(
             self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias, **config.aqlm
         )
-        self.o_proj =
+        self.o_proj = QuantizedLinear(
             self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias, **config.aqlm
         )
         self._init_rope()
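The modeling change swaps the layer construction that previously relied on the bundled inference.py for `QuantizedLinear` from the standalone `aqlm` package, forwarding the quantization hyperparameters stored in `config.aqlm`. A rough sketch of the resulting construction follows; the group sizes, codebook settings, and layer dimensions are illustrative assumptions, not values read from this repository's config.

from aqlm import QuantizedLinear  # the external AQLM inference library this commit switches to

# Hypothetical AQLM settings; the real values are carried by config.aqlm
# (populated from the model's configuration), not hard-coded here.
aqlm_kwargs = dict(
    in_group_size=8,
    out_group_size=1,
    num_codebooks=1,
    nbits_per_codebook=16,
)

# Mirrors how LlamaMLP builds its projections after this commit
# (Llama-7B-like sizes, chosen only for illustration).
hidden_size, intermediate_size = 4096, 11008
gate_proj = QuantizedLinear(hidden_size, intermediate_size, bias=False, **aqlm_kwargs)
up_proj = QuantizedLinear(hidden_size, intermediate_size, bias=False, **aqlm_kwargs)
down_proj = QuantizedLinear(intermediate_size, hidden_size, bias=False, **aqlm_kwargs)

With the layers coming from the library, the per-layer quantization geometry is configured once in the model config rather than re-implemented next to the modeling code.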