AppleSwing commited on
Commit
0be51d4
1 Parent(s): f494d8b

Add GSM8K (#27)

Browse files

- Add GSM8K (900a631539af9b1ea4ccf861e170bcdeae8d46fe)
- Merge branch 'main' into pr/27 (e82a7af1535e766b16314ac4eefb0d2fa1fbaee4)
- Delete gsm8k yamls (f38163c1a9a66242aa5baf0ba91bd786b8d802d3)
- Fix some bugs (9ffef81821400a94b5d4c08eddb5268944b26e7f)
- Fix bugs on wrappers and add quantization requirement (28b60907c7a9ed112ee151c6eadb22d1e7074116)
- Fix bugs in gsm8k (4045483a84607da8b1c2505dc7f1ba2bdd407f47)

backend-cli.py CHANGED
@@ -17,7 +17,7 @@ from src.backend.manage_requests import EvalRequest
17
  from src.leaderboard.read_evals import EvalResult
18
 
19
  from src.envs import QUEUE_REPO, RESULTS_REPO, API, DEBUG_QUEUE_REPO, DEBUG_RESULTS_REPO
20
- from src.utils import my_snapshot_download, analyze_gpu_stats, parse_nvidia_smi, monitor_gpus
21
 
22
  from src.leaderboard.read_evals import get_raw_eval_results
23
 
@@ -28,6 +28,8 @@ import time
28
  import pprint
29
  import logging
30
 
 
 
31
 
32
  # Configure the root logger
33
  logging.basicConfig(
@@ -42,6 +44,20 @@ eval_logger = logging.getLogger("lm-eval")
42
  # Explicitly set the level for 'lm-eval' logger to WARNING
43
  eval_logger.setLevel(logging.WARNING)
44
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45
 
46
  def my_set_eval_request(api, eval_request, set_to_status, hf_repo, local_dir):
47
  for i in range(10):
@@ -126,9 +142,6 @@ def request_to_result_name(request: EvalRequest) -> str:
126
  def process_evaluation(task: Task, eval_request: EvalRequest, limit: Optional[int] = None) -> dict:
127
  batch_size = 1
128
  batch_size = eval_request.batch_size
129
-
130
- if args.debug:
131
- RESULTS_REPO = DEBUG_RESULTS_REPO
132
 
133
  init_gpu_info = analyze_gpu_stats(parse_nvidia_smi())
134
  # if init_gpu_info['Mem(M)'] > 500:
@@ -137,6 +150,12 @@ def process_evaluation(task: Task, eval_request: EvalRequest, limit: Optional[in
137
  stop_event = threading.Event()
138
  monitor_thread = threading.Thread(target=monitor_gpus, args=(stop_event, 5, gpu_stats_list))
139
  monitor_thread.start()
 
 
 
 
 
 
140
 
141
  try:
142
  results = run_evaluation(
@@ -198,6 +217,8 @@ def process_evaluation(task: Task, eval_request: EvalRequest, limit: Optional[in
198
  repo_id=RESULTS_REPO,
199
  repo_type="dataset",
200
  )
 
 
201
  return results
202
 
203
 
@@ -366,21 +387,7 @@ def maybe_refresh_results(thr: int, hard_task_lst: Optional[list[str]] = None) -
366
 
367
  return False
368
 
369
-
370
- def get_gpu_details():
371
- gpus = GPUtil.getGPUs()
372
- gpu = gpus[0]
373
- name = gpu.name.replace(" ", "-")
374
- # Convert memory from MB to GB and round to nearest whole number
375
- memory_gb = round(gpu.memoryTotal / 1024)
376
- memory = f"{memory_gb}GB"
377
- formatted_name = f"{name}-{memory}"
378
- return formatted_name
379
-
380
  def process_pending_requests() -> bool:
381
- if args.debug:
382
- QUEUE_REPO = DEBUG_QUEUE_REPO
383
-
384
  sanity_checks()
385
  print("Processing pending requests")
386
  current_pending_status = [PENDING_STATUS]
@@ -443,13 +450,14 @@ def get_args():
443
  parser = argparse.ArgumentParser(description="Run the backend")
444
  parser.add_argument("--debug", action="store_true", help="Run in debug mode")
445
  # debug parameters
446
- parser.add_argument("--task", type=str, default="selfcheckgpt,mmlu", help="Task to debug")
447
  parser.add_argument("--model", type=str, default="mistralai/Mixtral-8x7B-Instruct-v0.1,mistralai/Mixtral-8x7B-v0.1", help="Model to debug")
448
  parser.add_argument("--precision", type=str, default="float32,float16,8bit,4bit", help="Precision to debug")
449
  parser.add_argument("--inference-framework", type=str, default="hf-chat", help="Inference framework to debug")
450
  parser.add_argument("--limit", type=int, default=None, help="Limit for the number of samples")
451
  parser.add_argument("--gpu-type", type=str, default="NVIDIA-A100-PCIe-80GB",
452
  help="GPU type. NVIDIA-A100-PCIe-80GB; NVIDIA-RTX-A5000-24GB; NVIDIA-H100-PCIe-80GB")
 
453
  return parser.parse_args()
454
 
455
 
@@ -457,7 +465,7 @@ if __name__ == "__main__":
457
  args = get_args()
458
  local_debug = args.debug
459
  # debug specific task by ping
460
- if local_debug:
461
  # debug_model_names = [args.model] # Use model from arguments
462
  # debug_task_name = [args.task] # Use task from arguments
463
  debug_model_names = args.model.split(",")
@@ -471,42 +479,60 @@ if __name__ == "__main__":
471
  task_name = task.benchmark
472
  if task_name not in debug_task_name:
473
  continue
474
- try:
475
- eval_request = EvalRequest(
476
- model=debug_model_name,
477
- private=False,
478
- status="",
479
- json_filepath="",
480
- precision=precision, # Use precision from arguments
481
- inference_framework=args.inference_framework, # Use inference framework from arguments
482
- gpu_type=args.gpu_type
483
- )
484
- curr_gpu_type = get_gpu_details()
485
- if eval_request.gpu_type != curr_gpu_type:
486
- print(f"GPU type mismatch: {eval_request.gpu_type} vs {curr_gpu_type}")
487
- raise Exception("GPU type mismatch")
488
- results = process_evaluation(task, eval_request, limit=args.limit)
489
- except Exception as e:
490
- print(f"debug running error: {e}")
491
- else:
 
 
492
  while True:
493
  res = False
494
-
495
  # if random.randint(0, 10) == 0:
496
  res = process_pending_requests()
497
  print(f"waiting for 60 seconds")
498
  time.sleep(60)
499
-
500
  # if res is False:
501
  # if random.randint(0, 5) == 0:
502
  # res = maybe_refresh_results(100)
503
  # else:
504
  # res = process_finished_requests(100)
505
-
506
  # time.sleep(60)
507
-
508
  # if res is False:
509
  # if random.randint(0, 5) == 0:
510
  # res = maybe_refresh_results(0)
511
  # else:
512
  # res = process_finished_requests(0)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
  from src.leaderboard.read_evals import EvalResult
18
 
19
  from src.envs import QUEUE_REPO, RESULTS_REPO, API, DEBUG_QUEUE_REPO, DEBUG_RESULTS_REPO
20
+ from src.utils import my_snapshot_download, analyze_gpu_stats, parse_nvidia_smi, monitor_gpus, get_gpu_details
21
 
22
  from src.leaderboard.read_evals import get_raw_eval_results
23
 
 
28
  import pprint
29
  import logging
30
 
31
+ from lm_eval.filters.extraction import RegexFilter
32
+
33
 
34
  # Configure the root logger
35
  logging.basicConfig(
 
44
  # Explicitly set the level for 'lm-eval' logger to WARNING
45
  eval_logger.setLevel(logging.WARNING)
46
 
47
+ def tuple_input_decorator(func):
48
+ def wrapper(self, resps, docs):
49
+ stripped_resps = [[resp_data[0] for resp_data in group] for group in resps]
50
+
51
+ filtered_resps = func(self, stripped_resps, docs)
52
+
53
+ combined_resps = []
54
+ for original_group, new_group in zip(resps, filtered_resps):
55
+ combined_group = [(new_resp,) + rest_of_data[1:] for new_resp, rest_of_data in zip(new_group, original_group)]
56
+ combined_resps.append(combined_group)
57
+
58
+ return combined_resps
59
+ return wrapper
60
+
61
 
62
  def my_set_eval_request(api, eval_request, set_to_status, hf_repo, local_dir):
63
  for i in range(10):
 
142
  def process_evaluation(task: Task, eval_request: EvalRequest, limit: Optional[int] = None) -> dict:
143
  batch_size = 1
144
  batch_size = eval_request.batch_size
 
 
 
145
 
146
  init_gpu_info = analyze_gpu_stats(parse_nvidia_smi())
147
  # if init_gpu_info['Mem(M)'] > 500:
 
150
  stop_event = threading.Event()
151
  monitor_thread = threading.Thread(target=monitor_gpus, args=(stop_event, 5, gpu_stats_list))
152
  monitor_thread.start()
153
+
154
+ original_apply = RegexFilter.apply
155
+ if task.benchmark in ["gsm8k", "gsm8k_cot", "gsm8k_cot_self_consistency", "gsm8k_custom"]:
156
+ RegexFilter.apply = tuple_input_decorator(RegexFilter.apply)
157
+ else:
158
+ RegexFilter.apply = original_apply
159
 
160
  try:
161
  results = run_evaluation(
 
217
  repo_id=RESULTS_REPO,
218
  repo_type="dataset",
219
  )
220
+
221
+ RegexFilter.apply = original_apply
222
  return results
223
 
224
 
 
387
 
388
  return False
389
 
 
 
 
 
 
 
 
 
 
 
 
390
  def process_pending_requests() -> bool:
 
 
 
391
  sanity_checks()
392
  print("Processing pending requests")
393
  current_pending_status = [PENDING_STATUS]
 
450
  parser = argparse.ArgumentParser(description="Run the backend")
451
  parser.add_argument("--debug", action="store_true", help="Run in debug mode")
452
  # debug parameters
453
+ parser.add_argument("--task", type=str, default="selfcheckgpt,mmlu, gsm8k", help="Task to debug")
454
  parser.add_argument("--model", type=str, default="mistralai/Mixtral-8x7B-Instruct-v0.1,mistralai/Mixtral-8x7B-v0.1", help="Model to debug")
455
  parser.add_argument("--precision", type=str, default="float32,float16,8bit,4bit", help="Precision to debug")
456
  parser.add_argument("--inference-framework", type=str, default="hf-chat", help="Inference framework to debug")
457
  parser.add_argument("--limit", type=int, default=None, help="Limit for the number of samples")
458
  parser.add_argument("--gpu-type", type=str, default="NVIDIA-A100-PCIe-80GB",
459
  help="GPU type. NVIDIA-A100-PCIe-80GB; NVIDIA-RTX-A5000-24GB; NVIDIA-H100-PCIe-80GB")
460
+ parser.add_argument("--debug_repo", action="store_true", help="Use debug repo")
461
  return parser.parse_args()
462
 
463
 
 
465
  args = get_args()
466
  local_debug = args.debug
467
  # debug specific task by ping
468
+ if local_debug and not args.debug_repo:
469
  # debug_model_names = [args.model] # Use model from arguments
470
  # debug_task_name = [args.task] # Use task from arguments
471
  debug_model_names = args.model.split(",")
 
479
  task_name = task.benchmark
480
  if task_name not in debug_task_name:
481
  continue
482
+ # try:
483
+ eval_request = EvalRequest(
484
+ model=debug_model_name,
485
+ private=False,
486
+ status="",
487
+ json_filepath="",
488
+ precision=precision, # Use precision from arguments
489
+ inference_framework=args.inference_framework, # Use inference framework from arguments
490
+ gpu_type=args.gpu_type
491
+ )
492
+ curr_gpu_type = get_gpu_details()
493
+ if eval_request.gpu_type != curr_gpu_type:
494
+ print(f"GPU type mismatch: {eval_request.gpu_type} vs {curr_gpu_type}")
495
+ raise Exception("GPU type mismatch")
496
+ results = process_evaluation(task, eval_request, limit=args.limit)
497
+ # except Exception as e:
498
+ # print(f"debug running error: {e}")
499
+ elif local_debug and args.debug_repo:
500
+ QUEUE_REPO = DEBUG_QUEUE_REPO
501
+ RESULTS_REPO = DEBUG_RESULTS_REPO
502
  while True:
503
  res = False
 
504
  # if random.randint(0, 10) == 0:
505
  res = process_pending_requests()
506
  print(f"waiting for 60 seconds")
507
  time.sleep(60)
 
508
  # if res is False:
509
  # if random.randint(0, 5) == 0:
510
  # res = maybe_refresh_results(100)
511
  # else:
512
  # res = process_finished_requests(100)
 
513
  # time.sleep(60)
 
514
  # if res is False:
515
  # if random.randint(0, 5) == 0:
516
  # res = maybe_refresh_results(0)
517
  # else:
518
  # res = process_finished_requests(0)
519
+ elif not local_debug and not args.debug_repo:
520
+ while True:
521
+ res = False
522
+ # if random.randint(0, 10) == 0:
523
+ res = process_pending_requests()
524
+ print(f"waiting for 60 seconds")
525
+ time.sleep(60)
526
+ # if res is False:
527
+ # if random.randint(0, 5) == 0:
528
+ # res = maybe_refresh_results(100)
529
+ # else:
530
+ # res = process_finished_requests(100)
531
+ # time.sleep(60)
532
+ # if res is False:
533
+ # if random.randint(0, 5) == 0:
534
+ # res = maybe_refresh_results(0)
535
+ # else:
536
+ # res = process_finished_requests(0)
537
+ else:
538
+ raise Exception("Cannot use debug_repo without local debug flag")
requirements.txt CHANGED
@@ -30,4 +30,5 @@ evaluate
30
  spacy==3.7.4
31
  selfcheckgpt
32
  immutabledict
33
- gputil
 
 
30
  spacy==3.7.4
31
  selfcheckgpt
32
  immutabledict
33
+ gputil
34
+ bitsandbytes
src/backend/envs.py CHANGED
@@ -57,6 +57,7 @@ class Tasks(Enum):
57
 
58
  # task20 = Task("race", "acc", "RACE", 0)
59
  task21 = Task("mmlu", "acc", "MMLU", 5)
 
60
 
61
 
62
  EVAL_REQUESTS_PATH_BACKEND = os.path.join(CACHE_PATH, "eval-queue-bk")
 
57
 
58
  # task20 = Task("race", "acc", "RACE", 0)
59
  task21 = Task("mmlu", "acc", "MMLU", 5)
60
+ task22 = Task("gsm8k_custom", "em", "GSM8K", 5)
61
 
62
 
63
  EVAL_REQUESTS_PATH_BACKEND = os.path.join(CACHE_PATH, "eval-queue-bk")
src/backend/hflm_with_measurement.py CHANGED
@@ -295,6 +295,8 @@ class HFLMWithMeasurement(HFLM):
295
  # and we don't want a warning from HF
296
  generation_kwargs["temperature"] = generation_kwargs.get("temperature", 0.0)
297
  do_sample = generation_kwargs.get("do_sample", None)
 
 
298
 
299
  # The temperature has to be a strictly positive float -- if it is 0.0, use greedy decoding strategies
300
  if generation_kwargs.get("temperature") == 0.0 and do_sample is None:
@@ -302,22 +304,40 @@ class HFLMWithMeasurement(HFLM):
302
 
303
  if do_sample is False and generation_kwargs.get("temperature") == 0.0:
304
  generation_kwargs.pop("temperature")
 
 
 
 
305
  # build stopping criteria
306
- stopping_criteria = stop_sequences_criteria(
307
- self.tokenizer, stop, context.shape[1], context.shape[0]
308
- )
309
- stop_watch = StopWatch(self.tokenizer)
310
- start = time()
311
- res = self.model.generate(
312
- input_ids=context,
313
- max_length=max_length,
314
- stopping_criteria=stopping_criteria,
315
- pad_token_id=self.tokenizer.pad_token_id,
316
- use_cache=True,
317
- streamer=stop_watch,
318
- **generation_kwargs,
319
- )
320
- end = time()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
321
 
322
  batch_size = context.shape[0]
323
  output_length = stop_watch.decoding_iterations
@@ -408,6 +428,11 @@ class HFLMWithMeasurement(HFLM):
408
  until = [eos]
409
  else:
410
  until.append(eos)
 
 
 
 
 
411
  if "max_gen_toks" in kwargs.keys():
412
  max_gen_toks = kwargs.pop("max_gen_toks")
413
  else:
@@ -427,6 +452,8 @@ class HFLMWithMeasurement(HFLM):
427
  left_truncate_len=max_ctx_len,
428
  truncation=self.truncation,
429
  )
 
 
430
  context_enc = context_enc.to(self.device)
431
  attn_masks = attn_masks.to(self.device)
432
 
@@ -445,16 +472,18 @@ class HFLMWithMeasurement(HFLM):
445
  for cont_toks, context in zip(cont_toks_list, contexts):
446
  # discard context + left-padding toks if using causal decoder-only LM
447
  if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
 
448
  cont_toks = cont_toks[context_enc.shape[1] :]
449
-
450
  s = self.tok_decode(cont_toks)
451
 
452
  # use secondary stop seqs to cut off should-have-been-stopped content post-hoc
453
- for term in until:
454
- if len(term) > 0:
455
- # ignore '' separator,
456
- # for seq2seq case where self.tok_decode(self.eot_token_id) = ''
457
- s = s.split(term)[0]
 
458
 
459
  res.append((s, end_to_end_time, prefilling_time, token_per_sec))
460
 
 
295
  # and we don't want a warning from HF
296
  generation_kwargs["temperature"] = generation_kwargs.get("temperature", 0.0)
297
  do_sample = generation_kwargs.get("do_sample", None)
298
+
299
+ is_gsm8k = generation_kwargs.get("is_gsm8k", False)
300
 
301
  # The temperature has to be a strictly positive float -- if it is 0.0, use greedy decoding strategies
302
  if generation_kwargs.get("temperature") == 0.0 and do_sample is None:
 
304
 
305
  if do_sample is False and generation_kwargs.get("temperature") == 0.0:
306
  generation_kwargs.pop("temperature")
307
+
308
+ generation_kwargs.pop("is_gsm8k")
309
+
310
+ if not is_gsm8k:
311
  # build stopping criteria
312
+ stopping_criteria = stop_sequences_criteria(
313
+ self.tokenizer, stop, context.shape[1], context.shape[0]
314
+ )
315
+ stop_watch = StopWatch(self.tokenizer)
316
+ start = time()
317
+ res = self.model.generate(
318
+ input_ids=context,
319
+ max_length=max_length,
320
+ stopping_criteria=stopping_criteria,
321
+ pad_token_id=self.tokenizer.pad_token_id,
322
+ use_cache=True,
323
+ streamer=stop_watch,
324
+ **generation_kwargs,
325
+ )
326
+ end = time()
327
+ else:
328
+ # print("Using GSM8K")
329
+ stop_watch = StopWatch(self.tokenizer)
330
+ start = time()
331
+ res = self.model.generate(
332
+ input_ids=context,
333
+ max_length=max_length,
334
+ eos_token_id=stop,
335
+ pad_token_id=self.tokenizer.pad_token_id,
336
+ use_cache=True,
337
+ streamer=stop_watch,
338
+ **generation_kwargs,
339
+ )
340
+ end = time()
341
 
342
  batch_size = context.shape[0]
343
  output_length = stop_watch.decoding_iterations
 
428
  until = [eos]
429
  else:
430
  until.append(eos)
431
+
432
+ is_gsm8k = kwargs.get("is_gsm8k", False)
433
+ if is_gsm8k:
434
+ until = [self.tokenizer.eos_token_id, self.tokenizer.convert_tokens_to_ids("<|eot_id|>")]
435
+
436
  if "max_gen_toks" in kwargs.keys():
437
  max_gen_toks = kwargs.pop("max_gen_toks")
438
  else:
 
452
  left_truncate_len=max_ctx_len,
453
  truncation=self.truncation,
454
  )
455
+
456
+ # print("context: ", self.tok_decode(context_enc[0]))
457
  context_enc = context_enc.to(self.device)
458
  attn_masks = attn_masks.to(self.device)
459
 
 
472
  for cont_toks, context in zip(cont_toks_list, contexts):
473
  # discard context + left-padding toks if using causal decoder-only LM
474
  if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
475
+ # print("After Generation: ", self.tok_decode(cont_toks))
476
  cont_toks = cont_toks[context_enc.shape[1] :]
477
+
478
  s = self.tok_decode(cont_toks)
479
 
480
  # use secondary stop seqs to cut off should-have-been-stopped content post-hoc
481
+ if not is_gsm8k:
482
+ for term in until:
483
+ if len(term) > 0:
484
+ # ignore '' separator,
485
+ # for seq2seq case where self.tok_decode(self.eot_token_id) = ''
486
+ s = s.split(term)[0]
487
 
488
  res.append((s, end_to_end_time, prefilling_time, token_per_sec))
489
 
src/backend/tasks/gsm8k/gsm8k-custom.yaml ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ group:
2
+ - math_word_problems
3
+ task: gsm8k_custom
4
+ dataset_path: gsm8k
5
+ dataset_name: main
6
+ output_type: generate_until
7
+ training_split: train
8
+ fewshot_split: train
9
+ test_split: test
10
+ doc_to_text: "Question: {{question}}\nAnswer:"
11
+ doc_to_target: "{{answer}}" #" {{answer.split('### ')[-1].rstrip()}}"
12
+ metric_list:
13
+ - metric: exact_match
14
+ aggregation: mean
15
+ higher_is_better: true
16
+ ignore_case: true
17
+ ignore_punctuation: false
18
+ regexes_to_ignore:
19
+ - ","
20
+ - "\\$"
21
+ - "(?s).*#### "
22
+ - "\\.$"
23
+ generation_kwargs:
24
+ until:
25
+ - "<|eot_id|>"
26
+ do_sample: false
27
+ temperature: 0.0
28
+ is_gsm8k: true
29
+ repeats: 1
30
+ num_fewshot: 5
31
+ filter_list:
32
+ # - name: "strict-match"
33
+ # filter:
34
+ # - function: "regex"
35
+ # regex_pattern: "#### (\\-?[0-9\\.\\,]+)"
36
+ # - function: "take_first"
37
+ - name: "flexible-extract"
38
+ filter:
39
+ - function: "regex"
40
+ group_select: -1
41
+ regex_pattern: "(-?[$0-9.,]{2,})|(-?[0-9]+)"
42
+ - function: "take_first"
43
+ metadata:
44
+ version: 3.0
src/display/utils.py CHANGED
@@ -75,6 +75,7 @@ class Tasks(Enum):
75
  # # XXX include me back at some point
76
  selfcheck = Task("selfcheckgpt", "max-selfcheckgpt", "SelfCheckGPT")
77
  mmlu = Task("mmlu", "acc", "MMLU") #MMLU/Acc (5-shot)
 
78
 
79
 
80
  # These classes are for user facing column names,
 
75
  # # XXX include me back at some point
76
  selfcheck = Task("selfcheckgpt", "max-selfcheckgpt", "SelfCheckGPT")
77
  mmlu = Task("mmlu", "acc", "MMLU") #MMLU/Acc (5-shot)
78
+ gsm8k = Task("gsm8k_custom", "em", "GSM8K") #GSM8K/EM (8-shot)
79
 
80
 
81
  # These classes are for user facing column names,
src/submission/check_validity.py CHANGED
@@ -130,7 +130,8 @@ def already_submitted_models(requested_models_dir: str) -> set[str]:
130
  continue
131
  with open(os.path.join(root, file), "r") as f:
132
  info = json.load(f)
133
- file_names.append(f"{info['model']}_{info['revision']}_{info['precision']}_{info['inference_framework']}_{info['gpu_type']}")
 
134
 
135
  # Select organisation
136
  if info["model"].count("/") == 0 or "submitted_time" not in info:
 
130
  continue
131
  with open(os.path.join(root, file), "r") as f:
132
  info = json.load(f)
133
+ if not info["status"] == "FINISHED" and not info["status"] == "RUNNING":
134
+ file_names.append(f"{info['model']}_{info['revision']}_{info['precision']}_{info['inference_framework']}_{info['gpu_type']}")
135
 
136
  # Select organisation
137
  if info["model"].count("/") == 0 or "submitted_time" not in info:
src/utils.py CHANGED
@@ -3,12 +3,48 @@ from huggingface_hub import snapshot_download
3
  import subprocess
4
  import re
5
  import os
 
6
 
7
  try:
8
  from src.display.utils import GPU_TEMP, GPU_Mem, GPU_Power, GPU_Util, GPU_Name
9
  except:
10
  print("local debug: from display.utils")
11
  from display.utils import GPU_TEMP, GPU_Mem, GPU_Power, GPU_Util, GPU_Name
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
 
13
  def my_snapshot_download(repo_id, revision, local_dir, repo_type, max_workers):
14
  for i in range(10):
@@ -52,11 +88,11 @@ def parse_nvidia_smi():
52
  print("Failed to query GPU indices.")
53
  return []
54
  gpu_indices = result.stdout.strip().split('\n')
55
- print(f"gpu_indices: {gpu_indices}")
56
  gpu_stats = []
57
 
58
  gpu_info_pattern = re.compile(r'(\d+)C\s+P\d+\s+(\d+)W / \d+W\s+\|\s+(\d+)MiB / \d+MiB\s+\|\s+(\d+)%')
59
- gpu_name_pattern = re.compile(r'NVIDIA\s+([\w\s]+?\d+GB)')
60
 
61
  gpu_name = ""
62
  for index in gpu_indices:
@@ -80,7 +116,7 @@ def parse_nvidia_smi():
80
 
81
  if len(gpu_info) >= 4:
82
  gpu_stats.append(gpu_info)
83
- print(f"gpu_stats: {gpu_stats}")
84
  gpu_name = f"{len(gpu_stats)}x{gpu_name}"
85
  gpu_stats_total = {
86
  GPU_TEMP: 0,
@@ -131,5 +167,70 @@ def analyze_gpu_stats(stats_list):
131
 
132
  return avg_stats
133
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
134
  if __name__ == "__main__":
135
  print(analyze_gpu_stats(parse_nvidia_smi()))
 
3
  import subprocess
4
  import re
5
  import os
6
+ import GPUtil
7
 
8
  try:
9
  from src.display.utils import GPU_TEMP, GPU_Mem, GPU_Power, GPU_Util, GPU_Name
10
  except:
11
  print("local debug: from display.utils")
12
  from display.utils import GPU_TEMP, GPU_Mem, GPU_Power, GPU_Util, GPU_Name
13
+
14
+ MEM_BW_DICT ={
15
+ "NVIDIA-A100-PCIe-80GB": 1935,
16
+ "NVIDIA-A100-SXM-80GB": 2039,
17
+ "NVIDIA-H100-PCIe-80GB": 2039,
18
+ "NVIDIA-RTX-A5000-24GB": 768
19
+ }
20
+
21
+ PEAK_FLOPS_DICT = {
22
+ "float32":{
23
+ "NVIDIA-A100-PCIe-80GB": 312e12,
24
+ "NVIDIA-A100-SXM-80GB": 312e12,
25
+ "NVIDIA-H100-PCIe-80GB": 756e12,
26
+ "NVIDIA-RTX-A5000-24GB": 222.2e12
27
+ },
28
+ "float16":{
29
+ "NVIDIA-A100-PCIe-80GB": 624e12,
30
+ "NVIDIA-A100-SXM-80GB": 624e12,
31
+ "NVIDIA-H100-PCIe-80GB": 1513e12,
32
+ "NVIDIA-RTX-A5000-24GB": 444.4e12
33
+ },
34
+ "8bit":{
35
+ "NVIDIA-A100-PCIe-80GB": 1248e12,
36
+ "NVIDIA-A100-SXM-80GB": 1248e12,
37
+ "NVIDIA-H100-PCIe-80GB": 3026e12,
38
+ "NVIDIA-RTX-A5000-24GB": 889e12
39
+ },
40
+ "4bit": {
41
+ "NVIDIA-A100-PCIe-80GB": 2496e12,
42
+ "NVIDIA-A100-SXM-80GB": 2496e12,
43
+ "NVIDIA-H100-PCIe-80GB": 6052e12,
44
+ "NVIDIA-RTX-A5000-24GB": 1778e12
45
+ }
46
+
47
+ }
48
 
49
  def my_snapshot_download(repo_id, revision, local_dir, repo_type, max_workers):
50
  for i in range(10):
 
88
  print("Failed to query GPU indices.")
89
  return []
90
  gpu_indices = result.stdout.strip().split('\n')
91
+ # print(f"gpu_indices: {gpu_indices}")
92
  gpu_stats = []
93
 
94
  gpu_info_pattern = re.compile(r'(\d+)C\s+P\d+\s+(\d+)W / \d+W\s+\|\s+(\d+)MiB / \d+MiB\s+\|\s+(\d+)%')
95
+ gpu_name_pattern = re.compile(r'NVIDIA\s+([\w\s]+\d+(?:\s*GB)?)')
96
 
97
  gpu_name = ""
98
  for index in gpu_indices:
 
116
 
117
  if len(gpu_info) >= 4:
118
  gpu_stats.append(gpu_info)
119
+ # print(f"gpu_stats: {gpu_stats}")
120
  gpu_name = f"{len(gpu_stats)}x{gpu_name}"
121
  gpu_stats_total = {
122
  GPU_TEMP: 0,
 
167
 
168
  return avg_stats
169
 
170
+ def get_gpu_number():
171
+ visible_devices = os.getenv('CUDA_VISIBLE_DEVICES', None)
172
+ if visible_devices is not None:
173
+ gpu_indices = visible_devices.split(',')
174
+ else:
175
+ # Query all GPU indices if CUDA_VISIBLE_DEVICES is not set
176
+ result = subprocess.run(['nvidia-smi', '--query-gpu=index', '--format=csv,noheader'], capture_output=True, text=True)
177
+ if result.returncode != 0:
178
+ print("Failed to query GPU indices.")
179
+ return []
180
+ gpu_indices = result.stdout.strip().split('\n')
181
+ # print(f"gpu_indices: {gpu_indices}")
182
+ gpu_stats = []
183
+
184
+ gpu_info_pattern = re.compile(r'(\d+)C\s+P\d+\s+(\d+)W / \d+W\s+\|\s+(\d+)MiB / \d+MiB\s+\|\s+(\d+)%')
185
+
186
+ for index in gpu_indices:
187
+ result = subprocess.run(['nvidia-smi', '-i', index], capture_output=True, text=True)
188
+ output = result.stdout.strip()
189
+ lines = output.split("\n")
190
+ for line in lines:
191
+ match = gpu_info_pattern.search(line)
192
+ gpu_info = {}
193
+ if match:
194
+ temp, power_usage, mem_usage, gpu_util = map(int, match.groups())
195
+ gpu_info.update({
196
+ GPU_TEMP: temp,
197
+ GPU_Power: power_usage,
198
+ GPU_Mem: round(mem_usage / 1024, 2),
199
+ GPU_Util: gpu_util
200
+ })
201
+
202
+ if len(gpu_info) >= 4:
203
+ gpu_stats.append(gpu_info)
204
+
205
+ return len(gpu_stats)
206
+
207
+ def get_gpu_details():
208
+ gpus = GPUtil.getGPUs()
209
+ gpu = gpus[0]
210
+ name = gpu.name.replace(" ", "-")
211
+ # Convert memory from MB to GB and round to nearest whole number
212
+ memory_gb = round(gpu.memoryTotal / 1024)
213
+ memory = f"{memory_gb}GB"
214
+ formatted_name = f"{name}-{memory}"
215
+ return formatted_name
216
+
217
+ def get_peak_bw(gpu_name):
218
+ return MEM_BW_DICT[gpu_name]
219
+
220
+ def get_peak_flops(gpu_name, precision):
221
+ return PEAK_FLOPS_DICT[precision][gpu_name]
222
+
223
+ def transfer_precision2bytes(precision):
224
+ if precision == "float32":
225
+ return 4
226
+ elif precision == "float16":
227
+ return 2
228
+ elif precision == "8bit":
229
+ return 1
230
+ elif precision == "4bit":
231
+ return 0.5
232
+ else:
233
+ raise ValueError(f"Unsupported precision: {precision}")
234
+
235
  if __name__ == "__main__":
236
  print(analyze_gpu_stats(parse_nvidia_smi()))