t0-0 commited on
Commit
bf7bdee
1 Parent(s): 4c9c17b

Add auto/fp32 option and set auto as the default for submission

Browse files
Files changed (2) hide show
  1. app.py +1 -1
  2. src/display/utils.py +6 -0
app.py CHANGED
@@ -582,7 +582,7 @@ with gr.Blocks() as demo_submission:
582
  label="Precision",
583
  choices=[i.value.name for i in Precision],
584
  multiselect=False,
585
- value="float16",
586
  )
587
  add_special_tokens = gr.Dropdown(
588
  label="AddSpecialTokens",
 
582
  label="Precision",
583
  choices=[i.value.name for i in Precision],
584
  multiselect=False,
585
+ value="auto",
586
  )
587
  add_special_tokens = gr.Dropdown(
588
  label="AddSpecialTokens",
src/display/utils.py CHANGED
@@ -129,13 +129,19 @@ class WeightType(Enum):
129
 
130
 
131
  class Precision(Enum):
 
132
  float16 = ModelDetails("float16")
 
133
  bfloat16 = ModelDetails("bfloat16")
134
 
135
  @staticmethod
136
  def from_str(precision: str) -> "Precision":
 
 
137
  if precision in ["torch.float16", "float16"]:
138
  return Precision.float16
 
 
139
  if precision in ["torch.bfloat16", "bfloat16"]:
140
  return Precision.bfloat16
141
  raise ValueError(f"Unsupported precision type: {precision}")
 
129
 
130
 
131
  class Precision(Enum):
132
+ auto = ModelDetails("auto")
133
  float16 = ModelDetails("float16")
134
+ float32 = ModelDetails("float32")
135
  bfloat16 = ModelDetails("bfloat16")
136
 
137
  @staticmethod
138
  def from_str(precision: str) -> "Precision":
139
+ if precision == "auto":
140
+ return Precision.auto
141
  if precision in ["torch.float16", "float16"]:
142
  return Precision.float16
143
+ if precision in ["torch.float32", "float32"]:
144
+ return Precision.float32
145
  if precision in ["torch.bfloat16", "bfloat16"]:
146
  return Precision.bfloat16
147
  raise ValueError(f"Unsupported precision type: {precision}")