AppleSwing commited on
Commit
b20ad66
1 Parent(s): c7bf25a

Add bfloat16

Browse files
Files changed (1) hide show
  1. src/utils.py +1 -1
src/utils.py CHANGED
@@ -223,7 +223,7 @@ def get_peak_flops(gpu_name, precision):
223
  def transfer_precision2bytes(precision):
224
  if precision == "float32":
225
  return 4
226
- elif precision == "float16":
227
  return 2
228
  elif precision == "8bit":
229
  return 1
 
223
  def transfer_precision2bytes(precision):
224
  if precision == "float32":
225
  return 4
226
+ elif precision in ["float16", "bfloat16"]:
227
  return 2
228
  elif precision == "8bit":
229
  return 1