File size: 2,028 Bytes
0533d8a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2441f02
0533d8a
 
 
 
 
 
 
 
 
 
 
 
 
 
239a905
0533d8a
 
 
5e71d7d
dea7b06
 
0533d8a
 
dea7b06
5e71d7d
0533d8a
e041ed4
 
5e71d7d
0533d8a
 
 
 
9612937
2441f02
0533d8a
9612937
5e71d7d
 
0533d8a
dea7b06
 
 
 
 
 
 
 
 
 
 
5e71d7d
e041ed4
 
 
 
 
 
5e71d7d
60d7aa9
e041ed4
5e71d7d
0533d8a
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
import gradio as gr 
import matplotlib.pyplot as plt
import numpy as np

### CHINCHILLA PARAMS:
# Fitted scaling-law constants from Hoffmann et al. (2022), "Training
# Compute-Optimal Large Language Models": L(N, D) = E + A/N^alpha + B/D^beta.
E = 1.62
A = 406.4
B = 410.7
alpha = 0.336
beta = 0.283

# One billion — UI inputs are given in units of B parameters / B tokens.
Bn = 10**9

# Prefactor for the compute-optimal model size: N_opt = G * (C/6)^(beta/(alpha+beta)).
G = ((alpha*A)/(beta*B))**(1/(alpha+beta)) 
###

def to_flops(N, D):
    """Approximate training compute (FLOPs) via the standard C = 6*N*D rule."""
    flops_per_token = 6 * N
    return flops_per_token * D

def n_opt(C):
    """Chinchilla-optimal model size (parameters) for a compute budget C in FLOPs."""
    exponent = beta / (alpha + beta)
    return G * (C / 6) ** exponent

def d_opt(C):
    """Chinchilla-optimal dataset size (tokens) for a compute budget C in FLOPs."""
    exponent = alpha / (alpha + beta)
    return (C / 6) ** exponent / G

def compute_kd(kn):
    """Dataset-size factor kd that keeps loss equal when the model size is
    scaled by kn relative to the compute-optimal point (iso-loss contour of
    the Chinchilla loss fit)."""
    loss_term_ratio = (A / B) * G ** (-alpha - beta)
    inner = 1 - (kn ** (-alpha) - 1) * loss_term_ratio
    return inner ** (-1 / beta)

def compute_overhead(kn, kd):
    """Fractional extra training compute of a (kn*N_opt, kd*D_opt) run
    versus the compute-optimal run (0 means exactly optimal)."""
    relative_compute = kd * kn
    return relative_compute - 1

### PRECOMPUTE CURVE:
# Range of model-size fractions (kn) swept when drawing the overhead curve.
kn_min = 0.2
kn_max = 2

# BUG FIX: the x-values previously came from np.linspace(0.05, 2, 100) while
# the overheads were computed over np.linspace(0.2, 2, 100), so the plotted
# curve was horizontally shifted; kn_min/kn_max were defined but unused.
# Build ONE array and use it for both x-values and overhead computation.
kns = np.linspace(kn_min, kn_max, 100)
overheads = [compute_overhead(kn, compute_kd(kn)) * 100 for kn in kns]

def plot_curve(kn, kd):
    """Plot the precomputed overhead curve and mark the user's (kn, kd) point.

    Returns the matplotlib Figure so Gradio can render it in a gr.Plot.
    """
    fig = plt.figure()
    plt.plot(kns, overheads, color="black")
    # BUG FIX: `markerfacecolor`/`markeredgecolor` are Line2D keywords and are
    # rejected by plt.scatter (raises on unknown property); the scatter
    # equivalents are `c` and `edgecolors`.
    plt.scatter([kn], [compute_overhead(kn, kd)*100], marker="D",
                c="red", edgecolors="black", label="You are here!")
    plt.xlabel("Fraction of compute optimal model size")
    plt.ylabel("Compute overhead (%)")
    plt.legend(loc="best")
    return fig


def compute(N, D):
    """Gradio callback: report Chinchilla-optimal sizes and the trade-off for
    a model of N (in B parameters) trained on D (in B tokens).

    Returns (markdown_text, matplotlib_figure) matching the two outputs
    wired up in the Blocks UI.
    """
    C = to_flops(N * Bn, D * Bn)
    N_opt = n_opt(C)
    D_opt = d_opt(C)

    # Actual model size as a fraction of the compute-optimal size, and the
    # matching dataset-size factor on the same iso-loss contour.
    kn = Bn * N / N_opt
    kd = compute_kd(kn)

    fig = plot_curve(kn, kd)

    # Fixed typo "datset" and labeled the dataset size in B tokens, since the
    # displayed value is divided by Bn (matches the model-size line above it).
    text = f"""\
## Compute:
Compute budget (TFLOPs): {C:.2E}

## Chinchilla optimal:
Optimal model size:\t\t {N_opt/Bn:.2f}B
Optimal dataset size (B tokens):\t {D_opt/Bn:.2f}

## Your setting trade-off:
Training compute overhead (%):\t {100*compute_overhead(kn, kd):.2f}
Inference cost fraction (%):\t {kn*100:.2f}"""
    return text, fig

with gr.Blocks() as demo:
    # Inputs are in billions of parameters / tokens; compute() scales by Bn.
    N = gr.Number(value=1, label="Model size (in B parameters)")
    # BUG FIX: the label string was missing its closing parenthesis.
    D = gr.Number(value=100, label="Dataset size (in B tokens)")
    button = gr.Button("Compute!")

    plot = gr.Plot(value=plt)
    md = gr.Markdown("")

    button.click(fn=compute, inputs=[N, D], outputs=[md, plot])

demo.launch()