Spaces:
Sleeping
Sleeping
File size: 2,715 Bytes
0533d8a a082eb0 0533d8a 1edea07 0533d8a ed4871a 0533d8a 2441f02 0533d8a 4cf381c 0533d8a 4795aa0 239a905 0533d8a a082eb0 13461c8 0533d8a dea7b06 a082eb0 53e52ca a082eb0 5e71d7d 0533d8a e041ed4 5e71d7d 0533d8a 9612937 2441f02 0533d8a 9612937 5e71d7d 0533d8a dea7b06 afadb31 dea7b06 4795aa0 afadb31 4795aa0 5e71d7d e041ed4 1edea07 4795aa0 e041ed4 5e71d7d 60d7aa9 e041ed4 5e71d7d c73cf94 0533d8a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 |
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.ticker import MultipleLocator
# Intro text rendered at the top of the Gradio app.
# Fix: "training comput" -> "training compute" (user-facing typo).
INTRO = """# Harm's law
The Chinchilla scaling laws focus on optimally scaling training compute but often we also care about inference cost.
This tool follows [Harm de Vries' blog post](https://www.harmdevries.com/post/model-size-vs-compute-overhead/) and visualizes the tradeoff between training compute and inference cost (i.e. model size).
"""
### CHINCHILLA PARAMS:
# Fitted coefficients of the Chinchilla loss curve (Hoffmann et al., 2022):
# L(N, D) = E + A / N**alpha + B / D**beta
E = 1.62
A = 406.4
B = 410.7
alpha = 0.336
beta = 0.283
# One billion; converts between absolute counts and "B" (billions) units.
Bn = 10**9
# Prefactor of the compute-optimal solution: N_opt = G * (C/6)**(beta/(alpha+beta)).
G = ((alpha*A)/(beta*B))**(1/(alpha+beta))
### FUNCTIONS
def to_flops(N, D):
    """Approximate total training compute, C = 6 * N * D FLOPs.

    N is the parameter count and D the number of training tokens.
    """
    flops_per_param_per_token = 6
    return flops_per_param_per_token * N * D
def n_opt(C):
    """Compute-optimal model size (parameters) for compute budget C (FLOPs)."""
    exponent = beta / (alpha + beta)
    return G * (C / 6) ** exponent
def d_opt(C):
    """Compute-optimal dataset size (tokens) for compute budget C (FLOPs)."""
    exponent = alpha / (alpha + beta)
    return (C / 6) ** exponent / G
def compute_kd(kn):
    """Data-size fraction kd matching a model-size fraction kn.

    Given kn = N / N_opt, returns kd = D / D_opt such that the loss stays on
    the same iso-loss contour as the Chinchilla-optimal allocation.
    """
    frac = (A / B) * G ** (-alpha - beta)
    inner = 1 - (kn ** (-alpha) - 1) * frac
    return inner ** (-1 / beta)
def compute_overhead(kn, kd):
    """Fractional extra training compute vs. Chinchilla-optimal: kn * kd - 1."""
    scaled_compute = kn * kd
    return scaled_compute - 1
### PRECOMPUTE CURVE:
# Range of model-size fractions (relative to compute-optimal) shown in the plot.
kn_min = 0.2
kn_max = 2
# Fix: the linspace endpoints were hard-coded literals (0.2, 2), so editing
# kn_min/kn_max had no effect; use the named constants instead.
kns = np.linspace(kn_min, kn_max, 100)
# Compute overhead (%) along the iso-loss contour for each kn.
overheads = [compute_overhead(kn, compute_kd(kn)) * 100 for kn in kns]
def plot_curve(kn, kd):
    """Plot the precomputed overhead curve and mark the user's (kn, kd) point.

    Returns the matplotlib Figure so Gradio can render it.
    """
    fig, ax = plt.subplots()
    ax.plot(kns, overheads, color="black", zorder=1)
    overhead_here = compute_overhead(kn, kd) * 100
    ax.scatter([kn], [overhead_here], s=64, marker="x", c="red",
               label="You are here!", zorder=2)
    ax.scatter([1.0], [0.0], marker="x", s=64, c="blue",
               label="Chinchilla optimal", zorder=2)
    ax.set_xlabel("Fraction of compute optimal model size")
    ax.set_ylabel("Compute overhead (%)")
    ax.legend(loc="best")
    ax.grid(True, which="both")
    ax.grid(True, which="minor", alpha=0.5)
    ax.yaxis.set_minor_locator(MultipleLocator(10))
    return fig
def compute(N, D):
    """Compare a (model size, dataset size) choice to the Chinchilla optimum.

    Args:
        N: model size in billions of parameters.
        D: dataset size in billions of tokens.

    Returns:
        A (markdown summary, matplotlib Figure) pair for the Gradio outputs.
    """
    # Convert from billions to absolute counts before computing the budget.
    C = to_flops(N * Bn, D * Bn)
    N_opt = n_opt(C)
    D_opt = d_opt(C)
    # kn: chosen model size as a fraction of compute-optimal;
    # kd: data fraction needed to stay on the same iso-loss contour.
    kn = Bn * N / N_opt
    kd = compute_kd(kn)
    fig = plot_curve(kn, kd)
    # Fixes vs. original: C = 6*N*D is plain FLOPs (label said "TFLOPs"),
    # "datset" -> "dataset", dataset size gets the "B" suffix like model size,
    # and the leftover debug print() was removed.
    text = f"""\
## Compute:
Compute budget (FLOPs): {C:.2E}
## Chinchilla optimal:
Optimal model size:\t\t {N_opt/Bn:.2f}B
Optimal dataset size (tokens):\t {D_opt/Bn:.2f}B
## Your setting trade-off:
Training compute overhead:\t {100*compute_overhead(kn, kd):.2f}%
Inference cost fraction:\t {kn*100:.2f}%"""
    return text, fig
# Gradio UI: two numeric inputs, a plot, and a markdown summary.
# Fix: removed the stray trailing "|" after demo.launch() (syntax error left
# over from scraping) and restored the with-block indentation.
with gr.Blocks() as demo:
    gr.Markdown(INTRO)
    N = gr.Number(value=1, label="Model size (in B parameters):")
    D = gr.Number(value=100, label="Dataset size (in B tokens):")
    button = gr.Button("Compute!")
    plot = gr.Plot(value=plt)
    md = gr.Markdown("")
    button.click(fn=compute, inputs=[N, D], outputs=[md, plot])
    # Populate the outputs once on page load using the default input values.
    demo.load(fn=compute, inputs=[N, D], outputs=[md, plot])

demo.launch()