t0-0 commited on
Commit
559d198
1 Parent(s): bd95334

Remove 'auto' from Enum and add handling for submissions with 'auto'.

Browse files
Files changed (3) hide show
  1. app.py +1 -1
  2. src/display/utils.py +0 -3
  3. src/submission/submit.py +25 -7
app.py CHANGED
@@ -579,7 +579,7 @@ with gr.Blocks() as demo_submission:
579
  with gr.Column():
580
  precision = gr.Dropdown(
581
  label="Precision",
582
- choices=[i.value.name for i in Precision],
583
  multiselect=False,
584
  value="auto",
585
  )
 
579
  with gr.Column():
580
  precision = gr.Dropdown(
581
  label="Precision",
582
+ choices=[i.value.name for i in Precision] + ["auto"],
583
  multiselect=False,
584
  value="auto",
585
  )
src/display/utils.py CHANGED
@@ -129,15 +129,12 @@ class WeightType(Enum):
129
 
130
 
131
  class Precision(Enum):
132
- auto = ModelDetails("auto")
133
  float16 = ModelDetails("float16")
134
  bfloat16 = ModelDetails("bfloat16")
135
  float32 = ModelDetails("float32")
136
 
137
  @staticmethod
138
  def from_str(precision: str) -> "Precision":
139
- if precision == "auto":
140
- return Precision.auto
141
  if precision == "float16":
142
  return Precision.float16
143
  if precision == "bfloat16":
 
129
 
130
 
131
  class Precision(Enum):
 
132
  float16 = ModelDetails("float16")
133
  bfloat16 = ModelDetails("bfloat16")
134
  float32 = ModelDetails("float32")
135
 
136
  @staticmethod
137
  def from_str(precision: str) -> "Precision":
 
 
138
  if precision == "float16":
139
  return Precision.float16
140
  if precision == "bfloat16":
src/submission/submit.py CHANGED
@@ -1,6 +1,8 @@
1
  import json
2
  from datetime import datetime, timezone
3
 
 
 
4
  from src.display.formatting import styled_error, styled_message, styled_warning
5
  from src.display.utils import EvalQueuedModel, LLMJpEvalVersion, VllmVersion
6
  from src.envs import API, EVAL_REQUESTS_PATH, HF_TOKEN, QUEUE_REPO
@@ -25,6 +27,29 @@ def add_new_eval(
25
 
26
  revision = revision or "main"
27
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
  model_data = EvalQueuedModel(
29
  model=model_id,
30
  revision=revision,
@@ -47,13 +72,6 @@ def add_new_eval(
47
  if model_type is None or model_type == "":
48
  return styled_error("Please select a model type.")
49
 
50
- # Is the model on the hub?
51
- model_on_hub, error, _ = is_model_on_hub(
52
- model_name=model_id, revision=revision, token=HF_TOKEN, test_tokenizer=True
53
- )
54
- if not model_on_hub:
55
- return styled_error(f'Model "{model_id}" {error}')
56
-
57
  # Is the model info correctly filled?
58
  try:
59
  model_info = API.model_info(repo_id=model_id, revision=revision)
 
1
  import json
2
  from datetime import datetime, timezone
3
 
4
+ import torch
5
+
6
  from src.display.formatting import styled_error, styled_message, styled_warning
7
  from src.display.utils import EvalQueuedModel, LLMJpEvalVersion, VllmVersion
8
  from src.envs import API, EVAL_REQUESTS_PATH, HF_TOKEN, QUEUE_REPO
 
27
 
28
  revision = revision or "main"
29
 
30
+ # Is the model on the hub?
31
+ model_on_hub, error, config = is_model_on_hub(
32
+ model_name=model_id, revision=revision, token=HF_TOKEN, test_tokenizer=True
33
+ )
34
+ if not model_on_hub:
35
+ return styled_error(f'Model "{model_id}" {error}')
36
+ if precision == "auto":
37
+ dtype = ""
38
+ if hasattr(config, "dtype"):
39
+ dtype = config.dtype
40
+ elif hasattr(config, "torch_dtype"):
41
+ dtype = config.torch_dtype
42
+ if dtype == torch.float16:
43
+ precision = "float16"
44
+ elif dtype in torch.bfloat16:
45
+ precision = "bfloat16"
46
+ elif dtype in torch.float32:
47
+ precision = "float32"
48
+ else:
49
+ return styled_error(
50
+ "Unable to retrieve a valid dtype from config.json. Please select an appropriate one from fp16/fp32/bf16 and resubmit."
51
+ )
52
+
53
  model_data = EvalQueuedModel(
54
  model=model_id,
55
  revision=revision,
 
72
  if model_type is None or model_type == "":
73
  return styled_error("Please select a model type.")
74
 
 
 
 
 
 
 
 
75
  # Is the model info correctly filled?
76
  try:
77
  model_info = API.model_info(repo_id=model_id, revision=revision)