AppleSwing commited on
Commit
dbe8db4
1 Parent(s): 6e99f9d

Apply GPU type verification on backend debug mode

Browse files
Files changed (1) hide show
  1. backend-cli.py +6 -1
backend-cli.py CHANGED
@@ -448,7 +448,8 @@ def get_args():
448
  parser.add_argument("--precision", type=str, default="float32,float16,8bit,4bit", help="Precision to debug")
449
  parser.add_argument("--inference-framework", type=str, default="hf-chat", help="Inference framework to debug")
450
  parser.add_argument("--limit", type=int, default=None, help="Limit for the number of samples")
451
- parser.add_argument("--gpu-type", type=str, default="NVIDIA-A100-PCIe-80GB", help="GPU type")
 
452
  return parser.parse_args()
453
 
454
 
@@ -480,6 +481,10 @@ if __name__ == "__main__":
480
  inference_framework=args.inference_framework, # Use inference framework from arguments
481
  gpu_type=args.gpu_type
482
  )
 
 
 
 
483
  results = process_evaluation(task, eval_request, limit=args.limit)
484
  except Exception as e:
485
  print(f"debug running error: {e}")
 
448
  parser.add_argument("--precision", type=str, default="float32,float16,8bit,4bit", help="Precision to debug")
449
  parser.add_argument("--inference-framework", type=str, default="hf-chat", help="Inference framework to debug")
450
  parser.add_argument("--limit", type=int, default=None, help="Limit for the number of samples")
451
+ parser.add_argument("--gpu-type", type=str, default="NVIDIA-A100-PCIe-80GB",
452
+ help="GPU type. NVIDIA-A100-PCIe-80GB; NVIDIA-RTX-A5000-24G; NVIDIA-H100-PCIe-80G")
453
  return parser.parse_args()
454
 
455
 
 
481
  inference_framework=args.inference_framework, # Use inference framework from arguments
482
  gpu_type=args.gpu_type
483
  )
484
+ curr_gpu_type = get_gpu_details()
485
+ if eval_request.gpu_type != curr_gpu_type:
486
+ print(f"GPU type mismatch: {eval_request.gpu_type} vs {curr_gpu_type}")
487
+ raise Exception("GPU type mismatch")
488
  results = process_evaluation(task, eval_request, limit=args.limit)
489
  except Exception as e:
490
  print(f"debug running error: {e}")