AppleSwing
commited on
Commit
•
b20ad66
1
Parent(s):
c7bf25a
Add bfloat16
Browse files- src/utils.py +1 -1
src/utils.py
CHANGED
@@ -223,7 +223,7 @@ def get_peak_flops(gpu_name, precision):
|
|
223 |
def transfer_precision2bytes(precision):
|
224 |
if precision == "float32":
|
225 |
return 4
|
226 |
-
elif precision
|
227 |
return 2
|
228 |
elif precision == "8bit":
|
229 |
return 1
|
|
|
223 |
def transfer_precision2bytes(precision):
|
224 |
if precision == "float32":
|
225 |
return 4
|
226 |
+
elif precision in ["float16", "bfloat16"]:
|
227 |
return 2
|
228 |
elif precision == "8bit":
|
229 |
return 1
|