""" Main component: the trainer handles everything: * initializations * training * saving """ import inspect import warnings from copy import deepcopy from pathlib import Path from time import time import numpy as np from comet_ml import ExistingExperiment, Experiment warnings.simplefilter("ignore", UserWarning) import torch import torch.nn as nn from addict import Dict from torch import autograd, sigmoid, softmax from torch.cuda.amp import GradScaler, autocast from tqdm import tqdm from climategan.data import get_all_loaders from climategan.discriminator import OmniDiscriminator, create_discriminator from climategan.eval_metrics import accuracy, mIOU from climategan.fid import compute_val_fid from climategan.fire import add_fire from climategan.generator import OmniGenerator, create_generator from climategan.logger import Logger from climategan.losses import get_losses from climategan.optim import get_optimizer from climategan.transforms import DiffTransforms from climategan.tutils import ( divide_pred, get_num_params, get_WGAN_gradient, lrgb2srgb, normalize, print_num_parameters, shuffle_batch_tuple, srgb2lrgb, vgg_preprocess, zero_grad, ) from climategan.utils import ( comet_kwargs, div_dict, find_target_size, flatten_opts, get_display_indices, get_existing_comet_id, get_latest_opts, merge, resolve, sum_dict, Timer, ) try: import torch_xla.core.xla_model as xm # type: ignore except ImportError: pass class Trainer: """Main trainer class""" def __init__(self, opts, comet_exp=None, verbose=0, device=None): """Trainer class to gather various model training procedures such as training evaluating saving and logging init: * creates an addict.Dict logger * creates logger.exp as a comet_exp experiment if `comet` arg is True * sets the device (1 GPU or CPU) Args: opts (addict.Dict): options to configure the trainer, the data, the models comet (bool, optional): whether to log the trainer with comet.ml. Defaults to False. verbose (int, optional): printing level to debug. Defaults to 0. """ super().__init__() self.opts = opts self.verbose = verbose self.logger = Logger(self) self.losses = None self.G = self.D = None self.real_val_fid_stats = None self.use_pl4m = False self.is_setup = False self.loaders = self.all_loaders = None self.exp = None self.current_mode = "train" self.diff_transforms = None self.kitti_pretrain = self.opts.train.kitti.pretrain self.pseudo_training_tasks = set(self.opts.train.pseudo.tasks) self.lr_names = {} self.base_display_images = {} self.kitty_display_images = {} self.domain_labels = {"s": 0, "r": 1} self.device = device or torch.device( "cuda:0" if torch.cuda.is_available() else "cpu" ) if isinstance(comet_exp, Experiment): self.exp = comet_exp if self.opts.train.amp: optimizers = [ self.opts.gen.opt.optimizer.lower(), self.opts.dis.opt.optimizer.lower(), ] if "extraadam" in optimizers: raise ValueError( "AMP does not work with ExtraAdam ({})".format(optimizers) ) self.grad_scaler_d = GradScaler() self.grad_scaler_g = GradScaler() # ------------------------------- # ----- Legacy Overwrites ----- # ------------------------------- if ( self.opts.gen.s.depth_feat_fusion is True or self.opts.gen.s.depth_dada_fusion is True ): self.opts.gen.s.use_dada = True @torch.no_grad() def paint_and_mask(self, image_batch, mask_batch=None, resolution="approx"): """ Paints a batch of images (or a single image with a batch dim of 1). If masks are not provided, they are inferred from the masker. Resolution can either be the train-time resolution or the closest multiple of 2 ** spade_n_up Operations performed without gradient If resolution == "approx" then the output image has the shape: (dim // 2 ** spade_n_up) * 2 ** spade_n_up, for dim in [height, width] eg: (1000, 1300) => (896, 1280) for spade_n_up = 7 If resolution == "exact" then the output image has the same shape: we first process in "approx" mode then upsample bilinear If resolution == "basic" image output shape is the train-time's (typically 640x640) If resolution == "upsample" image is inferred as "basic" and then upsampled to original size Args: image_batch (torch.Tensor): 4D batch of images to flood mask_batch (torch.Tensor, optional): Masks for the images. Defaults to None (infer with Masker). resolution (str, optional): "approx", "exact" or False Returns: torch.Tensor: N x C x H x W where H and W depend on `resolution` """ assert resolution in {"approx", "exact", "basic", "upsample"} previous_mode = self.current_mode if previous_mode == "train": self.eval_mode() if mask_batch is None: mask_batch = self.G.mask(x=image_batch) else: assert len(image_batch) == len(mask_batch) assert image_batch.shape[-2:] == mask_batch.shape[-2:] if resolution not in {"approx", "exact"}: painted = self.G.paint(mask_batch, image_batch) if resolution == "upsample": painted = nn.functional.interpolate( painted, size=image_batch.shape[-2:], mode="bilinear" ) else: # save latent shape zh = self.G.painter.z_h zw = self.G.painter.z_w # adapt latent shape to approximately keep the resolution self.G.painter.z_h = ( image_batch.shape[-2] // 2 ** self.opts.gen.p.spade_n_up ) self.G.painter.z_w = ( image_batch.shape[-1] // 2 ** self.opts.gen.p.spade_n_up ) painted = self.G.paint(mask_batch, image_batch) self.G.painter.z_h = zh self.G.painter.z_w = zw if resolution == "exact": painted = nn.functional.interpolate( painted, size=image_batch.shape[-2:], mode="bilinear" ) if previous_mode == "train": self.train_mode() return painted def _p(self, *args, **kwargs): """ verbose-dependant print util """ if self.verbose > 0: print(*args, **kwargs) @torch.no_grad() def infer_all( self, x, numpy=True, stores={}, bin_value=-1, half=False, xla=False, cloudy=False, auto_resize_640=False, ignore_event=set(), ): """ Create a dictionnary of events from a numpy or tensor, single or batch image data. stores is a dictionnary of times for the Timer class. bin_value is used to binarize (or not) flood masks """ assert self.is_setup assert len(x.shape) in {3, 4}, f"Unknown Data shape {x.shape}" # convert numpy to tensor if not isinstance(x, torch.Tensor): x = torch.tensor(x, device=self.device) # add batch dimension if len(x.shape) == 3: x.unsqueeze_(0) # permute channels as second dimension if x.shape[1] != 3: assert x.shape[-1] == 3, f"Unknown x shape to permute {x.shape}" x = x.permute(0, 3, 1, 2) # send to device if x.device != self.device: x = x.to(self.device) # interpolate to standard input size if auto_resize_640 and (x.shape[-1] != 640 or x.shape[-2] != 640): x = torch.nn.functional.interpolate(x, (640, 640), mode="bilinear") if half: x = x.half() # adjust painter's latent vector self.G.painter.set_latent_shape(x.shape, True) with Timer(store=stores.get("all events", [])): # encode with Timer(store=stores.get("encode", [])): z = self.G.encode(x) if xla: xm.mark_step() # predict from masker with Timer(store=stores.get("depth", [])): depth, z_depth = self.G.decoders["d"](z) if xla: xm.mark_step() with Timer(store=stores.get("segmentation", [])): segmentation = self.G.decoders["s"](z, z_depth) if xla: xm.mark_step() with Timer(store=stores.get("mask", [])): cond = self.G.make_m_cond(depth, segmentation, x) mask = self.G.mask(z=z, cond=cond, z_depth=z_depth) if xla: xm.mark_step() # apply events if "wildfire" not in ignore_event: with Timer(store=stores.get("wildfire", [])): wildfire = self.compute_fire(x, seg_preds=segmentation) if "smog" not in ignore_event: with Timer(store=stores.get("smog", [])): smog = self.compute_smog(x, d=depth, s=segmentation) if "flood" not in ignore_event: with Timer(store=stores.get("flood", [])): flood = self.compute_flood( x, m=mask, s=segmentation, cloudy=cloudy, bin_value=bin_value, ) if xla: xm.mark_step() if numpy: with Timer(store=stores.get("numpy", [])): # normalize to 0-1 flood = normalize(flood).cpu() smog = normalize(smog).cpu() wildfire = normalize(wildfire).cpu() # convert to numpy flood = flood.permute(0, 2, 3, 1).numpy() smog = smog.permute(0, 2, 3, 1).numpy() wildfire = wildfire.permute(0, 2, 3, 1).numpy() # convert to 0-255 uint8 flood = (flood * 255).astype(np.uint8) smog = (smog * 255).astype(np.uint8) wildfire = (wildfire * 255).astype(np.uint8) return {"flood": flood, "wildfire": wildfire, "smog": smog} @classmethod def resume_from_path( cls, path, overrides={}, setup=True, inference=False, new_exp=False, device=None, verbose=1, ): """ Resume and optionally setup a trainer from a specific path, using the latest opts and checkpoint. Requires path to contain opts.yaml (or increased), url.txt (or increased) and checkpoints/ Args: path (str | pathlib.Path): Trainer to resume overrides (dict, optional): Override loaded opts with those. Defaults to {}. setup (bool, optional): Wether or not to setup the trainer before returning it. Defaults to True. inference (bool, optional): Setup should be done in inference mode or not. Defaults to False. new_exp (bool, optional): Re-use existing comet exp in path or create a new one? Defaults to False. device (torch.device, optional): Device to use Returns: climategan.Trainer: Loaded and resumed trainer """ p = resolve(path) print(p) assert p.exists() c = p / "checkpoints" assert c.exists() and c.is_dir() opts = get_latest_opts(p) opts = Dict(merge(overrides, opts)) opts.train.resume = True if new_exp is None: exp = None elif new_exp is True: exp = Experiment(project_name="climategan", **comet_kwargs) exp.log_asset_folder( str(resolve(Path(__file__)).parent), recursive=True, log_file_name=True, ) exp.log_parameters(flatten_opts(opts)) else: comet_id = get_existing_comet_id(p) exp = ExistingExperiment(previous_experiment=comet_id, **comet_kwargs) trainer = cls(opts, comet_exp=exp, device=device, verbose=verbose) if setup: trainer.setup(inference=inference) return trainer def save(self): save_dir = Path(self.opts.output_path) / Path("checkpoints") save_dir.mkdir(exist_ok=True) save_path = save_dir / "latest_ckpt.pth" # Construct relevant state dicts / optims: # Save at least G save_dict = { "epoch": self.logger.epoch, "G": self.G.state_dict(), "g_opt": self.g_opt.state_dict(), "step": self.logger.global_step, } if self.D is not None and get_num_params(self.D) > 0: save_dict["D"] = self.D.state_dict() save_dict["d_opt"] = self.d_opt.state_dict() if ( self.logger.epoch >= self.opts.train.min_save_epoch and self.logger.epoch % self.opts.train.save_n_epochs == 0 ): torch.save(save_dict, save_dir / f"epoch_{self.logger.epoch}_ckpt.pth") torch.save(save_dict, save_path) def resume(self, inference=False): tpu = "xla" in str(self.device) if tpu: print("Resuming on TPU:", self.device) m_path = Path(self.opts.load_paths.m) p_path = Path(self.opts.load_paths.p) pm_path = Path(self.opts.load_paths.pm) output_path = Path(self.opts.output_path) map_loc = self.device if not tpu else "cpu" if "m" in self.opts.tasks and "p" in self.opts.tasks: # ---------------------------------------- # ----- Masker and Painter Loading ----- # ---------------------------------------- # want to resume a pm model but no path was provided: # resume a single pm model from output_path if all([str(p) == "none" for p in [m_path, p_path, pm_path]]): checkpoint_path = output_path / "checkpoints/latest_ckpt.pth" print("Resuming P+M model from", str(checkpoint_path)) checkpoint = torch.load(checkpoint_path, map_location=map_loc) # want to resume a pm model with a pm_path provided: # resume a single pm model from load_paths.pm # depending on whether a dir or a file is specified elif str(pm_path) != "none": assert pm_path.exists() if pm_path.is_dir(): checkpoint_path = pm_path / "checkpoints/latest_ckpt.pth" else: assert pm_path.suffix == ".pth" checkpoint_path = pm_path print("Resuming P+M model from", str(checkpoint_path)) checkpoint = torch.load(checkpoint_path, map_location=map_loc) # want to resume a pm model, pm_path not provided: # m_path and p_path must be provided as dirs or pth files elif m_path != p_path: assert m_path.exists() assert p_path.exists() if m_path.is_dir(): m_path = m_path / "checkpoints/latest_ckpt.pth" if p_path.is_dir(): p_path = p_path / "checkpoints/latest_ckpt.pth" assert m_path.suffix == ".pth" assert p_path.suffix == ".pth" print(f"Resuming P+M model from \n -{p_path} \nand \n -{m_path}") m_checkpoint = torch.load(m_path, map_location=map_loc) p_checkpoint = torch.load(p_path, map_location=map_loc) checkpoint = merge(m_checkpoint, p_checkpoint) else: raise ValueError( "Cannot resume a P+M model with provided load_paths:\n{}".format( self.opts.load_paths ) ) else: # ---------------------------------- # ----- Single Model Loading ----- # ---------------------------------- # cannot specify both paths if str(m_path) != "none" and str(p_path) != "none": raise ValueError( "Opts tasks are {} but received 2 values for the load_paths".format( self.opts.tasks ) ) # specified m elif str(m_path) != "none": print(m_path) assert m_path.exists() assert "m" in self.opts.tasks model = "M" if m_path.is_dir(): m_path = m_path / "checkpoints/latest_ckpt.pth" checkpoint_path = m_path # specified m elif str(p_path) != "none": assert p_path.exists() assert "p" in self.opts.tasks model = "P" if p_path.is_dir(): p_path = p_path / "checkpoints/latest_ckpt.pth" checkpoint_path = p_path # specified neither p nor m: resume from output_path else: model = "P" if "p" in self.opts.tasks else "M" checkpoint_path = output_path / "checkpoints/latest_ckpt.pth" print(f"Resuming {model} model from {checkpoint_path}") checkpoint = torch.load(checkpoint_path, map_location=map_loc) # On TPUs must send the data to the xla device as it cannot be mapped # there directly from torch.load if tpu: checkpoint = xm.send_cpu_data_to_device(checkpoint, self.device) # ----------------------- # ----- Restore G ----- # ----------------------- if inference: incompatible_keys = self.G.load_state_dict(checkpoint["G"], strict=False) if incompatible_keys.missing_keys: print("WARNING: Missing keys in self.G.load_state_dict, keeping inits") print(incompatible_keys.missing_keys) if incompatible_keys.unexpected_keys: print("WARNING: Ignoring Unexpected keys in self.G.load_state_dict") print(incompatible_keys.unexpected_keys) else: self.G.load_state_dict(checkpoint["G"]) if inference: # only G is needed to infer print("Done loading checkpoints.") return self.g_opt.load_state_dict(checkpoint["g_opt"]) # ------------------------------ # ----- Resume scheduler ----- # ------------------------------ # https://discuss.pytorch.org/t/a-problem-occured-when-resuming-an-optimizer/28822 for _ in range(self.logger.epoch + 1): self.update_learning_rates() # ----------------------- # ----- Restore D ----- # ----------------------- if self.D is not None and get_num_params(self.D) > 0: self.D.load_state_dict(checkpoint["D"]) self.d_opt.load_state_dict(checkpoint["d_opt"]) # --------------------------- # ----- Resore logger ----- # --------------------------- self.logger.epoch = checkpoint["epoch"] self.logger.global_step = checkpoint["step"] self.exp.log_text( "Resuming from epoch {} & step {}".format( checkpoint["epoch"], checkpoint["step"] ) ) # Round step to even number for extraGradient if self.logger.global_step % 2 != 0: self.logger.global_step += 1 def eval_mode(self): """ Set trainer's models in eval mode """ if self.G is not None: self.G.eval() if self.D is not None: self.D.eval() self.current_mode = "eval" def train_mode(self): """ Set trainer's models in train mode """ if self.G is not None: self.G.train() if self.D is not None: self.D.train() self.current_mode = "train" def assert_z_matches_x(self, x, z): assert x.shape[0] == ( z.shape[0] if not isinstance(z, (list, tuple)) else z[0].shape[0] ), "x-> {}, z->{}".format( x.shape, z.shape if not isinstance(z, (list, tuple)) else z[0].shape ) def batch_to_device(self, b): """sends the data in b to self.device Args: b (dict): the batch dictionnay Returns: dict: the batch dictionnary with its "data" field sent to self.device """ for task, tensor in b["data"].items(): b["data"][task] = tensor.to(self.device) return b def sample_painter_z(self, batch_size): return self.G.sample_painter_z(batch_size, self.device) @property def train_loaders(self): """Get a zip of all training loaders Returns: generator: zip generator yielding tuples: (batch_rf, batch_rn, batch_sf, batch_sn) """ return zip(*list(self.loaders["train"].values())) @property def val_loaders(self): """Get a zip of all validation loaders Returns: generator: zip generator yielding tuples: (batch_rf, batch_rn, batch_sf, batch_sn) """ return zip(*list(self.loaders["val"].values())) def compute_latent_shape(self): """Compute the latent shape, i.e. the Encoder's output shape, from a batch. Raises: ValueError: If no loader, the latent_shape cannot be inferred Returns: tuple: (c, h, w) """ x = None for mode in self.all_loaders: for domain in self.all_loaders.loaders[mode]: x = ( self.all_loaders[mode][domain] .dataset[0]["data"]["x"] .to(self.device) ) break if x is not None: break if x is None: raise ValueError("No batch found to compute_latent_shape") x = x.unsqueeze(0) z = self.G.encode(x) return z.shape[1:] if not isinstance(z, (list, tuple)) else z[0].shape[1:] def g_opt_step(self): """Run an optimizing step ; if using ExtraAdam, there needs to be an extrapolation step every other step """ if "extra" in self.opts.gen.opt.optimizer.lower() and ( self.logger.global_step % 2 == 0 ): self.g_opt.extrapolation() else: self.g_opt.step() def d_opt_step(self): """Run an optimizing step ; if using ExtraAdam, there needs to be an extrapolation step every other step """ if "extra" in self.opts.dis.opt.optimizer.lower() and ( self.logger.global_step % 2 == 0 ): self.d_opt.extrapolation() else: self.d_opt.step() def update_learning_rates(self): if self.g_scheduler is not None: self.g_scheduler.step() if self.d_scheduler is not None: self.d_scheduler.step() def setup(self, inference=False): """Prepare the trainer before it can be used to train the models: * initialize G and D * creates 2 optimizers """ self.logger.global_step = 0 start_time = time() self.logger.time.start_time = start_time verbose = self.verbose if not inference: self.all_loaders = get_all_loaders(self.opts) # ----------------------- # ----- Generator ----- # ----------------------- __t = time() print("Creating generator...") self.G: OmniGenerator = create_generator( self.opts, device=self.device, no_init=inference, verbose=verbose ) self.has_painter = get_num_params(self.G.painter) or self.G.load_val_painter() if self.has_painter: self.G.painter.set_latent_shape(find_target_size(self.opts, "x"), True) print(f"Generator OK in {time() - __t:.1f}s.") if inference: # Inference mode: no more than a Generator needed print("Inference mode: no Discriminator, no optimizers") print_num_parameters(self) self.switch_data(to="base") if self.opts.train.resume: self.resume(True) self.eval_mode() print("Trainer is in evaluation mode.") print("Setup done.") self.is_setup = True return # --------------------------- # ----- Discriminator ----- # --------------------------- self.D: OmniDiscriminator = create_discriminator( self.opts, self.device, verbose=verbose ) print("Discriminator OK.") print_num_parameters(self) # -------------------------- # ----- Optimization ----- # -------------------------- # Get different optimizers for each task (different learning rates) self.g_opt, self.g_scheduler, self.lr_names["G"] = get_optimizer( self.G, self.opts.gen.opt, self.opts.tasks ) if get_num_params(self.D) > 0: self.d_opt, self.d_scheduler, self.lr_names["D"] = get_optimizer( self.D, self.opts.dis.opt, self.opts.tasks, True ) else: self.d_opt, self.d_scheduler = None, None self.losses = get_losses(self.opts, verbose, device=self.device) if "p" in self.opts.tasks and self.opts.gen.p.diff_aug.use: self.diff_transforms = DiffTransforms(self.opts.gen.p.diff_aug) if verbose > 0: for mode, mode_dict in self.all_loaders.items(): for domain, domain_loader in mode_dict.items(): print( "Loader {} {} : {}".format( mode, domain, len(domain_loader.dataset) ) ) # ---------------------------- # ----- Display images ----- # ---------------------------- self.set_display_images() # ------------------------------- # ----- Log Architectures ----- # ------------------------------- self.logger.log_architecture() # ----------------------------- # ----- Set data source ----- # ----------------------------- if self.kitti_pretrain: self.switch_data(to="kitti") else: self.switch_data(to="base") # ------------------------- # ----- Setup Done. ----- # ------------------------- print(" " * 50, end="\r") print("Done creating display images") if self.opts.train.resume: print("Resuming Model (inference: False)") self.resume(False) else: print("Not resuming: starting a new model") print("Setup done.") self.is_setup = True def switch_data(self, to="kitti"): caller = inspect.stack()[1].function print(f"[{caller}] Switching data source to", to) self.data_source = to if to == "kitti": self.display_images = self.kitty_display_images if self.all_loaders is not None: self.loaders = { mode: {"s": self.all_loaders[mode]["kitti"]} for mode in self.all_loaders } else: self.display_images = self.base_display_images if self.all_loaders is not None: self.loaders = { mode: { domain: self.all_loaders[mode][domain] for domain in self.all_loaders[mode] if domain != "kitti" } for mode in self.all_loaders } if ( self.logger.global_step % 2 != 0 and "extra" in self.opts.dis.opt.optimizer.lower() ): print( "Warning: artificially bumping step to run an extrapolation step first." ) self.logger.global_step += 1 def set_display_images(self, use_all=False): for mode, mode_dict in self.all_loaders.items(): if self.kitti_pretrain: self.kitty_display_images[mode] = {} self.base_display_images[mode] = {} for domain in mode_dict: if self.kitti_pretrain and domain == "kitti": target_dict = self.kitty_display_images else: if domain == "kitti": continue target_dict = self.base_display_images dataset = self.all_loaders[mode][domain].dataset display_indices = ( get_display_indices(self.opts, domain, len(dataset)) if not use_all else list(range(len(dataset))) ) ldis = len(display_indices) print( f" Creating {ldis} {mode} {domain} display images...", end="\r", flush=True, ) target_dict[mode][domain] = [ Dict(dataset[i]) for i in display_indices if (print(f"({i})", end="\r") is None and i < len(dataset)) ] if self.exp is not None: for im_id, d in enumerate(target_dict[mode][domain]): self.exp.log_parameter( "display_image_{}_{}_{}".format(mode, domain, im_id), d["paths"], ) def train(self): """For each epoch: * train * eval * save """ assert self.is_setup for self.logger.epoch in range( self.logger.epoch, self.logger.epoch + self.opts.train.epochs ): # backprop painter's disc loss to masker if ( self.logger.epoch == self.opts.gen.p.pl4m_epoch and get_num_params(self.G.painter) > 0 and "p" in self.opts.tasks and self.opts.gen.m.use_pl4m ): print( "\n\n >>> Enabling pl4m at epoch {}\n\n".format(self.logger.epoch) ) self.use_pl4m = True self.run_epoch() self.run_evaluation(verbose=1) self.save() # end vkitti2 pre-training if self.logger.epoch == self.opts.train.kitti.epochs - 1: self.switch_data(to="base") self.kitti_pretrain = False # end pseudo training if self.logger.epoch == self.opts.train.pseudo.epochs - 1: self.pseudo_training_tasks = set() def run_epoch(self): """Runs an epoch: * checks trainer is setup * gets a tuple of batches per domain * sends batches to device * updates sequentially G, D """ assert self.is_setup self.train_mode() if self.exp is not None: self.exp.log_parameter("epoch", self.logger.epoch) epoch_len = min(len(loader) for loader in self.loaders["train"].values()) epoch_desc = "Epoch {}".format(self.logger.epoch) self.logger.time.epoch_start = time() for multi_batch_tuple in tqdm( self.train_loaders, desc=epoch_desc, total=epoch_len, mininterval=0.5, unit="batch", ): self.logger.time.step_start = time() multi_batch_tuple = shuffle_batch_tuple(multi_batch_tuple) # The `[0]` is because the domain is contained in a list multi_domain_batch = { batch["domain"][0]: self.batch_to_device(batch) for batch in multi_batch_tuple } # ------------------------------ # ----- Update Generator ----- # ------------------------------ # freeze params of the discriminator if self.d_opt is not None: for param in self.D.parameters(): param.requires_grad = False self.update_G(multi_domain_batch) # ---------------------------------- # ----- Update Discriminator ----- # ---------------------------------- # unfreeze params of the discriminator if self.d_opt is not None and not self.kitti_pretrain: for param in self.D.parameters(): param.requires_grad = True self.update_D(multi_domain_batch) # ------------------------- # ----- Log Metrics ----- # ------------------------- self.logger.global_step += 1 self.logger.log_step_time(time()) if not self.kitti_pretrain: self.update_learning_rates() self.logger.log_learning_rates() self.logger.log_epoch_time(time()) def update_G(self, multi_domain_batch, verbose=0): """Perform an update on g from multi_domain_batch which is a dictionary domain => batch * automatic mixed precision according to self.opts.train.amp * compute loss for each task * loss.backward() * g_opt_step() * g_opt.step() or .extrapolation() depending on self.logger.global_step * logs losses on comet.ml with self.logger.log_losses(model_to_update="G") Args: multi_domain_batch (dict): dictionnary of domain batches """ zero_grad(self.G) if self.opts.train.amp: with autocast(): g_loss = self.get_G_loss(multi_domain_batch, verbose) self.grad_scaler_g.scale(g_loss).backward() self.grad_scaler_g.step(self.g_opt) self.grad_scaler_g.update() else: g_loss = self.get_G_loss(multi_domain_batch, verbose) g_loss.backward() self.g_opt_step() self.logger.log_losses(model_to_update="G", mode="train") def update_D(self, multi_domain_batch, verbose=0): zero_grad(self.D) if self.opts.train.amp: with autocast(): d_loss = self.get_D_loss(multi_domain_batch, verbose) self.grad_scaler_d.scale(d_loss).backward() self.grad_scaler_d.step(self.d_opt) self.grad_scaler_d.update() else: d_loss = self.get_D_loss(multi_domain_batch, verbose) d_loss.backward() self.d_opt_step() self.logger.losses.disc.total_loss = d_loss.item() self.logger.log_losses(model_to_update="D", mode="train") def get_D_loss(self, multi_domain_batch, verbose=0): """Compute the discriminators' losses: * for each domain-specific batch: * encode the image * get the conditioning tensor if using spade * source domain is the data's domain, sequentially r|s then f|n * get the target domain accordingly * compute the translated image from the data * compute the source domain discriminator's loss on the data * compute the target domain discriminator's loss on the translated image # ? In this setting, each D[decoder][domain] is updated twice towards # real or fake data See readme's update d section for details Args: multi_domain_batch ([type]): [description] Returns: [type]: [description] """ disc_loss = { "m": {"Advent": 0}, "s": {"Advent": 0}, } if self.opts.dis.p.use_local_discriminator: disc_loss["p"] = {"global": 0, "local": 0} else: disc_loss["p"] = {"gan": 0} for domain, batch in multi_domain_batch.items(): x = batch["data"]["x"] # --------------------- # ----- Painter ----- # --------------------- if domain == "rf" and self.has_painter: m = batch["data"]["m"] # sample vector with torch.no_grad(): # see spade compute_discriminator_loss fake = self.G.paint(m, x) if self.opts.gen.p.diff_aug.use: fake = self.diff_transforms(fake) x = self.diff_transforms(x) fake = fake.detach() fake.requires_grad_() if self.opts.dis.p.use_local_discriminator: fake_d_global = self.D["p"]["global"](fake) real_d_global = self.D["p"]["global"](x) fake_d_local = self.D["p"]["local"](fake * m) real_d_local = self.D["p"]["local"](x * m) global_loss = self.losses["D"]["p"](fake_d_global, False, True) global_loss += self.losses["D"]["p"](real_d_global, True, True) local_loss = self.losses["D"]["p"](fake_d_local, False, True) local_loss += self.losses["D"]["p"](real_d_local, True, True) disc_loss["p"]["global"] += global_loss disc_loss["p"]["local"] += local_loss else: real_cat = torch.cat([m, x], axis=1) fake_cat = torch.cat([m, fake], axis=1) real_fake_cat = torch.cat([real_cat, fake_cat], dim=0) real_fake_d = self.D["p"](real_fake_cat) real_d, fake_d = divide_pred(real_fake_d) disc_loss["p"]["gan"] = self.losses["D"]["p"](fake_d, False, True) disc_loss["p"]["gan"] += self.losses["D"]["p"](real_d, True, True) # -------------------- # ----- Masker ----- # -------------------- else: z = self.G.encode(x) s_pred = d_pred = cond = z_depth = None if "s" in batch["data"]: if "d" in self.opts.tasks and self.opts.gen.s.use_dada: d_pred, z_depth = self.G.decoders["d"](z) step_loss, s_pred = self.masker_s_loss( x, z, d_pred, z_depth, None, domain, for_="D" ) step_loss *= self.opts.train.lambdas.advent.adv_main disc_loss["s"]["Advent"] += step_loss if "m" in batch["data"]: if "d" in self.opts.tasks: if self.opts.gen.m.use_spade: if d_pred is None: d_pred, z_depth = self.G.decoders["d"](z) cond = self.G.make_m_cond(d_pred, s_pred, x) elif self.opts.gen.m.use_dada: if d_pred is None: d_pred, z_depth = self.G.decoders["d"](z) step_loss, _ = self.masker_m_loss( x, z, None, domain, for_="D", cond=cond, z_depth=z_depth, depth_preds=d_pred, ) step_loss *= self.opts.train.lambdas.advent.adv_main disc_loss["m"]["Advent"] += step_loss self.logger.losses.disc.update( { dom: { k: v.item() if isinstance(v, torch.Tensor) else v for k, v in d.items() } for dom, d in disc_loss.items() } ) loss = sum(v for d in disc_loss.values() for k, v in d.items()) return loss def get_G_loss(self, multi_domain_batch, verbose=0): m_loss = p_loss = None # For now, always compute "representation loss" g_loss = 0 if any(t in self.opts.tasks for t in "msd"): m_loss = self.get_masker_loss(multi_domain_batch) self.logger.losses.gen.masker = m_loss.item() g_loss += m_loss if "p" in self.opts.tasks and not self.kitti_pretrain: p_loss = self.get_painter_loss(multi_domain_batch) self.logger.losses.gen.painter = p_loss.item() g_loss += p_loss assert g_loss != 0 and not isinstance(g_loss, int), "No update in get_G_loss!" self.logger.losses.gen.total_loss = g_loss.item() return g_loss def get_masker_loss(self, multi_domain_batch): # TODO update docstrings """Only update the representation part of the model, meaning everything but the translation part * for each batch in available domains: * compute task-specific losses * compute the adaptation and translation decoders' auto-encoding losses * compute the adaptation decoder's translation losses (GAN and Cycle) Args: multi_domain_batch (dict): dictionnary mapping domain names to batches from the trainer's loaders Returns: torch.Tensor: scalar loss tensor, weighted according to opts.train.lambdas """ m_loss = 0 for domain, batch in multi_domain_batch.items(): # We don't care about the flooded domain here if domain == "rf": continue x = batch["data"]["x"] z = self.G.encode(x) # -------------------------------------- # ----- task-specific losses (2) ----- # -------------------------------------- d_pred = s_pred = z_depth = None for task in ["d", "s", "m"]: if task not in batch["data"]: continue target = batch["data"][task] if task == "d": loss, d_pred, z_depth = self.masker_d_loss( x, z, target, domain, "G" ) m_loss += loss self.logger.losses.gen.task["d"][domain] = loss.item() elif task == "s": loss, s_pred = self.masker_s_loss( x, z, d_pred, z_depth, target, domain, "G" ) m_loss += loss self.logger.losses.gen.task["s"][domain] = loss.item() elif task == "m": cond = None if self.opts.gen.m.use_spade: if not self.opts.gen.m.detach: d_pred = d_pred.clone() s_pred = s_pred.clone() cond = self.G.make_m_cond(d_pred, s_pred, x) loss, _ = self.masker_m_loss( x, z, target, domain, "G", cond=cond, z_depth=z_depth, depth_preds=d_pred, ) m_loss += loss self.logger.losses.gen.task["m"][domain] = loss.item() return m_loss def get_painter_loss(self, multi_domain_batch): """Computes the translation loss when flooding/deflooding images Args: multi_domain_batch (dict): dictionnary mapping domain names to batches from the trainer's loaders Returns: torch.Tensor: scalar loss tensor, weighted according to opts.train.lambdas """ step_loss = 0 # self.g_opt.zero_grad() lambdas = self.opts.train.lambdas batch_domain = "rf" batch = multi_domain_batch[batch_domain] x = batch["data"]["x"] # ! different mask: hides water to be reconstructed # ! 1 for water, 0 otherwise m = batch["data"]["m"] fake_flooded = self.G.paint(m, x) # ---------------------- # ----- VGG Loss ----- # ---------------------- if lambdas.G.p.vgg != 0: loss = self.losses["G"]["p"]["vgg"]( vgg_preprocess(fake_flooded * m), vgg_preprocess(x * m) ) loss *= lambdas.G.p.vgg self.logger.losses.gen.p.vgg = loss.item() step_loss += loss # --------------------- # ----- TV Loss ----- # --------------------- if lambdas.G.p.tv != 0: loss = self.losses["G"]["p"]["tv"](fake_flooded * m) loss *= lambdas.G.p.tv self.logger.losses.gen.p.tv = loss.item() step_loss += loss # -------------------------- # ----- Context Loss ----- # -------------------------- if lambdas.G.p.context != 0: loss = self.losses["G"]["p"]["context"](fake_flooded, x, m) loss *= lambdas.G.p.context self.logger.losses.gen.p.context = loss.item() step_loss += loss # --------------------------------- # ----- Reconstruction Loss ----- # --------------------------------- if lambdas.G.p.reconstruction != 0: loss = self.losses["G"]["p"]["reconstruction"](fake_flooded, x, m) loss *= lambdas.G.p.reconstruction self.logger.losses.gen.p.reconstruction = loss.item() step_loss += loss # ------------------------------------- # ----- Local & Global GAN Loss ----- # ------------------------------------- if self.opts.gen.p.diff_aug.use: fake_flooded = self.diff_transforms(fake_flooded) x = self.diff_transforms(x) if self.opts.dis.p.use_local_discriminator: fake_d_global = self.D["p"]["global"](fake_flooded) fake_d_local = self.D["p"]["local"](fake_flooded * m) real_d_global = self.D["p"]["global"](x) # Note: discriminator returns [out_1,...,out_num_D] outputs # Each out_i is a list [feat1, feat2, ..., pred_i] self.logger.losses.gen.p.gan = 0 loss = self.losses["G"]["p"]["gan"](fake_d_global, True, False) loss += self.losses["G"]["p"]["gan"](fake_d_local, True, False) loss *= lambdas.G["p"]["gan"] self.logger.losses.gen.p.gan = loss.item() step_loss += loss # ----------------------------------- # ----- Feature Matching Loss ----- # ----------------------------------- # (only on global discriminator) # Order must be real, fake if self.opts.dis.p.get_intermediate_features: loss = self.losses["G"]["p"]["featmatch"](real_d_global, fake_d_global) loss *= lambdas.G["p"]["featmatch"] if isinstance(loss, float): self.logger.losses.gen.p.featmatch = loss else: self.logger.losses.gen.p.featmatch = loss.item() step_loss += loss # ------------------------------------------- # ----- Single Discriminator GAN Loss ----- # ------------------------------------------- else: real_cat = torch.cat([m, x], axis=1) fake_cat = torch.cat([m, fake_flooded], axis=1) real_fake_cat = torch.cat([real_cat, fake_cat], dim=0) real_fake_d = self.D["p"](real_fake_cat) real_d, fake_d = divide_pred(real_fake_d) loss = self.losses["G"]["p"]["gan"](fake_d, True, False) self.logger.losses.gen.p.gan = loss.item() step_loss += loss # ----------------------------------- # ----- Feature Matching Loss ----- # ----------------------------------- if self.opts.dis.p.get_intermediate_features and lambdas.G.p.featmatch != 0: loss = self.losses["G"]["p"]["featmatch"](real_d, fake_d) loss *= lambdas.G.p.featmatch if isinstance(loss, float): self.logger.losses.gen.p.featmatch = loss else: self.logger.losses.gen.p.featmatch = loss.item() step_loss += loss return step_loss def masker_d_loss(self, x, z, target, domain, for_="G"): assert for_ in {"G", "D"} self.assert_z_matches_x(x, z) assert x.shape[0] == target.shape[0] zero_loss = torch.tensor(0.0, device=self.device) weight = self.opts.train.lambdas.G.d.main prediction, z_depth = self.G.decoders["d"](z) if self.opts.gen.d.classify.enable: target.squeeze_(1) full_loss = self.losses["G"]["tasks"]["d"](prediction, target) full_loss *= weight if weight == 0 or (domain == "r" and "d" not in self.pseudo_training_tasks): return zero_loss, prediction, z_depth return full_loss, prediction, z_depth def masker_s_loss(self, x, z, depth_preds, z_depth, target, domain, for_="G"): assert for_ in {"G", "D"} assert domain in {"r", "s"} self.assert_z_matches_x(x, z) assert x.shape[0] == target.shape[0] if target is not None else True full_loss = torch.tensor(0.0, device=self.device) softmax_preds = None # -------------------------- # ----- Segmentation ----- # -------------------------- pred = None if for_ == "G" or self.opts.gen.s.use_advent: pred = self.G.decoders["s"](z, z_depth) # Supervised segmentation loss: crossent for sim domain, # crossent_pseudo for real ; loss is crossent in any case if for_ == "G": if domain == "s" or "s" in self.pseudo_training_tasks: if domain == "s": logger = self.logger.losses.gen.task["s"]["crossent"] weight = self.opts.train.lambdas.G["s"]["crossent"] else: logger = self.logger.losses.gen.task["s"]["crossent_pseudo"] weight = self.opts.train.lambdas.G["s"]["crossent_pseudo"] if weight != 0: # Cross-Entropy loss loss_func = self.losses["G"]["tasks"]["s"]["crossent"] loss = loss_func(pred, target.squeeze(1)) loss *= weight full_loss += loss logger[domain] = loss.item() if domain == "r": weight = self.opts.train.lambdas.G["s"]["minent"] if self.opts.gen.s.use_minent and weight != 0: softmax_preds = softmax(pred, dim=1) # Entropy minimization loss loss = self.losses["G"]["tasks"]["s"]["minent"](softmax_preds) loss *= weight full_loss += loss self.logger.losses.gen.task["s"]["minent"]["r"] = loss.item() # Fool ADVENT discriminator if self.opts.gen.s.use_advent: if self.opts.gen.s.use_dada and depth_preds is not None: depth_preds = depth_preds.detach() else: depth_preds = None if for_ == "D": domain_label = domain logger = {} loss_func = self.losses["D"]["advent"] pred = pred.detach() weight = self.opts.train.lambdas.advent.adv_main else: domain_label = "s" logger = self.logger.losses.gen.task["s"]["advent"] loss_func = self.losses["G"]["tasks"]["s"]["advent"] weight = self.opts.train.lambdas.G["s"]["advent"] if (for_ == "D" or domain == "r") and weight != 0: if softmax_preds is None: softmax_preds = softmax(pred, dim=1) loss = loss_func( softmax_preds, self.domain_labels[domain_label], self.D["s"]["Advent"], depth_preds, ) loss *= weight full_loss += loss logger[domain] = loss.item() if for_ == "D": # WGAN: clipping or GP if self.opts.dis.s.gan_type == "GAN" or "WGAN_norm": pass elif self.opts.dis.s.gan_type == "WGAN": for p in self.D["s"]["Advent"].parameters(): p.data.clamp_( self.opts.dis.s.wgan_clamp_lower, self.opts.dis.s.wgan_clamp_upper, ) elif self.opts.dis.s.gan_type == "WGAN_gp": prob_need_grad = autograd.Variable(pred, requires_grad=True) d_out = self.D["s"]["Advent"](prob_need_grad) gp = get_WGAN_gradient(prob_need_grad, d_out) gp_loss = gp * self.opts.train.lambdas.advent.WGAN_gp full_loss += gp_loss else: raise NotImplementedError return full_loss, pred def masker_m_loss( self, x, z, target, domain, for_="G", cond=None, z_depth=None, depth_preds=None ): assert for_ in {"G", "D"} assert domain in {"r", "s"} self.assert_z_matches_x(x, z) assert x.shape[0] == target.shape[0] if target is not None else True full_loss = torch.tensor(0.0, device=self.device) pred_logits = self.G.decoders["m"](z, cond=cond, z_depth=z_depth) pred_prob = sigmoid(pred_logits) pred_prob_complementary = 1 - pred_prob prob = torch.cat([pred_prob, pred_prob_complementary], dim=1) if for_ == "G": # TV loss weight = self.opts.train.lambdas.G.m.tv if weight != 0: loss = self.losses["G"]["tasks"]["m"]["tv"](pred_prob) loss *= weight full_loss += loss self.logger.losses.gen.task["m"]["tv"][domain] = loss.item() weight = self.opts.train.lambdas.G.m.bce if domain == "s" and weight != 0: # CrossEnt Loss loss = self.losses["G"]["tasks"]["m"]["bce"](pred_logits, target) loss *= weight full_loss += loss self.logger.losses.gen.task["m"]["bce"]["s"] = loss.item() if domain == "r": weight = self.opts.train.lambdas.G["m"]["gi"] if self.opts.gen.m.use_ground_intersection and weight != 0: # GroundIntersection loss loss = self.losses["G"]["tasks"]["m"]["gi"](pred_prob, target) loss *= weight full_loss += loss self.logger.losses.gen.task["m"]["gi"]["r"] = loss.item() weight = self.opts.train.lambdas.G.m.pl4m if self.use_pl4m and weight != 0: # Painter loss pl4m_loss = self.painter_loss_for_masker(x, pred_prob) pl4m_loss *= weight full_loss += pl4m_loss self.logger.losses.gen.task.m.pl4m.r = pl4m_loss.item() weight = self.opts.train.lambdas.advent.ent_main if self.opts.gen.m.use_minent and weight != 0: # MinEnt loss loss = self.losses["G"]["tasks"]["m"]["minent"](prob) loss *= weight full_loss += loss self.logger.losses.gen.task["m"]["minent"]["r"] = loss.item() if self.opts.gen.m.use_advent: # AdvEnt loss if self.opts.gen.m.use_dada and depth_preds is not None: depth_preds = depth_preds.detach() depth_preds = torch.nn.functional.interpolate( depth_preds, size=x.shape[-2:], mode="nearest" ) else: depth_preds = None if for_ == "D": domain_label = domain logger = {} loss_func = self.losses["D"]["advent"] prob = prob.detach() weight = self.opts.train.lambdas.advent.adv_main else: domain_label = "s" logger = self.logger.losses.gen.task["m"]["advent"] loss_func = self.losses["G"]["tasks"]["m"]["advent"] weight = self.opts.train.lambdas.advent.adv_main if (for_ == "D" or domain == "r") and weight != 0: loss = loss_func( prob.to(self.device), self.domain_labels[domain_label], self.D["m"]["Advent"], depth_preds, ) loss *= weight full_loss += loss logger[domain] = loss.item() if for_ == "D": # WGAN: clipping or GP if self.opts.dis.m.gan_type == "GAN" or "WGAN_norm": pass elif self.opts.dis.m.gan_type == "WGAN": for p in self.D["s"]["Advent"].parameters(): p.data.clamp_( self.opts.dis.m.wgan_clamp_lower, self.opts.dis.m.wgan_clamp_upper, ) elif self.opts.dis.m.gan_type == "WGAN_gp": prob_need_grad = autograd.Variable(prob, requires_grad=True) d_out = self.D["s"]["Advent"](prob_need_grad) gp = get_WGAN_gradient(prob_need_grad, d_out) gp_loss = self.opts.train.lambdas.advent.WGAN_gp * gp full_loss += gp_loss else: raise NotImplementedError return full_loss, prob def painter_loss_for_masker(self, x, m): # pl4m loss # painter should not be updated for param in self.G.painter.parameters(): param.requires_grad = False # TODO for param in self.D.painter.parameters(): # param.requires_grad = False fake_flooded = self.G.paint(m, x) if self.opts.dis.p.use_local_discriminator: fake_d_global = self.D["p"]["global"](fake_flooded) fake_d_local = self.D["p"]["local"](fake_flooded * m) # Note: discriminator returns [out_1,...,out_num_D] outputs # Each out_i is a list [feat1, feat2, ..., pred_i] pl4m_loss = self.losses["G"]["p"]["gan"](fake_d_global, True, False) pl4m_loss += self.losses["G"]["p"]["gan"](fake_d_local, True, False) else: real_cat = torch.cat([m, x], axis=1) fake_cat = torch.cat([m, fake_flooded], axis=1) real_fake_cat = torch.cat([real_cat, fake_cat], dim=0) real_fake_d = self.D["p"](real_fake_cat) _, fake_d = divide_pred(real_fake_d) pl4m_loss = self.losses["G"]["p"]["gan"](fake_d, True, False) if "p" in self.opts.tasks: for param in self.G.painter.parameters(): param.requires_grad = True return pl4m_loss @torch.no_grad() def run_evaluation(self, verbose=0): print("******************* Running Evaluation ***********************") start_time = time() self.eval_mode() val_logger = None nb_of_batches = None for i, multi_batch_tuple in enumerate(self.val_loaders): # create a dictionnary (domain => batch) from tuple # (batch_domain_0, ..., batch_domain_i) # and send it to self.device nb_of_batches = i + 1 multi_domain_batch = { batch["domain"][0]: self.batch_to_device(batch) for batch in multi_batch_tuple } self.get_G_loss(multi_domain_batch, verbose) if val_logger is None: val_logger = deepcopy(self.logger.losses.generator) else: val_logger = sum_dict(val_logger, self.logger.losses.generator) val_logger = div_dict(val_logger, nb_of_batches) self.logger.losses.generator = val_logger self.logger.log_losses(model_to_update="G", mode="val") for d in self.opts.domains: self.logger.log_comet_images("train", d) self.logger.log_comet_images("val", d) if "m" in self.opts.tasks and self.has_painter and not self.kitti_pretrain: self.logger.log_comet_combined_images("train", "r") self.logger.log_comet_combined_images("val", "r") if self.exp is not None: print() if "m" in self.opts.tasks or "s" in self.opts.tasks: self.eval_images("val", "r") self.eval_images("val", "s") if "p" in self.opts.tasks and not self.kitti_pretrain: val_fid = compute_val_fid(self) if self.exp is not None: self.exp.log_metric("val_fid", val_fid, step=self.logger.global_step) else: print("Validation FID Score", val_fid) self.train_mode() timing = int(time() - start_time) print("****************** Done in {}s *********************".format(timing)) def eval_images(self, mode, domain): if domain == "s" and self.kitti_pretrain: domain = "kitti" if domain == "rf" or domain not in self.display_images[mode]: return metric_funcs = {"accuracy": accuracy, "mIOU": mIOU} metric_avg_scores = {"m": {}} if "s" in self.opts.tasks: metric_avg_scores["s"] = {} if "d" in self.opts.tasks and domain == "s" and self.opts.gen.d.classify.enable: metric_avg_scores["d"] = {} for key in metric_funcs: for task in metric_avg_scores: metric_avg_scores[task][key] = [] for im_set in self.display_images[mode][domain]: x = im_set["data"]["x"].unsqueeze(0).to(self.device) z = self.G.encode(x) s_pred = d_pred = z_depth = None if "d" in metric_avg_scores: d_pred, z_depth = self.G.decoders["d"](z) d_pred = d_pred.detach().cpu() if domain == "s": d = im_set["data"]["d"].unsqueeze(0).detach() for metric in metric_funcs: metric_score = metric_funcs[metric](d_pred, d) metric_avg_scores["d"][metric].append(metric_score) if "s" in metric_avg_scores: if z_depth is None: if self.opts.gen.s.use_dada and "d" in self.opts.tasks: _, z_depth = self.G.decoders["d"](z) s_pred = self.G.decoders["s"](z, z_depth).detach().cpu() s = im_set["data"]["s"].unsqueeze(0).detach() for metric in metric_funcs: metric_score = metric_funcs[metric](s_pred, s) metric_avg_scores["s"][metric].append(metric_score) if "m" in self.opts: cond = None if s_pred is not None and d_pred is not None: cond = self.G.make_m_cond(d_pred, s_pred, x) if z_depth is None: if self.opts.gen.m.use_dada and "d" in self.opts.tasks: _, z_depth = self.G.decoders["d"](z) pred_mask = ( (self.G.mask(z=z, cond=cond, z_depth=z_depth)).detach().cpu() ) pred_mask = (pred_mask > 0.5).to(torch.float32) pred_prob = torch.cat([1 - pred_mask, pred_mask], dim=1) m = im_set["data"]["m"].unsqueeze(0).detach() for metric in metric_funcs: if metric != "mIOU": metric_score = metric_funcs[metric](pred_mask, m) else: metric_score = metric_funcs[metric](pred_prob, m) metric_avg_scores["m"][metric].append(metric_score) metric_avg_scores = { task: { metric: np.mean(values) if values else float("nan") for metric, values in met_dict.items() } for task, met_dict in metric_avg_scores.items() } metric_avg_scores = { task: { metric: value if not np.isnan(value) else -1 for metric, value in met_dict.items() } for task, met_dict in metric_avg_scores.items() } if self.exp is not None: self.exp.log_metrics( flatten_opts(metric_avg_scores), prefix=f"metrics_{mode}_{domain}", step=self.logger.global_step, ) else: print(f"metrics_{mode}_{domain}") print(flatten_opts(metric_avg_scores)) return 0 def functional_test_mode(self): import atexit self.opts.output_path = ( Path("~").expanduser() / "climategan" / "functional_tests" ) Path(self.opts.output_path).mkdir(parents=True, exist_ok=True) with open(Path(self.opts.output_path) / "is_functional.test", "w") as f: f.write("trainer functional test - delete this dir") if self.exp is not None: self.exp.log_parameter("is_functional_test", True) atexit.register(self.del_output_path) def del_output_path(self, force=False): import shutil if not Path(self.opts.output_path).exists(): return if (Path(self.opts.output_path) / "is_functional.test").exists() or force: shutil.rmtree(self.opts.output_path) def compute_fire(self, x, seg_preds=None, z=None, z_depth=None): """ Transforms input tensor given wildfires event Args: x (torch.Tensor): Input tensor seg_preds (torch.Tensor): Semantic segmentation predictions for input tensor z (torch.Tensor): Latent vector of encoded "x". Can be None if seg_preds is given. Returns: torch.Tensor: Wildfire version of input tensor """ if seg_preds is None: if z is None: z = self.G.encode(x) seg_preds = self.G.decoders["s"](z, z_depth) return add_fire(x, seg_preds, self.opts.events.fire) def compute_flood( self, x, z=None, z_depth=None, m=None, s=None, cloudy=None, bin_value=-1 ): """ Applies a flood (mask + paint) to an input image, with optionally pre-computed masker z or mask Args: x (torch.Tensor): B x C x H x W -1:1 input image z (torch.Tensor, optional): B x C x H x W Masker latent vector. Defaults to None. m (torch.Tensor, optional): B x 1 x H x W Mask. Defaults to None. bin_value (float, optional): Mask binarization value. Set to -1 to use smooth masks (no binarization) Returns: torch.Tensor: B x 3 x H x W -1:1 flooded image """ if m is None: if z is None: z = self.G.encode(x) if "d" in self.opts.tasks and self.opts.gen.m.use_dada and z_depth is None: _, z_depth = self.G.decoders["d"](z) m = self.G.mask(x=x, z=z, z_depth=z_depth) if bin_value >= 0: m = (m > bin_value).to(m.dtype) if cloudy: assert s is not None return self.G.paint_cloudy(m, x, s) return self.G.paint(m, x) def compute_smog(self, x, z=None, d=None, s=None, use_sky_seg=False): # implementation from the paper: # HazeRD: An outdoor scene dataset and benchmark for single image dehazing sky_mask = None if d is None or (use_sky_seg and s is None): if z is None: z = self.G.encode(x) if d is None: d, _ = self.G.decoders["d"](z) if use_sky_seg and s is None: if "s" not in self.opts.tasks: raise ValueError( "Cannot have " + "(use_sky_seg is True and s is None and 's' not in tasks)" ) s = self.G.decoders["s"](z) # TODO: s to sky mask # TODO: interpolate to d's size params = self.opts.events.smog airlight = params.airlight * torch.ones(3) airlight = airlight.view(1, -1, 1, 1).to(self.device) irradiance = srgb2lrgb(x) beta = torch.tensor([params.beta / params.vr] * 3) beta = beta.view(1, -1, 1, 1).to(self.device) d = normalize(d, mini=0.3, maxi=1.0) d = 1.0 / d d = normalize(d, mini=0.1, maxi=1) if sky_mask is not None: d[sky_mask] = 1 d = torch.nn.functional.interpolate( d, size=x.shape[-2:], mode="bilinear", align_corners=True ) d = d.repeat(1, 3, 1, 1) transmission = torch.exp(d * -beta) smogged = transmission * irradiance + (1 - transmission) * airlight smogged = lrgb2srgb(smogged) # add yellow filter alpha = params.alpha / 255 yellow_mask = torch.Tensor([params.yellow_color]) / 255 yellow_filter = ( yellow_mask.unsqueeze(2) .unsqueeze(2) .repeat(1, 1, smogged.shape[-2], smogged.shape[-1]) .to(self.device) ) smogged = smogged * (1 - alpha) + yellow_filter * alpha return smogged