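"""Rough estimator of the GPU (and CPU) memory needed to train a transformer
language model.

The per-module helpers below return *element counts*: `model` for weights and
`grad` for activations that must be kept for the backward pass.
`calculate_memory` converts these counts to bytes according to the precision
level, optimizer, ZeRO stage, and offload settings, and reports the totals.
Predefined architectures are looked up by name in `models.py`.
"""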
import argparse
import math
from models import models


def get_GB(nbytes):
    return nbytes/(1024**3)


def vocab(bsz, seqlen, dmodel, vocab_dim):
    # assumes tied input/output embeddings

    w = vocab_dim*dmodel                # shared embedding weight
    emb = seqlen*bsz*dmodel             # input embedding activations
    emb_norm = seqlen*bsz*dmodel        # embedding layer-norm activations
    pos_emb = seqlen*bsz*dmodel         # positional embedding activations
    out_emb = seqlen*bsz*vocab_dim      # output projection activations
    softmax_emb = seqlen*bsz*vocab_dim  # softmax activations

    model = w + dmodel  # + dmodel for the embedding layer-norm weight
    grad = emb + emb_norm + pos_emb + out_emb + softmax_emb
    return model, grad
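
# Illustrative numbers (hypothetical config): with bsz=1, seqlen=1024,
# dmodel=1024 and a 50,257-entry vocabulary, vocab() counts
#   w = 50257*1024 = 51,463,168 weight elements, and
#   grad = 3*1,048,576 + 2*51,463,168 = 106,072,064 activation elements,
# so the two vocab_dim-sized activations dominate the embedding memory.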


def transformer(bsz, seqlen, dmodel, nlayers, vocab_size, dhid=None,
                checkpoint=False, shared_groups=None):
    if dhid is None:
        dhid = 4*dmodel
    model = 0
    grad = 0
    for _ in range(nlayers):
        m, g = transformer_layer(bsz, seqlen, dmodel, dhid, checkpoint=checkpoint)
        model += m
        grad += g

    if shared_groups is not None:
        # ALBERT-style sharing: only one weight copy per shared group of layers
        model = model / nlayers * shared_groups

    m, g = vocab(bsz, seqlen, dmodel, vocab_size)
    model += m
    grad += g

    return model, grad
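
# Example (hypothetical mid-sized config; returns element counts, not bytes):
#   model, grad = transformer(bsz=8, seqlen=1024, dmodel=1024, nlayers=24,
#                             vocab_size=50257)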

def layer_norm(bsz, seqlen, dmodel):
    w = dmodel
    x_grad = bsz*seqlen*dmodel
    return w, x_grad


def transformer_layer(bsz, seqlen, dmodel, dhid, checkpoint=False):
    model = 0
    grad = 0

    m, g = ffn(bsz, seqlen, dmodel, dhid, 'gelu')
    model += m
    grad += g*3

    m, g = attention_layer(bsz, seqlen, dmodel)
    model += m
    grad += g*5.0

    m, g = layer_norm(bsz, seqlen, dmodel)
    model += m
    grad += g*1.0

    if checkpoint:
        # gradient checkpointing stores only the layer input and recomputes
        # the remaining activations during the backward pass
        grad = bsz * seqlen * dmodel

    return model, grad
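
# Illustrative effect of checkpointing (bsz=1, seqlen=1024, dmodel=1024,
# dhid=4096): without it a layer stores 3*10,485,760 + 5*11,534,336 +
# 1,048,576 = 90,177,536 activation elements; with it, only the
# bsz*seqlen*dmodel = 1,048,576-element layer input, an 86x saving.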

def attention_layer(bsz, seqlen, dmodel):
    w_proj = dmodel*3*dmodel
    w_out = dmodel*dmodel

    x_residual = bsz*seqlen*dmodel
    x_proj = bsz*seqlen*dmodel*3
    #x_proj_contiguous = bsz*seqlen*dmodel*3
    x_proj_contiguous = 0

    x_qscaled = bsz*seqlen*dmodel
    x_qk = bsz*seqlen*seqlen*2 # we need to store both input sequence directions for gradient computation
    x_softmax = bsz*seqlen*seqlen
    x_softmax_v = bsz*seqlen*dmodel*2 # we need to store both input sequence directions for gradient computation
    #x_out_contiguous = bsz*seqlen*dmodel
    x_out_contiguous = 0
    x_out = bsz*seqlen*dmodel

    model = w_proj + w_out
    grad = (x_residual + x_proj + x_proj_contiguous + x_qscaled + x_qk
            + x_softmax + x_softmax_v + x_out_contiguous + x_out)
    return model, grad
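
# Illustrative numbers (bsz=1, seqlen=1024, dmodel=1024): model =
# 3,145,728 + 1,048,576 = 4,194,304 weight elements and grad = 11,534,336
# activation elements, with the seqlen*seqlen attention scores (x_qk,
# x_softmax) contributing 3,145,728 of the latter.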



def ffn(bsz, seqlen, dmodel, dhid, func='relu'):
    # out = linear2(act(linear1(x))) + x
    w1 = dmodel*dhid
    w2 = dhid*dmodel
    model = w1 + w2
    x1 = bsz*seqlen*dhid
    if func != 'relu':
        x1 *= 2  # in-place activation is not possible with most non-ReLU functions
    x2 = bsz*seqlen*dmodel
    residual = bsz*seqlen*dmodel
    grad = x1 + x2 + residual

    return model, grad
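
# Illustrative numbers (bsz=1, seqlen=1024, dmodel=1024, dhid=4096, GELU):
# model = 2*1024*4096 = 8,388,608 weight elements; grad = 2*4,194,304
# (x1, doubled for GELU) + 1,048,576 (x2) + 1,048,576 (residual)
# = 10,485,760 activation elements.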


OPTIMIZERS = ['adam', 'adafactor', 'adafactor-fac-only', '8-bit-adam', '16-bit-adam']


def parse_args(args=None):
    parser = argparse.ArgumentParser('Memory calculator')

    parser.add_argument('--nlayers', type=int, help='The number of transformer layers.')
    parser.add_argument('--bsz', type=int, default=1, help='The batch size. Default: 1')
    parser.add_argument('--seqlen', type=int, help='The sequence length.')
    parser.add_argument('--dmodel', type=int, help='The core model size.')
    parser.add_argument('--dhid', type=int, default=None,
                        help='The hidden size of the FFN layer. Default: 4x model size.')
    parser.add_argument('--fp16-level', type=str, default='O1',
                        help='FP16 level to use. O0 = pure FP32; O1/O2 = mixed precision '
                             '(fp16 + fp32 master weights); O3 = pure fp16. Default: O1.')
    parser.add_argument('--model', default='', choices=list(models.keys()), help='Predefined NLP transformer models')
    parser.add_argument('--optimizer', default='adam', choices=OPTIMIZERS, help='The optimizer to use.')
    parser.add_argument('--vocab_size', type=int, default=None, help='The vocabulary size.')
    parser.add_argument('--offload', action='store_true', help='Whether to use optimizer offload.')
    parser.add_argument('--ngpus', type=int, default=1, help='The number of gpus. Default: 1')
    parser.add_argument('--zero', type=int, default=0,
                        help='The ZeRO stage (0 = off, 1 = partition optimizer state, '
                             '2 = also gradients, 3 = also weights). Default: 0.')
    parser.add_argument('--shared_groups', type=int, default=None, help='Number of shared layer groups (as in ALBERT). Defaults to no sharing.')
    parser.add_argument('--checkpoint', action='store_true', help='Use gradient checkpointing.')

    return parser.parse_args(args)


def calculate_memory(args):
    if args.model != '':
        if args.model not in models:
            raise ValueError(f'{args.model} is not supported')
        else:
            for key, value in models[args.model].items():
                if getattr(args, key, None) is None:
                    setattr(args, key, value)

    model, grad = transformer(args.bsz, args.seqlen, args.dmodel, args.nlayers,
                              args.vocab_size, args.dhid, args.checkpoint,
                              args.shared_groups)
    parameters = model

    if args.optimizer == 'adam':
        optim = 8*model  # fp32 momentum + fp32 variance: 8 bytes per parameter
    elif args.optimizer == '8-bit-adam':
        optim = 2*model  # 8-bit momentum + 8-bit variance: 2 bytes per parameter
    elif args.optimizer in ['16-bit-adam', 'adafactor']:
        optim = 4*model  # 16-bit state: 4 bytes per parameter
    elif args.optimizer in ['adafactor-fac-only']:
        # factored second moments only; treated here as negligible
        optim = math.log(model)

    if args.fp16_level == 'O0':
        # pure fp32: 4 bytes per weight, gradient, and activation
        wgrad = 4*model
        model = 4*model
        grad = 4*grad
    elif args.fp16_level in ['O1', 'O2']:
        # mixed precision: fp16 weights/gradients + fp32 master weights
        wgrad = 2*model
        model = 4*model + (2*model)
        grad = 2*grad
    elif args.fp16_level == 'O3':
        # pure fp16: 2 bytes per weight, gradient, and activation
        wgrad = 2*model
        model = 2*model
        grad = 2*grad

    model = get_GB(model)
    grad = get_GB(grad)
    optim = get_GB(optim)
    wgrad = get_GB(wgrad)

    cpu_mem = 0
    overhead = 0

    if args.zero == 1:
        if not args.offload:
            # assumes a 200 Gbit/s (25 GB/s) interconnect, e.g. InfiniBand HDR
            overhead += optim/25

        optim = optim / args.ngpus
    elif args.zero == 2:
        if not args.offload:
            # assumes a 200 Gbit/s (25 GB/s) interconnect, e.g. InfiniBand HDR
            overhead += optim/25
            overhead += wgrad/25

        optim = optim / args.ngpus
        wgrad = wgrad / args.ngpus
    elif args.zero == 3:
        if not args.offload:
            # assumes a 200 Gbit/s (25 GB/s) interconnect, e.g. InfiniBand HDR
            overhead += optim/25
            overhead += model/25
            overhead += wgrad/25

        optim = optim / args.ngpus
        model = model / args.ngpus
        wgrad = wgrad / args.ngpus


    if args.offload:
        cpu_mem = optim + wgrad
        optim = 0
        wgrad = 0
        if args.ngpus <= 2:
            # 12 GB/s for PCIe 3.0 and 1-2x GPU setup (16 lanes, 16 GB/s theoretical)
            overhead = cpu_mem/12
        else:
            # 6 GB/s for PCIe 3.0 and 4x GPU setup
            overhead = cpu_mem/6


    total_mem = model + grad + optim + wgrad
    return locals()  # expose all computed values to the caller by name


if __name__ == '__main__':
    args = parse_args()
    mem = calculate_memory(args)
    print('')
    print(f'Model: {args.model} with batch size {args.bsz} and sequence length {args.seqlen} and a total of {mem["parameters"]/1e9:.4f}B parameters.')
    print('='*80)
    print('Weight memory:           {0:.2f} GB ({1:.2f}%)'.format(mem['model'], 100*mem['model']/mem['total_mem']))
    print('Weight gradient memory:  {0:.2f} GB ({1:.2f}%)'.format(mem['wgrad'], 100*mem['wgrad']/mem['total_mem']))
    print('Input gradient memory:   {0:.2f} GB ({1:.2f}%)'.format(mem['grad'], 100*mem['grad']/mem['total_mem']))
    print('Optimizer memory:        {0:.2f} GB ({1:.2f}%)'.format(mem['optim'], 100*mem['optim']/mem['total_mem']))
    print('Total GPU memory:        {0:.2f} GB'.format(mem['total_mem']))
    if mem['cpu_mem'] > 0:
        print('Total CPU memory:        {0:.2f} GB'.format(mem['cpu_mem']))
    if mem['overhead'] > 0:
        print('Overhead: {0:.2f} seconds per update (can be partially overlapped with compute)'.format(mem['overhead']))
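
# Example invocation (file name and values here are illustrative):
#   python calculator.py --nlayers 24 --bsz 8 --seqlen 1024 --dmodel 1024 \
#       --vocab_size 50257 --optimizer adam --fp16-level O1 --ngpus 8 --zero 2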