# (removed non-code extraction artifacts: file-size line, commit hashes, line-number gutter)
import gradio as gr
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.ticker import MultipleLocator
# Intro text rendered in the UI (markdown). Fixed typo: "comput" -> "compute".
HARM_INTRO = """
The Chinchilla scaling laws focus on optimally scaling training compute but often we also care about inference cost.
This tool follows [Harm de Vries' blog post](https://www.harmdevries.com/post/model-size-vs-compute-overhead/) and visualizes the tradeoff between training compute and inference cost (i.e. model size).
"""
### GPU specs: peak throughput in FLOP/s (presumably dense BF16/FP16 tensor-core rates — TODO confirm)
A100_flops = 312e12
H100_flops = 990e12
### CHINCHILLA PARAMS: fitted constants of the parametric loss L(N, D) = E + A/N^alpha + B/D^beta
E = 1.62
A = 406.4
B = 410.7
alpha = 0.336
beta = 0.283
Bn = 10**9  # one billion — converts between raw counts and "B parameters"/"B tokens" units
# G: constant from the compute-optimal solution, used by n_opt/d_opt below
G = ((alpha*A)/(beta*B))**(1/(alpha+beta))
### FUNCTIONS
def to_flops(N, D):
    """Approximate total training compute (FLOPs) via the C ≈ 6·N·D rule of thumb."""
    return N * D * 6
def n_opt(C):
    """Chinchilla-optimal model size (parameters) for a training-compute budget C (FLOPs)."""
    exponent = beta / (alpha + beta)
    return G * (C / 6) ** exponent
def d_opt(C):
    """Chinchilla-optimal dataset size (tokens) for a training-compute budget C (FLOPs)."""
    exponent = alpha / (alpha + beta)
    return (C / 6) ** exponent / G
def compute_kd(kn):
    """Dataset scaling factor kd that keeps loss unchanged when the model is scaled by kn.

    Inverts the parametric Chinchilla loss along an iso-loss contour
    (derivation in Harm de Vries' blog post referenced in HARM_INTRO).
    """
    frac = (A / B) * G ** (-alpha - beta)
    inner = 1 - (kn ** -alpha - 1) * frac
    return inner ** (-1 / beta)
def compute_overhead(kn, kd):
    """Fractional extra training compute relative to the Chinchilla optimum: kn·kd − 1."""
    return (kn * kd) - 1
### PRECOMPUTE CURVE:
# Sample the overhead curve once at import time; plot_curve() reuses kns/overheads
# on every call instead of recomputing them.
kn_min = 0.18
kn_max = 2
kns = np.linspace(kn_min, kn_max, 100)
overheads = [compute_overhead(k, compute_kd(k)) * 100 for k in kns]
def plot_curve(kn, kd):
    """Plot the precomputed overhead curve, marking the user's configuration (red)
    and the Chinchilla-optimal point (blue). Returns the matplotlib Figure."""
    fig, ax = plt.subplots(dpi=200, figsize=(5, 3))
    ax.plot(kns, overheads, color="black", zorder=1)
    ax.scatter([kn], [compute_overhead(kn, kd) * 100], s=100, marker="o", c="red", label="You are here!", zorder=2)
    ax.scatter([1.0], [0.0], marker="o", s=100, c="blue", label="Chinchilla optimal", zorder=2)
    ax.set_xlabel("Fraction of Chinchilla optimal model size")
    ax.set_ylabel("Compute overhead (%)")
    ax.legend(loc="best")
    ax.grid(True, which="both")
    ax.grid(True, which="minor", alpha=0.5)
    ax.yaxis.set_minor_locator(MultipleLocator(10))
    fig.tight_layout()
    return fig
def compute(N, D, gpu_type, gpu_util, n_gpus, gpu_price):
    """Build the markdown training summary and the overhead plot for one configuration.

    Args:
        N: model size in billions of parameters.
        D: dataset size in billions of tokens.
        gpu_type: "H100" uses H100 throughput; anything else falls back to A100.
        gpu_util: GPU utilization as a percentage (0-100).
        n_gpus: number of GPUs in the cluster (affects wall-clock days only).
        gpu_price: cost in $/GPU/hour.

    Returns:
        (markdown_text, matplotlib_figure) for the gr.Markdown and gr.Plot outputs.
    """
    C = to_flops(N * Bn, D * Bn)  # total training compute in raw FLOPs
    N_opt = n_opt(C)
    D_opt = d_opt(C)

    kn = Bn * N / N_opt  # chosen model size as a fraction of the Chinchilla-optimal size
    kd = compute_kd(kn)
    fig = plot_curve(kn, kd)

    gpu_util = gpu_util / 100  # percent -> fraction
    if gpu_type == "H100":
        gpu_flops = H100_flops * gpu_util
    else:
        gpu_flops = A100_flops * gpu_util
    # Aggregate GPU-hours across the whole cluster (seconds of GPU work / 3600).
    gpu_hours = C / (gpu_flops * 3600)

    # BUGFIX: C is in raw FLOPs, not TFLOPs — the original label overstated units by 1e12.
    text = f"""\
## Training summary
|Training compute| Training cost | Training time | Total GPU hours |
|:----|:-------|:-------|:-------|
|{C:.2E} FLOPs | ${(gpu_hours * gpu_price)/1e6:.2f}M | {gpu_hours/(24*n_gpus):.2f} days | {gpu_hours/1_000_000:.2f}M |
## Chinchilla and Training/Inference Trade-off
Optimal model/dataset size for training compute and how it translates to training overhead and inference savings according to Harm's law
|Chinchilla optimal model | Chinchilla optimal dataset | Training overhead | Inference savings|
|:----|:-------|:----|:-------|
| {N_opt/Bn:.2f}B parameters | {D_opt/Bn:.2f}B tokens | {100*compute_overhead(kn, kd):.2f}%| {100 - kn*100:.2f}% |
"""
    return text, fig
# Gradio UI: two input rows, a trigger button, and a plot+markdown results area.
with gr.Blocks() as demo:
    gr.Markdown("# Train LLMs")
    gr.Markdown("## Training configuration")
    with gr.Row():
        N = gr.Number(value=7, label="Model size (in B parameters):")
        D = gr.Number(value=2000, label="Dataset size (in B tokens):")
    gr.Markdown("## Cluster configuration")
    with gr.Row():
        n_gpus = gr.Number(value=1000, label="Number of GPUs")
        gpu_type = gr.Dropdown(choices=["A100", "H100"], value="H100", label="GPU type")
        gpu_util = gr.Number(value=50, label="% GPU utilization")
        gpu_price = gr.Number(value=3.00, label="$/GPU/Hour")
    button = gr.Button("Compute!")
    with gr.Row():
        with gr.Column():
            gr.Markdown("## Harm's law")
            # BUGFIX: the original passed the pyplot *module* (value=plt) where a
            # Figure is expected; start empty — demo.load fills it on page load.
            plot = gr.Plot()
            gr.Markdown(HARM_INTRO)
        with gr.Column():
            md = gr.Markdown("")
    button.click(fn=compute, inputs=[N, D, gpu_type, gpu_util, n_gpus, gpu_price], outputs=[md, plot])
    # Render the default configuration immediately when the page loads.
    demo.load(fn=compute, inputs=[N, D, gpu_type, gpu_util, n_gpus, gpu_price], outputs=[md, plot])
demo.launch()