# harms-law / app.py
# Last commit: "Update app.py" (e041ed4) by lvwerra (HF staff); file size 1.6 kB.
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
### CHINCHILLA PARAMS:
# NOTE(review): presumably the fitted coefficients of the parametric loss
# L(N, D) = E + A / N**alpha + B / D**beta -- confirm against the source of
# these values.
E = 1.62
A = 406.4
B = 410.7
alpha = 0.336
beta = 0.283

# One billion; used to convert UI inputs given "in B" to absolute counts.
Bn = 1_000_000_000

# Prefactor shared by n_opt (multiplies by G) and d_opt (multiplies by 1/G).
G = (alpha * A / (beta * B)) ** (1 / (alpha + beta))
###
def to_flops(N, D):
    """Return the compute estimate ``6 * N * D``.

    Callers in this file pass N as a parameter count and D as a token count.
    """
    flops_per_param_token = 6
    return flops_per_param_token * N * D
def n_opt(C):
    """Return ``G * (C / 6) ** (beta / (alpha + beta))`` for compute budget C.

    Used by ``compute`` as the compute-optimal model size for budget C.
    """
    exponent = beta / (alpha + beta)
    return G * (C / 6) ** exponent
def d_opt(C):
    """Return ``(1 / G) * (C / 6) ** (alpha / (alpha + beta))`` for budget C.

    Counterpart of ``n_opt``; presumably the compute-optimal token count for
    budget C -- confirm against the derivation of G.
    """
    exponent = alpha / (alpha + beta)
    inverse_g = 1 / G
    return inverse_g * (C / 6) ** exponent
def get_kd(kn):
    """Map a model-size factor ``kn`` to the matching data factor ``kd``.

    Evaluates ``(1 - (kn**-alpha - 1) * (A/B) * G**(-alpha-beta)) ** (-1/beta)``.

    NOTE(review): presumably kd is the multiple of the compute-optimal token
    count needed to reach the compute-optimal loss with a model kn times the
    compute-optimal size -- confirm the derivation.
    """
    scale = (A / B) * G ** (-alpha - beta)
    base = 1 - (kn ** -alpha - 1) * scale
    return base ** (1 / -beta)
def compute_overhead(kn, kd):
    """Return ``kn * kd - 1``: the fractional overhead of the two factors.

    A result of 0 means no overhead (``kn * kd == 1``); callers multiply by
    100 to express it as a percentage.
    """
    combined_factor = kn * kd
    return combined_factor - 1
### PRECOMPUTE CURVE:
# Range of model-size factors covered by the precomputed curve.
kn_min = 0.2
kn_max = 2

# x-values of the curve. Both the x grid and the overhead values below are
# derived from this single array so that plot_curve draws matching (x, y)
# pairs; previously the x grid started at 0.05 while the overheads were
# evaluated on a grid starting at 0.2.
kns = np.linspace(kn_min, kn_max, 100)

# y-values: compute overhead in percent for each model-size factor.
overheads = [compute_overhead(kn, get_kd(kn)) * 100 for kn in kns]
def plot_curve(kn, kd):
    """Draw the precomputed overhead curve and mark one point on it.

    Args:
        kn: x-coordinate of the marker (fraction of compute-optimal model
            size, per the x-axis label).
        kd: y-coordinate of the marker. The y-axis is labelled as compute
            overhead in percent, so callers are expected to pass a value in
            those units; the parameter name is kept for compatibility.

    Returns:
        The current matplotlib figure, so callers can hand it to a plot
        component if they wish. Callers that ignore the return value still
        get the drawing on pyplot's current figure.
    """
    # Start from a clean figure; without this every call would stack another
    # curve and marker on top of the previous ones.
    plt.clf()
    plt.plot(kns, overheads)
    plt.scatter([kn], [kd])
    plt.xlabel("Fraction of compute optimal model size")
    plt.ylabel("Compute overhead (%)")
    return plt.gcf()
def compute(N, D):
    """Summarise compute budget and overhead for a model/dataset size pair.

    Args:
        N: model size in billions of parameters.
        D: dataset size in billions of tokens.

    Returns:
        A text summary with the compute budget, the training compute
        overhead (in percent) and the inference cost fraction (in percent).
        If either input is missing or not positive, a short message asking
        for positive values is returned instead.
    """
    # A blank number field yields None, and a zero budget would make the
    # ratio below divide by zero.
    if N is None or D is None or N <= 0 or D <= 0:
        return "Model size and dataset size must both be positive."

    C = to_flops(N * Bn, D * Bn)
    N_opt = n_opt(C)

    # N is given in billions while N_opt is an absolute parameter count, so
    # N must be scaled by Bn before taking the ratio.
    kn = (N * Bn) / N_opt
    overhead_pct = 100 * compute_overhead(kn, get_kd(kn))

    plot_curve(kn, overhead_pct)

    # NOTE(review): C is 6 * N * D with absolute counts, while the label
    # below says TFLOPs -- confirm the intended unit.
    text = f"""Compute budget (TFLOPs): {C:.2E}\nTraining compute overhead (%): {overhead_pct:.2f}\nInference cost fraction (%): {kn*100:.2f}"""
    return text
# Gradio UI: two numeric inputs, a button, a plot and a text output.
with gr.Blocks() as demo:
    N = gr.Number(value=1, label="Model size (in B parameters)")
    D = gr.Number(value=100, label="Dataset size (in B tokens)")
    button = gr.Button("Compute!")
    gr.Plot(value=plt)
    # Starts empty; filled with the summary text returned by compute().
    md = gr.Markdown("")
    button.click(fn=compute, inputs=[N, D], outputs=[md])
demo.launch()