Support bfloat16

#3
by nouamanetazi HF staff - opened
Files changed (1) hide show
  1. app.py +2 -2
app.py CHANGED
@@ -85,7 +85,7 @@ def calculate_memory(model_name:str, library:str, options:list, access_token:str
85
  for dtype in options:
86
  dtype_total_size = total_size
87
  dtype_largest_layer = largest_layer[0]
88
- if dtype in ("float16", "fp16"):
89
  dtype_total_size /= 2
90
  dtype_largest_layer /= 2
91
  elif dtype == "int8":
@@ -149,7 +149,7 @@ with gr.Blocks() as demo:
149
  with gr.Row():
150
  library = gr.Radio(["auto", "transformers", "timm"], label="Library", value="auto")
151
  options = gr.CheckboxGroup(
152
- ["float32", "float16", "int8", "int4"],
153
  value="float32",
154
  label="Model Precision",
155
  )
 
85
  for dtype in options:
86
  dtype_total_size = total_size
87
  dtype_largest_layer = largest_layer[0]
88
+ if dtype in ("float16", "fp16", "bfloat16", "bf16"):
89
  dtype_total_size /= 2
90
  dtype_largest_layer /= 2
91
  elif dtype == "int8":
 
149
  with gr.Row():
150
  library = gr.Radio(["auto", "transformers", "timm"], label="Library", value="auto")
151
  options = gr.CheckboxGroup(
152
+ ["float32", "float16/bfloat16", "int8", "int4"],
153
  value="float32",
154
  label="Model Precision",
155
  )