n42 committed on
Commit
56914a9
·
1 Parent(s): 5a9e237

add float32 as data type

Browse files
Files changed (2) hide show
  1. app.py +1 -1
  2. helpers.py +2 -0
app.py CHANGED
@@ -180,7 +180,7 @@ with gr.Blocks(analytics_enabled=False) as demo:
180
  gr.Markdown("### Device specific settings")
181
  with gr.Row():
182
  in_devices = gr.Dropdown(label="Device:", value=config.value["device"], choices=devices, filterable=True, multiselect=False, allow_custom_value=True)
183
- in_data_type = gr.Radio(label="Data Type:", value=config.value["data_type"], choices=["bfloat16", "float16"], info="`bfloat16` is not supported on MPS devices right now; Half-precision weights, will save GPU memory, see https://huggingface.co/docs/diffusers/main/en/optimization/fp16")
184
  in_allow_tensorfloat32 = gr.Radio(label="Allow TensorFloat32:", value=config.value["allow_tensorfloat32"], choices=["True", "False"], info="is not supported on MPS devices right now; use TensorFloat-32 is faster, but results in slightly less accurate computations, see https://huggingface.co/docs/diffusers/main/en/optimization/fp16 ")
185
  in_variant = gr.Radio(label="Variant:", value=config.value["variant"], choices=["fp16", None], info="Use half-precision weights will save GPU memory, not all models support that, see https://huggingface.co/docs/diffusers/main/en/optimization/fp16 ")
186
 
 
180
  gr.Markdown("### Device specific settings")
181
  with gr.Row():
182
  in_devices = gr.Dropdown(label="Device:", value=config.value["device"], choices=devices, filterable=True, multiselect=False, allow_custom_value=True)
183
+ in_data_type = gr.Radio(label="Data Type:", value=config.value["data_type"], choices=["bfloat16", "float16", "float32"], info="`bfloat16` is not supported on MPS devices right now; `float16` may also not be supported on all devices, Half-precision weights, will save GPU memory, see https://huggingface.co/docs/diffusers/main/en/optimization/fp16")
184
  in_allow_tensorfloat32 = gr.Radio(label="Allow TensorFloat32:", value=config.value["allow_tensorfloat32"], choices=["True", "False"], info="is not supported on MPS devices right now; use TensorFloat-32 is faster, but results in slightly less accurate computations, see https://huggingface.co/docs/diffusers/main/en/optimization/fp16 ")
185
  in_variant = gr.Radio(label="Variant:", value=config.value["variant"], choices=["fp16", None], info="Use half-precision weights will save GPU memory, not all models support that, see https://huggingface.co/docs/diffusers/main/en/optimization/fp16 ")
186
 
helpers.py CHANGED
@@ -29,6 +29,8 @@ def get_data_type(str_data_type):
29
 
30
  if str_data_type == "bfloat16":
31
  return torch.bfloat16 # BFloat16 is not supported on MPS as of 01/2024
 
 
32
  else:
33
  return torch.float16 # Half-precision weights, as of https://huggingface.co/docs/diffusers/main/en/optimization/fp16 will save GPU memory
34
 
 
29
 
30
  if str_data_type == "bfloat16":
31
  return torch.bfloat16 # BFloat16 is not supported on MPS as of 01/2024
32
+ if str_data_type == "float32":
33
+ return torch.float32 # Full precision (FP32); widely supported, including on MPS devices, but uses more GPU memory
34
  else:
35
  return torch.float16 # Half-precision weights, as of https://huggingface.co/docs/diffusers/main/en/optimization/fp16 will save GPU memory
36