lvwerra HF staff committed on
Commit
5e71d7d
1 Parent(s): 068a2a6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -6
app.py CHANGED
@@ -42,15 +42,16 @@ for kn in np.linspace(0.2, 2, 100):
42
  overheads.append(compute_overhead(kn, kd)*100)
43
 
44
  def plot_curve(kn, kd):
 
45
  plt.plot(kns, overheads)
46
  plt.scatter([kn], [kd])
47
  plt.xlabel("Fraction of compute optimal model size")
48
  plt.ylabel("Compute overhead (%)")
49
-
50
 
51
 
52
  def compute(N, D):
53
- print(N, D)
54
  C = to_flops(N * Bn, D * Bn)
55
  N_opt = n_opt(C)
56
  D_opt = d_opt(C)
@@ -58,19 +59,21 @@ def compute(N, D):
58
  kn = N/N_opt
59
  kd = compute_kd(kn)
60
 
61
- plot_curve(kn, kd)
 
 
62
 
63
  text = f"""Compute budget (TFLOPs): {C:.2E}\n\nTraining compute overhead (%): {100*compute_overhead(kn, kd):.2f}\n\nInference cost fraction (%): {kn*100:.2f}"""
64
- return text
65
 
66
  with gr.Blocks() as demo:
67
  N = gr.Number(value=1, label="Model size (in B parameters)")
68
  D = gr.Number(value=100, label="Dataset size (in B tokens")
69
  button = gr.Button("Compute!")
70
 
71
- gr.Plot(value=plt)
72
  md = gr.Markdown("")
73
 
74
- button.click(fn=compute, inputs=[N, D], outputs=[md])
75
 
76
  demo.launch()
 
42
  overheads.append(compute_overhead(kn, kd)*100)
43
 
44
  def plot_curve(kn, kd):
45
+ fig = plt.figure()
46
  plt.plot(kns, overheads)
47
  plt.scatter([kn], [kd])
48
  plt.xlabel("Fraction of compute optimal model size")
49
  plt.ylabel("Compute overhead (%)")
50
+ return fig
51
 
52
 
53
  def compute(N, D):
54
+
55
  C = to_flops(N * Bn, D * Bn)
56
  N_opt = n_opt(C)
57
  D_opt = d_opt(C)
 
59
  kn = N/N_opt
60
  kd = compute_kd(kn)
61
 
62
+ print(N, D, N_opt, D_opt, kn, kd)
63
+
64
+ fig = plot_curve(kn, kd)
65
 
66
  text = f"""Compute budget (TFLOPs): {C:.2E}\n\nTraining compute overhead (%): {100*compute_overhead(kn, kd):.2f}\n\nInference cost fraction (%): {kn*100:.2f}"""
67
+ return text, fig
68
 
69
  with gr.Blocks() as demo:
70
  N = gr.Number(value=1, label="Model size (in B parameters)")
71
  D = gr.Number(value=100, label="Dataset size (in B tokens")
72
  button = gr.Button("Compute!")
73
 
74
+ plot = gr.Plot(value=plt)
75
  md = gr.Markdown("")
76
 
77
+ button.click(fn=compute, inputs=[N, D], outputs=[md, plot])
78
 
79
  demo.launch()