import gradio as gr
import matplotlib.pyplot as plt
import numpy as np

INTRO = """# Harm's law 

The Chinchilla scaling laws focus on optimally scaling training compute but often we also care about inference cost. 
This tool follows [Harm de Vries' blog post](https://www.harmdevries.com/post/model-size-vs-compute-overhead/) and visualizes the tradeoff between training comput and inference cost (i.e. model size). 
"""

### CHINCHILLA PARAMS:
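# Fitted constants of the Chinchilla parametric loss
#   L(N, D) = E + A / N**alpha + B / D**beta
# where N is the number of parameters and D the number of training tokens.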
E = 1.62
A = 406.4
B = 410.7
alpha = 0.336
beta = 0.283

Bn = 10**9  # one billion; model and dataset sizes are entered in B parameters / B tokens

# Coefficient of the compute-optimal allocation of L(N, D) under the constraint C = 6*N*D (see n_opt/d_opt).
G = ((alpha*A)/(beta*B))**(1/(alpha+beta))

### FUNCTIONS
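# Training compute from the standard approximation C ≈ 6 * N * D FLOPs.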
def to_flops(N, D):
    return 6 * N * D

def n_opt(C):
    # Chinchilla-optimal model size (parameters) for a compute budget C in FLOPs.
    return G * ((C/6) ** (beta / (alpha+beta)))

def d_opt(C):
    # Chinchilla-optimal dataset size (tokens) for a compute budget C in FLOPs.
    return (1/G) * ((C/6) ** (alpha / (alpha+beta)))

def compute_kd(kn):
    # For a model kn = N/N_opt times the compute-optimal size, return the data
    # multiplier kd = D/D_opt that reaches the same loss as the compute-optimal
    # model, i.e. solve L(kn*N_opt, kd*D_opt) = L(N_opt, D_opt) for kd.
    frac = (A/B)*(G**(-alpha-beta))
    kd = (1-((kn**-alpha -1)*frac))**(1/(-beta))
    return kd

def compute_overhead(kn, kd):
    # C = 6*(kn*N_opt)*(kd*D_opt) = kn*kd*C_opt, so the extra training compute
    # relative to the compute-optimal run is kn*kd - 1.
    return kn*kd - 1
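
# Sanity check: at the compute-optimal point kn = 1 we get kd = 1 and 0% overhead,
# which is the minimum of the curve plotted below.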

### PRECOMPUTE CURVE:
kn_min = 0.2
kn_max = 2

kns = np.linspace(kn_min, kn_max, 100)
overheads = []
for kn in kns:
    kd = compute_kd(kn)
    overheads.append(compute_overhead(kn, kd)*100)

def plot_curve(kn, kd):
    # Draw the precomputed overhead curve and mark the user's setting on it.
    fig = plt.figure()
    plt.plot(kns, overheads, color="black")
    plt.scatter([kn], [compute_overhead(kn, kd)*100], marker="x", c="red", label="You are here!")
    plt.xlabel("Fraction of compute optimal model size")
    plt.ylabel("Compute overhead (%)")
    plt.legend(loc="best")
    return fig


def compute(N, D):
    # N is in billions of parameters and D in billions of tokens; convert to
    # absolute counts to get the training-compute budget in FLOPs.
    C = to_flops(N * Bn, D * Bn)
    N_opt = n_opt(C)
    D_opt = d_opt(C)

    # Fraction of the compute-optimal model size and the matching data multiplier.
    kn = Bn*N/N_opt
    kd = compute_kd(kn)

    # Log intermediate values to the console for debugging.
    print(N, D, N_opt/Bn, D_opt/Bn, kn, kd)
    
    fig = plot_curve(kn, kd)

    text = f"""\
## Compute:
Compute budget (FLOPs): {C:.2E}

## Chinchilla optimal:
Optimal model size:\t\t {N_opt/Bn:.2f}B

Optimal dataset size (B tokens):\t {D_opt/Bn:.2f}

## Your setting trade-off:
Training compute overhead:\t {100*compute_overhead(kn, kd):.2f}%

Inference cost fraction:\t {kn*100:.2f}%"""
    return text, fig
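
# Gradio UI: number inputs for model and dataset size, a button, and the
# plot/markdown outputs, populated on page load and on every click.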

with gr.Blocks() as demo:
    gr.Markdown(INTRO)
    N = gr.Number(value=1, label="Model size (in B parameters):")
    D = gr.Number(value=100, label="Dataset size (in B tokens):")
    button = gr.Button("Compute!")
    
    # Output components; filled in by the click and load events below.
    plot = gr.Plot(value=plt)
    md = gr.Markdown("")

    button.click(fn=compute, inputs=[N, D], outputs=[md, plot])
    # Also run compute once on page load so the default values are displayed immediately.
    demo.load(fn=compute, inputs=[N, D], outputs=[md, plot])
demo.launch()