add float32 as data type
Browse files- app.py +1 -1
- helpers.py +2 -0
app.py
CHANGED
@@ -180,7 +180,7 @@ with gr.Blocks(analytics_enabled=False) as demo:
|
|
180 |
gr.Markdown("### Device specific settings")
|
181 |
with gr.Row():
|
182 |
in_devices = gr.Dropdown(label="Device:", value=config.value["device"], choices=devices, filterable=True, multiselect=False, allow_custom_value=True)
|
183 |
-
in_data_type = gr.Radio(label="Data Type:", value=config.value["data_type"], choices=["bfloat16", "float16"], info="`bfloat16` is not supported on MPS devices right now; Half-precision weights, will save GPU memory, see https://huggingface.co/docs/diffusers/main/en/optimization/fp16")
|
184 |
in_allow_tensorfloat32 = gr.Radio(label="Allow TensorFloat32:", value=config.value["allow_tensorfloat32"], choices=["True", "False"], info="is not supported on MPS devices right now; use TensorFloat-32 is faster, but results in slightly less accurate computations, see https://huggingface.co/docs/diffusers/main/en/optimization/fp16 ")
|
185 |
in_variant = gr.Radio(label="Variant:", value=config.value["variant"], choices=["fp16", None], info="Use half-precision weights will save GPU memory, not all models support that, see https://huggingface.co/docs/diffusers/main/en/optimization/fp16 ")
|
186 |
|
|
|
180 |
gr.Markdown("### Device specific settings")
|
181 |
with gr.Row():
|
182 |
in_devices = gr.Dropdown(label="Device:", value=config.value["device"], choices=devices, filterable=True, multiselect=False, allow_custom_value=True)
|
183 |
+
in_data_type = gr.Radio(label="Data Type:", value=config.value["data_type"], choices=["bfloat16", "float16", "float32"], info="`bfloat16` is not supported on MPS devices right now; `float16` may also not be supported on all devices, Half-precision weights, will save GPU memory, see https://huggingface.co/docs/diffusers/main/en/optimization/fp16")
|
184 |
in_allow_tensorfloat32 = gr.Radio(label="Allow TensorFloat32:", value=config.value["allow_tensorfloat32"], choices=["True", "False"], info="is not supported on MPS devices right now; use TensorFloat-32 is faster, but results in slightly less accurate computations, see https://huggingface.co/docs/diffusers/main/en/optimization/fp16 ")
|
185 |
in_variant = gr.Radio(label="Variant:", value=config.value["variant"], choices=["fp16", None], info="Use half-precision weights will save GPU memory, not all models support that, see https://huggingface.co/docs/diffusers/main/en/optimization/fp16 ")
|
186 |
|
helpers.py
CHANGED
@@ -29,6 +29,8 @@ def get_data_type(str_data_type):
|
|
29 |
|
30 |
if str_data_type == "bfloat16":
|
31 |
return torch.bfloat16 # BFloat16 is not supported on MPS as of 01/2024
|
|
|
|
|
32 |
else:
|
33 |
return torch.float16 # Half-precision weights, as of https://huggingface.co/docs/diffusers/main/en/optimization/fp16 will save GPU memory
|
34 |
|
|
|
29 |
|
30 |
if str_data_type == "bfloat16":
|
31 |
return torch.bfloat16 # BFloat16 is not supported on MPS as of 01/2024
|
32 |
+
if str_data_type == "float32":
|
33 |
+
return torch.float32 # BFloat16 is not supported on MPS as of 01/2024
|
34 |
else:
|
35 |
return torch.float16 # Half-precision weights, as of https://huggingface.co/docs/diffusers/main/en/optimization/fp16 will save GPU memory
|
36 |
|