import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.figure import Figure

INTRO = """# Harm's law

The Chinchilla scaling laws focus on optimally scaling training compute, but often we also care about inference cost. This tool follows [Harm de Vries' blog post](https://www.harmdevries.com/post/model-size-vs-compute-overhead/) and visualizes the tradeoff between training compute and inference cost (i.e. model size).
"""

### CHINCHILLA PARAMS:
# Parametric loss fit L(N, D) = E + A / N**alpha + B / D**beta.
# (E is not used by the calculations below.)
E = 1.62
A = 406.4
B = 410.7
alpha = 0.336
beta = 0.283

# The UI takes model and dataset sizes in billions.
Bn = 10**9

# Prefactor of the compute-optimal model size:
# N_opt(C) = G * (C / 6) ** (beta / (alpha + beta)).
G = ((alpha * A) / (beta * B)) ** (1 / (alpha + beta))


### FUNCTIONS
def to_flops(N, D):
    """Return the approximate training compute (6 * N * D FLOPs).

    Args:
        N: number of model parameters.
        D: number of training tokens.
    """
    return 6 * N * D


def n_opt(C):
    """Return the compute-optimal model size (parameters) for C FLOPs."""
    return G * ((C / 6) ** (beta / (alpha + beta)))


def d_opt(C):
    """Return the compute-optimal dataset size (tokens) for C FLOPs."""
    return (1 / G) * ((C / 6) ** (alpha / (alpha + beta)))


def compute_kd(kn):
    """Return the dataset-size factor kd that pairs with model-size factor kn.

    A model of size ``kn * N_opt`` trained on ``kd * D_opt`` tokens reaches
    the same Chinchilla loss as the compute-optimal ``(N_opt, D_opt)`` pair.

    Args:
        kn: model size as a fraction of the compute-optimal size; must be > 0.

    Returns:
        kd, or ``float("inf")`` when the model is so small that no finite
        amount of data reaches that loss.
    """
    frac = (A / B) * (G ** (-alpha - beta))
    base = 1 - (kn ** -alpha - 1) * frac
    if base <= 0:
        # Raising a non-positive base to the negative power -1/beta would
        # yield a complex number (base < 0) or divide by zero (base == 0).
        # Both mean the target loss is unreachable, so kd is unbounded.
        return float("inf")
    return base ** (1 / (-beta))


def compute_overhead(kn, kd):
    """Return the extra training compute of the (kn, kd) setting.

    Training FLOPs scale with N * D, so the ratio to the compute-optimal
    budget is ``kn * kd``; the result is a fraction (0.1 means +10%).
    """
    return kn * kd - 1


### PRECOMPUTE CURVE:
kn_min = 0.2
kn_max = 2
kns = np.linspace(kn_min, kn_max, 100)
# Compute overhead (in percent) for each model-size fraction in ``kns``.
overheads = [compute_overhead(kn, compute_kd(kn)) * 100 for kn in kns]


def plot_curve(kn, kd):
    """Plot the precomputed overhead curve and mark the user's setting.

    A standalone ``Figure`` is used instead of ``plt.figure()`` so figures
    are not registered with pyplot and do not accumulate across requests.
    The marker is omitted when the overhead is not finite.
    """
    fig = Figure()
    ax = fig.subplots()
    ax.plot(kns, overheads, color="black")
    overhead_pct = compute_overhead(kn, kd) * 100
    if np.isfinite(overhead_pct):
        ax.scatter([kn], [overhead_pct], marker="x", c="red", label="You are here!")
        ax.legend(loc="best")
    ax.set_xlabel("Fraction of compute optimal model size")
    ax.set_ylabel("Compute overhead (%)")
    return fig


def compute(N, D):
    """Build the markdown summary and plot for the UI.

    Args:
        N: model size in billions of parameters.
        D: dataset size in billions of tokens.

    Returns:
        ``(markdown_text, figure)``; the figure is ``None`` for invalid input.
    """
    if N is None or D is None or N <= 0 or D <= 0:
        return "**Please enter a positive model size and dataset size.**", None

    C = to_flops(N * Bn, D * Bn)
    N_opt = n_opt(C)
    D_opt = d_opt(C)
    kn = Bn * N / N_opt
    kd = compute_kd(kn)
    overhead = compute_overhead(kn, kd)
    fig = plot_curve(kn, kd)

    if np.isfinite(overhead):
        overhead_text = f"{100 * overhead:.2f}%"
    else:
        overhead_text = "infinite (a model this small cannot reach the compute-optimal loss)"

    text = f"""\
## Compute:
Compute budget (FLOPs): {C:.2E}

## Chinchilla optimal:
Optimal model size:\t\t {N_opt/Bn:.2f}B

Optimal dataset size (tokens):\t {D_opt/Bn:.2f}B

## Your setting trade-off:
Training compute overhead:\t {overhead_text}

Inference cost fraction:\t {kn*100:.2f}%"""
    return text, fig


with gr.Blocks() as demo:
    gr.Markdown(INTRO)
    N = gr.Number(value=1, label="Model size (in B parameters):")
    D = gr.Number(value=100, label="Dataset size (in B tokens):")
    button = gr.Button("Compute!")
    plot = gr.Plot(value=plt)
    md = gr.Markdown("")
    button.click(fn=compute, inputs=[N, D], outputs=[md, plot])

demo.launch()