import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.ticker import MultipleLocator
INTRO = """# Harm's law
The Chinchilla scaling laws focus on optimally scaling training compute, but often we also care about inference cost.
This tool follows [Harm de Vries' blog post](https://www.harmdevries.com/post/model-size-vs-compute-overhead/) and visualizes the trade-off between training compute and inference cost (i.e. model size).
"""
### CHINCHILLA PARAMS:
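# Fitted constants of the Chinchilla parametric loss
#   L(N, D) = E + A / N**alpha + B / D**beta
# (values as used in Harm de Vries' post). Bn converts between billions and
# absolute parameter/token counts; G is the coefficient of the compute-optimal
# model size derived from these constants below.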
E = 1.62
A = 406.4
B = 410.7
alpha = 0.336
beta = 0.283
Bn = 10**9
G = ((alpha*A)/(beta*B))**(1/(alpha+beta))
### FUNCTIONS
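# Training compute under the standard approximation C ≈ 6 * N * D FLOPs
# for a model with N parameters trained on D tokens.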
def to_flops(N, D):
return 6 * N * D
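
# Compute-optimal model and dataset size for a budget C, obtained by
# minimizing L(N, D) subject to 6 * N * D = C (the Chinchilla parametric-fit solution):
#   N_opt = G * (C/6)**(beta / (alpha + beta))
#   D_opt = (1/G) * (C/6)**(alpha / (alpha + beta))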
def n_opt(C):
return G * ((C/6) ** (beta / (alpha+beta)))
def d_opt(C):
return (1/G) * ((C/6) ** (alpha / (alpha+beta)))
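
# For a model scaled to kn = N / N_opt, return the data scaling kd = D / D_opt
# needed to reach the same loss as the compute-optimal model, i.e. the solution
# of L(kn * N_opt, kd * D_opt) = L(N_opt, D_opt), which works out to
#   kd = (1 - (kn**-alpha - 1) * (A/B) * G**(-alpha - beta)) ** (-1/beta)
# (following Harm de Vries' post).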
def compute_kd(kn):
frac = (A/B)*(G**(-alpha-beta))
kd = (1-((kn**-alpha -1)*frac))**(1/(-beta))
return kd
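
# Relative extra training compute of the scaled run: it costs
# 6 * (kn * N_opt) * (kd * D_opt) = kn * kd * C, so the overhead is kn * kd - 1.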
def compute_overhead(kn, kd):
return kn*kd - 1
### PRECOMPUTE CURVE:
kn_min = 0.2
kn_max = 2
kns = np.linspace(kn_min, kn_max, 100)
overheads = []
for kn in kns:
kd = compute_kd(kn)
overheads.append(compute_overhead(kn, kd)*100)
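
# Plot the precomputed overhead curve and mark the user's setting (kn, kd)
# as well as the Chinchilla-optimal point (kn = 1, overhead = 0).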
def plot_curve(kn, kd):
fig, ax = plt.subplots()
plt.plot(kns, overheads, color="black", zorder=1)
plt.scatter([kn], [compute_overhead(kn, kd)*100], s=64, marker="x", c="red", label="You are here!", zorder=2)
plt.scatter([1.0], [0.0], marker="x", s=64, c="blue", label="Chinchilla optimal", zorder=2)
plt.xlabel("Fraction of compute optimal model size")
plt.ylabel("Compute overhead (%)")
plt.legend(loc="best")
plt.grid(True, which="both")
plt.grid(True, which="minor", alpha=0.5)
ax.yaxis.set_minor_locator(MultipleLocator(10))
return fig
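
# Gradio callback: N and D are given in billions of parameters/tokens.
# Convert them to absolute counts, compare against the compute-optimal
# configuration, and return a markdown summary plus the overhead plot.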
def compute(N, D):
C = to_flops(N * Bn, D * Bn)
N_opt = n_opt(C)
D_opt = d_opt(C)
kn = Bn*N/N_opt
kd = compute_kd(kn)
print(N, D, N_opt/Bn, D_opt/Bn, kn, kd)
fig = plot_curve(kn, kd)
text = f"""\
## Compute:
Compute budget (FLOPs): {C:.2E}
## Chinchilla optimal:
Optimal model size:\t\t {N_opt/Bn:.2f}B
Optimal dataset size (tokens):\t {D_opt/Bn:.2f}B
## Your setting trade-off:
Training compute overhead:\t {100*compute_overhead(kn, kd):.2f}%
Inference cost fraction:\t {kn*100:.2f}%"""
return text, fig
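
# Gradio UI: number inputs for model and dataset size, a button wired to
# compute(), and an initial render on page load.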
with gr.Blocks() as demo:
gr.Markdown(INTRO)
N = gr.Number(value=1, label="Model size (in B parameters):")
D = gr.Number(value=100, label="Dataset size (in B tokens):")
button = gr.Button("Compute!")
plot = gr.Plot(value=plt)
md = gr.Markdown("")
button.click(fn=compute, inputs=[N, D], outputs=[md, plot])
demo.load(fn=compute, inputs=[N, D], outputs=[md, plot])
demo.launch()