Gkason committed
Commit 7264121
1 parent: 84d35a5

Create app.py

Files changed (1):
app.py (+741, -0)
app.py ADDED
@@ -0,0 +1,741 @@
import argparse, os, sys, datetime, glob, importlib, csv
import numpy as np
import time
import torch
import torchvision
import pytorch_lightning as pl

from packaging import version
from omegaconf import OmegaConf
from torch.utils.data import random_split, DataLoader, Dataset, Subset
from functools import partial
from PIL import Image

from pytorch_lightning import seed_everything
from pytorch_lightning.trainer import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint, Callback, LearningRateMonitor
from pytorch_lightning.utilities.distributed import rank_zero_only
from pytorch_lightning.utilities import rank_zero_info

from ldm.data.base import Txt2ImgIterableBaseDataset
from ldm.util import instantiate_from_config


def get_parser(**parser_kwargs):
    def str2bool(v):
        if isinstance(v, bool):
            return v
        if v.lower() in ("yes", "true", "t", "y", "1"):
            return True
        elif v.lower() in ("no", "false", "f", "n", "0"):
            return False
        else:
            raise argparse.ArgumentTypeError("Boolean value expected.")

    parser = argparse.ArgumentParser(**parser_kwargs)
    parser.add_argument(
        "-n",
        "--name",
        type=str,
        const=True,
        default="",
        nargs="?",
        help="postfix for logdir",
    )
    parser.add_argument(
        "-r",
        "--resume",
        type=str,
        const=True,
        default="",
        nargs="?",
        help="resume from logdir or checkpoint in logdir",
    )
    parser.add_argument(
        "-b",
        "--base",
        nargs="*",
        metavar="base_config.yaml",
        help="paths to base configs. Loaded from left-to-right. "
             "Parameters can be overwritten or added with command-line options of the form `--key value`.",
        default=list(),
    )
    parser.add_argument(
        "-t",
        "--train",
        type=str2bool,
        const=True,
        default=False,
        nargs="?",
        help="train",
    )
    parser.add_argument(
        "--no-test",
        type=str2bool,
        const=True,
        default=False,
        nargs="?",
        help="disable test",
    )
    parser.add_argument(
        "-p",
        "--project",
        help="name of new or path to existing project"
    )
    parser.add_argument(
        "-d",
        "--debug",
        type=str2bool,
        nargs="?",
        const=True,
        default=False,
        help="enable post-mortem debugging",
    )
    parser.add_argument(
        "-s",
        "--seed",
        type=int,
        default=23,
        help="seed for seed_everything",
    )
    parser.add_argument(
        "-f",
        "--postfix",
        type=str,
        default="",
        help="post-postfix for default name",
    )
    parser.add_argument(
        "-l",
        "--logdir",
        type=str,
        default="logs",
        help="directory for logging data",
    )
    parser.add_argument(
        "--scale_lr",
        type=str2bool,
        nargs="?",
        const=True,
        default=True,
        help="scale base-lr by ngpu * batch_size * n_accumulate",
    )
    return parser
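

# Example invocation (illustrative; the config path below is a placeholder, and
# pytorch_lightning Trainer flags such as --gpus are accepted because the
# __main__ block below extends this parser via Trainer.add_argparse_args):
#   python app.py --base configs/example.yaml -t --gpus 0,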


def nondefault_trainer_args(opt):
    parser = argparse.ArgumentParser()
    parser = Trainer.add_argparse_args(parser)
    args = parser.parse_args([])
    return sorted(k for k in vars(args) if getattr(opt, k) != getattr(args, k))
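
# Illustrative behaviour of nondefault_trainer_args (values made up): if the CLI
# passed `--gpus 0,` and `--max_epochs 10`, those keys differ from the Trainer
# defaults, so this returns ['gpus', 'max_epochs'] and the __main__ block copies
# only those onto trainer_config.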


class WrappedDataset(Dataset):
    """Wraps an arbitrary object with __len__ and __getitem__ into a pytorch dataset"""

    def __init__(self, dataset):
        self.data = dataset

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]


def worker_init_fn(_):
    worker_info = torch.utils.data.get_worker_info()

    dataset = worker_info.dataset
    worker_id = worker_info.id

    if isinstance(dataset, Txt2ImgIterableBaseDataset):
        split_size = dataset.num_records // worker_info.num_workers
        # reset num_records to the true number to retain reliable length information
        dataset.sample_ids = dataset.valid_ids[worker_id * split_size:(worker_id + 1) * split_size]
        current_id = np.random.choice(len(np.random.get_state()[1]), 1)
        return np.random.seed(np.random.get_state()[1][current_id] + worker_id)
    else:
        return np.random.seed(np.random.get_state()[1][0] + worker_id)
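
# Worked example for the sharding above (numbers are illustrative): with
# dataset.num_records = 1000 and num_workers = 4, split_size is 250 and worker 2
# is assigned valid_ids[500:750]; each worker is then reseeded so that random
# augmentations differ between workers.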


class DataModuleFromConfig(pl.LightningDataModule):
    def __init__(self, batch_size, train=None, validation=None, test=None, predict=None,
                 wrap=False, num_workers=None, shuffle_test_loader=False, use_worker_init_fn=False,
                 shuffle_val_dataloader=False):
        super().__init__()
        self.batch_size = batch_size
        self.dataset_configs = dict()
        self.num_workers = num_workers if num_workers is not None else batch_size * 2
        self.use_worker_init_fn = use_worker_init_fn
        if train is not None:
            self.dataset_configs["train"] = train
            self.train_dataloader = self._train_dataloader
        if validation is not None:
            self.dataset_configs["validation"] = validation
            self.val_dataloader = partial(self._val_dataloader, shuffle=shuffle_val_dataloader)
        if test is not None:
            self.dataset_configs["test"] = test
            self.test_dataloader = partial(self._test_dataloader, shuffle=shuffle_test_loader)
        if predict is not None:
            self.dataset_configs["predict"] = predict
            self.predict_dataloader = self._predict_dataloader
        self.wrap = wrap

    def prepare_data(self):
        for data_cfg in self.dataset_configs.values():
            instantiate_from_config(data_cfg)

    def setup(self, stage=None):
        self.datasets = dict(
            (k, instantiate_from_config(self.dataset_configs[k]))
            for k in self.dataset_configs)
        if self.wrap:
            for k in self.datasets:
                self.datasets[k] = WrappedDataset(self.datasets[k])

    def _train_dataloader(self):
        is_iterable_dataset = isinstance(self.datasets['train'], Txt2ImgIterableBaseDataset)
        if is_iterable_dataset or self.use_worker_init_fn:
            init_fn = worker_init_fn
        else:
            init_fn = None
        return DataLoader(self.datasets["train"], batch_size=self.batch_size,
                          num_workers=self.num_workers, shuffle=False if is_iterable_dataset else True,
                          worker_init_fn=init_fn)

    def _val_dataloader(self, shuffle=False):
        if isinstance(self.datasets['validation'], Txt2ImgIterableBaseDataset) or self.use_worker_init_fn:
            init_fn = worker_init_fn
        else:
            init_fn = None
        return DataLoader(self.datasets["validation"],
                          batch_size=self.batch_size,
                          num_workers=self.num_workers,
                          worker_init_fn=init_fn,
                          shuffle=shuffle)

    def _test_dataloader(self, shuffle=False):
        is_iterable_dataset = isinstance(self.datasets['train'], Txt2ImgIterableBaseDataset)
        if is_iterable_dataset or self.use_worker_init_fn:
            init_fn = worker_init_fn
        else:
            init_fn = None

        # do not shuffle dataloader for iterable dataset
        shuffle = shuffle and (not is_iterable_dataset)

        return DataLoader(self.datasets["test"], batch_size=self.batch_size,
                          num_workers=self.num_workers, worker_init_fn=init_fn, shuffle=shuffle)

    def _predict_dataloader(self, shuffle=False):
        if isinstance(self.datasets['predict'], Txt2ImgIterableBaseDataset) or self.use_worker_init_fn:
            init_fn = worker_init_fn
        else:
            init_fn = None
        return DataLoader(self.datasets["predict"], batch_size=self.batch_size,
                          num_workers=self.num_workers, worker_init_fn=init_fn)

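# A minimal sketch of driving this data module directly (the target path and
# parameter values are hypothetical placeholders, shown only for illustration):
#   data = DataModuleFromConfig(
#       batch_size=4,
#       num_workers=8,
#       train={"target": "ldm.data.example.ExampleTrain", "params": {"size": 256}},
#   )
#   data.prepare_data()
#   data.setup()
#   train_loader = data.train_dataloader()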

class SetupCallback(Callback):
    def __init__(self, resume, now, logdir, ckptdir, cfgdir, config, lightning_config):
        super().__init__()
        self.resume = resume
        self.now = now
        self.logdir = logdir
        self.ckptdir = ckptdir
        self.cfgdir = cfgdir
        self.config = config
        self.lightning_config = lightning_config

    def on_keyboard_interrupt(self, trainer, pl_module):
        if trainer.global_rank == 0:
            print("Summoning checkpoint.")
            ckpt_path = os.path.join(self.ckptdir, "last.ckpt")
            trainer.save_checkpoint(ckpt_path)

    def on_pretrain_routine_start(self, trainer, pl_module):
        if trainer.global_rank == 0:
            # Create logdirs and save configs
            os.makedirs(self.logdir, exist_ok=True)
            os.makedirs(self.ckptdir, exist_ok=True)
            os.makedirs(self.cfgdir, exist_ok=True)

            if "callbacks" in self.lightning_config:
                if 'metrics_over_trainsteps_checkpoint' in self.lightning_config['callbacks']:
                    os.makedirs(os.path.join(self.ckptdir, 'trainstep_checkpoints'), exist_ok=True)
            print("Project config")
            print(OmegaConf.to_yaml(self.config))
            OmegaConf.save(self.config,
                           os.path.join(self.cfgdir, "{}-project.yaml".format(self.now)))

            print("Lightning config")
            print(OmegaConf.to_yaml(self.lightning_config))
            OmegaConf.save(OmegaConf.create({"lightning": self.lightning_config}),
                           os.path.join(self.cfgdir, "{}-lightning.yaml".format(self.now)))

        else:
            # ModelCheckpoint callback created log directory --- remove it
            if not self.resume and os.path.exists(self.logdir):
                dst, name = os.path.split(self.logdir)
                dst = os.path.join(dst, "child_runs", name)
                os.makedirs(os.path.split(dst)[0], exist_ok=True)
                try:
                    os.rename(self.logdir, dst)
                except FileNotFoundError:
                    pass


class ImageLogger(Callback):
    def __init__(self, batch_frequency, max_images, clamp=True, increase_log_steps=True,
                 rescale=True, disabled=False, log_on_batch_idx=False, log_first_step=False,
                 log_images_kwargs=None):
        super().__init__()
        self.rescale = rescale
        self.batch_freq = batch_frequency
        self.max_images = max_images
        self.logger_log_images = {
            pl.loggers.TestTubeLogger: self._testtube,
        }
        self.log_steps = [2 ** n for n in range(int(np.log2(self.batch_freq)) + 1)]
        if not increase_log_steps:
            self.log_steps = [self.batch_freq]
        self.clamp = clamp
        self.disabled = disabled
        self.log_on_batch_idx = log_on_batch_idx
        self.log_images_kwargs = log_images_kwargs if log_images_kwargs else {}
        self.log_first_step = log_first_step

    @rank_zero_only
    def _testtube(self, pl_module, images, batch_idx, split):
        for k in images:
            grid = torchvision.utils.make_grid(images[k])
            grid = (grid + 1.0) / 2.0  # -1,1 -> 0,1; c,h,w

            tag = f"{split}/{k}"
            pl_module.logger.experiment.add_image(
                tag, grid,
                global_step=pl_module.global_step)

    @rank_zero_only
    def log_local(self, save_dir, split, images,
                  global_step, current_epoch, batch_idx):
        root = os.path.join(save_dir, "images", split)
        for k in images:
            grid = torchvision.utils.make_grid(images[k], nrow=4)
            if self.rescale:
                grid = (grid + 1.0) / 2.0  # -1,1 -> 0,1; c,h,w
            grid = grid.transpose(0, 1).transpose(1, 2).squeeze(-1)
            grid = grid.numpy()
            grid = (grid * 255).astype(np.uint8)
            filename = "{}_gs-{:06}_e-{:06}_b-{:06}.png".format(
                k,
                global_step,
                current_epoch,
                batch_idx)
            path = os.path.join(root, filename)
            os.makedirs(os.path.split(path)[0], exist_ok=True)
            Image.fromarray(grid).save(path)

    def log_img(self, pl_module, batch, batch_idx, split="train"):
        check_idx = batch_idx if self.log_on_batch_idx else pl_module.global_step
        if (self.check_frequency(check_idx) and  # batch_idx % self.batch_freq == 0
                hasattr(pl_module, "log_images") and
                callable(pl_module.log_images) and
                self.max_images > 0):
            logger = type(pl_module.logger)

            is_train = pl_module.training
            if is_train:
                pl_module.eval()

            with torch.no_grad():
                images = pl_module.log_images(batch, split=split, **self.log_images_kwargs)

            for k in images:
                N = min(images[k].shape[0], self.max_images)
                images[k] = images[k][:N]
                if isinstance(images[k], torch.Tensor):
                    images[k] = images[k].detach().cpu()
                    if self.clamp:
                        images[k] = torch.clamp(images[k], -1., 1.)

            self.log_local(pl_module.logger.save_dir, split, images,
                           pl_module.global_step, pl_module.current_epoch, batch_idx)

            logger_log_images = self.logger_log_images.get(logger, lambda *args, **kwargs: None)
            logger_log_images(pl_module, images, pl_module.global_step, split)

            if is_train:
                pl_module.train()

    def check_frequency(self, check_idx):
        if ((check_idx % self.batch_freq) == 0 or (check_idx in self.log_steps)) and (
                check_idx > 0 or self.log_first_step):
            try:
                self.log_steps.pop(0)
            except IndexError as e:
                print(e)
                pass
            return True
        return False
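
    # Worked example of the schedule above (with the defaults used later in this
    # file): batch_frequency=750 and increase_log_steps=True give
    # log_steps = [1, 2, 4, ..., 512]; each of those early steps triggers one log
    # and is popped, after which images are logged every 750th global step (or
    # batch_idx when log_on_batch_idx is set).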

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
        if not self.disabled and (pl_module.global_step > 0 or self.log_first_step):
            self.log_img(pl_module, batch, batch_idx, split="train")

    def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
        if not self.disabled and pl_module.global_step > 0:
            self.log_img(pl_module, batch, batch_idx, split="val")
        if hasattr(pl_module, 'calibrate_grad_norm'):
            if (pl_module.calibrate_grad_norm and batch_idx % 25 == 0) and batch_idx > 0:
                self.log_gradients(trainer, pl_module, batch_idx=batch_idx)


class CUDACallback(Callback):
    # see https://github.com/SeanNaren/minGPT/blob/master/mingpt/callback.py
    def on_train_epoch_start(self, trainer, pl_module):
        # Reset the memory use counter
        torch.cuda.reset_peak_memory_stats(trainer.root_gpu)
        torch.cuda.synchronize(trainer.root_gpu)
        self.start_time = time.time()

    def on_train_epoch_end(self, trainer, pl_module, outputs):
        torch.cuda.synchronize(trainer.root_gpu)
        max_memory = torch.cuda.max_memory_allocated(trainer.root_gpu) / 2 ** 20
        epoch_time = time.time() - self.start_time

        try:
            max_memory = trainer.training_type_plugin.reduce(max_memory)
            epoch_time = trainer.training_type_plugin.reduce(epoch_time)

            rank_zero_info(f"Average Epoch time: {epoch_time:.2f} seconds")
            rank_zero_info(f"Average Peak memory {max_memory:.2f}MiB")
        except AttributeError:
            pass


if __name__ == "__main__":
    # custom parser to specify config files, train, test and debug mode,
    # postfix, resume.
    # `--key value` arguments are interpreted as arguments to the trainer.
    # `nested.key=value` arguments are interpreted as config parameters.
    # configs are merged from left-to-right followed by command line parameters.

    # model:
    #   base_learning_rate: float
    #   target: path to lightning module
    #   params:
    #       key: value
    # data:
    #   target: main.DataModuleFromConfig
    #   params:
    #      batch_size: int
    #      wrap: bool
    #      train:
    #          target: path to train dataset
    #          params:
    #              key: value
    #      validation:
    #          target: path to validation dataset
    #          params:
    #              key: value
    #      test:
    #          target: path to test dataset
    #          params:
    #              key: value
    # lightning: (optional, has sane defaults and can be specified on cmdline)
    #   trainer:
    #       additional arguments to trainer
    #   logger:
    #       logger to instantiate
    #   modelcheckpoint:
    #       modelcheckpoint to instantiate
    #   callbacks:
    #       callback1:
    #           target: importpath
    #           params:
    #               key: value
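
    # A minimal illustrative config matching the schema above (import paths and
    # values are placeholders/assumptions, not files shipped with this script):
    #
    #   model:
    #     base_learning_rate: 1.0e-06
    #     target: ldm.models.diffusion.ddpm.LatentDiffusion
    #     params: {...}
    #   data:
    #     target: main.DataModuleFromConfig
    #     params:
    #       batch_size: 4
    #       wrap: false
    #       train:
    #         target: ldm.data.example.ExampleTrain
    #         params: {...}
    #   lightning:
    #     trainer:
    #       max_epochs: 200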

    now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")

    # add cwd for convenience and to make classes in this file available when
    # running as `python main.py`
    # (in particular `main.DataModuleFromConfig`)
    sys.path.append(os.getcwd())

    parser = get_parser()
    parser = Trainer.add_argparse_args(parser)

    opt, unknown = parser.parse_known_args()
    if opt.name and opt.resume:
        raise ValueError(
            "-n/--name and -r/--resume cannot both be specified. "
            "If you want to resume training in a new log folder, "
            "use -n/--name in combination with --resume_from_checkpoint"
        )
    if opt.resume:
        if not os.path.exists(opt.resume):
            raise ValueError("Cannot find {}".format(opt.resume))
        if os.path.isfile(opt.resume):
            paths = opt.resume.split("/")
            # idx = len(paths)-paths[::-1].index("logs")+1
            # logdir = "/".join(paths[:idx])
            logdir = "/".join(paths[:-2])
            ckpt = opt.resume
        else:
            assert os.path.isdir(opt.resume), opt.resume
            logdir = opt.resume.rstrip("/")
            ckpt = os.path.join(logdir, "checkpoints", "last.ckpt")

        opt.resume_from_checkpoint = ckpt
        base_configs = sorted(glob.glob(os.path.join(logdir, "configs/*.yaml")))
        opt.base = base_configs + opt.base
        _tmp = logdir.split("/")
        nowname = _tmp[-1]
    else:
        if opt.name:
            name = "_" + opt.name
        elif opt.base:
            cfg_fname = os.path.split(opt.base[0])[-1]
            cfg_name = os.path.splitext(cfg_fname)[0]
            name = "_" + cfg_name
        else:
            name = ""
        nowname = now + name + opt.postfix
        logdir = os.path.join(opt.logdir, nowname)

    ckptdir = os.path.join(logdir, "checkpoints")
    cfgdir = os.path.join(logdir, "configs")
    seed_everything(opt.seed)

    try:
        # init and save configs
        configs = [OmegaConf.load(cfg) for cfg in opt.base]
        cli = OmegaConf.from_dotlist(unknown)
        config = OmegaConf.merge(*configs, cli)
        lightning_config = config.pop("lightning", OmegaConf.create())
        # merge trainer cli with config
        trainer_config = lightning_config.get("trainer", OmegaConf.create())
        # default to ddp
        trainer_config["accelerator"] = "ddp"
        for k in nondefault_trainer_args(opt):
            trainer_config[k] = getattr(opt, k)
        if "gpus" not in trainer_config:
            del trainer_config["accelerator"]
            cpu = True
        else:
            gpuinfo = trainer_config["gpus"]
            print(f"Running on GPUs {gpuinfo}")
            cpu = False
        trainer_opt = argparse.Namespace(**trainer_config)
        lightning_config.trainer = trainer_config

        # model
        model = instantiate_from_config(config.model)

        # trainer and callbacks
        trainer_kwargs = dict()

        # default logger configs
        default_logger_cfgs = {
            "wandb": {
                "target": "pytorch_lightning.loggers.WandbLogger",
                "params": {
                    "name": nowname,
                    "save_dir": logdir,
                    "offline": opt.debug,
                    "id": nowname,
                }
            },
            "testtube": {
                "target": "pytorch_lightning.loggers.TestTubeLogger",
                "params": {
                    "name": "testtube",
                    "save_dir": logdir,
                }
            },
        }
        default_logger_cfg = default_logger_cfgs["testtube"]
        if "logger" in lightning_config:
            logger_cfg = lightning_config.logger
        else:
            logger_cfg = OmegaConf.create()
        logger_cfg = OmegaConf.merge(default_logger_cfg, logger_cfg)
        trainer_kwargs["logger"] = instantiate_from_config(logger_cfg)

        # modelcheckpoint - use TrainResult/EvalResult(checkpoint_on=metric) to
        # specify which metric is used to determine best models
        default_modelckpt_cfg = {
            "target": "pytorch_lightning.callbacks.ModelCheckpoint",
            "params": {
                "dirpath": ckptdir,
                "filename": "{epoch:06}",
                "verbose": True,
                "save_last": True,
            }
        }
        if hasattr(model, "monitor"):
            print(f"Monitoring {model.monitor} as checkpoint metric.")
            default_modelckpt_cfg["params"]["monitor"] = model.monitor
            default_modelckpt_cfg["params"]["save_top_k"] = 3

        if "modelcheckpoint" in lightning_config:
            modelckpt_cfg = lightning_config.modelcheckpoint
        else:
            modelckpt_cfg = OmegaConf.create()
        modelckpt_cfg = OmegaConf.merge(default_modelckpt_cfg, modelckpt_cfg)
        print(f"Merged modelckpt-cfg: \n{modelckpt_cfg}")
        if version.parse(pl.__version__) < version.parse('1.4.0'):
            trainer_kwargs["checkpoint_callback"] = instantiate_from_config(modelckpt_cfg)

        # add callback which sets up log directory
        default_callbacks_cfg = {
            "setup_callback": {
                "target": "main.SetupCallback",
                "params": {
                    "resume": opt.resume,
                    "now": now,
                    "logdir": logdir,
                    "ckptdir": ckptdir,
                    "cfgdir": cfgdir,
                    "config": config,
                    "lightning_config": lightning_config,
                }
            },
            "image_logger": {
                "target": "main.ImageLogger",
                "params": {
                    "batch_frequency": 750,
                    "max_images": 4,
                    "clamp": True
                }
            },
            "learning_rate_logger": {
                "target": "main.LearningRateMonitor",
                "params": {
                    "logging_interval": "step",
                    # "log_momentum": True
                }
            },
            "cuda_callback": {
                "target": "main.CUDACallback"
            },
        }
        if version.parse(pl.__version__) >= version.parse('1.4.0'):
            default_callbacks_cfg.update({'checkpoint_callback': modelckpt_cfg})

        if "callbacks" in lightning_config:
            callbacks_cfg = lightning_config.callbacks
        else:
            callbacks_cfg = OmegaConf.create()

        if 'metrics_over_trainsteps_checkpoint' in callbacks_cfg:
            print(
                'Caution: Saving checkpoints every n train steps without deleting. This might require some free space.')
            default_metrics_over_trainsteps_ckpt_dict = {
                'metrics_over_trainsteps_checkpoint':
                    {"target": 'pytorch_lightning.callbacks.ModelCheckpoint',
                     'params': {
                         "dirpath": os.path.join(ckptdir, 'trainstep_checkpoints'),
                         "filename": "{epoch:06}-{step:09}",
                         "verbose": True,
                         'save_top_k': -1,
                         'every_n_train_steps': 10000,
                         'save_weights_only': True
                     }
                     }
            }
            default_callbacks_cfg.update(default_metrics_over_trainsteps_ckpt_dict)

        callbacks_cfg = OmegaConf.merge(default_callbacks_cfg, callbacks_cfg)
        if 'ignore_keys_callback' in callbacks_cfg and hasattr(trainer_opt, 'resume_from_checkpoint'):
            callbacks_cfg.ignore_keys_callback.params['ckpt_path'] = trainer_opt.resume_from_checkpoint
        elif 'ignore_keys_callback' in callbacks_cfg:
            del callbacks_cfg['ignore_keys_callback']

        trainer_kwargs["callbacks"] = [instantiate_from_config(callbacks_cfg[k]) for k in callbacks_cfg]

        trainer = Trainer.from_argparse_args(trainer_opt, **trainer_kwargs)
        trainer.logdir = logdir  ###

        # data
        data = instantiate_from_config(config.data)
        # NOTE according to https://pytorch-lightning.readthedocs.io/en/latest/datamodules.html
        # calling these ourselves should not be necessary but it is.
        # lightning still takes care of proper multiprocessing though
        data.prepare_data()
        data.setup()
        print("#### Data ####")
        for k in data.datasets:
            print(f"{k}, {data.datasets[k].__class__.__name__}, {len(data.datasets[k])}")

        # configure learning rate
        bs, base_lr = config.data.params.batch_size, config.model.base_learning_rate
        if not cpu:
            ngpu = len(lightning_config.trainer.gpus.strip(",").split(','))
        else:
            ngpu = 1
        if 'accumulate_grad_batches' in lightning_config.trainer:
            accumulate_grad_batches = lightning_config.trainer.accumulate_grad_batches
        else:
            accumulate_grad_batches = 1
        print(f"accumulate_grad_batches = {accumulate_grad_batches}")
        lightning_config.trainer.accumulate_grad_batches = accumulate_grad_batches
        if opt.scale_lr:
            model.learning_rate = accumulate_grad_batches * ngpu * bs * base_lr
            print(
                "Setting learning rate to {:.2e} = {} (accumulate_grad_batches) * {} (num_gpus) * {} (batchsize) * {:.2e} (base_lr)".format(
                    model.learning_rate, accumulate_grad_batches, ngpu, bs, base_lr))
        else:
            model.learning_rate = base_lr
            print("++++ NOT USING LR SCALING ++++")
            print(f"Setting learning rate to {model.learning_rate:.2e}")
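
        # Worked example of the scaling rule above (numbers are illustrative only):
        # with accumulate_grad_batches=2, ngpu=4, batch_size=8 and base_lr=1.0e-06,
        # the learning rate is set to 2 * 4 * 8 * 1.0e-06 = 6.4e-05.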


        # allow checkpointing via USR1
        def melk(*args, **kwargs):
            # run all checkpoint hooks
            if trainer.global_rank == 0:
                print("Summoning checkpoint.")
                ckpt_path = os.path.join(ckptdir, "last.ckpt")
                trainer.save_checkpoint(ckpt_path)


        def divein(*args, **kwargs):
            if trainer.global_rank == 0:
                import pudb
                pudb.set_trace()


        import signal

        signal.signal(signal.SIGUSR1, melk)
        signal.signal(signal.SIGUSR2, divein)
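
        # Illustrative usage: while training runs, `kill -USR1 <pid>` from another
        # shell makes rank 0 write last.ckpt via melk(), and `kill -USR2 <pid>` drops
        # into the pudb debugger via divein() (the pid/shell commands are examples,
        # not part of this script).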

        # run
        if opt.train:
            try:
                trainer.fit(model, data)
            except Exception:
                melk()
                raise
        if not opt.no_test and not trainer.interrupted:
            trainer.test(model, data)
    except Exception:
        if opt.debug and trainer.global_rank == 0:
            try:
                import pudb as debugger
            except ImportError:
                import pdb as debugger
            debugger.post_mortem()
        raise
    finally:
        # move newly created debug project to debug_runs
        if opt.debug and not opt.resume and trainer.global_rank == 0:
            dst, name = os.path.split(logdir)
            dst = os.path.join(dst, "debug_runs", name)
            os.makedirs(os.path.split(dst)[0], exist_ok=True)
            os.rename(logdir, dst)
        if trainer.global_rank == 0:
            print(trainer.profiler.summary())