mridulk committed on
Commit
f72a645
·
1 Parent(s): 4580788

added main

Browse files
Files changed (1) hide show
  1. main.py +769 -0
main.py ADDED
@@ -0,0 +1,769 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse, os, sys, datetime, glob, importlib, csv
2
+ import numpy as np
3
+ import time
4
+ import torch
5
+ import torchvision
6
+ import pytorch_lightning as pl
7
+
8
+ from packaging import version
9
+ from omegaconf import OmegaConf
10
+ from torch.utils.data import random_split, DataLoader, Dataset, Subset
11
+ from functools import partial
12
+ from PIL import Image
13
+
14
+ from pytorch_lightning import seed_everything
15
+ from pytorch_lightning.trainer import Trainer
16
+ from pytorch_lightning.callbacks import ModelCheckpoint, Callback, LearningRateMonitor
17
+ from pytorch_lightning.utilities.distributed import rank_zero_only
18
+ from pytorch_lightning.utilities import rank_zero_info
19
+ import ldm.data.constants as CONSTANTS
20
+
21
+ from ldm.data.base import Txt2ImgIterableBaseDataset
22
+ from ldm.util import instantiate_from_config
23
+
24
def get_monitor(target):
    """Return the name of the metric used as the default checkpoint monitor.

    ``target`` is accepted for interface compatibility but is not consulted:
    the result is always ``"val"`` followed by ``CONSTANTS.RECLOSS``.
    """
    split_prefix = "val"
    return split_prefix + CONSTANTS.RECLOSS
26
+
27
def get_parser(**parser_kwargs):
    """Create the argument parser holding this script's own options.

    Args:
        **parser_kwargs: forwarded unchanged to ``argparse.ArgumentParser``.

    Returns:
        An ``argparse.ArgumentParser`` with the name/resume/base/train/
        no-test/project/debug/seed/postfix/logdir/scale_lr options registered,
        in that order.
    """

    def str2bool(v):
        """Map a yes/no style token to a bool; real bools pass through."""
        if isinstance(v, bool):
            return v
        token = v.lower()
        if token in ("yes", "true", "t", "y", "1"):
            return True
        if token in ("no", "false", "f", "n", "0"):
            return False
        raise argparse.ArgumentTypeError("Boolean value expected.")

    def optional_text(help_text):
        # String option whose value may be omitted on the command line.
        return dict(type=str, const=True, default="", nargs="?", help=help_text)

    def switch(default, help_text):
        # Boolean option usable bare (-> True) or with an explicit yes/no value.
        return dict(type=str2bool, const=True, default=default, nargs="?", help=help_text)

    option_table = [
        (("-n", "--name"), optional_text("postfix for logdir")),
        (("-r", "--resume"), optional_text("resume from logdir or checkpoint in logdir")),
        (
            ("-b", "--base"),
            dict(
                nargs="*",
                metavar="base_config.yaml",
                help="paths to base configs. Loaded from left-to-right. "
                     "Parameters can be overwritten or added with command-line options of the form `--key value`.",
                default=[],
            ),
        ),
        (("-t", "--train"), switch(False, "train")),
        (("--no-test",), switch(False, "disable test")),
        (("-p", "--project"), dict(help="name of new or path to existing project")),
        (("-d", "--debug"), switch(False, "enable post-mortem debugging")),
        (("-s", "--seed"), dict(type=int, default=23, help="seed for seed_everything")),
        (("-f", "--postfix"), dict(type=str, default="", help="post-postfix for default name")),
        (("-l", "--logdir"), dict(type=str, default="logs", help="directory for logging dat shit")),
        (("--scale_lr",), switch(True, "scale base-lr by ngpu * batch_size * n_accumulate")),
    ]

    parser = argparse.ArgumentParser(**parser_kwargs)
    for flags, spec in option_table:
        parser.add_argument(*flags, **spec)
    return parser
127
+
128
+
129
def nondefault_trainer_args(opt):
    """List the Trainer options whose value in ``opt`` differs from the default.

    A throwaway parser carrying only the Trainer's argparse options is parsed
    with an empty command line to obtain the defaults; every option name whose
    value on ``opt`` compares unequal to that default is returned.

    Args:
        opt: namespace that has an attribute for every Trainer option.

    Returns:
        Sorted list of option names.
    """
    reference_parser = Trainer.add_argparse_args(argparse.ArgumentParser())
    defaults = vars(reference_parser.parse_args([]))
    changed = [key for key, default in defaults.items() if getattr(opt, key) != default]
    changed.sort()
    return changed
134
+
135
+
136
class WrappedDataset(Dataset):
    """Adapter that presents any object with __len__ and __getitem__ as a pytorch dataset"""

    def __init__(self, dataset):
        # Only a reference is stored; the wrapped object is not copied.
        self.data = dataset

    def __getitem__(self, idx):
        wrapped = self.data
        return wrapped[idx]

    def __len__(self):
        wrapped = self.data
        return len(wrapped)
147
+
148
+
149
def worker_init_fn(_):
    """Per-worker DataLoader initialiser: shard iterable datasets and reseed numpy.

    For a ``Txt2ImgIterableBaseDataset`` the worker's ``sample_ids`` is set to
    its own contiguous slice of ``valid_ids`` (slice length is
    ``num_records // num_workers``), and numpy is reseeded from a randomly
    picked word of the current numpy RNG state plus the worker id. For any
    other dataset numpy is reseeded from the first state word plus the worker
    id. The positional argument is ignored; worker details come from
    ``torch.utils.data.get_worker_info()``.
    """
    info = torch.utils.data.get_worker_info()
    worker_dataset, wid = info.dataset, info.id

    if not isinstance(worker_dataset, Txt2ImgIterableBaseDataset):
        return np.random.seed(np.random.get_state()[1][0] + wid)

    shard_len = worker_dataset.num_records // info.num_workers
    shard_start = wid * shard_len
    worker_dataset.sample_ids = worker_dataset.valid_ids[shard_start:shard_start + shard_len]
    # The draw below advances the RNG before the state is read again, so the
    # order of these two statements matters.
    picked = np.random.choice(len(np.random.get_state()[1]), 1)
    return np.random.seed(np.random.get_state()[1][picked] + wid)
163
+
164
+
165
class DataModuleFromConfig(pl.LightningDataModule):
    """Lightning data module whose datasets are instantiated from config entries.

    Each of ``train`` / ``validation`` / ``test`` / ``predict`` is a config
    passed to ``instantiate_from_config``. Only for the splits that are given,
    the matching ``*_dataloader`` attribute is installed on the instance.

    Args:
        batch_size: batch size used by every DataLoader.
        train, validation, test, predict: dataset configs (or None to skip).
        wrap: if true, every instantiated dataset is wrapped in ``WrappedDataset``.
        num_workers: DataLoader workers; defaults to ``batch_size * 2``.
        shuffle_test_loader, shuffle_val_dataloader: accepted for interface
            compatibility; NOTE(review) they are currently not forwarded to
            the loaders, which are installed with ``shuffle=False``.
        use_worker_init_fn: force use of ``worker_init_fn`` for all loaders.
    """

    def __init__(self, batch_size, train=None, validation=None, test=None, predict=None,
                 wrap=False, num_workers=None, shuffle_test_loader=False, use_worker_init_fn=False,
                 shuffle_val_dataloader=False):
        super().__init__()
        self.batch_size = batch_size
        self.dataset_configs = dict()
        self.num_workers = num_workers if num_workers is not None else batch_size * 2
        self.use_worker_init_fn = use_worker_init_fn
        if train is not None:
            self.dataset_configs["train"] = train
            self.train_dataloader = self._train_dataloader
        if validation is not None:
            self.dataset_configs["validation"] = validation
            self.val_dataloader = self._val_dataloader
        if test is not None:
            self.dataset_configs["test"] = test
            self.test_dataloader = self._test_dataloader
        if predict is not None:
            self.dataset_configs["predict"] = predict
            self.predict_dataloader = self._predict_dataloader
        self.wrap = wrap

    def prepare_data(self):
        """Instantiate every configured dataset once and discard the result."""
        for data_cfg in self.dataset_configs.values():
            instantiate_from_config(data_cfg)

    def setup(self, stage=None):
        """Build ``self.datasets`` (split name -> dataset), wrapping if requested."""
        self.datasets = dict(
            (k, instantiate_from_config(self.dataset_configs[k]))
            for k in self.dataset_configs)
        if self.wrap:
            for k in self.datasets:
                self.datasets[k] = WrappedDataset(self.datasets[k])

    def _is_iterable(self, split):
        """Tell whether the dataset of ``split`` is a Txt2ImgIterableBaseDataset."""
        return isinstance(self.datasets[split], Txt2ImgIterableBaseDataset)

    def _init_fn_for(self, is_iterable_dataset):
        """Pick ``worker_init_fn`` for iterable datasets or when explicitly forced."""
        if is_iterable_dataset or self.use_worker_init_fn:
            return worker_init_fn
        return None

    def _train_dataloader(self):
        """DataLoader over the train split; shuffled unless the dataset is iterable."""
        is_iterable_dataset = self._is_iterable("train")
        return DataLoader(self.datasets["train"], batch_size=self.batch_size,
                          num_workers=self.num_workers, shuffle=not is_iterable_dataset,
                          worker_init_fn=self._init_fn_for(is_iterable_dataset))

    def _val_dataloader(self, shuffle=False):
        """DataLoader over the validation split."""
        return DataLoader(self.datasets["validation"],
                          batch_size=self.batch_size,
                          num_workers=self.num_workers,
                          worker_init_fn=self._init_fn_for(self._is_iterable("validation")),
                          shuffle=shuffle)

    def _test_dataloader(self, shuffle=False):
        """DataLoader over the test split; never shuffled for iterable datasets."""
        # Inspect the *test* dataset. Looking at the train split here raised
        # KeyError when only a test set was configured and could select the
        # wrong init_fn / shuffle setting otherwise.
        is_iterable_dataset = self._is_iterable("test")

        # do not shuffle dataloader for iterable dataset
        shuffle = shuffle and (not is_iterable_dataset)

        return DataLoader(self.datasets["test"], batch_size=self.batch_size,
                          num_workers=self.num_workers,
                          worker_init_fn=self._init_fn_for(is_iterable_dataset), shuffle=shuffle)

    def _predict_dataloader(self, shuffle=False):
        """DataLoader over the predict split (``shuffle`` is accepted but unused)."""
        return DataLoader(self.datasets["predict"], batch_size=self.batch_size,
                          num_workers=self.num_workers,
                          worker_init_fn=self._init_fn_for(self._is_iterable("predict")))
243
+
244
+
245
class SetupCallback(Callback):
    """Callback that prepares the run's log directory and stores the configs.

    Args:
        resume: truthy when the run resumes an existing log directory.
        now: timestamp string used as prefix of the saved config file names.
        logdir, ckptdir, cfgdir: directories for logs, checkpoints and configs.
        config: project config (OmegaConf) to print and save.
        lightning_config: lightning config (OmegaConf) to print and save.
    """

    def __init__(self, resume, now, logdir, ckptdir, cfgdir, config, lightning_config):
        super().__init__()
        self.resume = resume
        self.now = now
        self.logdir = logdir
        self.ckptdir = ckptdir
        self.cfgdir = cfgdir
        self.config = config
        self.lightning_config = lightning_config

    @staticmethod
    def _to_yaml(cfg):
        """Render a config as YAML text on both newer and older OmegaConf."""
        # ``.pretty()`` was removed from newer OmegaConf releases in favour of
        # ``OmegaConf.to_yaml``; prefer the latter and fall back when absent.
        if hasattr(OmegaConf, "to_yaml"):
            return OmegaConf.to_yaml(cfg)
        return cfg.pretty()

    def on_keyboard_interrupt(self, trainer, pl_module):
        """On rank 0, save ``last.ckpt`` into the checkpoint directory."""
        if trainer.global_rank == 0:
            print("Summoning checkpoint.")
            ckpt_path = os.path.join(self.ckptdir, "last.ckpt")
            trainer.save_checkpoint(ckpt_path)

    def on_pretrain_routine_start(self, trainer, pl_module):
        """Create directories and dump configs on rank 0; tidy up on other ranks."""
        if trainer.global_rank == 0:
            # Create logdirs and save configs
            os.makedirs(self.logdir, exist_ok=True)
            os.makedirs(self.ckptdir, exist_ok=True)
            os.makedirs(self.cfgdir, exist_ok=True)

            if "callbacks" in self.lightning_config:
                if 'metrics_over_trainsteps_checkpoint' in self.lightning_config['callbacks']:
                    os.makedirs(os.path.join(self.ckptdir, 'trainstep_checkpoints'), exist_ok=True)
            print("Project config")
            print(self._to_yaml(self.config))
            OmegaConf.save(self.config,
                           os.path.join(self.cfgdir, "{}-project.yaml".format(self.now)))

            print("Lightning config")
            print(self._to_yaml(self.lightning_config))
            OmegaConf.save(OmegaConf.create({"lightning": self.lightning_config}),
                           os.path.join(self.cfgdir, "{}-lightning.yaml".format(self.now)))

        else:
            # ModelCheckpoint callback created log directory --- remove it
            if not self.resume and os.path.exists(self.logdir):
                dst, name = os.path.split(self.logdir)
                dst = os.path.join(dst, "child_runs", name)
                os.makedirs(os.path.split(dst)[0], exist_ok=True)
                try:
                    os.rename(self.logdir, dst)
                except FileNotFoundError:
                    # Another process already moved the directory; nothing to do.
                    pass
294
+
295
+
296
class ImageLogger(Callback):
    """Callback that periodically renders ``pl_module.log_images`` output to disk.

    Images are requested from the module's ``log_images`` method, truncated to
    ``max_images`` per key, optionally clamped to [-1, 1], written as PNG grids
    under ``<logger.save_dir>/images/<split>`` and additionally handed to a
    logger-specific handler when one is registered for the logger type.

    Args:
        batch_frequency: log whenever the step index is a multiple of this.
        max_images: maximum number of images per key; ``<= 0`` disables logging.
        clamp: clamp tensor images to [-1, 1] before saving.
        increase_log_steps: also log at the power-of-two steps up to
            ``batch_frequency``; otherwise only ``batch_frequency`` is listed.
        rescale: map values from [-1, 1] to [0, 1] before writing PNGs.
        disabled: turn the callback off entirely.
        log_on_batch_idx: use ``batch_idx`` instead of ``global_step`` as the
            step index for the frequency check.
        log_first_step: allow logging at step index 0.
        log_images_kwargs: extra keyword arguments for ``pl_module.log_images``.
    """

    def __init__(self, batch_frequency, max_images, clamp=True, increase_log_steps=True,
                 rescale=True, disabled=False, log_on_batch_idx=False, log_first_step=False,
                 log_images_kwargs=None):
        super().__init__()
        self.rescale = rescale
        self.batch_freq = batch_frequency
        self.max_images = max_images
        self.logger_log_images = {
            pl.loggers.TestTubeLogger: self._testtube,
        }
        self.log_steps = [2 ** n for n in range(int(np.log2(self.batch_freq)) + 1)]
        if not increase_log_steps:
            self.log_steps = [self.batch_freq]
        self.clamp = clamp
        self.disabled = disabled
        self.log_on_batch_idx = log_on_batch_idx
        self.log_images_kwargs = log_images_kwargs if log_images_kwargs else {}
        self.log_first_step = log_first_step

    @rank_zero_only
    def _testtube(self, pl_module, images, batch_idx, split):
        """Send one image grid per key to the logger's experiment via ``add_image``."""
        for k in images:
            grid = torchvision.utils.make_grid(images[k])
            grid = (grid + 1.0) / 2.0  # -1,1 -> 0,1; c,h,w

            tag = f"{split}/{k}"
            pl_module.logger.experiment.add_image(
                tag, grid,
                global_step=pl_module.global_step)

    @rank_zero_only
    def log_local(self, save_dir, split, images,
                  global_step, current_epoch, batch_idx):
        """Write one PNG grid per key below ``<save_dir>/images/<split>``."""
        root = os.path.join(save_dir, "images", split)
        for k in images:
            grid = torchvision.utils.make_grid(images[k], nrow=4)
            if self.rescale:
                grid = (grid + 1.0) / 2.0  # -1,1 -> 0,1; c,h,w
            # c,h,w -> h,w,c (a trailing singleton channel axis is dropped)
            grid = grid.transpose(0, 1).transpose(1, 2).squeeze(-1)
            grid = grid.numpy()
            grid = (grid * 255).astype(np.uint8)
            filename = "{}_gs-{:06}_e-{:06}_b-{:06}.png".format(
                k,
                global_step,
                current_epoch,
                batch_idx)
            path = os.path.join(root, filename)
            os.makedirs(os.path.split(path)[0], exist_ok=True)
            Image.fromarray(grid).save(path)

    def log_img(self, pl_module, batch, batch_idx, split="train"):
        """Render and store images for ``batch`` if the frequency check passes."""
        check_idx = batch_idx if self.log_on_batch_idx else pl_module.global_step
        if (self.check_frequency(check_idx) and  # batch_idx % self.batch_freq == 0
                hasattr(pl_module, "log_images") and
                callable(pl_module.log_images) and
                self.max_images > 0):
            logger = type(pl_module.logger)

            # Render in eval mode and restore training mode afterwards.
            is_train = pl_module.training
            if is_train:
                pl_module.eval()

            with torch.no_grad():
                images = pl_module.log_images(batch, split=split, **self.log_images_kwargs)

            for k in images:
                N = min(images[k].shape[0], self.max_images)
                images[k] = images[k][:N]
                if isinstance(images[k], torch.Tensor):
                    images[k] = images[k].detach().cpu()
                    if self.clamp:
                        images[k] = torch.clamp(images[k], -1., 1.)

            self.log_local(pl_module.logger.save_dir, split, images,
                           pl_module.global_step, pl_module.current_epoch, batch_idx)

            # Loggers without a registered handler fall back to a no-op.
            logger_log_images = self.logger_log_images.get(logger, lambda *args, **kwargs: None)
            logger_log_images(pl_module, images, pl_module.global_step, split)

            if is_train:
                pl_module.train()

    def check_frequency(self, check_idx):
        """Return True when images should be logged at step index ``check_idx``.

        Each positive answer consumes the first pending entry of
        ``self.log_steps``, if any is left.
        """
        if ((check_idx % self.batch_freq) == 0 or (check_idx in self.log_steps)) and (
                check_idx > 0 or self.log_first_step):
            # Once the warm-up schedule is used up there is nothing to pop;
            # test for that instead of raising and printing an IndexError
            # on every subsequent trigger.
            if self.log_steps:
                self.log_steps.pop(0)
            return True
        return False

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
        """Log training images unless disabled (step 0 only with ``log_first_step``)."""
        if not self.disabled and (pl_module.global_step > 0 or self.log_first_step):
            self.log_img(pl_module, batch, batch_idx, split="train")

    def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
        """Log validation images and, where supported, gradient statistics."""
        if not self.disabled and pl_module.global_step > 0:
            self.log_img(pl_module, batch, batch_idx, split="val")
        if hasattr(pl_module, 'calibrate_grad_norm'):
            if (pl_module.calibrate_grad_norm and batch_idx % 25 == 0) and batch_idx > 0:
                # This class does not define ``log_gradients`` itself; call it
                # only when a subclass (or the instance) provides one, instead
                # of failing with AttributeError in the middle of validation.
                log_gradients = getattr(self, "log_gradients", None)
                if callable(log_gradients):
                    log_gradients(trainer, pl_module, batch_idx=batch_idx)
400
+
401
+
402
class CUDACallback(Callback):
    """Reports wall-clock time and peak CUDA memory of every training epoch."""

    # see https://github.com/SeanNaren/minGPT/blob/master/mingpt/callback.py
    def on_train_epoch_start(self, trainer, pl_module):
        """Clear the peak-memory statistics and remember the epoch start time."""
        device = trainer.root_gpu
        torch.cuda.reset_peak_memory_stats(device)
        torch.cuda.synchronize(device)
        self.start_time = time.time()

    def on_train_epoch_end(self, trainer, pl_module, outputs):
        """Reduce epoch time and peak memory through the training plugin and report them."""
        device = trainer.root_gpu
        torch.cuda.synchronize(device)
        peak_mib = torch.cuda.max_memory_allocated(device) / 2 ** 20
        elapsed = time.time() - self.start_time

        try:
            max_memory = trainer.training_type_plugin.reduce(peak_mib)
            epoch_time = trainer.training_type_plugin.reduce(elapsed)
        except AttributeError:
            # No usable training-type plugin: skip the report.
            return

        try:
            rank_zero_info(f"Average Epoch time: {epoch_time:.2f} seconds")
            rank_zero_info(f"Average Peak memory {max_memory:.2f}MiB")
        except AttributeError:
            pass
423
+
424
+
425
def count_requested_gpus(gpus):
    """Return how many GPUs a Lightning-style ``gpus`` setting asks for.

    Args:
        gpus: an int (number of devices, negative meaning "all available"),
            a comma separated string of device ids such as ``"0,1,"``, a
            string holding a single integer, a sequence of ids, or None.

    Returns:
        The device count, never less than 1 so it is safe to use as a
        learning-rate multiplier.
    """
    if gpus is None:
        n = 0
    elif isinstance(gpus, str):
        spec = gpus.strip()
        if "," in spec:
            n = len([part for part in spec.split(",") if part.strip()])
        elif spec:
            n = int(spec)
        else:
            n = 0
    elif isinstance(gpus, int):
        n = gpus
    else:
        n = len(gpus)
    if n < 0:
        n = torch.cuda.device_count()
    return max(n, 1)


if __name__ == "__main__":
    # custom parser to specify config files, train, test and debug mode,
    # postfix, resume.
    # `--key value` arguments are interpreted as arguments to the trainer.
    # `nested.key=value` arguments are interpreted as config parameters.
    # configs are merged from left-to-right followed by command line parameters.

    # model:
    #   base_learning_rate: float
    #   target: path to lightning module
    #   params:
    #       key: value
    # data:
    #   target: main.DataModuleFromConfig
    #   params:
    #      batch_size: int
    #      wrap: bool
    #      train:
    #          target: path to train dataset
    #          params:
    #              key: value
    #      validation:
    #          target: path to validation dataset
    #          params:
    #              key: value
    #      test:
    #          target: path to test dataset
    #          params:
    #              key: value
    # lightning: (optional, has sane defaults and can be specified on cmdline)
    #   trainer:
    #       additional arguments to trainer
    #   logger:
    #       logger to instantiate
    #   modelcheckpoint:
    #       modelcheckpoint to instantiate
    #   callbacks:
    #       callback1:
    #           target: importpath
    #           params:
    #               key: value

    now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")

    # add cwd for convenience and to make classes in this file available when
    # running as `python main.py`
    # (in particular `main.DataModuleFromConfig`)
    sys.path.append(os.getcwd())

    parser = get_parser()
    parser = Trainer.add_argparse_args(parser)

    opt, unknown = parser.parse_known_args()
    if opt.name and opt.resume:
        raise ValueError(
            "-n/--name and -r/--resume cannot be specified both. "
            "If you want to resume training in a new log folder, "
            "use -n/--name in combination with --resume_from_checkpoint"
        )
    if opt.resume:
        if not os.path.exists(opt.resume):
            raise ValueError("Cannot find {}".format(opt.resume))
        if os.path.isfile(opt.resume):
            # <logdir>/checkpoints/<file>.ckpt -> <logdir>
            paths = opt.resume.split("/")
            logdir = "/".join(paths[:-2])
            ckpt = opt.resume
        else:
            assert os.path.isdir(opt.resume), opt.resume
            logdir = opt.resume.rstrip("/")
            ckpt = os.path.join(logdir, "checkpoints", "last.ckpt")

        opt.resume_from_checkpoint = ckpt
        # Configs stored with the run are loaded first, so configs given on
        # the command line are merged on top of them.
        base_configs = sorted(glob.glob(os.path.join(logdir, "configs/*.yaml")))
        opt.base = base_configs + opt.base
        _tmp = logdir.split("/")
        nowname = _tmp[-1]
    else:
        if opt.name:
            name = "_" + opt.name
        elif opt.base:
            cfg_fname = os.path.split(opt.base[0])[-1]
            cfg_name = os.path.splitext(cfg_fname)[0]
            name = "_" + cfg_name
        else:
            name = ""
        nowname = now + name + opt.postfix
        logdir = os.path.join(opt.logdir, nowname)

    ckptdir = os.path.join(logdir, "checkpoints")
    cfgdir = os.path.join(logdir, "configs")
    seed_everything(opt.seed)

    # Bound before the try block so the except/finally clauses can tell
    # "the trainer was never created" apart from "the trainer exists";
    # otherwise an early failure is masked by a NameError.
    trainer = None


    def on_rank_zero():
        # Without a trainer there is no rank information; treat as rank 0.
        return trainer is None or trainer.global_rank == 0


    try:
        # init and save configs
        configs = [OmegaConf.load(cfg) for cfg in opt.base]
        cli = OmegaConf.from_dotlist(unknown)
        config = OmegaConf.merge(*configs, cli)
        lightning_config = config.pop("lightning", OmegaConf.create())
        # merge trainer cli with config
        trainer_config = lightning_config.get("trainer", OmegaConf.create())
        # default to ddp
        trainer_config["accelerator"] = "ddp"
        for k in nondefault_trainer_args(opt):
            trainer_config[k] = getattr(opt, k)
        if "gpus" not in trainer_config:
            del trainer_config["accelerator"]
            cpu = True
        else:
            gpuinfo = trainer_config["gpus"]
            print(f"Running on GPUs {gpuinfo}")
            cpu = False
        trainer_opt = argparse.Namespace(**trainer_config)
        lightning_config.trainer = trainer_config

        # model
        model = instantiate_from_config(config.model)

        # trainer and callbacks
        trainer_kwargs = dict()

        # default logger configs
        # NOTE wandb < 0.10.0 interferes with shutdown
        # wandb >= 0.10.0 seems to fix it but still interferes with pudb
        # debugging (wrongly sized pudb ui)
        default_logger_cfgs = {
            "wandb": {
                "target": "pytorch_lightning.loggers.WandbLogger",
                "params": {
                    "name": nowname,
                    "save_dir": logdir,
                    "offline": opt.debug,
                    "id": nowname,
                    "project": config.model.project
                }
            },
            "testtube": {
                "target": "pytorch_lightning.loggers.TestTubeLogger",
                "params": {
                    "name": "testtube",
                    "save_dir": logdir,
                }
            },
        }

        default_logger_cfg = default_logger_cfgs["wandb"]  # "testtube" "wandb"
        # ``.get`` yields None for a missing key instead of depending on how
        # attribute access to a missing key behaves.
        logger_cfg = lightning_config.get("logger") or OmegaConf.create()
        logger_cfg = OmegaConf.merge(default_logger_cfg, logger_cfg)
        trainer_kwargs["logger"] = instantiate_from_config(logger_cfg)

        # Wandb configs
        if rank_zero_only.rank == 0:
            trainer_kwargs["logger"].experiment.config["lr"] = config.model.base_learning_rate
            trainer_kwargs["logger"].experiment.config["batch_size"] = config.data.params.batch_size
            trainer_kwargs["logger"].watch(model, log_freq=100)

        # modelcheckpoint - use TrainResult/EvalResult(checkpoint_on=metric) to
        # specify which metric is used to determine best models
        default_modelckpt_cfg = {
            "target": "pytorch_lightning.callbacks.ModelCheckpoint",
            "params": {
                "dirpath": ckptdir,
                "filename": "{epoch:06}",
                "verbose": True,
                "monitor": get_monitor(config.model.target),
                "save_top_k": 1,
                "mode": "min",
                "period": 3,
                "save_last": True,
            }
        }
        if hasattr(model, "monitor"):
            print(f"Monitoring {model.monitor} as checkpoint metric.")
            default_modelckpt_cfg["params"]["monitor"] = model.monitor
            default_modelckpt_cfg["params"]["save_top_k"] = 5

        if "modelcheckpoint" in lightning_config:
            modelckpt_cfg = lightning_config.modelcheckpoint
        else:
            modelckpt_cfg = OmegaConf.create()
        modelckpt_cfg = OmegaConf.merge(default_modelckpt_cfg, modelckpt_cfg)
        print(f"Merged modelckpt-cfg: \n{modelckpt_cfg}")
        # Before 1.4.0 the checkpoint callback is passed as a Trainer keyword;
        # from 1.4.0 on it is added to the regular callbacks further below.
        if version.parse(pl.__version__) < version.parse('1.4.0'):
            trainer_kwargs["checkpoint_callback"] = instantiate_from_config(modelckpt_cfg)

        # add callback which sets up log directory
        default_callbacks_cfg = {
            "setup_callback": {
                "target": "main.SetupCallback",
                "params": {
                    "resume": opt.resume,
                    "now": now,
                    "logdir": logdir,
                    "ckptdir": ckptdir,
                    "cfgdir": cfgdir,
                    "config": config,
                    "lightning_config": lightning_config,
                }
            },
            "image_logger": {
                "target": "main.ImageLogger",
                "params": {
                    "batch_frequency": 750,
                    "max_images": 4,
                    "clamp": True
                }
            },
            "learning_rate_logger": {
                "target": "main.LearningRateMonitor",
                "params": {
                    "logging_interval": "step",
                    # "log_momentum": True
                }
            },
            "cuda_callback": {
                "target": "main.CUDACallback"
            },
        }
        if version.parse(pl.__version__) >= version.parse('1.4.0'):
            default_callbacks_cfg.update({'checkpoint_callback': modelckpt_cfg})

        if "callbacks" in lightning_config:
            callbacks_cfg = lightning_config.callbacks
        else:
            callbacks_cfg = OmegaConf.create()

        if 'metrics_over_trainsteps_checkpoint' in callbacks_cfg:
            print(
                'Caution: Saving checkpoints every n train steps without deleting. This might require some free space.')
            default_metrics_over_trainsteps_ckpt_dict = {
                'metrics_over_trainsteps_checkpoint':
                    {"target": 'pytorch_lightning.callbacks.ModelCheckpoint',
                     'params': {
                         "dirpath": os.path.join(ckptdir, 'trainstep_checkpoints'),
                         "filename": "{epoch:06}-{step:09}",
                         "verbose": True,
                         'save_top_k': -1,
                         'every_n_train_steps': 10000,
                         'save_weights_only': True
                     }
                     }
            }
            default_callbacks_cfg.update(default_metrics_over_trainsteps_ckpt_dict)

        callbacks_cfg = OmegaConf.merge(default_callbacks_cfg, callbacks_cfg)
        if 'ignore_keys_callback' in callbacks_cfg and hasattr(trainer_opt, 'resume_from_checkpoint'):
            callbacks_cfg.ignore_keys_callback.params['ckpt_path'] = trainer_opt.resume_from_checkpoint
        elif 'ignore_keys_callback' in callbacks_cfg:
            del callbacks_cfg['ignore_keys_callback']

        trainer_kwargs["callbacks"] = [instantiate_from_config(callbacks_cfg[k]) for k in callbacks_cfg]

        trainer = Trainer.from_argparse_args(trainer_opt, **trainer_kwargs)
        trainer.logdir = logdir  ###

        # data
        data = instantiate_from_config(config.data)
        # NOTE according to https://pytorch-lightning.readthedocs.io/en/latest/datamodules.html
        # calling these ourselves should not be necessary but it is.
        # lightning still takes care of proper multiprocessing though
        data.prepare_data()
        data.setup()
        print("#### Data #####")
        for k in data.datasets:
            print(f"{k}, {data.datasets[k].__class__.__name__}, {len(data.datasets[k])}")

        # configure learning rate
        bs, base_lr = config.data.params.batch_size, config.model.base_learning_rate
        if not cpu:
            # ``gpus`` may be an int (e.g. ``--gpus 2``) as well as an id
            # string such as "0,1,", so it cannot simply be split on commas.
            ngpu = count_requested_gpus(lightning_config.trainer.gpus)
        else:
            ngpu = 1
        if 'accumulate_grad_batches' in lightning_config.trainer:
            accumulate_grad_batches = lightning_config.trainer.accumulate_grad_batches
        else:
            accumulate_grad_batches = 1
        print(f"accumulate_grad_batches = {accumulate_grad_batches}")
        lightning_config.trainer.accumulate_grad_batches = accumulate_grad_batches
        if opt.scale_lr:
            model.learning_rate = accumulate_grad_batches * ngpu * bs * base_lr
            print(
                "Setting learning rate to {:.2e} = {} (accumulate_grad_batches) * {} (num_gpus) * {} (batchsize) * {:.2e} (base_lr)".format(
                    model.learning_rate, accumulate_grad_batches, ngpu, bs, base_lr))
        else:
            model.learning_rate = base_lr
            print("++++ NOT USING LR SCALING ++++")
            print(f"Setting learning rate to {model.learning_rate:.2e}")


        # allow checkpointing via USR1
        def melk(*args, **kwargs):
            # run all checkpoint hooks
            if trainer.global_rank == 0:
                print("Summoning checkpoint.")
                ckpt_path = os.path.join(ckptdir, "last.ckpt")
                trainer.save_checkpoint(ckpt_path)


        # drop into the pudb debugger via USR2
        def divein(*args, **kwargs):
            if trainer.global_rank == 0:
                import pudb
                pudb.set_trace()


        import signal

        signal.signal(signal.SIGUSR1, melk)
        signal.signal(signal.SIGUSR2, divein)

        # run
        if opt.train:
            try:
                trainer.fit(model, data)
            except Exception:
                # Save a last checkpoint before propagating the failure.
                melk()
                raise
        if not opt.no_test and not trainer.interrupted:
            trainer.test(model, data)
    except Exception:
        if opt.debug and on_rank_zero():
            try:
                import pudb as debugger
            except ImportError:
                import pdb as debugger
            debugger.post_mortem()
        raise
    finally:
        # move newly created debug project to debug_runs
        # (only when the log directory was actually created, so the cleanup
        # cannot raise on top of an earlier error)
        if opt.debug and not opt.resume and on_rank_zero() and os.path.isdir(logdir):
            dst, name = os.path.split(logdir)
            dst = os.path.join(dst, "debug_runs", name)
            os.makedirs(os.path.split(dst)[0], exist_ok=True)
            os.rename(logdir, dst)
        if trainer is not None and trainer.global_rank == 0:
            print(trainer.profiler.summary())