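"""Gradio Space visualizing the trade-off between training compute overhead and
inference cost (i.e. model size), following Harm de Vries' blog post linked in the UI below.
(Module docstring added for context; the rest of the file is the original app code.)"""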
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
### CHINCHILLA PARAMS:
E = 1.62
A = 406.4
B = 410.7
alpha = 0.336
beta = 0.283
Bn = 10**9  # one billion; model and dataset sizes in the UI are entered in billions
G = ((alpha*A)/(beta*B))**(1/(alpha+beta))
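
# The constants above parameterize the Chinchilla-style loss fit
#     L(N, D) = E + A / N**alpha + B / D**beta
# with training compute approximated as C ~= 6 * N * D.
# Minimizing L(N, D) under a fixed budget C gives the compute-optimal allocation
# implemented by n_opt / d_opt below:
#     N_opt = G * (C / 6) ** (beta / (alpha + beta))
#     D_opt = (1 / G) * (C / 6) ** (alpha / (alpha + beta))
# (Hoffmann et al., "Training Compute-Optimal Large Language Models"; the
# trade-off analysis follows Harm de Vries' blog post linked in the UI below.)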
### FUNCTIONS
def to_flops(N, D):
    # Standard approximation of training compute: C ~= 6 * N * D FLOPs
    return 6 * N * D

def n_opt(C):
    # Compute-optimal model size (parameters) for a budget of C FLOPs
    return G * ((C / 6) ** (beta / (alpha + beta)))

def d_opt(C):
    # Compute-optimal dataset size (tokens) for a budget of C FLOPs
    return (1 / G) * ((C / 6) ** (alpha / (alpha + beta)))
def compute_kd(kn):
    # Given a model scaled to kn = N / N_opt, return the data multiplier
    # kd = D / D_opt needed to match the loss of the compute-optimal model.
    frac = (A / B) * (G ** (-alpha - beta))
    kd = (1 - ((kn ** -alpha - 1) * frac)) ** (1 / (-beta))
    return kd

def compute_overhead(kn, kd):
    # Extra training compute relative to the compute-optimal run:
    # C'/C = (6 * kn*N_opt * kd*D_opt) / (6 * N_opt * D_opt) = kn * kd
    return kn * kd - 1
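
# Illustrative use of the two helpers above (not part of the original app):
# a model at half the compute-optimal size (kn = 0.5) needs
# kd = compute_kd(0.5) times the compute-optimal token count, at an extra
# training cost of compute_overhead(0.5, compute_kd(0.5)), roughly 20%
# with the constants above.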
### PRECOMPUTE CURVE:
# Overhead curve for models between 20% and 200% of the compute-optimal size.
kn_min = 0.2
kn_max = 2
kns = np.linspace(kn_min, kn_max, 100)
overheads = []
for kn in kns:
    kd = compute_kd(kn)
    overheads.append(compute_overhead(kn, kd) * 100)  # in percent
def plot_curve(kn, kd):
    # Plot the precomputed overhead curve and mark the user's setting on it.
    fig = plt.figure()
    plt.plot(kns, overheads, color="black")
    plt.scatter([kn], [compute_overhead(kn, kd) * 100], marker="x", c="red", label="You are here!")
    plt.xlabel("Fraction of compute optimal model size")
    plt.ylabel("Compute overhead (%)")
    plt.legend(loc="best")
    return fig
def compute(N, D):
    # N: model size in billions of parameters, D: dataset size in billions of tokens
    C = to_flops(N * Bn, D * Bn)
    N_opt = n_opt(C)
    D_opt = d_opt(C)

    kn = Bn * N / N_opt
    kd = compute_kd(kn)
    print(N, D, N_opt / Bn, D_opt / Bn, kn, kd)  # debug output

    fig = plot_curve(kn, kd)
    text = f"""\
## Compute:
Compute budget (FLOPs): {C:.2E}
## Chinchilla optimal:
Optimal model size:\t\t {N_opt/Bn:.2f}B
Optimal dataset size (tokens):\t {D_opt/Bn:.2f}B
## Your setting trade-off:
Training compute overhead:\t {100*compute_overhead(kn, kd):.2f}%
Inference cost fraction:\t {kn*100:.2f}%"""
    return text, fig
with gr.Blocks() as demo:
    gr.Markdown("""\
# Harm's law

The Chinchilla scaling laws focus on optimally scaling training compute, but often we also care about inference cost.
This tool follows [Harm de Vries' blog post](https://www.harmdevries.com/post/model-size-vs-compute-overhead/) and visualizes the trade-off between training compute and inference cost (i.e. model size).
""")
    N = gr.Number(value=1, label="Model size (in B parameters):")
    D = gr.Number(value=100, label="Dataset size (in B tokens):")
    button = gr.Button("Compute!")
    plot = gr.Plot(value=plt)
    md = gr.Markdown("")

    button.click(fn=compute, inputs=[N, D], outputs=[md, plot])

demo.launch()
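
# To run locally (assuming gradio, matplotlib and numpy are installed):
#     python app.py
# demo.launch() then prints a local URL where the interface is served.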