import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.ticker import MultipleLocator
INTRO = """# Harm's law
The Chinchilla scaling laws focus on compute-optimal training, but often we also care about inference cost.
This tool follows [Harm de Vries' blog post](https://www.harmdevries.com/post/model-size-vs-compute-overhead/) and visualizes the trade-off between training compute and inference cost (i.e. model size).
"""
### CHINCHILLA PARAMS:
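# Constants of the parametric loss fit L(N, D) = E + A / N**alpha + B / D**beta
# (Chinchilla-style fit, values as used in the blog post linked above).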
E = 1.62
A = 406.4
B = 410.7
alpha = 0.336
beta = 0.283
Bn = 10**9
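# G determines the compute-optimal split between parameters and tokens:
#   N_opt(C) = G * (C/6)**(beta/(alpha+beta)),  D_opt(C) = (1/G) * (C/6)**(alpha/(alpha+beta))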
G = ((alpha*A)/(beta*B))**(1/(alpha+beta))
### FUNCTIONS
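# Standard training-compute approximation: C ~= 6 * N * D FLOPs for N parameters and D tokens.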
def to_flops(N, D):
return 6 * N * D
def n_opt(C):
return G * ((C/6) ** (beta / (alpha+beta)))
def d_opt(C):
return (1/G) * ((C/6) ** (alpha / (alpha+beta)))
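# Given kn = N / N_opt, solve for kd = D / D_opt such that the scaled model reaches the
# same predicted loss as the compute-optimal model (iso-loss condition from the blog post).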
def compute_kd(kn):
frac = (A/B)*(G**(-alpha-beta))
kd = (1-((kn**-alpha -1)*frac))**(1/(-beta))
return kd
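# The scaled model costs 6*(kn*N_opt)*(kd*D_opt) = kn*kd*C, so the relative compute overhead is kn*kd - 1.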
def compute_overhead(kn, kd):
return kn*kd - 1
### PRECOMPUTE CURVE:
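# Overhead (%) as a function of kn over [kn_min, kn_max]; used as the background curve in the plot.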
kn_min = 0.2
kn_max = 2
kns = np.linspace(kn_min, kn_max, 100)
overheads = []
for kn in kns:
kd = compute_kd(kn)
overheads.append(compute_overhead(kn, kd)*100)
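# Plot the precomputed overhead curve and mark the user's setting against the Chinchilla optimum.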
def plot_curve(kn, kd):
fig, ax = plt.subplots(dpi=200, figsize=(5, 3))
plt.plot(kns, overheads, color="black", zorder=1)
plt.scatter([kn], [compute_overhead(kn, kd)*100], s=100, marker="o", c="red", label="You are here!", zorder=2)
plt.scatter([1.0], [0.0], marker="o", s=100, c="blue", label="Chinchilla optimal", zorder=2)
plt.xlabel("Fraction of Chinchilla optimal model size")
plt.ylabel("Compute overhead (%)")
plt.legend(loc="best")
plt.grid(True, which="both")
plt.grid(True, which="minor", alpha=0.5)
ax.yaxis.set_minor_locator(MultipleLocator(10))
plt.tight_layout()
return fig
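# Main Gradio callback: takes the model size N and dataset size D (both in billions),
# computes the training budget, the Chinchilla-optimal allocation, and the overhead/savings summary.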
def compute(N, D):
C = to_flops(N * Bn, D * Bn)
N_opt = n_opt(C)
D_opt = d_opt(C)
kn = Bn*N/N_opt
kd = compute_kd(kn)
fig = plot_curve(kn, kd)
text = f"""\
## Compute:
Your specified setting corresponds to the following training compute budget.
**Compute budget (FLOPs): {C:.2E}**
## Chinchilla optimal:
If you are optimizing for model performance and ignore inference cost, this is the optimal training setting:
**Optimal model size: {N_opt/Bn:.2f}B parameters**
**Optimal dataset size: {D_opt/Bn:.2f}B tokens**
## Your setting trade-off:
Compared to the compute-optimal model:
**Training compute overhead: {100*compute_overhead(kn, kd):.2f}%**
**Inference cost savings: {100 - kn*100:.2f}%** """
return text, fig
with gr.Blocks() as demo:
gr.Markdown(INTRO)
with gr.Row():
N = gr.Number(value=7, label="Model size (in B parameters):")
D = gr.Number(value=2000, label="Dataset size (in B tokens):")
button = gr.Button("Compute!")
    plot = gr.Plot()
md = gr.Markdown("")
button.click(fn=compute, inputs=[N, D], outputs=[md, plot])
demo.load(fn=compute, inputs=[N, D], outputs=[md, plot])
demo.launch()