import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.ticker import MultipleLocator
INTRO = """# Harm's law
The Chinchilla scaling laws focus on optimally scaling training compute, but often we also care about inference cost.
This tool follows [Harm de Vries' blog post](https://www.harmdevries.com/post/model-size-vs-compute-overhead/) and visualizes the trade-off between training compute and inference cost (i.e. model size).
"""
### CHINCHILLA PARAMS:
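# The constants below are the fitted coefficients of the Chinchilla parametric loss
# (Hoffmann et al., 2022): L(N, D) = E + A / N**alpha + B / D**beta, where N is the
# number of parameters and D the number of training tokens. Bn is simply 1e9, used to
# convert between absolute counts and "billions".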
E = 1.62
A = 406.4
B = 410.7
alpha = 0.336
beta = 0.283
Bn = 10**9
G = ((alpha*A)/(beta*B))**(1/(alpha+beta))
### FUNCTIONS
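# Training compute is estimated with the standard C ≈ 6 * N * D FLOPs approximation,
# e.g. N = 7e9 parameters trained on D = 2e12 tokens gives C ≈ 8.4e22 FLOPs.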
def to_flops(N, D):
    return 6 * N * D
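# Minimizing the parametric loss subject to C = 6 * N * D gives the compute-optimal
# allocation used below:
#   N_opt(C) = G * (C/6) ** (beta / (alpha + beta))
#   D_opt(C) = (1/G) * (C/6) ** (alpha / (alpha + beta))
# with G = ((alpha * A) / (beta * B)) ** (1 / (alpha + beta)) as defined above.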
def n_opt(C):
    return G * ((C/6) ** (beta / (alpha+beta)))

def d_opt(C):
    return (1/G) * ((C/6) ** (alpha / (alpha+beta)))
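# For a model scaled to kn * N_opt parameters, compute_kd returns the data multiplier kd
# such that training on kd * D_opt tokens reaches the same loss as the compute-optimal
# model, i.e. it solves
#   A/(kn*N_opt)**alpha + B/(kd*D_opt)**beta = A/N_opt**alpha + B/D_opt**beta.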
def compute_kd(kn):
    frac = (A/B)*(G**(-alpha-beta))
    kd = (1-((kn**-alpha -1)*frac))**(1/(-beta))
    return kd
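# Training the (kn * N_opt, kd * D_opt) model costs 6 * (kn*N_opt) * (kd*D_opt) = kn * kd * C
# FLOPs, so the compute overhead relative to the Chinchilla-optimal run is kn * kd - 1.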
def compute_overhead(kn, kd):
    return kn*kd - 1
### PRECOMPUTE CURVE:
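# Sweep kn (model size as a fraction of the optimal size) over [0.2, 2] and record the
# training compute overhead in percent; this fixed curve is the black line in every plot.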
kn_min = 0.2
kn_max = 2
kns = np.linspace(kn_min, kn_max, 100)
overheads = []
for kn in kns:
    kd = compute_kd(kn)
    overheads.append(compute_overhead(kn, kd)*100)
def plot_curve(kn, kd):
    fig, ax = plt.subplots()
    plt.plot(kns, overheads, color="black", zorder=1)
    plt.scatter([kn], [compute_overhead(kn, kd)*100], s=100, marker="o", c="red", label="You are here!", zorder=2)
    plt.scatter([1.0], [0.0], marker="o", s=100, c="blue", label="Chinchilla optimal", zorder=2)
    plt.xlabel("Fraction of Chinchilla optimal model size")
    plt.ylabel("Compute overhead (%)")
    plt.legend(loc="best")
    plt.grid(True, which="both")
    plt.grid(True, which="minor", alpha=0.5)
    ax.yaxis.set_minor_locator(MultipleLocator(10))
    return fig
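# Gradio callback: takes the model size N and dataset size D in billions, converts them
# to absolute counts, and returns a markdown summary plus the overhead plot for the
# corresponding point on the curve.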
def compute(N, D):
    C = to_flops(N * Bn, D * Bn)
    N_opt = n_opt(C)
    D_opt = d_opt(C)

    kn = Bn*N/N_opt
    kd = compute_kd(kn)

    fig = plot_curve(kn, kd)

    text = f"""\
## Compute:
Your specified setting corresponds to the following training compute budget.

**Compute budget (FLOPs): {C:.2E}**

## Chinchilla optimal:
If you are optimizing for model performance and ignore inference cost, this is the optimal training setting:

**Optimal model size: {N_opt/Bn:.2f}B parameters**

**Optimal dataset size: {D_opt/Bn:.2f}B tokens**

## Your setting's trade-off:
Compared to the compute-optimal model:

**Training compute overhead: {100*compute_overhead(kn, kd):.2f}%**

**Inference cost savings: {100 - kn*100:.2f}%**"""
    return text, fig
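# Gradio UI: two number inputs and a button; the summary text and the plot are refreshed
# on button click and once when the page loads.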
with gr.Blocks() as demo:
    gr.Markdown(INTRO)
    with gr.Row():
        N = gr.Number(value=7, label="Model size (in B parameters):")
        D = gr.Number(value=2000, label="Dataset size (in B tokens):")
    button = gr.Button("Compute!")
    plot = gr.Plot(value=plt)
    md = gr.Markdown("")

    button.click(fn=compute, inputs=[N, D], outputs=[md, plot])
    demo.load(fn=compute, inputs=[N, D], outputs=[md, plot])

demo.launch()