justheuristic commited on
Commit
2133880
1 Parent(s): 52aac07

sharing groups

Browse files
Files changed (3) hide show
  1. app.py +5 -3
  2. mem_calc.py +5 -5
  3. models.py +13 -7
app.py CHANGED
@@ -27,15 +27,17 @@ share_params = col2.checkbox("Share parameters", value=False)
27
 
28
  with st.expander("More options"):
29
  batch_size = int(st.number_input('Microbatch size (sequences)', min_value=1, step=1, value=1, format="%i"))
30
- seq_len = int(st.number_input('Sequence length (max. tokens)', min_value=1, step=1, value=1024, format="%i"))
31
  precisions_names = ('Full', 'Mixed ("O1")', 'Pure 16-bit')
32
  precisions_values = ('O0', 'O1', 'O3')
 
 
33
  precision = st.selectbox('Precision', precisions_names, index=1)
34
 
35
  args = mem_calc.parse_args(f"""
36
  --model {model} --optimizer {optimizers_values[optimizers_names.index(optimizer)]}
37
- {'--checkpoint' if checkpoint else ''} {'--offload' if offload else ''} {'--albert' if share_params else ''}
38
- --fp16-level {precisions_values[precisions_names.index(precision)]} --bsz {batch_size} --seqlen {seq_len}
 
39
  """.split())
40
 
41
 
27
 
28
  with st.expander("More options"):
29
  batch_size = int(st.number_input('Microbatch size (sequences)', min_value=1, step=1, value=1, format="%i"))
 
30
  precisions_names = ('Full', 'Mixed ("O1")', 'Pure 16-bit')
31
  precisions_values = ('O0', 'O1', 'O3')
32
+ sharing_groups = int(st.number_input('Shared parameter groups (used if Share parameters is checked)',
33
+ min_value=1, step=1, value=1, format="%i"))
34
  precision = st.selectbox('Precision', precisions_names, index=1)
35
 
36
  args = mem_calc.parse_args(f"""
37
  --model {model} --optimizer {optimizers_values[optimizers_names.index(optimizer)]}
38
+ {'--checkpoint' if checkpoint else ''} {'--offload' if offload else ''}
39
+ --fp16-level {precisions_values[precisions_names.index(precision)]} --bsz {batch_size}
40
+ {f'--shared_groups {sharing_groups}' if share_params else ''}
41
  """.split())
42
 
43
 
mem_calc.py CHANGED
@@ -24,7 +24,7 @@ def vocab(bsz, seqlen, dmodel, vocab_dim):
24
 
25
 
26
  def transformer(bsz, seqlen, dmodel, nlayers, vocab_type, dhid=None,
27
- checkpoint=False, albert=False):
28
  if dhid is None: dhid = 4*dmodel
29
  model = 0
30
  grad = 0
@@ -33,8 +33,8 @@ def transformer(bsz, seqlen, dmodel, nlayers, vocab_type, dhid=None,
33
  model += m
34
  grad += g
35
 
36
- if albert:
37
- model = model / nlayers
38
 
39
  m, g = vocab(bsz, seqlen, dmodel, vocab_type)
40
  model += m
@@ -128,7 +128,7 @@ def parse_args(args=None):
128
  parser.add_argument('--ngpus', type=int, default=1, help='The number of gpus. Default: 1')
129
  parser.add_argument('--zero', type=int, default=0,
130
  help='The ZeRO level (1 optimizer, 2 optimizer+weights, 3 everything. Default: 1')
131
- parser.add_argument('--albert', action='store_true', help='Use parameter sharing.')
132
  parser.add_argument('--checkpoint', action='store_true', help='Use gradient checkpointing.')
133
 
134
  return parser.parse_args(args)
@@ -143,7 +143,7 @@ def calculate_memory(args):
143
  if getattr(args, key, None) is None:
144
  setattr(args, key, value)
145
 
146
- model, grad = transformer(args.bsz, args.seqlen, args.dmodel, args.nlayers, args.vocab_size, args.dhid, args.checkpoint, args.albert)
147
  parameters = model
148
 
149
  if args.optimizer == 'adam':
24
 
25
 
26
  def transformer(bsz, seqlen, dmodel, nlayers, vocab_type, dhid=None,
27
+ checkpoint=False, shared_groups=None):
28
  if dhid is None: dhid = 4*dmodel
29
  model = 0
30
  grad = 0
33
  model += m
34
  grad += g
35
 
36
+ if shared_groups is not None:
37
+ model = model / nlayers * shared_groups
38
 
39
  m, g = vocab(bsz, seqlen, dmodel, vocab_type)
40
  model += m
128
  parser.add_argument('--ngpus', type=int, default=1, help='The number of gpus. Default: 1')
129
  parser.add_argument('--zero', type=int, default=0,
130
  help='The ZeRO level (1 optimizer, 2 optimizer+weights, 3 everything. Default: 1')
131
+ parser.add_argument('--shared_groups', type=int, default=None, help='Number of shared layer groups (as in ALBERT). Defaults to no sharing.')
132
  parser.add_argument('--checkpoint', action='store_true', help='Use gradient checkpointing.')
133
 
134
  return parser.parse_args(args)
143
  if getattr(args, key, None) is None:
144
  setattr(args, key, value)
145
 
146
+ model, grad = transformer(args.bsz, args.seqlen, args.dmodel, args.nlayers, args.vocab_size, args.dhid, args.checkpoint, args.shared_groups)
147
  parameters = model
148
 
149
  if args.optimizer == 'adam':
models.py CHANGED
@@ -56,13 +56,6 @@ models['gpt2-xl']['dhid'] = 1600*4
56
  models['gpt2-xl']['nlayers'] = 48
57
  models['gpt2-xl']['vocab_size'] = 50257
58
 
59
- models['gpt-j-6b'] = {}
60
- models['gpt-j-6b']['seqlen'] = 2048
61
- models['gpt-j-6b']['dmodel'] = 4096
62
- models['gpt-j-6b']['dhid'] = 4096 * 4
63
- models['gpt-j-6b']['nlayers'] = 28
64
- models['gpt-j-6b']['vocab_size'] = 50400
65
-
66
  models['gpt3-s'] = {}
67
  models['gpt3-s']['seqlen'] = 2048
68
  models['gpt3-s']['dmodel'] = 768
@@ -118,3 +111,16 @@ models['gpt3-175b']['dmodel'] = 12288
118
  models['gpt3-175b']['dhid'] = 12288*4
119
  models['gpt3-175b']['nlayers'] = 96
120
  models['gpt3-175b']['vocab_size'] = 50257 # from public reimplementations
 
 
 
 
 
 
 
 
 
 
 
 
 
56
  models['gpt2-xl']['nlayers'] = 48
57
  models['gpt2-xl']['vocab_size'] = 50257
58
 
 
 
 
 
 
 
 
59
  models['gpt3-s'] = {}
60
  models['gpt3-s']['seqlen'] = 2048
61
  models['gpt3-s']['dmodel'] = 768
111
  models['gpt3-175b']['dhid'] = 12288*4
112
  models['gpt3-175b']['nlayers'] = 96
113
  models['gpt3-175b']['vocab_size'] = 50257 # from public reimplementations
114
+
115
+ models['gpt-j-6b'] = {}
116
+ models['gpt-j-6b']['seqlen'] = 2048
117
+ models['gpt-j-6b']['dmodel'] = 4096
118
+ models['gpt-j-6b']['dhid'] = 4096 * 4
119
+ models['gpt-j-6b']['nlayers'] = 28
120
+ models['gpt-j-6b']['vocab_size'] = 50400
121
+
122
+ models['dalle-12b'] = {}
123
+ models['dalle-12b']['seqlen'] = 1024 + 256
124
+ models['dalle-12b']['dmodel'] = 62 * 64
125
+ models['dalle-12b']['nlayers'] = 64
126
+ models['dalle-12b']['vocab_size'] = 8192 + 16384