AppleSwing committed on
Commit
dd01425
1 Parent(s): b20ad66

fix calculation error of mbu

Browse files
src/backend/hflm_with_measurement.py CHANGED
@@ -312,11 +312,14 @@ class HFLMWithMeasurement(HFLM):
312
  if do_sample is False and generation_kwargs.get("temperature") == 0.0:
313
  generation_kwargs.pop("temperature")
314
 
315
- generation_kwargs.pop("is_gsm8k")
 
 
316
  context_length = context.shape[1]
317
 
318
  if not is_gsm8k:
319
  # build stopping criteria
 
320
  stopping_criteria = stop_sequences_criteria(
321
  self.tokenizer, stop, context.shape[1], context.shape[0]
322
  )
@@ -354,7 +357,6 @@ class HFLMWithMeasurement(HFLM):
354
 
355
  model_info = API.model_info(repo_id=self.pretrained, revision=self.revision)
356
  model_size_param = get_model_size(model_info=model_info, precision=self.precision)
357
- model_size = model_size_param * precision_bytes
358
 
359
  model_config = self.model.config
360
 
@@ -401,7 +403,7 @@ class HFLMWithMeasurement(HFLM):
401
  prefilling_time = stop_watch.prefilling_time / batch_size
402
  decoding_time = stop_watch.decoding_time / batch_size
403
  token_per_sec = output_length / decoding_time
404
- ach_mem_bw = (model_size / 1e9 + kv_size) * token_per_sec
405
 
406
  flops_per_token = 2 * model_size + 2 * n_layers * context_length * d_model
407
  peak_flops_single = get_peak_flops(get_gpu_details(), self.precision)
 
312
  if do_sample is False and generation_kwargs.get("temperature") == 0.0:
313
  generation_kwargs.pop("temperature")
314
 
315
+ if is_gsm8k:
316
+ generation_kwargs.pop("is_gsm8k")
317
+
318
  context_length = context.shape[1]
319
 
320
  if not is_gsm8k:
321
  # build stopping criteria
322
+ print("Using normal stopping criteria")
323
  stopping_criteria = stop_sequences_criteria(
324
  self.tokenizer, stop, context.shape[1], context.shape[0]
325
  )
 
357
 
358
  model_info = API.model_info(repo_id=self.pretrained, revision=self.revision)
359
  model_size_param = get_model_size(model_info=model_info, precision=self.precision)
 
360
 
361
  model_config = self.model.config
362
 
 
403
  prefilling_time = stop_watch.prefilling_time / batch_size
404
  decoding_time = stop_watch.decoding_time / batch_size
405
  token_per_sec = output_length / decoding_time
406
+ ach_mem_bw = (model_size * precision_bytes / 1e9 + kv_size) * token_per_sec
407
 
408
  flops_per_token = 2 * model_size + 2 * n_layers * context_length * d_model
409
  peak_flops_single = get_peak_flops(get_gpu_details(), self.precision)
src/utils.py CHANGED
@@ -31,6 +31,12 @@ PEAK_FLOPS_DICT = {
31
  "NVIDIA-H100-PCIe-80GB": 1513e12,
32
  "NVIDIA-RTX-A5000-24GB": 444.4e12
33
  },
 
 
 
 
 
 
34
  "8bit":{
35
  "NVIDIA-A100-PCIe-80GB": 1248e12,
36
  "NVIDIA-A100-SXM-80GB": 1248e12,
 
31
  "NVIDIA-H100-PCIe-80GB": 1513e12,
32
  "NVIDIA-RTX-A5000-24GB": 444.4e12
33
  },
34
+ "bfloat16":{
35
+ "NVIDIA-A100-PCIe-80GB": 624e12,
36
+ "NVIDIA-A100-SXM-80GB": 624e12,
37
+ "NVIDIA-H100-PCIe-80GB": 1513e12,
38
+ "NVIDIA-RTX-A5000-24GB": 444.4e12
39
+ },
40
  "8bit":{
41
  "NVIDIA-A100-PCIe-80GB": 1248e12,
42
  "NVIDIA-A100-SXM-80GB": 1248e12,