Spaces:
Runtime error
Runtime error
File size: 2,028 Bytes
0533d8a 2441f02 0533d8a 239a905 0533d8a 5e71d7d dea7b06 0533d8a dea7b06 5e71d7d 0533d8a e041ed4 5e71d7d 0533d8a 9612937 2441f02 0533d8a 9612937 5e71d7d 0533d8a dea7b06 5e71d7d e041ed4 5e71d7d 60d7aa9 e041ed4 5e71d7d 0533d8a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 |
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
### CHINCHILLA PARAMS:
# NOTE(review): presumably the fitted coefficients of the Chinchilla
# parametric loss L(N, D) = E + A / N**alpha + B / D**beta -- confirm the
# values against the intended source.
E = 1.62
A = 406.4
B = 410.7
alpha = 0.336
beta = 0.283

# One billion: used to convert the "in billions" inputs to absolute counts.
Bn = 1_000_000_000

# Coefficient shared by n_opt(), d_opt() and compute_kd().
G = ((alpha * A) / (beta * B)) ** (1 / (alpha + beta))
###
def to_flops(N, D):
    """Return ``6 * N * D``, the compute figure this app uses for a model
    with ``N`` parameters trained on ``D`` tokens."""
    per_token_cost = 6 * N
    return per_token_cost * D
def n_opt(C):
    """Return the model size treated as optimal for a compute budget ``C``.

    Evaluates ``G * (C / 6) ** (beta / (alpha + beta))`` with the
    module-level constants.
    """
    exponent = beta / (alpha + beta)
    budget_per_six = C / 6
    return G * (budget_per_six ** exponent)
def d_opt(C):
    """Return the dataset size treated as optimal for a compute budget ``C``.

    Evaluates ``(1 / G) * (C / 6) ** (alpha / (alpha + beta))`` with the
    module-level constants.
    """
    exponent = alpha / (alpha + beta)
    budget_per_six = C / 6
    return (1 / G) * (budget_per_six ** exponent)
def compute_kd(kn):
    """Return the dataset-size multiplier ``kd`` paired with ``kn``.

    Evaluates ``kd = (1 - (kn**-alpha - 1) * frac) ** (1 / -beta)`` where
    ``frac = (A / B) * G ** (-alpha - beta)``.

    NOTE(review): presumably ``kd`` is the factor by which the optimal token
    count must be scaled so that a model ``kn`` times the optimal size
    reaches the same loss -- confirm against the derivation.

    Args:
        kn: Model size as a fraction of the compute-optimal size. A scalar
            or a NumPy array.

    Returns:
        The multiplier ``kd``. For a scalar ``kn`` small enough that the
        base of the power is zero or negative, ``float("inf")`` is returned:
        the formula has no finite real solution there. Array inputs are
        evaluated element-wise, exactly as before.
    """
    frac = (A / B) * (G ** (-alpha - beta))
    base = 1 - ((kn ** -alpha - 1) * frac)
    # A non-positive Python float raised to a fractional power silently
    # becomes a complex number (or raises ZeroDivisionError at exactly 0),
    # which then corrupts the formatted output and the plot.
    if np.ndim(base) == 0 and base <= 0:
        return float("inf")
    return base ** (1 / (-beta))
def compute_overhead(kn, kd):
    """Return ``kn * kd - 1``: the relative extra cost of scaling the model
    size by ``kn`` and the dataset size by ``kd`` (0 means no overhead)."""
    relative_cost = kn * kd
    return relative_cost - 1
### PRECOMPUTE CURVE:
# Range of model-size fractions covered by the reference curve.
kn_min = 0.2
kn_max = 2
# The x grid and the y values must come from the same sample points.
# Previously x started at 0.05 while y was evaluated from 0.2, so the
# plotted curve was misaligned.
kns = np.linspace(kn_min, kn_max, 100)
# Overhead in percent at every sampled model-size fraction.
overheads = [compute_overhead(kn, compute_kd(kn)) * 100 for kn in kns]
def plot_curve(kn, kd):
    """Plot the precomputed overhead curve and mark the user's setting.

    Args:
        kn: Model size as a fraction of the compute-optimal size (x value
            of the marker).
        kd: Dataset-size multiplier paired with ``kn``. The marker's y value
            is ``compute_overhead(kn, kd) * 100``.

    Returns:
        The matplotlib figure containing the curve and the marker.
    """
    fig, ax = plt.subplots()
    ax.plot(kns, overheads, color="black")
    # scatter() draws a PathCollection, which has no ``markerfacecolor`` /
    # ``markeredgecolor`` properties (those belong to Line2D) and raises on
    # them. ``c`` and ``edgecolors`` are the scatter equivalents.
    ax.scatter(
        [kn],
        [compute_overhead(kn, kd) * 100],
        marker="D",
        c="red",
        edgecolors="black",
        zorder=3,  # keep the marker above the curve
        label="You are here!",
    )
    ax.set_xlabel("Fraction of compute optimal model size")
    ax.set_ylabel("Compute overhead (%)")
    ax.legend(loc="best")
    return fig
def compute(N, D):
    """Build the report text and plot for a model/dataset size choice.

    Args:
        N: Model size in billions of parameters.
        D: Dataset size in billions of tokens.

    Returns:
        A ``(text, fig)`` tuple: the Markdown report and the figure
        returned by ``plot_curve``.

    Raises:
        ValueError: If either input is missing or not strictly positive.
            Such inputs previously failed further down with an opaque
            ``TypeError`` or ``ZeroDivisionError``.
    """
    if N is None or D is None or N <= 0 or D <= 0:
        raise ValueError(
            "Model size and dataset size must both be positive numbers."
        )

    # Inputs are in billions; the scaling formulas work on absolute counts.
    C = to_flops(N * Bn, D * Bn)
    N_opt = n_opt(C)
    D_opt = d_opt(C)

    # Chosen model size relative to the optimal size for the same compute.
    kn = Bn * N / N_opt
    kd = compute_kd(kn)
    print(N, D, N_opt / Bn, D_opt / Bn, kn, kd)

    fig = plot_curve(kn, kd)

    # C is 6 * N * D in absolute units, i.e. FLOPs (not TFLOPs).
    text = "\n".join(
        [
            "## Compute:",
            f"Compute budget (FLOPs): {C:.2E}",
            "## Chinchilla optimal:",
            f"Optimal model size:\t\t {N_opt / Bn:.2f}B",
            f"Optimal dataset size (tokens):\t {D_opt / Bn:.2f}B",
            "## Your setting trade-off:",
            f"Training compute overhead (%):\t {100 * compute_overhead(kn, kd):.2f}",
            f"Inference cost fraction (%):\t {kn * 100:.2f}",
        ]
    )
    return text, fig
# UI: two numeric inputs, a button, and the plot and report outputs.
with gr.Blocks() as demo:
    N = gr.Number(value=1, label="Model size (in B parameters)")
    D = gr.Number(value=100, label="Dataset size (in B tokens)")
    button = gr.Button("Compute!")
    # NOTE(review): the initial value is the pyplot module itself, so the
    # initial plot depends on pyplot's global state -- confirm this is the
    # intended placeholder.
    plot = gr.Plot(value=plt)
    md = gr.Markdown("")
    # compute() returns (text, fig), matching the [md, plot] output order.
    button.click(fn=compute, inputs=[N, D], outputs=[md, plot])

demo.launch()