Andrei Panferov committed
Commit 03ea233
1 Parent(s): f1a2023

inference lib

Files changed (2):
  1. inference.py +0 -377
  2. modeling_llama_aqlm.py +8 -8
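
This commit removes the bundled inference.py (a standalone FinalizedQuantizedLinear implementation with its Triton kernels) and switches modeling_llama_aqlm.py to the QuantizedLinear layer provided by the external `aqlm` package. A minimal usage sketch, assuming `aqlm` is installed alongside transformers; the repo id below is a placeholder, not taken from this commit:

    # Assumed environment (not part of this commit): pip install aqlm transformers
    from transformers import AutoModelForCausalLM, AutoTokenizer

    model_id = "some-user/llama-aqlm"  # hypothetical checkpoint id
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        trust_remote_code=True,  # runs this repo's modeling_llama_aqlm.py
        torch_dtype="auto",
    )

    inputs = tokenizer("Hello", return_tensors="pt")
    print(tokenizer.decode(model.generate(**inputs, max_new_tokens=8)[0]))
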
inference.py DELETED
@@ -1,377 +0,0 @@
-""" This file serves as the single entry point to efficiently run FinalizedQuantizedLinear layers"""
-import functools
-import os
-from typing import Optional
-
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-import triton
-import triton.language as tl
-
-
-class FinalizedQuantizedLinear(nn.Module):
-    def __init__(
-        self,
-        in_features: int,
-        out_features: int,
-        in_group_size: int,
-        out_group_size: int,
-        num_codebooks: int,
-        nbits_per_codebook: int,
-        bias=True,
-        device=None,
-        dtype=None,
-    ):
-        factory_kwargs = {"device": device, "dtype": dtype}
-        super().__init__()
-        self.in_features = in_features
-        self.out_features = out_features
-
-        assert self.in_features % in_group_size == 0
-        assert self.out_features % out_group_size == 0
-        num_out_groups = out_features // out_group_size
-        num_in_groups = in_features // in_group_size
-        self.out_group_size, self.in_group_size = out_group_size, in_group_size
-        self.num_codebooks = num_codebooks
-        self.nbits_per_codebook = nbits_per_codebook
-        self.codebook_size = 2**nbits_per_codebook
-
-        # CODES & CODEBOOKS
-        self.codebooks = nn.Parameter(
-            torch.empty(
-                (num_codebooks, self.codebook_size, out_group_size, in_group_size),
-                **factory_kwargs,
-            ),
-            requires_grad=True,
-        )  # [num_codebooks, codebook_size, out_group_size, in_group_size]
-        self.codes = nn.Parameter(
-            torch.empty(
-                (num_out_groups, num_in_groups, num_codebooks),
-                device=device,
-                dtype=get_int_dtype(nbits_per_codebook),
-            ),
-            requires_grad=False,
-        )  # [num_out_groups, num_in_groups, num_codebooks]
-
-        # SCALES
-        self.scales = nn.Parameter(
-            torch.empty((num_out_groups, 1, 1, 1), **factory_kwargs), requires_grad=True
-        )  # [num_out_groups, num_in_groups, 1, 1] if scale_nbits > 0 else [num_out_groups, 1, 1, 1]
-
-        # BIAS
-        if bias:
-            self.bias = nn.Parameter(torch.empty(out_features, **factory_kwargs))
-        else:
-            self.register_parameter("bias", None)
-
-    def forward(self, input: torch.Tensor) -> torch.Tensor:
-        return forward_pass_quantized_linear(input, self.codes, self.codebooks, self.scales, self.bias)
-
-
-def get_int_dtype(nbits: int) -> torch.dtype:
-    if nbits <= 8:
-        return torch.int8
-    if nbits <= 16:
-        return torch.int16
-    if nbits <= 32:
-        return torch.int32
-    if nbits <= 64:
-        return torch.int64
-    raise ValueError(f"No dtype available for {nbits}-bit codebooks")
-
-
-@torch.inference_mode()
-def unpack_int_data(data: torch.IntTensor, nbits: int) -> torch.IntTensor:
-    return data.to(torch.int64) % (2**nbits)
-
-
-@functools.lru_cache()
-def maybe_script(fn: callable) -> callable:
-    """Apply torch.jit.script to function unless one is using TPU. TPU does not support torch.jit.script."""
-    using_tpu = bool(os.environ.get("TPU_NAME"))
-    # this is a reserved variable that must be set to TPU address (e.g. grpc://11.22.33.44:1337) for TPU to function
-    should_script = int(os.environ.get("AQ_USE_JIT", not using_tpu))
-    return torch.jit.script(fn) if should_script else fn
-
-
-@maybe_script
-def _dequantize_weight(
-    codes: torch.Tensor, codebooks: torch.Tensor, scales: Optional[torch.Tensor] = None
-) -> torch.Tensor:
-    """
-    Decode float weights from quantization codes. Differentiable.
-    :param codes: tensor of integer quantization codes, shape [*dims, num_out_groups, num_in_groups, num_codebooks]
-    :param codebooks: tensor of vectors for each quantization code, [num_codebooks, codebook_size, out_group_size, in_group_size]
-    :param scales: weight will be multiplied by this factor, must be broadcastable with [*dims, out_groups, num_in_groups, out_group_size, in_group_size]
-    :return: reconstructed weight tensor of shape [*dims, num_in_groups*group_size]
-    """
-    num_out_groups, num_in_groups, num_codebooks = codes.shape[-3:]
-    num_codebooks, codebook_size, out_group_size, in_group_size = codebooks.shape
-    out_features = num_out_groups * out_group_size
-    in_features = num_in_groups * in_group_size
-    codebook_offsets = torch.arange(
-        0, num_codebooks * codebook_size, codebook_size, device=codes.device
-    )  # shape: [num_codebooks]
-    reconstructed_weight_flat = F.embedding_bag(
-        codes.flatten(0, -2) + codebook_offsets,
-        codebooks.flatten(0, 1).flatten(-2, -1),
-        mode="sum",
-    )  # [prod(dims) * num_out_groups * num_in_groups, out_group_size * in_group_size]
-
-    reconstructed_weight_groupwise = reconstructed_weight_flat.view(
-        list(codes.shape[:-3]) + [num_out_groups, num_in_groups, out_group_size, in_group_size]
-    )
-    if scales is not None:
-        reconstructed_weight_groupwise = reconstructed_weight_groupwise.mul(scales)
-    return reconstructed_weight_groupwise.swapaxes(-3, -2).reshape(list(codes.shape[:-3]) + [out_features, in_features])
-
-
-def forward_pass_quantized_linear(
-    input: torch.Tensor,
-    codes: torch.IntTensor,
-    codebooks: torch.Tensor,
-    scales: torch.Tensor,
-    bias: Optional[torch.Tensor],
-) -> torch.Tensor:
-    if input.is_cuda:
-        return triton_matmul(input, codes, codebooks, scales, bias)
-    else:
-        dequantized_weight = _dequantize_weight(
-            unpack_int_data(codes, codebooks.shape[0].bit_length() - 1),
-            codebooks,
-            scales,
-        )
-        return F.linear(input, dequantized_weight, bias)
-
-
-@triton.autotune(
-    configs=[
-        triton.Config({"UNUSED": 1}, num_stages=num_stages, num_warps=num_warps)
-        for num_stages in (1, 2, 3, 4, 5)
-        for num_warps in (1, 2, 4, 8)
-    ],
-    key=[
-        "in_features",
-        "out_features",
-        "num_codebooks",
-        "codebook_size",
-        "out_group_size",
-        "in_group_size",
-        "num_input_groups",
-        "num_input_groups_next_power_of_2",
-        "compute_in_fp32",
-        "has_bias",
-    ],
-)
-@triton.jit
-def _aqlm_gemv_simple(
-    input_vec_ptr,
-    output_vec_ptr,
-    codes_ptr,
-    codebooks_ptr,
-    scales_ptr,
-    bias_ptr,
-    in_features: tl.constexpr,
-    out_features: tl.constexpr,
-    num_codebooks: tl.constexpr,
-    codebook_size: tl.constexpr,
-    out_group_size: tl.constexpr,
-    in_group_size: tl.constexpr,
-    num_input_groups: tl.constexpr,
-    num_input_groups_next_power_of_2: tl.constexpr,
-    compute_in_fp32: tl.constexpr,
-    has_bias: tl.constexpr,
-    UNUSED: tl.constexpr,
-):
-    # variables ending with "_i" mean "for i-th output unit"
-    pid = tl.program_id(axis=0)  # [0, 1, ... {out_features-1}]
-
-    # Stage 1: load input data
-    input_vec = tl.load(
-        input_vec_ptr
-        + tl.arange(0, num_input_groups_next_power_of_2)[:, None, None, None] * in_group_size
-        + tl.arange(0, in_group_size)[None, None, None, :],
-        mask=tl.arange(0, num_input_groups_next_power_of_2)[:, None, None, None] < num_input_groups,
-    )
-    # [in_features//in_group_size, 1, 1, group_size]
-    # Note: we could simply load input_vec then reshape
-    # input_vec = tl.load(input_vec_ptr + tl.arange(0, in_features))  # [in_features]
-    # input_vec = tl.view(input_vec, [num_input_groups, 1, in_group_size])
-    # , but this does not work because tl.view may reorder elements arbitrarily; see its docstring
-
-    # Stage 2: load integer codes for the active row
-    # [in_features // in_group_size, num_codebooks]
-    codes_i_ptrs = (
-        codes_ptr
-        + pid * num_input_groups * num_codebooks
-        + tl.arange(0, num_input_groups_next_power_of_2)[:, None] * num_codebooks
-        + tl.arange(0, num_codebooks)[None, :]
-    )
-    codes_i_mask_1d = tl.arange(0, num_input_groups_next_power_of_2) < num_input_groups
-
-    codes_i = tl.load(codes_i_ptrs, mask=codes_i_mask_1d[:, None])  # [in_features//in_group_size, num_codebooks]
-    codes_i = codes_i.to(tl.int32)
-    codes_i = (codes_i) + (codes_i < 0) * codebook_size  # aka 2 ** nbits_per_codebook
-    # ^-- (because codes are int16 tensors that contain uint data)
-
-    # The following alternative does not work:
-    # codes_i = codes_i.to(tl.int32) % codebook_size  # aka 2 ** nbits_per_codebook
-
-    # shift codes_i so that codebooks after 0th point to correct indices in codebooks_ptr
-    codes_i += tl.arange(0, num_codebooks)[None, :] * codebook_size  # aka 2 ** nbits_per_codebook
-    # ^-- [in_group_size, num_codebooks]
-
-    # Stage 3: convert codes to pointers to every individual (activated) weight in codebooks
-    # [in_features // in_group_size, num_codebooks, out_group_size, in_group_size]
-    out_group_ix = tl.arange(0, out_group_size)[None, None, :, None]
-    in_group_ix = tl.arange(0, in_group_size)[None, None, None, :]
-    weight_i_ptrs = (
-        codebooks_ptr
-        + codes_i[:, :, None, None] * out_group_size * in_group_size
-        + out_group_ix * in_group_size
-        + in_group_ix
-    )
-
-    # Stage 4: reconstruct weights, multiply by inputs and write out
-    weights_i = tl.load(weight_i_ptrs, mask=codes_i_mask_1d[:, None, None, None], other=0)
-    if compute_in_fp32:
-        weights_i = weights_i.to(tl.float32)
-        input_vec = input_vec.to(tl.float32)
-    # ^-- [in_features // in_group_size, num_codebooks, out_group_size, in_group_size]
-
-    if out_group_size == 1:
-        scale = tl.load(scales_ptr + pid).to(weights_i.dtype)  # scalar
-        output_i = tl.sum(weights_i * input_vec) * scale
-        if has_bias:
-            output_i += tl.load(bias_ptr + pid).to(weights_i.dtype)
-        tl.store(output_vec_ptr + pid, output_i.to(input_vec.dtype))
-    else:
-        output_i = tl.sum(tl.sum(weights_i, axis=2) * input_vec, axis=0)  # [out_group_size]
-        output_i *= tl.load(scales_ptr + pid).to(weights_i.dtype)
-        if has_bias:
-            output_i += tl.load(bias_ptr + pid).to(weights_i.dtype)
-        tl.store(output_vec_ptr + pid * out_group_size + tl.arange(0, out_group_size), output_i.to(input_vec.dtype))
-
-
-def next_power_of_2(x):
-    return 1 if x == 0 else 2 ** (x - 1).bit_length()
-
-
-def aqlm_gemv_simple(
-    input_vec: torch.Tensor,
-    codes_i16: torch.ShortTensor,
-    codebooks: torch.Tensor,
-    scales: torch.Tensor,
-    bias: Optional[torch.Tensor],
-    compute_in_fp32: bool = True,
-):
-
-    device, dtype = codebooks.device, codebooks.dtype
-    num_codebooks, codebook_size, out_group_size, in_group_size = codebooks.shape
-    in_features = input_vec.shape[1]
-    out_features = codes_i16.shape[0] * out_group_size
-    num_input_groups = codes_i16.shape[1]
-    assert input_vec.ndim == 2 and input_vec.shape[0] == 1, "do reshape; now!"
-    assert scales.shape == (out_features // out_group_size, 1, 1, 1)
-    assert in_features % in_group_size == 0
-    assert codebooks.shape[1] < 2**32
-
-    output_vec = torch.empty(1, out_features, device=device, dtype=dtype)
-    # 1D launch kernel where each block computes output unit
-    grid = lambda META: (out_features // out_group_size,)
-    _aqlm_gemv_simple[grid](
-        input_vec,
-        output_vec,
-        codes_i16,
-        codebooks,
-        scales,
-        bias,
-        in_features,
-        out_features,
-        num_codebooks,
-        codebook_size,
-        out_group_size,
-        in_group_size,
-        num_input_groups,
-        next_power_of_2(num_input_groups),
-        compute_in_fp32,
-        bias is not None,
-    )
-
-    return output_vec
-
-
-def aqlm_gemm_stupid(
-    input: torch.Tensor,
-    codes_i16: torch.ShortTensor,
-    codebooks: torch.Tensor,
-    scales: torch.Tensor,
-    bias: Optional[torch.Tensor],
-    compute_in_fp32: bool = True,
-):
-    device, dtype = codebooks.device, codebooks.dtype
-    num_codebooks, codebook_size, out_group_size, in_group_size = codebooks.shape
-    in_features = input.shape[1]
-    out_features = codes_i16.shape[0] * out_group_size
-    num_input_groups = codes_i16.shape[1]
-    assert input.ndim == 2
-    assert scales.shape == (out_features // out_group_size, 1, 1, 1)
-    assert in_features % in_group_size == 0
-    assert codebooks.shape[1] < 2**32
-
-    output = torch.empty(input.shape[0], out_features, device=device, dtype=dtype)
-    for i in range(input.shape[0]):
-        # 1D launch kernel where each block computes output unit
-        grid = lambda META: (out_features // out_group_size,)
-        _aqlm_gemv_simple[grid](
-            input[i],
-            output[i],
-            codes_i16,
-            codebooks,
-            scales,
-            bias,
-            in_features,
-            out_features,
-            num_codebooks,
-            codebook_size,
-            out_group_size,
-            in_group_size,
-            num_input_groups,
-            next_power_of_2(num_input_groups),
-            compute_in_fp32,
-            bias is not None,
-        )
-
-    return output
-
-
-def triton_matmul(
-    input: torch.Tensor,
-    codes: torch.IntTensor,
-    codebooks: torch.Tensor,
-    scales: torch.Tensor,
-    bias: Optional[torch.Tensor],
-    compute_in_fp32: bool = True,
-) -> torch.Tensor:
-    input_shape = input.shape
-    input = input.reshape(-1, input_shape[-1])
-
-    if input.shape[0] == 1:
-        return aqlm_gemv_simple(
-            input,
-            codes,
-            codebooks,
-            scales,
-            bias,
-            compute_in_fp32,
-        ).reshape(input_shape[:-1] + (-1,))
-    else:
-        return aqlm_gemm_stupid(
-            input,
-            codes,
-            codebooks,
-            scales,
-            bias,
-            compute_in_fp32,
-        ).reshape(input_shape[:-1] + (-1,))
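
(Aside, not part of the commit.) The core of the deleted _dequantize_weight is an additive-codebook lookup: each weight group is reconstructed as the sum of one codebook vector per codebook, selected by the stored integer codes, then multiplied by a per-row scale. A toy sketch of that lookup-and-sum, checked against a naive loop:

    import torch
    import torch.nn.functional as F

    num_codebooks, codebook_size, group_size = 2, 16, 8
    codebooks = torch.randn(num_codebooks, codebook_size, group_size)
    codes = torch.randint(codebook_size, (5, num_codebooks))  # 5 weight groups

    # Offset each codebook's codes into the flattened codebook table, then sum the
    # selected vectors per group -- the same F.embedding_bag(mode="sum") trick used above.
    offsets = torch.arange(num_codebooks) * codebook_size
    fast = F.embedding_bag(codes + offsets, codebooks.flatten(0, 1), mode="sum")

    # Naive reference: explicitly add one vector per codebook for every group.
    slow = torch.stack(
        [sum(codebooks[c, codes[g, c]] for c in range(num_codebooks)) for g in range(5)]
    )
    assert torch.allclose(fast, slow)
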
 
modeling_llama_aqlm.py CHANGED
@@ -53,7 +53,7 @@ from transformers.utils import (
 from transformers.utils.import_utils import is_torch_fx_available
 
 from .configuration_llama_aqlm import LlamaConfig
-from .inference import FinalizedQuantizedLinear
+from aqlm import QuantizedLinear
 
 if is_flash_attn_2_available():
     from flash_attn import flash_attn_func, flash_attn_varlen_func
@@ -246,9 +246,9 @@ class LlamaMLP(nn.Module):
         self.config = config
         self.hidden_size = config.hidden_size
         self.intermediate_size = config.intermediate_size
-        self.gate_proj = FinalizedQuantizedLinear(self.hidden_size, self.intermediate_size, bias=False, **config.aqlm)
-        self.up_proj = FinalizedQuantizedLinear(self.hidden_size, self.intermediate_size, bias=False, **config.aqlm)
-        self.down_proj = FinalizedQuantizedLinear(self.intermediate_size, self.hidden_size, bias=False, **config.aqlm)
+        self.gate_proj = QuantizedLinear(self.hidden_size, self.intermediate_size, bias=False, **config.aqlm)
+        self.up_proj = QuantizedLinear(self.hidden_size, self.intermediate_size, bias=False, **config.aqlm)
+        self.down_proj = QuantizedLinear(self.intermediate_size, self.hidden_size, bias=False, **config.aqlm)
         self.act_fn = ACT2FN[config.hidden_act]
 
     def forward(self, x):
@@ -314,16 +314,16 @@ class LlamaAttention(nn.Module):
                 f" and `num_heads`: {self.num_heads})."
             )
 
-        self.q_proj = FinalizedQuantizedLinear(
+        self.q_proj = QuantizedLinear(
             self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias, **config.aqlm
         )
-        self.k_proj = FinalizedQuantizedLinear(
+        self.k_proj = QuantizedLinear(
             self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias, **config.aqlm
         )
-        self.v_proj = FinalizedQuantizedLinear(
+        self.v_proj = QuantizedLinear(
            self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias, **config.aqlm
         )
-        self.o_proj = FinalizedQuantizedLinear(
+        self.o_proj = QuantizedLinear(
            self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias, **config.aqlm
         )
         self._init_rope()
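
(Aside, not part of the commit.) The **config.aqlm kwargs splatted into every QuantizedLinear above carry the quantization hyperparameters that the deleted layer's constructor declared: in_group_size, out_group_size, num_codebooks and nbits_per_codebook. A sketch of what that block might look like; the values below are illustrative, not read from this repository's config.json:

    # Hypothetical example values; the real ones live in the checkpoint's config.json.
    aqlm_kwargs = {
        "in_group_size": 8,        # input columns per quantization group
        "out_group_size": 1,       # output rows per quantization group
        "num_codebooks": 1,        # additive codebooks per group
        "nbits_per_codebook": 16,  # i.e. codebook_size = 2 ** 16 entries
    }
    # modeling_llama_aqlm.py then builds each projection roughly as
    # QuantizedLinear(in_features, out_features, bias=..., **aqlm_kwargs)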