lvwerra HF staff committed on
Commit
0533d8a
1 Parent(s): ff94d19

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +65 -0
app.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import matplotlib.pyplot as plt
3
+ import numpy as np
4
+
5
### CHINCHILLA PARAMS:
# Presumably the coefficients of the Chinchilla parametric loss fit
# (Hoffmann et al., 2022) -- TODO confirm the exact source of these values.
E = 1.62  # not referenced by the visible code
A = 406.4
B = 410.7
alpha = 0.336
beta = 0.283

# One billion: converts the UI inputs (given "in B") to raw counts.
Bn = 10**9

# Shared coefficient of n_opt / d_opt; note G**(alpha+beta) == alpha*A / (beta*B).
G = ((alpha * A) / (beta * B)) ** (1 / (alpha + beta))
###
16
+
17
def to_flops(N, D):
    """Return the compute estimate ``6 * N * D``.

    Args:
        N: Model size (number of parameters).
        D: Dataset size (number of tokens).

    Returns:
        The product ``6 * N * D`` (FLOPs, per the function name).
    """
    return 6 * N * D
19
+
20
def n_opt(C):
    """Return the compute-optimal model size for a compute budget ``C``.

    Computes ``G * (C / 6) ** (beta / (alpha + beta))``. Together with
    ``d_opt`` this satisfies ``6 * n_opt(C) * d_opt(C) == C``.

    Args:
        C: Compute budget in the same units ``to_flops`` returns.

    Returns:
        The optimal parameter count (a raw count, not billions).
    """
    return G * ((C / 6) ** (beta / (alpha + beta)))
22
+
23
def d_opt(C):
    """Return the compute-optimal dataset size for a compute budget ``C``.

    Computes ``(1 / G) * (C / 6) ** (alpha / (alpha + beta))``. Together with
    ``n_opt`` this satisfies ``6 * n_opt(C) * d_opt(C) == C``.

    Args:
        C: Compute budget in the same units ``to_flops`` returns.

    Returns:
        The optimal dataset size (presumably a raw token count -- confirm).
    """
    return (1 / G) * ((C / 6) ** (alpha / (alpha + beta)))
25
+
26
def get_kd(kn):
    """Return the dataset scaling factor ``kd`` paired with model factor ``kn``.

    Evaluates ``(1 - (kn**-alpha - 1) * frac) ** (-1 / beta)`` with
    ``frac = (A / B) * G ** (-alpha - beta)``. Presumably ``kd`` is the factor
    by which the dataset must grow to compensate for a model ``kn`` times the
    compute-optimal size -- confirm against the derivation.

    Args:
        kn: Model size as a fraction of the compute-optimal size. Scalars and
            NumPy arrays are both accepted.

    Returns:
        ``kd`` as a NumPy float (or array). Where the inner base is negative
        (``kn`` too small, or non-positive) the result is ``nan`` instead of a
        complex number or a warning; a base of exactly zero yields ``inf``.
    """
    frac = (A / B) * (G ** (-alpha - beta))
    # Out-of-domain inputs are expected here; map them to nan/inf silently
    # rather than emitting RuntimeWarnings or returning complex values.
    with np.errstate(divide="ignore", invalid="ignore"):
        base = 1 - (np.asarray(kn, dtype=float) ** -alpha - 1) * frac
        kd = np.power(base, -1 / beta)
    return kd
30
+
31
def compute_overhead(kn, kd):
    """Return the relative overhead ``kn * kd - 1``.

    Args:
        kn: Model size scaling factor.
        kd: Dataset size scaling factor.

    Returns:
        The overhead as a fraction (``0`` when ``kn * kd == 1``); callers
        multiply by 100 to obtain a percentage.
    """
    return kn * kd - 1
33
+
34
### PRECOMPUTE CURVE:
# Range of model-size fractions to plot. get_kd has no real-valued result for
# sufficiently small kn (its inner base turns negative), hence the lower bound.
kn_min = 0.2
kn_max = 2

# x and y must be sampled on the SAME grid so that plot_curve can pair them.
kns = np.linspace(kn_min, kn_max, 100)
# Overhead in percent for every sampled kn.
overheads = [compute_overhead(kn, get_kd(kn)) * 100 for kn in kns]
43
+
44
def plot_curve(kn, kd):
    """Plot the precomputed overhead curve and mark one operating point.

    Args:
        kn: x coordinate of the marker (fraction of compute-optimal model size).
        kd: y coordinate of the marker. Despite the name, the caller passes the
            compute overhead in percent, matching the y axis.

    Returns:
        The newly created matplotlib figure. It is also left as pyplot's
        current figure, so code that reads the global ``plt`` state still works.
    """
    # A fresh figure per call keeps repeated calls from stacking curves and
    # markers on top of each other.
    fig, ax = plt.subplots()
    ax.plot(kns, overheads)
    ax.scatter([kn], [kd])
    ax.set_xlabel("Fraction of compute optimal model size")
    ax.set_ylabel("Compute overhead (%)")
    return fig
49
+
50
def _describe(N, D):
    """Compute the plot and summary text for the given UI inputs.

    Args:
        N: Model size in billions of parameters (value of the Number input).
        D: Dataset size in billions of tokens (value of the Number input).

    Returns:
        A ``(figure, markdown_text)`` pair for the Plot and Markdown outputs.
        The figure is ``None`` when the inputs cannot be evaluated.
    """
    # An emptied Number field yields None; non-positive sizes are meaningless.
    if N is None or D is None or N <= 0 or D <= 0:
        return None, "Please enter positive values for model size and dataset size."

    C = to_flops(N * Bn, D * Bn)
    # n_opt returns a raw parameter count, so N must be converted from
    # billions before taking the ratio.
    kn = (N * Bn) / n_opt(C)
    kd = get_kd(kn)
    # get_kd has no real-valued answer when the model is far below the
    # compute-optimal size (it may yield nan, inf or a complex number).
    if isinstance(kd, complex) or not np.isfinite(kd):
        return None, (
            f"Model is too small relative to the compute-optimal size "
            f"(fraction: {kn:.3f}); the overhead is undefined."
        )
    overhead = 100 * compute_overhead(kn, kd)

    # Drop figures from earlier updates so plots do not accumulate.
    plt.close("all")
    fig = plot_curve(kn, overhead)
    if fig is None:
        # plot_curve drew on pyplot's implicit current figure.
        fig = plt.gcf()

    # Blank lines between entries: a single newline does not break a line
    # in Markdown.
    text = (
        f"Compute budget (FLOPs): {C:.2E}\n\n"
        f"Training compute overhead (%): {overhead:.2f}\n\n"
        f"Inference cost fraction (%): {kn * 100:.2f}"
    )
    return fig, text


with gr.Blocks() as demo:
    N = gr.Number(value=1, label="Model size (in B parameters)")
    D = gr.Number(value=100, label="Dataset size (in B tokens)")

    plot = gr.Plot()
    summary = gr.Markdown()

    # Component values only exist inside event callbacks, so the computation
    # is wired to the inputs (and to the initial page load) rather than run
    # once while the layout is being built.
    N.change(fn=_describe, inputs=[N, D], outputs=[plot, summary])
    D.change(fn=_describe, inputs=[N, D], outputs=[plot, summary])
    demo.load(fn=_describe, inputs=[N, D], outputs=[plot, summary])

demo.launch()