import gradio as gr 
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.ticker import MultipleLocator

INTRO = """# Harm's law 

The Chinchilla scaling laws focus on optimally scaling training compute, but often we also care about inference cost.
This tool follows [Harm de Vries' blog post](https://www.harmdevries.com/post/model-size-vs-compute-overhead/) and visualizes the trade-off between training compute and inference cost (i.e. model size).
"""

### CHINCHILLA PARAMS:
E = 1.62
A = 406.4
B = 410.7
alpha = 0.336
beta = 0.283
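# Fitted constants of the Chinchilla parametric loss
#   L(N, D) = E + A / N**alpha + B / D**beta
# where N is the model size in parameters and D the dataset size in tokens.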

Bn = 10**9

G = ((alpha*A)/(beta*B))**(1/(alpha+beta)) 
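# Minimizing L(N, D) under the constraint C = 6*N*D gives the compute-optimal allocation
#   N_opt = G * (C/6)**(beta/(alpha+beta)),  D_opt = (1/G) * (C/6)**(alpha/(alpha+beta))
# with G as defined above (implemented in n_opt and d_opt below).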

### FUNCTIONS
def to_flops(N, D):
    return 6 * N * D
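# Standard training-compute approximation C = 6*N*D, e.g. a 7B-parameter model trained on
# 2T tokens costs roughly 6 * 7e9 * 2e12 = 8.4e22 FLOPs.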

def n_opt(C):
    return G * ((C/6) ** (beta / (alpha+beta)))

def d_opt(C):
    return (1/G) * ((C/6) ** (alpha / (alpha+beta)))

def compute_kd(kn):
    frac = (A/B)*(G**(-alpha-beta))    
    kd = (1-((kn**-alpha -1)*frac))**(1/(-beta))
    return kd
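# compute_kd solves the iso-loss condition L(kn * N_opt, kd * D_opt) = L(N_opt, D_opt) for kd:
# given a model scaled to a fraction kn of the optimal size, kd is how much more (or less) data
# is needed to reach the same loss. Note that compute_kd(1.0) == 1.0 by construction.
# Here frac = (A/B) * G**(-alpha-beta) equals (A / N_opt**alpha) / (B / D_opt**beta) for any C.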

def compute_overhead(kn, kd):
    return kn*kd - 1
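# Since C = 6*N*D is linear in both N and D, training at (kn*N_opt, kd*D_opt) costs kn*kd times
# the Chinchilla-optimal budget, so the relative compute overhead is kn*kd - 1.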

### PRECOMPUTE CURVE:
kn_min = 0.2
kn_max = 2

kns = np.linspace(kn_min, kn_max, 100)
overheads = []
for kn in kns:
    kd = compute_kd(kn)
    overheads.append(compute_overhead(kn, kd)*100)

def plot_curve(kn, kd):
    fig, ax = plt.subplots()
    plt.plot(kns, overheads, color="black", zorder=1)
    plt.scatter([kn], [compute_overhead(kn, kd)*100], s=100, marker="o", c="red", label="You are here!", zorder=2)
    plt.scatter([1.0], [0.0], marker="o", s=100, c="blue", label="Chinchilla optimal", zorder=2)
    plt.xlabel("Fraction of Chinchilla optimal model size")
    plt.ylabel("Compute overhead (%)")
    plt.legend(loc="best")
    plt.grid(True, which="both")
    plt.grid(True, which="minor", alpha=0.5)
    ax.yaxis.set_minor_locator(MultipleLocator(10))

    return fig


def compute(N, D):
    
    C = to_flops(N * Bn, D * Bn)
    N_opt = n_opt(C)
    D_opt = d_opt(C)

    kn = Bn*N/N_opt
    kd = compute_kd(kn)
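    # kn is the chosen model size as a fraction of the Chinchilla-optimal size for this budget;
    # since inference cost scales roughly linearly with model size, 1 - kn is reported below
    # as the inference cost savings relative to the compute-optimal model.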
    
    fig = plot_curve(kn, kd)

    text = f"""\
## Compute:
Your specified setting corresponds to the following training compute budget.

**Compute budget (FLOPs): {C:.2E}**

## Chinchilla optimal:
If you are optimizing for model performance and ignoring inference cost, this is the optimal setting for training:

**Optimal model size: {N_opt/Bn:.2f}B parameters**

**Optimal dataset size: {D_opt/Bn:.2f}B tokens**

## Your setting's trade-off:
Compared to the compute-optimal model:

**Training compute overhead: {100*compute_overhead(kn, kd):.2f}%**

**Inference cost savings: {100 - kn*100:.2f}%** """
    return text, fig

with gr.Blocks() as demo:
    gr.Markdown(INTRO)
    with gr.Row():
        N = gr.Number(value=7, label="Model size (in B parameters):")
        D = gr.Number(value=2000, label="Dataset size (in B tokens):")
    
    button = gr.Button("Compute!")
    
    plot = gr.Plot(value=plt)
    md = gr.Markdown("")

    button.click(fn=compute, inputs=[N, D], outputs=[md, plot])
    demo.load(fn=compute, inputs=[N, D], outputs=[md, plot])
demo.launch()