Update app.py
Browse files
app.py
CHANGED
@@ -77,8 +77,10 @@ def plot_pens(tflpos_card, utilization, num_gps, training_days):
|
|
77 |
plt.axvline(ns[best], color='red')
|
78 |
plt.xlabel('model size')
|
79 |
plt.ylabel('loss')
|
|
|
|
|
80 |
|
81 |
-
return
|
82 |
|
83 |
|
84 |
if __name__ == "__main__":
|
@@ -86,12 +88,12 @@ if __name__ == "__main__":
|
|
86 |
fn=plot_pens,
|
87 |
inputs=[
|
88 |
gr.Textbox(label="TFLOP/s pre Card",value="40"),
|
89 |
-
gr.Slider(label="
|
90 |
-
gr.Textbox(label="Number of cards"
|
91 |
-
gr.Textbox(label="Training Days"
|
92 |
],
|
93 |
outputs=[
|
94 |
-
gr.
|
95 |
gr.Label(label="Total Compute Budget"),
|
96 |
gr.Label(label="Estimated Final Loss"),
|
97 |
gr.Label(label="Optimal Model Size"),
|
@@ -100,5 +102,5 @@ if __name__ == "__main__":
|
|
100 |
title="Compute-Optimal Model Estimator",
|
101 |
description=description,
|
102 |
article=article,
|
103 |
-
live=
|
104 |
).launch()
|
|
|
77 |
plt.axvline(ns[best], color='red')
|
78 |
plt.xlabel('model size')
|
79 |
plt.ylabel('loss')
|
80 |
+
fig.savefig("/tmp/tmp.jpg")
|
81 |
+
plt.close()
|
82 |
|
83 |
+
return "/tmp/tmp.jpg", c, round(losses[best], 3), best_model_size, best_dataset_size
|
84 |
|
85 |
|
86 |
if __name__ == "__main__":
|
|
|
88 |
fn=plot_pens,
|
89 |
inputs=[
|
90 |
gr.Textbox(label="TFLOP/s pre Card",value="40"),
|
91 |
+
gr.Slider(label="GPU Utilization", minimum=0, maximum=1, step=0.01,value=0.25),
|
92 |
+
gr.Textbox(label="Number of cards"),
|
93 |
+
gr.Textbox(label="Training Days")
|
94 |
],
|
95 |
outputs=[
|
96 |
+
gr.Image(label="Estimated Loss"),
|
97 |
gr.Label(label="Total Compute Budget"),
|
98 |
gr.Label(label="Estimated Final Loss"),
|
99 |
gr.Label(label="Optimal Model Size"),
|
|
|
102 |
title="Compute-Optimal Model Estimator",
|
103 |
description=description,
|
104 |
article=article,
|
105 |
+
live=True
|
106 |
).launch()
|