shigeki Ishida commited on
Commit
860d490
•
1 Parent(s): 0ef9174
Files changed (2) hide show
  1. app.py +2 -2
  2. src/display/utils.py +7 -8
app.py CHANGED
@@ -572,7 +572,7 @@ with gr.Blocks() as demo_submission:
572
  revision_name_textbox = gr.Textbox(label="Revision commit", placeholder="main")
573
  model_type = gr.Dropdown(
574
  label="Model type",
575
- choices=[t.to_str(" : ") for t in ModelType if t != ModelType.Unknown],
576
  multiselect=False,
577
  value=None,
578
  )
@@ -580,7 +580,7 @@ with gr.Blocks() as demo_submission:
580
  with gr.Column():
581
  precision = gr.Dropdown(
582
  label="Precision",
583
- choices=[i.value.name for i in Precision if i != Precision.Unknown],
584
  multiselect=False,
585
  value="float16",
586
  )
 
572
  revision_name_textbox = gr.Textbox(label="Revision commit", placeholder="main")
573
  model_type = gr.Dropdown(
574
  label="Model type",
575
+ choices=[t.to_str(" : ") for t in ModelType],
576
  multiselect=False,
577
  value=None,
578
  )
 
580
  with gr.Column():
581
  precision = gr.Dropdown(
582
  label="Precision",
583
+ choices=[i.value.name for i in Precision],
584
  multiselect=False,
585
  value="float16",
586
  )
src/display/utils.py CHANGED
@@ -105,7 +105,6 @@ class ModelType(Enum):
105
  FT = ModelDetails(name="fine-tuned", symbol="🔶")
106
  IFT = ModelDetails(name="instruction-tuned", symbol="â­•")
107
  RL = ModelDetails(name="RL-tuned", symbol="🟦")
108
- Unknown = ModelDetails(name="", symbol="?")
109
 
110
  def to_str(self, separator=" "):
111
  return f"{self.value.symbol}{separator}{self.value.name}"
@@ -120,7 +119,7 @@ class ModelType(Enum):
120
  return ModelType.RL
121
  if "instruction-tuned" in type or "â­•" in type:
122
  return ModelType.IFT
123
- return ModelType.Unknown
124
 
125
 
126
  class WeightType(Enum):
@@ -132,14 +131,14 @@ class WeightType(Enum):
132
  class Precision(Enum):
133
  float16 = ModelDetails("float16")
134
  bfloat16 = ModelDetails("bfloat16")
135
- Unknown = ModelDetails("?")
136
 
137
- def from_str(precision):
 
138
  if precision in ["torch.float16", "float16"]:
139
  return Precision.float16
140
  if precision in ["torch.bfloat16", "bfloat16"]:
141
  return Precision.bfloat16
142
- return Precision.Unknown
143
 
144
 
145
  class AddSpecialTokens(Enum):
@@ -150,14 +149,14 @@ class AddSpecialTokens(Enum):
150
  class NumFewShots(Enum):
151
  shots_0 = ModelDetails("0")
152
  shots_4 = ModelDetails("4")
153
- Unknown = ModelDetails("?")
154
 
155
- def from_str(shots):
 
156
  if shots == "0":
157
  return NumFewShots.shots_0
158
  if shots == "4":
159
  return NumFewShots.shots_4
160
- return NumFewShots.Unknown
161
 
162
 
163
  class LLMJpEvalVersion(Enum):
 
105
  FT = ModelDetails(name="fine-tuned", symbol="🔶")
106
  IFT = ModelDetails(name="instruction-tuned", symbol="â­•")
107
  RL = ModelDetails(name="RL-tuned", symbol="🟦")
 
108
 
109
  def to_str(self, separator=" "):
110
  return f"{self.value.symbol}{separator}{self.value.name}"
 
119
  return ModelType.RL
120
  if "instruction-tuned" in type or "â­•" in type:
121
  return ModelType.IFT
122
+ raise ValueError(f"Unsupported model type: {type}")
123
 
124
 
125
  class WeightType(Enum):
 
131
  class Precision(Enum):
132
  float16 = ModelDetails("float16")
133
  bfloat16 = ModelDetails("bfloat16")
 
134
 
135
+ @staticmethod
136
+ def from_str(precision: str) -> "Precision":
137
  if precision in ["torch.float16", "float16"]:
138
  return Precision.float16
139
  if precision in ["torch.bfloat16", "bfloat16"]:
140
  return Precision.bfloat16
141
+ raise ValueError(f"Unsupported precision type: {precision}")
142
 
143
 
144
  class AddSpecialTokens(Enum):
 
149
  class NumFewShots(Enum):
150
  shots_0 = ModelDetails("0")
151
  shots_4 = ModelDetails("4")
 
152
 
153
+ @staticmethod
154
+ def from_str(shots: str) -> "NumFewShots":
155
  if shots == "0":
156
  return NumFewShots.shots_0
157
  if shots == "4":
158
  return NumFewShots.shots_4
159
+ raise ValueError(f"Unsupported number of shots: {shots}. Must be either '0' or '4'")
160
 
161
 
162
  class LLMJpEvalVersion(Enum):