Andrei Panferov committed on
Commit 115e749
1 Parent(s): dfb8eb3

newer inference

Files changed (2)
  1. config.json +77 -19
  2. inference.py +73 -23
config.json CHANGED
@@ -1,34 +1,92 @@
 {
-  "architectures": [
-    "LlamaForCausalLM_AQLM"
-  ],
-  "auto_map": {
-    "AutoConfig": "configuration_llama_aqlm.LlamaConfig",
-    "AutoModelForCausalLM": "modeling_llama_aqlm.LlamaForCausalLM"
-  },
-  "bos_token_id": 1,
-  "eos_token_id": 2,
-  "hidden_act": "silu",
+  "vocab_size": 32000,
+  "max_position_embeddings": 4096,
   "hidden_size": 4096,
-  "initializer_range": 0.02,
   "intermediate_size": 11008,
-  "max_position_embeddings": 4096,
-  "model_type": "llama_aqlm",
-  "num_attention_heads": 32,
   "num_hidden_layers": 32,
+  "num_attention_heads": 32,
   "num_key_value_heads": 32,
-  "pretraining_tp": 1,
+  "hidden_act": "silu",
+  "initializer_range": 0.02,
   "rms_norm_eps": 1e-05,
+  "pretraining_tp": 1,
+  "use_cache": true,
+  "rope_theta": 10000.0,
   "rope_scaling": null,
-  "tie_word_embeddings": false,
+  "attention_bias": false,
+  "attention_dropout": 0.0,
+  "return_dict": true,
+  "output_hidden_states": false,
+  "output_attentions": false,
+  "torchscript": false,
   "torch_dtype": "float16",
-  "transformers_version": "4.31.0.dev0",
-  "use_cache": true,
-  "vocab_size": 32000,
+  "use_bfloat16": false,
+  "tf_legacy_loss": false,
+  "pruned_heads": {},
+  "tie_word_embeddings": false,
+  "is_encoder_decoder": false,
+  "is_decoder": false,
+  "cross_attention_hidden_size": null,
+  "add_cross_attention": false,
+  "tie_encoder_decoder": false,
+  "max_length": 20,
+  "min_length": 0,
+  "do_sample": false,
+  "early_stopping": false,
+  "num_beams": 1,
+  "num_beam_groups": 1,
+  "diversity_penalty": 0.0,
+  "temperature": 1.0,
+  "top_k": 50,
+  "top_p": 1.0,
+  "typical_p": 1.0,
+  "repetition_penalty": 1.0,
+  "length_penalty": 1.0,
+  "no_repeat_ngram_size": 0,
+  "encoder_no_repeat_ngram_size": 0,
+  "bad_words_ids": null,
+  "num_return_sequences": 1,
+  "chunk_size_feed_forward": 0,
+  "output_scores": false,
+  "return_dict_in_generate": false,
+  "forced_bos_token_id": null,
+  "forced_eos_token_id": null,
+  "remove_invalid_values": false,
+  "exponential_decay_length_penalty": null,
+  "suppress_tokens": null,
+  "begin_suppress_tokens": null,
+  "architectures": [
+    "LlamaForCausalLM"
+  ],
+  "finetuning_task": null,
+  "id2label": {
+    "0": "LABEL_0",
+    "1": "LABEL_1"
+  },
+  "label2id": {
+    "LABEL_0": 0,
+    "LABEL_1": 1
+  },
+  "tokenizer_class": null,
+  "prefix": null,
+  "bos_token_id": 1,
+  "pad_token_id": null,
+  "eos_token_id": 2,
+  "sep_token_id": null,
+  "decoder_start_token_id": null,
+  "task_specific_params": null,
+  "problem_type": null,
+  "_name_or_path": "",
+  "transformers_version": "4.36.2",
   "aqlm": {
     "nbits_per_codebook": 16,
     "num_codebooks": 1,
     "out_group_size": 1,
     "in_group_size": 8
+  },
+  "model_type": "llama_aqlm",
+  "auto_map": {
+    "AutoConfig": "configuration_llama_aqlm.LlamaConfig",
+    "AutoModelForCausalLM": "modeling_llama_aqlm.LlamaForCausalLM"
   }
 }
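
With "model_type": "llama_aqlm" and auto_map pointing at configuration_llama_aqlm.LlamaConfig and modeling_llama_aqlm.LlamaForCausalLM, the transformers Auto classes can resolve the custom code shipped alongside this config. A minimal loading sketch, using a placeholder repository id and the standard trust_remote_code path (not part of this commit):

# Minimal loading sketch; "your-org/llama-2-7b-aqlm" is a hypothetical placeholder repo id.
# trust_remote_code=True lets transformers import the classes named in auto_map.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

repo_id = "your-org/llama-2-7b-aqlm"  # placeholder, replace with the actual repository
model = AutoModelForCausalLM.from_pretrained(
    repo_id,
    trust_remote_code=True,       # needed for configuration_llama_aqlm / modeling_llama_aqlm
    torch_dtype=torch.float16,    # matches "torch_dtype": "float16" above
)
tokenizer = AutoTokenizer.from_pretrained(repo_id)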
inference.py CHANGED
@@ -135,7 +135,7 @@ def forward_pass_quantized_linear(
     bias: Optional[torch.Tensor],
 ) -> torch.Tensor:
     if input.is_cuda:
-        return aqlm_gemm_stupid(input, codes, codebooks, scales, bias)
+        return triton_matmul(input, codes, codebooks, scales, bias)
     else:
         dequantized_weight = _dequantize_weight(
             unpack_int_data(codes, codebooks.shape[0].bit_length() - 1),
@@ -160,7 +160,6 @@ def forward_pass_quantized_linear(
         "in_group_size",
         "num_input_groups",
         "num_input_groups_next_power_of_2",
-        "has_bias",
         "compute_in_fp32",
     ],
 )
@@ -168,7 +167,7 @@
 def _aqlm_gemv_simple(
     input_vec_ptr,
     output_vec_ptr,
-    codes_i16_ptr,
+    codes_ptr,
     codebooks_ptr,
     scales_ptr,
     bias_ptr,
@@ -181,7 +180,6 @@ def _aqlm_gemv_simple(
     num_input_groups: tl.constexpr,
     num_input_groups_next_power_of_2: tl.constexpr,
     compute_in_fp32: tl.constexpr,
-    has_bias: tl.constexpr,
     UNUSED: tl.constexpr,
 ):
     # variables ending with "_i" mean "for i-th output unit"
@@ -203,7 +201,7 @@
     # Stage 2: load integer codes for the active row
     # [in_features // in_group_size, num_codebooks]
     codes_i_ptrs = (
-        codes_i16_ptr
+        codes_ptr
         + pid * num_input_groups * num_codebooks
         + tl.arange(0, num_input_groups_next_power_of_2)[:, None] * num_codebooks
         + tl.arange(0, num_codebooks)[None, :]
@@ -211,15 +209,12 @@
     codes_i_mask_1d = tl.arange(0, num_input_groups_next_power_of_2) < num_input_groups
 
     codes_i = tl.load(codes_i_ptrs, mask=codes_i_mask_1d[:, None])  # [in_features//in_group_size, num_codebooks]
-    if codes_i.dtype == tl.int16:
-        codes_i = codes_i.to(tl.int32)
-        codes_i = (codes_i) + (codes_i < 0) * codebook_size  # aka 2 ** nbits_per_codebook
-        # ^-- (because codes are int16 tensors that contain uint data)
+    codes_i = codes_i.to(tl.int32)
+    codes_i = (codes_i) + (codes_i < 0) * codebook_size  # aka 2 ** nbits_per_codebook
+    # ^-- (because codes are int16 tensors that contain uint data)
 
-        # The following alternative does not work:
-        # codes_i = codes_i.to(tl.int32) % codebook_size  # aka 2 ** nbits_per_codebook
-    else:
-        codes_i = codes_i.to(tl.int32)
+    # The following alternative does not work:
+    # codes_i = codes_i.to(tl.int32) % codebook_size  # aka 2 ** nbits_per_codebook
 
     # shift codes_i so that codebooks after 0th point to correct indices in codebooks_ptr
     codes_i += tl.arange(0, num_codebooks)[None, :] * codebook_size  # aka 2 ** nbits_per_codebook
@@ -280,7 +275,7 @@
     assert input_vec.ndim == 2 and input_vec.shape[0] == 1, "do reshape; now!"
     assert scales.shape == (out_features // out_group_size, 1, 1, 1)
     assert in_features % in_group_size == 0
-    assert codebooks.shape[1] == 2**16
+    assert codebooks.shape[1] < 2**32
 
     output_vec = torch.empty(1, out_features, device=device, dtype=dtype)
     # 1D launch kernel where each block computes output unit
@@ -301,7 +296,6 @@
         num_input_groups,
         next_power_of_2(num_input_groups),
         compute_in_fp32,
-        bias is not None,
     )
 
     return output_vec
@@ -315,11 +309,67 @@
     bias: Optional[torch.Tensor],
     compute_in_fp32: bool = True,
 ):
-    original_shape = input.shape
-    input = input.reshape(-1, original_shape[-1])
-    return torch.cat(
-        [
-            aqlm_gemv_simple(input_vec.unsqueeze(0), codes_i16, codebooks, scales, bias, compute_in_fp32)
-            for input_vec in input
-        ]
-    ).reshape(original_shape[:-1] + (-1,))
+    device, dtype = codebooks.device, codebooks.dtype
+    num_codebooks, codebook_size, out_group_size, in_group_size = codebooks.shape
+    in_features = input.shape[1]
+    out_features = codes_i16.shape[0] * out_group_size
+    num_input_groups = codes_i16.shape[1]
+    assert input.ndim == 2
+    assert scales.shape == (out_features // out_group_size, 1, 1, 1)
+    assert in_features % in_group_size == 0
+    assert codebooks.shape[1] < 2**32
+
+    output = torch.empty(input.shape[0], out_features, device=device, dtype=dtype)
+    for i in range(input.shape[0]):
+        # 1D launch kernel where each block computes output unit
+        grid = lambda META: (out_features // out_group_size,)
+        _aqlm_gemv_simple[grid](
+            input[i],
+            output[i],
+            codes_i16,
+            codebooks,
+            scales,
+            bias,
+            in_features,
+            out_features,
+            num_codebooks,
+            codebook_size,
+            out_group_size,
+            in_group_size,
+            num_input_groups,
+            next_power_of_2(num_input_groups),
+            compute_in_fp32,
+        )
+
+    return output
+
+
+def triton_matmul(
+    input: torch.Tensor,
+    codes: torch.IntTensor,
+    codebooks: torch.Tensor,
+    scales: torch.Tensor,
+    bias: Optional[torch.Tensor],
+    compute_in_fp32: bool = True,
+) -> torch.Tensor:
+    input_shape = input.shape
+    input = input.reshape(-1, input_shape[-1])
+
+    if input.shape[0] == 1:
+        return aqlm_gemv_simple(
+            input,
+            codes,
+            codebooks,
+            scales,
+            bias,
+            compute_in_fp32,
+        ).reshape(input_shape[:-1] + (-1,))
+    else:
+        return aqlm_gemm_stupid(
+            input,
+            codes,
+            codebooks,
+            scales,
+            bias,
+            compute_in_fp32,
+        ).reshape(input_shape[:-1] + (-1,))
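
The CUDA path now goes through triton_matmul, which flattens the input to 2-D and dispatches to aqlm_gemv_simple for a single row or to the kernel-per-row loop in aqlm_gemm_stupid for batches. Inside the kernel, the int16 branch is gone: the loaded codes are always widened to int32 and codebook_size is added to negative entries, which is how 16-bit unsigned codebook indices stored in an int16 tensor are recovered. A standalone sketch of that sign trick in plain PyTorch (illustrative only, not part of inference.py):

# Illustrative sketch: AQLM stores uint16 codebook indices in an int16 tensor,
# so indices >= 2**15 show up as negative values. Widening to int32 and adding
# codebook_size (2**16) to the negative entries recovers the original index,
# mirroring `codes_i = codes_i.to(tl.int32); codes_i = codes_i + (codes_i < 0) * codebook_size`.
import torch

codebook_size = 2**16
true_codes = torch.tensor([0, 1, 32767, 32768, 65535])    # unsigned indices
stored_codes = true_codes.to(torch.int16)                 # 32768 and 65535 wrap to negatives
recovered = stored_codes.to(torch.int32)
recovered = recovered + (recovered < 0).to(torch.int32) * codebook_size
assert torch.equal(recovered, true_codes.to(torch.int32))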