prajwalsahu5 committed on
Commit 2fbd5c0
1 Parent(s): fd599dc

Upload 3 files

Files changed (3)
  1. GPT Model/model.py +310 -0
  2. GPT Model/trainer.py +109 -0
  3. GPT Model/utils.py +103 -0
GPT Model/model.py ADDED
@@ -0,0 +1,310 @@
+ """
+ Full definition of a GPT Language Model, all of it in this single file.
+
+ References:
+ 1) the official GPT-2 TensorFlow implementation released by OpenAI:
+ https://github.com/openai/gpt-2/blob/master/src/model.py
+ 2) huggingface/transformers PyTorch implementation:
+ https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt2/modeling_gpt2.py
+ """
+
+ import math
+
+ import torch
+ import torch.nn as nn
+ from torch.nn import functional as F
+
+ from mingpt.utils import CfgNode as CN
+
+ # -----------------------------------------------------------------------------
+
+ class NewGELU(nn.Module):
+     """
+     Implementation of the GELU activation function currently in Google BERT repo (identical to OpenAI GPT).
+     Reference: Gaussian Error Linear Units (GELU) paper: https://arxiv.org/abs/1606.08415
+     """
+     def forward(self, x):
+         return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))
+
+ class CausalSelfAttention(nn.Module):
+     """
+     A vanilla multi-head masked self-attention layer with a projection at the end.
+     It is possible to use torch.nn.MultiheadAttention here but I am including an
+     explicit implementation here to show that there is nothing too scary here.
+     """
+
+     def __init__(self, config):
+         super().__init__()
+         assert config.n_embd % config.n_head == 0
+         # key, query, value projections for all heads, but in a batch
+         self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd)
+         # output projection
+         self.c_proj = nn.Linear(config.n_embd, config.n_embd)
+         # regularization
+         self.attn_dropout = nn.Dropout(config.attn_pdrop)
+         self.resid_dropout = nn.Dropout(config.resid_pdrop)
+         # causal mask to ensure that attention is only applied to the left in the input sequence
+         self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size))
+                                      .view(1, 1, config.block_size, config.block_size))
+         self.n_head = config.n_head
+         self.n_embd = config.n_embd
+
+     def forward(self, x):
+         B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)
+
+         # calculate query, key, values for all heads in batch and move head forward to be the batch dim
+         q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
+         k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
+         q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
+         v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
+
+         # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
+         att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
+         att = att.masked_fill(self.bias[:,:,:T,:T] == 0, float('-inf'))
+         att = F.softmax(att, dim=-1)
+         att = self.attn_dropout(att)
+         y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
+         y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side
+
+         # output projection
+         y = self.resid_dropout(self.c_proj(y))
+         return y
+
+ class Block(nn.Module):
+     """ an unassuming Transformer block """
+
+     def __init__(self, config):
+         super().__init__()
+         self.ln_1 = nn.LayerNorm(config.n_embd)
+         self.attn = CausalSelfAttention(config)
+         self.ln_2 = nn.LayerNorm(config.n_embd)
+         self.mlp = nn.ModuleDict(dict(
+             c_fc    = nn.Linear(config.n_embd, 4 * config.n_embd),
+             c_proj  = nn.Linear(4 * config.n_embd, config.n_embd),
+             act     = NewGELU(),
+             dropout = nn.Dropout(config.resid_pdrop),
+         ))
+         m = self.mlp
+         self.mlpf = lambda x: m.dropout(m.c_proj(m.act(m.c_fc(x)))) # MLP forward
+
+     def forward(self, x):
+         x = x + self.attn(self.ln_1(x))
+         x = x + self.mlpf(self.ln_2(x))
+         return x
+
+ class GPT(nn.Module):
+     """ GPT Language Model """
+
+     @staticmethod
+     def get_default_config():
+         C = CN()
+         # either model_type or (n_layer, n_head, n_embd) must be given in the config
+         C.model_type = 'gpt'
+         C.n_layer = None
+         C.n_head = None
+         C.n_embd = None
+         # these options must be filled in externally
+         C.vocab_size = None
+         C.block_size = None
+         # dropout hyperparameters
+         C.embd_pdrop = 0.1
+         C.resid_pdrop = 0.1
+         C.attn_pdrop = 0.1
+         return C
+
+     def __init__(self, config):
+         super().__init__()
+         assert config.vocab_size is not None
+         assert config.block_size is not None
+         self.block_size = config.block_size
+
+         type_given = config.model_type is not None
+         params_given = all([config.n_layer is not None, config.n_head is not None, config.n_embd is not None])
+         assert type_given ^ params_given # exactly one of these (XOR)
+         if type_given:
+             # translate from model_type to detailed configuration
+             config.merge_from_dict({
+                 # names follow the huggingface naming conventions
+                 # GPT-1
+                 'openai-gpt':   dict(n_layer=12, n_head=12, n_embd=768),  # 117M params
+                 # GPT-2 configs
+                 'gpt2':         dict(n_layer=12, n_head=12, n_embd=768),  # 124M params
+                 'gpt2-medium':  dict(n_layer=24, n_head=16, n_embd=1024), # 350M params
+                 'gpt2-large':   dict(n_layer=36, n_head=20, n_embd=1280), # 774M params
+                 'gpt2-xl':      dict(n_layer=48, n_head=25, n_embd=1600), # 1558M params
+                 # Gophers
+                 'gopher-44m':   dict(n_layer=8, n_head=16, n_embd=512),
+                 # (there are a number more...)
+                 # I made these tiny models up
+                 'gpt-mini':     dict(n_layer=6, n_head=6, n_embd=192),
+                 'gpt-micro':    dict(n_layer=4, n_head=4, n_embd=128),
+                 'gpt-nano':     dict(n_layer=3, n_head=3, n_embd=48),
+             }[config.model_type])
+
+         self.transformer = nn.ModuleDict(dict(
+             wte = nn.Embedding(config.vocab_size, config.n_embd),
+             wpe = nn.Embedding(config.block_size, config.n_embd),
+             drop = nn.Dropout(config.embd_pdrop),
+             h = nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
+             ln_f = nn.LayerNorm(config.n_embd),
+         ))
+         self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
+
+         # init all weights, and apply a special scaled init to the residual projections, per GPT-2 paper
+         self.apply(self._init_weights)
+         for pn, p in self.named_parameters():
+             if pn.endswith('c_proj.weight'):
+                 torch.nn.init.normal_(p, mean=0.0, std=0.02/math.sqrt(2 * config.n_layer))
+
+         # report number of parameters (note we don't count the decoder parameters in lm_head)
+         n_params = sum(p.numel() for p in self.transformer.parameters())
+         print("number of parameters: %.2fM" % (n_params/1e6,))
+
+     def _init_weights(self, module):
+         if isinstance(module, nn.Linear):
+             torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
+             if module.bias is not None:
+                 torch.nn.init.zeros_(module.bias)
+         elif isinstance(module, nn.Embedding):
+             torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
+         elif isinstance(module, nn.LayerNorm):
+             torch.nn.init.zeros_(module.bias)
+             torch.nn.init.ones_(module.weight)
+
+     @classmethod
+     def from_pretrained(cls, model_type):
+         """
+         Initialize a pretrained GPT model by copying over the weights
+         from a huggingface/transformers checkpoint.
+         """
+         assert model_type in {'gpt2', 'gpt2-medium', 'gpt2-large', 'gpt2-xl'}
+         from transformers import GPT2LMHeadModel
+
+         # create a from-scratch initialized minGPT model
+         config = cls.get_default_config()
+         config.model_type = model_type
+         config.vocab_size = 50257 # openai's model vocabulary
+         config.block_size = 1024  # openai's model block_size
+         model = GPT(config)
+         sd = model.state_dict()
+
+         # init a huggingface/transformers model
+         model_hf = GPT2LMHeadModel.from_pretrained(model_type)
+         sd_hf = model_hf.state_dict()
+
+         # copy while ensuring all of the parameters are aligned and match in names and shapes
+         keys = [k for k in sd_hf if not k.endswith('attn.masked_bias')] # ignore these
+         transposed = ['attn.c_attn.weight', 'attn.c_proj.weight', 'mlp.c_fc.weight', 'mlp.c_proj.weight']
+         # basically the openai checkpoints use a "Conv1D" module, but we only want to use a vanilla nn.Linear.
+         # this means that we have to transpose these weights when we import them
+         assert len(keys) == len(sd)
+         for k in keys:
+             if any(k.endswith(w) for w in transposed):
+                 # special treatment for the Conv1D weights we need to transpose
+                 assert sd_hf[k].shape[::-1] == sd[k].shape
+                 with torch.no_grad():
+                     sd[k].copy_(sd_hf[k].t())
+             else:
+                 # vanilla copy over the other parameters
+                 assert sd_hf[k].shape == sd[k].shape
+                 with torch.no_grad():
+                     sd[k].copy_(sd_hf[k])
+
+         return model
+
+     def configure_optimizers(self, train_config):
+         """
+         This long function is unfortunately doing something very simple and is being very defensive:
+         We are separating out all parameters of the model into two buckets: those that will experience
+         weight decay for regularization and those that won't (biases, and layernorm/embedding weights).
+         We are then returning the PyTorch optimizer object.
+         """
+
+         # separate out all parameters to those that will and won't experience regularizing weight decay
+         decay = set()
+         no_decay = set()
+         whitelist_weight_modules = (torch.nn.Linear, )
+         blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding)
+         for mn, m in self.named_modules():
+             for pn, p in m.named_parameters():
+                 fpn = '%s.%s' % (mn, pn) if mn else pn # full param name
+                 # random note: because named_modules and named_parameters are recursive
+                 # we will see the same tensors p many many times. but doing it this way
+                 # allows us to know which parent module any tensor p belongs to...
+                 if pn.endswith('bias'):
+                     # all biases will not be decayed
+                     no_decay.add(fpn)
+                 elif pn.endswith('weight') and isinstance(m, whitelist_weight_modules):
+                     # weights of whitelist modules will be weight decayed
+                     decay.add(fpn)
+                 elif pn.endswith('weight') and isinstance(m, blacklist_weight_modules):
+                     # weights of blacklist modules will NOT be weight decayed
+                     no_decay.add(fpn)
+
+         # validate that we considered every parameter
+         param_dict = {pn: p for pn, p in self.named_parameters()}
+         inter_params = decay & no_decay
+         union_params = decay | no_decay
+         assert len(inter_params) == 0, "parameters %s made it into both decay/no_decay sets!" % (str(inter_params), )
+         assert len(param_dict.keys() - union_params) == 0, "parameters %s were not separated into either decay/no_decay set!" \
+                                                     % (str(param_dict.keys() - union_params), )
+
+         # create the pytorch optimizer object
+         optim_groups = [
+             {"params": [param_dict[pn] for pn in sorted(list(decay))], "weight_decay": train_config.weight_decay},
+             {"params": [param_dict[pn] for pn in sorted(list(no_decay))], "weight_decay": 0.0},
+         ]
+         optimizer = torch.optim.AdamW(optim_groups, lr=train_config.learning_rate, betas=train_config.betas)
+         return optimizer
+
+     def forward(self, idx, targets=None):
+         device = idx.device
+         b, t = idx.size()
+         assert t <= self.block_size, f"Cannot forward sequence of length {t}, block size is only {self.block_size}"
+         pos = torch.arange(0, t, dtype=torch.long, device=device).unsqueeze(0) # shape (1, t)
+
+         # forward the GPT model itself
+         tok_emb = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd)
+         pos_emb = self.transformer.wpe(pos) # position embeddings of shape (1, t, n_embd)
+         x = self.transformer.drop(tok_emb + pos_emb)
+         for block in self.transformer.h:
+             x = block(x)
+         x = self.transformer.ln_f(x)
+         logits = self.lm_head(x)
+
+         # if we are given some desired targets also calculate the loss
+         loss = None
+         if targets is not None:
+             loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
+
+         return logits, loss
+
+     @torch.no_grad()
+     def generate(self, idx, max_new_tokens, temperature=1.0, do_sample=False, top_k=None):
+         """
+         Take a conditioning sequence of indices idx (LongTensor of shape (b,t)) and complete
+         the sequence max_new_tokens times, feeding the predictions back into the model each time.
+         Most likely you'll want to make sure to be in model.eval() mode of operation for this.
+         """
+         for _ in range(max_new_tokens):
+             # if the sequence context is growing too long we must crop it at block_size
+             idx_cond = idx if idx.size(1) <= self.block_size else idx[:, -self.block_size:]
+             # forward the model to get the logits for the index in the sequence
+             logits, _ = self(idx_cond)
+             # pluck the logits at the final step and scale by desired temperature
+             logits = logits[:, -1, :] / temperature
+             # optionally crop the logits to only the top k options
+             if top_k is not None:
+                 v, _ = torch.topk(logits, top_k)
+                 logits[logits < v[:, [-1]]] = -float('Inf')
+             # apply softmax to convert logits to (normalized) probabilities
+             probs = F.softmax(logits, dim=-1)
+             # either sample from the distribution or take the most likely element
+             if do_sample:
+                 idx_next = torch.multinomial(probs, num_samples=1)
+             else:
+                 _, idx_next = torch.topk(probs, k=1, dim=-1)
+             # append sampled index to the running sequence and continue
+             idx = torch.cat((idx, idx_next), dim=1)
+
+         return idx
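
A minimal usage sketch of the model defined above (not part of the committed file). It assumes the upstream minGPT import layout (`from mingpt.model import GPT`), consistent with this file's own `from mingpt.utils import ...` import; since the file here lives under `GPT Model/`, the import path may need adjusting. The vocabulary and block sizes below are illustrative.

import torch
from mingpt.model import GPT  # assumed import path; adjust for this repo's "GPT Model/" folder

# build a tiny model from one of the named presets
config = GPT.get_default_config()
config.model_type = 'gpt-nano'   # n_layer=3, n_head=3, n_embd=48
config.vocab_size = 65           # illustrative vocabulary size
config.block_size = 128          # maximum context length
model = GPT(config)

# forward pass: logits for every position, plus a loss when targets are given
idx = torch.randint(0, config.vocab_size, (2, 16))   # (batch, time) token ids
logits, loss = model(idx, targets=idx)               # targets here are placeholders, just to exercise the loss
print(logits.shape)                                  # torch.Size([2, 16, 65])

# autoregressive sampling from the same model
model.eval()
out = model.generate(idx[:, :4], max_new_tokens=10, temperature=1.0, do_sample=True, top_k=10)
print(out.shape)                                     # torch.Size([2, 14])

For the GPT-2 sizes, `GPT.from_pretrained('gpt2')` builds the same architecture and copies the weights over from the huggingface/transformers checkpoint instead of training from scratch.
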
GPT Model/trainer.py ADDED
@@ -0,0 +1,109 @@
+ """
+ Simple training loop; Boilerplate that could apply to any arbitrary neural network,
+ so nothing in this file really has anything to do with GPT specifically.
+ """
+
+ import time
+ from collections import defaultdict
+
+ import torch
+ from torch.utils.data.dataloader import DataLoader
+ from mingpt.utils import CfgNode as CN
+
+ class Trainer:
+
+     @staticmethod
+     def get_default_config():
+         C = CN()
+         # device to train on
+         C.device = 'auto'
+         # dataloader parameters
+         C.num_workers = 4
+         # optimizer parameters
+         C.max_iters = None
+         C.batch_size = 64
+         C.learning_rate = 3e-4
+         C.betas = (0.9, 0.95)
+         C.weight_decay = 0.1 # only applied on matmul weights
+         C.grad_norm_clip = 1.0
+         return C
+
+     def __init__(self, config, model, train_dataset):
+         self.config = config
+         self.model = model
+         self.optimizer = None
+         self.train_dataset = train_dataset
+         self.callbacks = defaultdict(list)
+
+         # determine the device we'll train on
+         if config.device == 'auto':
+             self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
+         else:
+             self.device = config.device
+         self.model = self.model.to(self.device)
+         print("running on device", self.device)
+
+         # variables that will be assigned to the trainer later, for logging etc.
+         self.iter_num = 0
+         self.iter_time = 0.0
+         self.iter_dt = 0.0
+
+     def add_callback(self, onevent: str, callback):
+         self.callbacks[onevent].append(callback)
+
+     def set_callback(self, onevent: str, callback):
+         self.callbacks[onevent] = [callback]
+
+     def trigger_callbacks(self, onevent: str):
+         for callback in self.callbacks.get(onevent, []):
+             callback(self)
+
+     def run(self):
+         model, config = self.model, self.config
+
+         # setup the optimizer
+         self.optimizer = model.configure_optimizers(config)
+
+         # setup the dataloader
+         train_loader = DataLoader(
+             self.train_dataset,
+             sampler=torch.utils.data.RandomSampler(self.train_dataset, replacement=True, num_samples=int(1e10)),
+             shuffle=False,
+             pin_memory=True,
+             batch_size=config.batch_size,
+             num_workers=config.num_workers,
+         )
+
+         model.train()
+         self.iter_num = 0
+         self.iter_time = time.time()
+         data_iter = iter(train_loader)
+         while True:
+
+             # fetch the next batch (x, y) and re-init iterator if needed
+             try:
+                 batch = next(data_iter)
+             except StopIteration:
+                 data_iter = iter(train_loader)
+                 batch = next(data_iter)
+             batch = [t.to(self.device) for t in batch]
+             x, y = batch
+
+             # forward the model
+             logits, self.loss = model(x, y)
+
+             # backprop and update the parameters
+             model.zero_grad(set_to_none=True)
+             self.loss.backward()
+             torch.nn.utils.clip_grad_norm_(model.parameters(), config.grad_norm_clip)
+             self.optimizer.step()
+
+             self.trigger_callbacks('on_batch_end')
+             self.iter_num += 1
+             tnow = time.time()
+             self.iter_dt = tnow - self.iter_time
+             self.iter_time = tnow
+
+             # termination conditions
+             if config.max_iters is not None and self.iter_num >= config.max_iters:
+                 break
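
A sketch of how the Trainer above is typically driven (not part of the committed file). It assumes the same minGPT import layout as the files themselves, and a map-style dataset whose `__getitem__` returns an `(x, y)` pair of LongTensors, which is what `GPT.forward` expects; the `RandomTokens` dataset below is hypothetical. Note that `run()` loops forever unless `max_iters` is set.

import torch
from torch.utils.data import Dataset
from mingpt.model import GPT        # assumed import paths, as in the files above
from mingpt.trainer import Trainer

class RandomTokens(Dataset):
    """Hypothetical toy dataset: random token chunks shifted by one position."""
    def __init__(self, vocab_size=65, block_size=32, length=1000):
        self.vocab_size, self.block_size, self.length = vocab_size, block_size, length
    def __len__(self):
        return self.length
    def __getitem__(self, i):
        chunk = torch.randint(0, self.vocab_size, (self.block_size + 1,))
        return chunk[:-1], chunk[1:]   # (x, y): predict the next token at every position

train_dataset = RandomTokens()

model_config = GPT.get_default_config()
model_config.model_type = 'gpt-nano'
model_config.vocab_size = train_dataset.vocab_size
model_config.block_size = train_dataset.block_size
model = GPT(model_config)

train_config = Trainer.get_default_config()
train_config.max_iters = 100        # required for the loop to terminate
train_config.num_workers = 0        # keep the sketch single-process
trainer = Trainer(train_config, model, train_dataset)

# callbacks fire on 'on_batch_end'; trainer.loss, iter_num and iter_dt are set inside run()
trainer.set_callback('on_batch_end',
    lambda t: print(f"iter {t.iter_num}: loss {t.loss.item():.4f}") if t.iter_num % 10 == 0 else None)
trainer.run()
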
GPT Model/utils.py ADDED
@@ -0,0 +1,103 @@
+
+ import os
+ import sys
+ import json
+ import random
+ from ast import literal_eval
+
+ import numpy as np
+ import torch
+
+ # -----------------------------------------------------------------------------
+
+ def set_seed(seed):
+     random.seed(seed)
+     np.random.seed(seed)
+     torch.manual_seed(seed)
+     torch.cuda.manual_seed_all(seed)
+
+ def setup_logging(config):
+     """ monotonous bookkeeping """
+     work_dir = config.system.work_dir
+     # create the work directory if it doesn't already exist
+     os.makedirs(work_dir, exist_ok=True)
+     # log the args (if any)
+     with open(os.path.join(work_dir, 'args.txt'), 'w') as f:
+         f.write(' '.join(sys.argv))
+     # log the config itself
+     with open(os.path.join(work_dir, 'config.json'), 'w') as f:
+         f.write(json.dumps(config.to_dict(), indent=4))
+
+ class CfgNode:
+     """ a lightweight configuration class inspired by yacs """
+     # TODO: convert to subclass from a dict like in yacs?
+     # TODO: implement freezing to prevent shooting of own foot
+     # TODO: additional existence/override checks when reading/writing params?
+
+     def __init__(self, **kwargs):
+         self.__dict__.update(kwargs)
+
+     def __str__(self):
+         return self._str_helper(0)
+
+     def _str_helper(self, indent):
+         """ need to have a helper to support nested indentation for pretty printing """
+         parts = []
+         for k, v in self.__dict__.items():
+             if isinstance(v, CfgNode):
+                 parts.append("%s:\n" % k)
+                 parts.append(v._str_helper(indent + 1))
+             else:
+                 parts.append("%s: %s\n" % (k, v))
+         parts = [' ' * (indent * 4) + p for p in parts]
+         return "".join(parts)
+
+     def to_dict(self):
+         """ return a dict representation of the config """
+         return { k: v.to_dict() if isinstance(v, CfgNode) else v for k, v in self.__dict__.items() }
+
+     def merge_from_dict(self, d):
+         self.__dict__.update(d)
+
+     def merge_from_args(self, args):
+         """
+         update the configuration from a list of strings that is expected
+         to come from the command line, i.e. sys.argv[1:].
+
+         The arguments are expected to be in the form of `--arg=value`, and
+         the arg can use . to denote nested sub-attributes. Example:
+
+         --model.n_layer=10 --trainer.batch_size=32
+         """
+         for arg in args:
+
+             keyval = arg.split('=')
+             assert len(keyval) == 2, "expecting each override arg to be of form --arg=value, got %s" % arg
+             key, val = keyval # unpack
+
+             # first translate val into a python object
+             try:
+                 val = literal_eval(val)
+                 """
+                 need some explanation here.
+                 - if val is simply a string, literal_eval will throw a ValueError
+                 - if val represents a thing (like 3, 3.14, [1,2,3], False, None, etc.) it will get created
+                 """
+             except ValueError:
+                 pass
+
+             # find the appropriate object to insert the attribute into
+             assert key[:2] == '--'
+             key = key[2:] # strip the '--'
+             keys = key.split('.')
+             obj = self
+             for k in keys[:-1]:
+                 obj = getattr(obj, k)
+             leaf_key = keys[-1]
+
+             # ensure that this attribute exists
+             assert hasattr(obj, leaf_key), f"{key} is not an attribute that exists in the config"
+
+             # overwrite the attribute
+             print("command line overwriting config attribute %s with %s" % (key, val))
+             setattr(obj, leaf_key, val)
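
A short sketch of the CfgNode config object above (not part of the committed file), again assuming the minGPT import path that the other two files use. It shows the nested-attribute layout and how `merge_from_args` applies `--a.b=value` style overrides: values are passed through `ast.literal_eval`, so numbers, booleans and lists parse into Python objects while anything that fails to parse stays a plain string.

from mingpt.utils import CfgNode as CN, set_seed  # assumed import path

set_seed(3407)

# build a nested config by hand; the node names here are illustrative
config = CN()
config.model = CN(model_type='gpt-nano', n_layer=None, n_head=None, n_embd=None)
config.trainer = CN(batch_size=64, learning_rate=3e-4)

# command-line style overrides, e.g. from sys.argv[1:]
config.merge_from_args(['--trainer.batch_size=32', '--model.model_type=gpt-micro'])

print(config)            # pretty-printed via _str_helper, nested nodes indented
print(config.to_dict())  # plain nested dict, suitable for json.dumps / setup_logging
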