justheuristic committed on
Commit 7b48c38
1 Parent(s): edcba35
.github/workflows/sync_to_hub.yaml CHANGED
@@ -17,4 +17,4 @@ jobs:
       - name: Push to hub
         env:
           HF_TOKEN: ${{ secrets.HF_TOKEN }}
-        run: git push https://training-transformers-together:$HF_TOKEN@huggingface.co/spaces/training-transformers-together/dashboard-embedded main --force
+        run: git push https://training-transformers-together:$HF_TOKEN@huggingface.co/spaces/training-transformers-together/calc main --force
README.md CHANGED
@@ -1,8 +1,8 @@
 ---
-title: Mini-dashboard
+title: Memory calculator
 emoji: ⚡
-colorFrom: gray
-colorTo: gray
+colorFrom: blue
+colorTo: blue
 sdk: streamlit
 app_file: app.py
 pinned: false
app.py CHANGED
@@ -5,41 +5,47 @@ If you're not a hedgehog, you shouldn't reuse this code. Use this instead: https
 
 import streamlit as st
 
-from dashboard_utils.main_metrics import get_main_metrics
-
-st.set_page_config(page_title="Training Transformers Together - Mini-Dashboard", layout="wide")
+import mem_calc
+from models import models
+st.set_page_config(page_title="Memory calculator", layout="centered")
 st.markdown("""<style>
 .reportview-container {
     top: -80px;
 }
 </style>""", unsafe_allow_html=True)
-source = get_main_metrics()
-st.vega_lite_chart(
-    source, {
-        "height": 200,
-        "title": {"text": "Training DALL-E with volunteers", "dy": 7},
-        # ^-- WARNING: do not use long titles, otherwise vega collapses on small screens
-        "$schema": "https://vega.github.io/schema/vega-lite/v5.json",
-        "description": "Current training progress",
-        "encoding": {"x": {"field": "wall time", "type": "temporal"}},
-        "config": {"axisX": {"labelAngle": -40}},
-        "resolve": {"scale": {"y": "independent"}},
-        "layer": [
-            {
-                "mark": {"type": "line", "point": {"tooltip": True, "filled": False, "strokeOpacity": 0},
-                         "color": "#85A9C5"},
-                "encoding": {
-                    "y": {"field": "training loss", "type": "quantitative", "axis": {"titleColor": "#85A9C5"},
-                          "scale": {"zero": False}}},
-            },
-            {
-                "mark": {"type": "line", "point": {"tooltip": True, "filled": False, "strokeOpacity": 0.0},
-                         "color": "#85C5A6", "opacity": 0.5},
-                "encoding": {
-                    "y": {"field": "active participants", "type": "quantitative",
-                          "axis": {"titleColor": "#85C5A6"}}},
-            },
-        ],
-    },
-    use_container_width=True,  # breaks on <600px screens
-)
+
+models = list(models.keys())  # respect the original order because py37
+model = st.selectbox('Model architecture', models, index=models.index("gpt2-l"))
+
+optimizers_names = ('32-bit', '16-bit', '8-bit', 'factorized')
+optimizers_values = ['adam', '16-bit-adam', '8-bit-adam', 'adafactor']
+optimizer = st.radio('Adam / LAMB states', optimizers_names)
+checkpoint = st.checkbox("Gradient checkpointing", value=True)
+offload = st.checkbox("Offload optimizer", value=False)
+share_params = st.checkbox("Share parameters", value=False)
+
+with st.expander("More options"):
+
+    precisions_names = ('Full', 'Mixed ("O1")', 'Pure 16-bit')
+    precisions_values = ('O0', 'O1', 'O3')
+    precision = st.selectbox('Precision', precisions_names, index=1)
+
+    vocab_size = int(st.number_input('Vocabulary size', min_value=1, step=1, value=50257, format="%i"))
+
+args = mem_calc.parse_args(f"""
+    --model {model} --vocab_size {vocab_size} --optimizer {optimizers_values[optimizers_names.index(optimizer)]}
+    {'--checkpoint' if checkpoint else ''} {'--offload' if offload else ''} {'--albert' if share_params else ''}
+    --fp16-level {precisions_values[precisions_names.index(precision)]}
+""".split())
+
+
+memory = mem_calc.calculate_memory(args)
+
+cols = st.columns(3)
+cols[0].metric("Parameters (GPU)", f"{memory['model']:.2f} GB", f"{memory['model']/memory['total_mem'] * 100:.2f} %", delta_color="off")
+cols[1].metric(f"Optimizer ({'CPU' if offload else 'GPU'})", f"{memory['optim']:.2f} GB", f"{memory['optim']/memory['total_mem'] * 100:.2f} %", delta_color="off")
+cols[2].metric("Activations (GPU)", f"{memory['grad']:.2f} GB", f"{memory['grad']/memory['total_mem'] * 100:.2f} %", delta_color="off")
+cols = st.columns(3)
+cols[0].metric("GPU total", f"{memory['total_mem']:.2f} GB")
+cols[1].metric("Offloaded to RAM", f"{memory['cpu_mem']:.2f} GB")
+cols[2].metric("Communication overhead", f"{memory['overhead'] * 1000:.2f} ms")
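
The app above only collects widget values, formats them into a CLI-style string, and defers all of the accounting to mem_calc. A minimal sketch of the same call without the Streamlit layer, with the flag values mirroring the app's defaults (gpt2-l, 32-bit Adam, gradient checkpointing on, mixed precision O1); it assumes mem_calc.py and models.py are importable from the working directory:

import mem_calc

# Same flags the app assembles from its widgets, hard-coded here for illustration.
args = mem_calc.parse_args(
    "--model gpt2-l --vocab_size 50257 --optimizer adam "
    "--checkpoint --fp16-level O1".split()
)
memory = mem_calc.calculate_memory(args)
print(f"Total GPU memory: {memory['total_mem']:.2f} GB")
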
dashboard_utils/main_metrics.py DELETED
@@ -1,33 +0,0 @@
-import datetime
-
-import streamlit as st
-import pandas as pd
-
-import wandb
-
-from dashboard_utils.time_tracker import _log, simple_time_tracker
-
-WANDB_REPO = "learning-at-home/dalle-hivemind"
-CACHE_TTL = 120  # note: in the text, we claim that this plot is updated every few minutes
-
-
-@st.cache(ttl=CACHE_TTL)
-@simple_time_tracker(_log)
-def get_main_metrics():
-    wandb.login(anonymous="must")
-    api = wandb.Api()
-    runs = api.runs(WANDB_REPO)
-    run = runs[0]
-    history = run.history(keys=["step", "loss", "alive peers", "_timestamp"])
-
-    steps = []
-    losses = []
-    alive_peers = []
-    dates = []
-    for _, row in history.iterrows():
-        steps.append(row["step"])
-        losses.append(row["loss"])
-        alive_peers.append(row["alive peers"])
-        dates.append(datetime.datetime.utcfromtimestamp(row["_timestamp"]))
-
-    return pd.DataFrame({"steps": steps, "training loss": losses, "active participants": alive_peers, "wall time": dates})
dashboard_utils/time_tracker.py DELETED
@@ -1,32 +0,0 @@
-from functools import wraps
-from time import time
-
-
-def simple_time_tracker(log_fun):
-    def _simple_time_tracker(fn):
-        @wraps(fn)
-        def wrapped_fn(*args, **kwargs):
-            start_time = time()
-
-            try:
-                result = fn(*args, **kwargs)
-            finally:
-                elapsed_time = time() - start_time
-
-                # log the result
-                log_fun(
-                    {
-                        "function_name": fn.__name__,
-                        "total_time": elapsed_time,
-                    }
-                )
-
-            return result
-
-        return wrapped_fn
-
-    return _simple_time_tracker
-
-
-def _log(message):
-    print("[SimpleTimeTracker] {function_name} {total_time:.3f}".format(**message))
mem_calc.py ADDED
@@ -0,0 +1,237 @@
+import argparse
+import math
+from models import models
+
+
+def get_GB(nbytes):
+    return nbytes/(1024**3)
+
+
+def vocab(bsz, seqlen, dmodel, vocab_dim):
+    # assumes tied embeddings
+
+    w = vocab_dim*dmodel
+    emb = seqlen*bsz*dmodel
+    emb_norm = seqlen*bsz*dmodel
+    pos_emb = seqlen*bsz*dmodel
+    out_emb = seqlen*bsz*vocab_dim
+    softmax_emb = seqlen*bsz*vocab_dim
+
+    model = w + dmodel
+    grad = emb + emb_norm + pos_emb + out_emb + softmax_emb
+    grad *= 1
+    return model, grad
+
+
+def transformer(bsz, seqlen, dmodel, nlayers, vocab_type, dhid=None,
+                checkpoint=False, albert=False):
+    if dhid is None: dhid = 4*dmodel
+    model = 0
+    grad = 0
+    for i in range(nlayers):
+        m, g = transformer_layer(bsz, seqlen, dmodel, dhid, checkpoint=checkpoint)
+        model += m
+        grad += g
+
+    if albert:
+        model = model / nlayers
+
+    m, g = vocab(bsz, seqlen, dmodel, vocab_type)
+    model += m
+    grad += g
+
+    return model, grad
+
+def layer_norm(bsz, seqlen, dmodel):
+    w = dmodel
+    x_grad = bsz*seqlen*dmodel
+    return w, x_grad
+
+
+def transformer_layer(bsz, seqlen, dmodel, dhid, checkpoint=False):
+    model = 0
+    grad = 0
+
+    m, g = ffn(bsz, seqlen, dmodel, dhid, 'gelu')
+    model += m
+    grad += g*3
+
+    m, g = attention_layer(bsz, seqlen, dmodel)
+    model += m
+    grad += g*5.0
+
+    m, g = layer_norm(bsz, seqlen, dmodel)
+    model += m
+    grad += g*1.0
+
+    if checkpoint:
+        grad = bsz * seqlen * dmodel
+
+    return model, grad
+
+def attention_layer(bsz, seqlen, dmodel):
+    w_proj = dmodel*3*dmodel
+    w_out = dmodel*dmodel
+
+    x_residual = bsz*seqlen*dmodel
+    x_proj = bsz*seqlen*dmodel*3
+    #x_proj_contiguous = bsz*seqlen*dmodel*3
+    x_proj_contiguous = 0
+
+    x_qscaled = bsz*seqlen*dmodel
+    x_qk = bsz*seqlen*seqlen*2  # we need to store both input sequence directions for gradient computation
+    x_softmax = bsz*seqlen*seqlen
+    x_softmax_v = bsz*seqlen*dmodel*2  # we need to store both input sequence directions for gradient computation
+    #x_out_contiguous = bsz*seqlen*dmodel
+    x_out_contiguous = 0
+    x_out = bsz*seqlen*dmodel
+
+    model = w_proj + w_out
+    grad = x_residual + x_proj + x_proj_contiguous + x_qscaled + x_qk + x_softmax + x_softmax_v + x_out_contiguous + x_out
+    return model, grad
+
+
+
+def ffn(bsz, seqlen, dmodel, dhid, func='relu'):
+    # out = linear(relu(linear(x), inplace=True)) + x
+    w1 = dmodel*dhid
+    w2 = dhid*dmodel
+    model = w1 + w2
+    wgrad = model
+    x1 = bsz*seqlen*dhid
+    if func != 'relu': x1 *= 2  # inplace not possible with most other functions
+    x2 = bsz*seqlen*dmodel
+    residual = bsz*seqlen*dmodel
+    grad = x1 + x2 + residual
+
+    return model, grad
+
+
+OPTIMIZERS = ['adam', 'adafactor', 'adafactor-fac-only', '8-bit-adam', '16-bit-adam']
+
+
+def parse_args(args=None):
+    parser = argparse.ArgumentParser('Memory calculator')
+
+    parser.add_argument('--nlayers', type=int, help='The number of transformer layers.')
+    parser.add_argument('--bsz', type=int, default=1, help='The batch size. Default: 1')
+    parser.add_argument('--seqlen', type=int, help='The sequence length.')
+    parser.add_argument('--dmodel', type=int, help='The core model size.')
+    parser.add_argument('--dhid', type=int, default=None,
+                        help='The hidden size of the FFN layer. Default: 4x model size.')
+    parser.add_argument('--fp16-level', type=str, default='O1',
+                        help='FP16-level to use. O0 = FP32; O1 = mixed-precision (16+32); O3 = fp16. Default: O1.')
+    parser.add_argument('--model', default='', choices=list(models.keys()), help='Predefined NLP transformer models')
+    parser.add_argument('--optimizer', default='adam', choices=OPTIMIZERS, help='The optimizer to use.')
+    parser.add_argument('--vocab_size', type=int, default=50257, help='The vocabulary size to use.')
+    parser.add_argument('--offload', action='store_true', help='Whether to use optimizer offload.')
+    parser.add_argument('--ngpus', type=int, default=1, help='The number of gpus. Default: 1')
+    parser.add_argument('--zero', type=int, default=0,
+                        help='The ZeRO level (1 optimizer, 2 optimizer+gradients, 3 everything). Default: 0')
+    parser.add_argument('--albert', action='store_true', help='Use parameter sharing.')
+    parser.add_argument('--checkpoint', action='store_true', help='Use gradient checkpointing.')
+
+    return parser.parse_args(args)
+
+
+def calculate_memory(args):
+    if args.model != '':
+        if args.model not in models:
+            raise ValueError(f'{args.model} is not supported')
+        else:
+            for key, value in models[args.model].items():
+                if getattr(args, key, None) is None:
+                    setattr(args, key, value)
+
+    model, grad = transformer(args.bsz, args.seqlen, args.dmodel, args.nlayers, args.vocab_size, args.dhid, args.checkpoint, args.albert)
+    parameters = model
+
+    if args.optimizer == 'adam':
+        optim = 8*model
+    elif args.optimizer == '8-bit-adam':
+        optim = 2*model
+    elif args.optimizer in ['16-bit-adam', 'adafactor']:
+        optim = 4*model
+    elif args.optimizer in ['adafactor-fac-only']:
+        optim = math.log(model)
+
+    if args.fp16_level == 'O0':
+        # fp32 weights
+        wgrad = 4*model
+        model = 4*model
+        grad = 4*grad  # fp32
+    elif args.fp16_level in ['O1', 'O2']:
+        # fp16 weights + fp32 master weights
+        wgrad = 2*model
+        model = 4*model + (2*model)
+        grad = 2*grad  # fp16
+    elif args.fp16_level == 'O3':
+        wgrad = 2*model
+        model = 2*model  # fp16
+        grad = 2*grad  # fp16
+
+    model = get_GB(model)
+    grad = get_GB(grad)
+    optim = get_GB(optim)
+    wgrad = get_GB(wgrad)
+
+    cpu_mem = 0
+    overhead = 0
+
+    if args.zero == 1:
+        if not args.offload:
+            # assumes PCIe 4.0 infiniband (200 Gbit/s = 25 GB/s)
+            overhead += optim/25
+
+        optim = optim / args.ngpus
+    elif args.zero == 2:
+        if not args.offload:
+            # assumes PCIe 4.0 infiniband (200 Gbit/s = 25 GB/s)
+            overhead += optim/25
+            overhead += wgrad/25
+
+        optim = optim / args.ngpus
+        wgrad = wgrad / args.ngpus
+    elif args.zero == 3:
+        if not args.offload:
+            # assumes PCIe 4.0 infiniband (200 Gbit/s = 25 GB/s)
+            overhead += optim/25
+            overhead += model/25
+            overhead += wgrad/25
+
+        optim = optim / args.ngpus
+        model = model / args.ngpus
+        wgrad = wgrad / args.ngpus
+
+
+    if args.offload:
+        cpu_mem = optim + wgrad
+        optim = 0
+        wgrad = 0
+        if args.ngpus <= 2:
+            # 12 GB/s for PCIe 3.0 and 1-2x GPU setup (16 lanes, 16 GB/s theoretical)
+            overhead = cpu_mem/12
+        else:
+            # 6 GB/s for PCIe 3.0 and 4x GPU setup
+            overhead = cpu_mem/6
+
+
+    total_mem = model + grad + optim + wgrad
+    return locals()
+
+
+if __name__ == '__main__':
+    args = parse_args()
+    mem = calculate_memory(args)
+    print('')
+    print(f'Model: {args.model} with batch size {args.bsz} and sequence length {args.seqlen} and a total of {mem["parameters"]/1e9:.4f}B parameters.')
+    print('='*80)
+    print('Weight memory: {0:.2f} GB ({1:.2f}%)'.format(mem['model'], 100*mem['model']/mem['total_mem']))
+    print('Weight gradient memory: {0:.2f} GB ({1:.2f}%)'.format(mem['wgrad'], 100*mem['wgrad']/mem['total_mem']))
+    print('Input gradient memory: {0:.2f} GB ({1:.2f}%)'.format(mem['grad'], 100*mem['grad']/mem['total_mem']))
+    print('Optimizer memory: {0:.2f} GB ({1:.2f}%)'.format(mem['optim'], 100*mem['optim']/mem['total_mem']))
+    print('Total GPU memory: {0:.2f} GB'.format(mem['total_mem']))
+    if mem['cpu_mem'] > 0:
+        print('Total CPU memory: {0:.2f} GB'.format(mem['cpu_mem']))
+    if mem['overhead'] > 0:
+        print('Overhead: {0:.2f} seconds per update (can be partially overlapped with compute)'.format(mem['overhead']))
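
The optimizer multipliers in calculate_memory are bytes of optimizer state per parameter: 8 for 32-bit Adam (two fp32 moments), 4 for 16-bit Adam and Adafactor, 2 for 8-bit Adam, and adafactor-fac-only is modelled as just the log of the parameter count. A short sketch comparing two of these settings through the module's public entry points; it assumes mem_calc.py and models.py are importable from the working directory:

import mem_calc

# Compare 32-bit Adam against 8-bit Adam for the same architecture.
for opt in ("adam", "8-bit-adam"):
    args = mem_calc.parse_args(f"--model gpt2-xl --optimizer {opt}".split())
    mem = mem_calc.calculate_memory(args)
    print(f"{opt:>10}: optimizer {mem['optim']:.2f} GB, total GPU {mem['total_mem']:.2f} GB")
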
models.py ADDED
@@ -0,0 +1,97 @@
+models = {}
+models['bert-s'] = {}
+models['bert-s']['seqlen'] = 512
+models['bert-s']['dmodel'] = 768
+models['bert-s']['dhidden'] = 3072
+models['bert-s']['nlayers'] = 12
+
+models['bert-l'] = {}
+models['bert-l']['seqlen'] = 512
+models['bert-l']['dmodel'] = 1024
+models['bert-l']['dhidden'] = 4096
+models['bert-l']['nlayers'] = 24
+
+models['t5-3b'] = {}
+models['t5-3b']['seqlen'] = 512
+models['t5-3b']['dmodel'] = 1024
+models['t5-3b']['dhidden'] = 16384
+models['t5-3b']['nlayers'] = 48
+
+models['t5-11b'] = {}
+models['t5-11b']['seqlen'] = 512
+models['t5-11b']['dmodel'] = 1024
+models['t5-11b']['dhidden'] = 64*1024
+models['t5-11b']['nlayers'] = 48
+
+models['gpt2-s'] = {}
+models['gpt2-s']['seqlen'] = 1024
+models['gpt2-s']['dmodel'] = 768
+models['gpt2-s']['dhidden'] = 768*4
+models['gpt2-s']['nlayers'] = 12
+
+models['gpt2-m'] = {}
+models['gpt2-m']['seqlen'] = 1024
+models['gpt2-m']['dmodel'] = 1024
+models['gpt2-m']['dhidden'] = 1024*4
+models['gpt2-m']['nlayers'] = 24
+
+models['gpt2-l'] = {}
+models['gpt2-l']['seqlen'] = 1024
+models['gpt2-l']['dmodel'] = 1280
+models['gpt2-l']['dhidden'] = 1280*4
+models['gpt2-l']['nlayers'] = 36
+
+models['gpt2-xl'] = {}
+models['gpt2-xl']['seqlen'] = 1024
+models['gpt2-xl']['dmodel'] = 1600
+models['gpt2-xl']['dhidden'] = 1600*4
+models['gpt2-xl']['nlayers'] = 48
+
+
+models['gpt3-s'] = {}
+models['gpt3-s']['seqlen'] = 2048
+models['gpt3-s']['dmodel'] = 768
+models['gpt3-s']['dhidden'] = 768*4
+models['gpt3-s']['nlayers'] = 12
+
+models['gpt3-m'] = {}
+models['gpt3-m']['seqlen'] = 2048
+models['gpt3-m']['dmodel'] = 1024
+models['gpt3-m']['dhidden'] = 1024*4
+models['gpt3-m']['nlayers'] = 24
+
+models['gpt3-l'] = {}
+models['gpt3-l']['seqlen'] = 2048
+models['gpt3-l']['dmodel'] = 1536
+models['gpt3-l']['dhidden'] = 1536*4
+models['gpt3-l']['nlayers'] = 24
+
+models['gpt3-xl'] = {}
+models['gpt3-xl']['seqlen'] = 2048
+models['gpt3-xl']['dmodel'] = 2560
+models['gpt3-xl']['dhidden'] = 2560*4
+models['gpt3-xl']['nlayers'] = 24
+
+models['gpt3-3b'] = {}
+models['gpt3-3b']['seqlen'] = 2048
+models['gpt3-3b']['dmodel'] = 2560
+models['gpt3-3b']['dhidden'] = 2560*4
+models['gpt3-3b']['nlayers'] = 32
+
+models['gpt3-7b'] = {}
+models['gpt3-7b']['seqlen'] = 2048
+models['gpt3-7b']['dmodel'] = 4096
+models['gpt3-7b']['dhidden'] = 4096*4
+models['gpt3-7b']['nlayers'] = 32
+
+models['gpt3-13b'] = {}
+models['gpt3-13b']['seqlen'] = 2048
+models['gpt3-13b']['dmodel'] = 5120
+models['gpt3-13b']['dhidden'] = 5120*4
+models['gpt3-13b']['nlayers'] = 40
+
+models['gpt3-175b'] = {}
+models['gpt3-175b']['seqlen'] = 2048
+models['gpt3-175b']['dmodel'] = 12288
+models['gpt3-175b']['dhidden'] = 12288*4
+models['gpt3-175b']['nlayers'] = 96
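
Each preset pins only seqlen, dmodel, dhidden and nlayers; calculate_memory copies a preset value onto args only when the corresponding attribute is still unset, so explicit flags always win over the preset. A small sketch of that merge behaviour, assuming mem_calc.py and models.py are importable from the working directory:

import mem_calc

# Preset values fill in unset arguments; explicit flags are kept as given.
defaults = mem_calc.parse_args("--model gpt2-l".split())
overridden = mem_calc.parse_args("--model gpt2-l --seqlen 2048".split())

mem_calc.calculate_memory(defaults)    # pulls seqlen=1024, dmodel=1280, nlayers=36 from the preset
mem_calc.calculate_memory(overridden)  # keeps the explicit --seqlen 2048

print(defaults.seqlen, overridden.seqlen)  # 1024 2048
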