t0-0 committed
Commit 559d198
1 Parent(s): bd95334
Remove 'auto' from Enum and add handling for submissions with 'auto'.
- app.py +1 -1
- src/display/utils.py +0 -3
- src/submission/submit.py +25 -7
app.py
CHANGED
@@ -579,7 +579,7 @@ with gr.Blocks() as demo_submission:
     with gr.Column():
         precision = gr.Dropdown(
             label="Precision",
-            choices=[i.value.name for i in Precision],
+            choices=[i.value.name for i in Precision] + ["auto"],
             multiselect=False,
             value="auto",
         )
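Note: a minimal sketch (stand-in ModelDetails/Precision definitions, not the real ones from src/display/utils.py) of how the dropdown choices evaluate after this change, with "auto" kept only as a UI choice rather than an enum member:

from dataclasses import dataclass
from enum import Enum

@dataclass
class ModelDetails:
    name: str

class Precision(Enum):
    float16 = ModelDetails("float16")
    bfloat16 = ModelDetails("bfloat16")
    float32 = ModelDetails("float32")

# Mirrors the updated gr.Dropdown choices expression in app.py.
choices = [i.value.name for i in Precision] + ["auto"]
print(choices)  # ['float16', 'bfloat16', 'float32', 'auto']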
src/display/utils.py
CHANGED
@@ -129,15 +129,12 @@ class WeightType(Enum):


 class Precision(Enum):
-    auto = ModelDetails("auto")
     float16 = ModelDetails("float16")
     bfloat16 = ModelDetails("bfloat16")
     float32 = ModelDetails("float32")

     @staticmethod
     def from_str(precision: str) -> "Precision":
-        if precision == "auto":
-            return Precision.auto
         if precision == "float16":
             return Precision.float16
         if precision == "bfloat16":
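Note: a quick sanity-check sketch (assuming the repo is on the import path) of the enum after this removal; "auto" is no longer a member, so any "auto" submission has to be resolved to a concrete dtype before Precision.from_str is reached:

from src.display.utils import Precision

# Only the three concrete dtypes remain, in declaration order.
assert [p.value.name for p in Precision] == ["float16", "bfloat16", "float32"]
assert Precision.from_str("bfloat16") is Precision.bfloat16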
src/submission/submit.py
CHANGED
@@ -1,6 +1,8 @@
 import json
 from datetime import datetime, timezone

+import torch
+
 from src.display.formatting import styled_error, styled_message, styled_warning
 from src.display.utils import EvalQueuedModel, LLMJpEvalVersion, VllmVersion
 from src.envs import API, EVAL_REQUESTS_PATH, HF_TOKEN, QUEUE_REPO
@@ -25,6 +27,29 @@ def add_new_eval(

     revision = revision or "main"

+    # Is the model on the hub?
+    model_on_hub, error, config = is_model_on_hub(
+        model_name=model_id, revision=revision, token=HF_TOKEN, test_tokenizer=True
+    )
+    if not model_on_hub:
+        return styled_error(f'Model "{model_id}" {error}')
+    if precision == "auto":
+        dtype = ""
+        if hasattr(config, "dtype"):
+            dtype = config.dtype
+        elif hasattr(config, "torch_dtype"):
+            dtype = config.torch_dtype
+        if dtype == torch.float16:
+            precision = "float16"
+        elif dtype == torch.bfloat16:
+            precision = "bfloat16"
+        elif dtype == torch.float32:
+            precision = "float32"
+        else:
+            return styled_error(
+                "Unable to retrieve a valid dtype from config.json. Please select an appropriate one from fp16/fp32/bf16 and resubmit."
+            )
+
     model_data = EvalQueuedModel(
         model=model_id,
         revision=revision,
@@ -47,13 +72,6 @@ def add_new_eval(
     if model_type is None or model_type == "":
         return styled_error("Please select a model type.")

-    # Is the model on the hub?
-    model_on_hub, error, _ = is_model_on_hub(
-        model_name=model_id, revision=revision, token=HF_TOKEN, test_tokenizer=True
-    )
-    if not model_on_hub:
-        return styled_error(f'Model "{model_id}" {error}')
-
     # Is the model info correctly filled?
     try:
         model_info = API.model_info(repo_id=model_id, revision=revision)
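Note: the new "auto" handling can be exercised in isolation. The helper below is a hypothetical re-packaging of the logic added in the second hunk (resolve_precision is not a function in the repo), using a SimpleNamespace to mimic the config object returned by is_model_on_hub:

from types import SimpleNamespace

import torch


def resolve_precision(precision, config):
    """Map precision == "auto" to a concrete dtype name from the model config, else None."""
    if precision != "auto":
        return precision
    dtype = ""
    if hasattr(config, "dtype"):
        dtype = config.dtype
    elif hasattr(config, "torch_dtype"):
        dtype = config.torch_dtype
    if dtype == torch.float16:
        return "float16"
    if dtype == torch.bfloat16:
        return "bfloat16"
    if dtype == torch.float32:
        return "float32"
    return None  # caller surfaces an error, as add_new_eval does


# Mimic a config whose torch_dtype has been parsed into a torch.dtype.
cfg = SimpleNamespace(torch_dtype=torch.bfloat16)
print(resolve_precision("auto", cfg))                 # bfloat16
print(resolve_precision("auto", SimpleNamespace()))   # None -> styled_error in add_new_eval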