"""Complete Generator architecture:
    * OmniGenerator
    * Encoder
    * Decoders
"""
from pathlib import Path
import traceback

import torch
import torch.nn as nn
import torch.nn.functional as F
import yaml
from addict import Dict
from torch import softmax

import climategan.strings as strings
from climategan.deeplab import create_encoder, create_segmentation_decoder
from climategan.depth import create_depth_decoder
from climategan.masker import create_mask_decoder
from climategan.painter import create_painter
from climategan.tutils import init_weights, mix_noise, normalize


def create_generator(opts, device="cpu", latent_shape=None, no_init=False, verbose=0):
    G = OmniGenerator(opts, latent_shape, verbose, no_init)
    if no_init:
        print("Sending to", device)
        return G.to(device)

    for model in G.decoders:
        net = G.decoders[model]
        if model == "s":
            continue
        if isinstance(net, nn.ModuleDict):
            for domain, domain_model in net.items():
                init_weights(
                    domain_model,
                    init_type=opts.gen[model].init_type,
                    init_gain=opts.gen[model].init_gain,
                    verbose=verbose,
                    caller=f"create_generator decoder {model} {domain}",
                )
        else:
            init_weights(
                G.decoders[model],
                init_type=opts.gen[model].init_type,
                init_gain=opts.gen[model].init_gain,
                verbose=verbose,
                caller=f"create_generator decoder {model}",
            )
    if G.encoder is not None and opts.gen.encoder.architecture == "base":
        init_weights(
            G.encoder,
            init_type=opts.gen.encoder.init_type,
            init_gain=opts.gen.encoder.init_gain,
            verbose=verbose,
            caller="create_generator encoder",
        )

    print("Sending to", device)
    return G.to(device)
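# Illustrative usage (comment only, not part of the module): a minimal sketch
# assuming `opts` is the addict.Dict configuration used throughout climategan,
# with `opts.tasks` (e.g. containing "m", "s", "d", "p") and the per-task
# `opts.gen.*` keys (init_type, init_gain, encoder.architecture, ...) populated.
#
#     device = "cuda" if torch.cuda.is_available() else "cpu"
#     G = create_generator(opts, device=device, verbose=1)
#     # The loop above (re-)initializes every decoder except "s"
#     # (skipped by the `if model == "s": continue` guard) and, only for the
#     # "base" encoder architecture, the encoder itself.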


class OmniGenerator(nn.Module):
    def __init__(self, opts, latent_shape=None, verbose=0, no_init=False):
        """Creates the generator. All decoders listed in opts.gen will be added
        to the Generator.decoders ModuleDict if opts.gen.DecoderInitial is not True.
        They can then be accessed as G.decoders.T or G.decoders["T"], for instance
        for the image Translation decoder.

        Args:
            opts (addict.Dict): configuration dict
        """
        super().__init__()
        self.opts = opts
        self.verbose = verbose
        self.encoder = None
        if any(t in opts.tasks for t in "msd"):
            self.encoder = create_encoder(opts, no_init, verbose)

        self.decoders = {}
        self.painter = nn.Module()

        if "d" in opts.tasks:
            self.decoders["d"] = create_depth_decoder(opts, no_init, verbose)

            if self.verbose > 0:
                print(f" - Add {self.decoders['d'].__class__.__name__}")

        if "s" in opts.tasks:
            self.decoders["s"] = create_segmentation_decoder(opts, no_init, verbose)

        if "m" in opts.tasks:
            self.decoders["m"] = create_mask_decoder(opts, no_init, verbose)

        self.decoders = nn.ModuleDict(self.decoders)

        if "p" in self.opts.tasks:
            self.painter = create_painter(opts, no_init, verbose)
        else:
            if self.verbose > 0:
                print(" - Add Empty Painter")

    def __str__(self):
        return strings.generator(self)

    def encode(self, x):
        """
        Forward x through the encoder

        Args:
            x (torch.Tensor): B3HW input tensor

        Returns:
            list: High and Low level features from the encoder
        """
        assert self.encoder is not None
        return self.encoder.forward(x)
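    # Illustrative sketch (comment only; assumes a built generator `G` and a
    # (B, 3, H, W) batch `x`): the encoder's output is the feature list
    # consumed by every decoder below.
    #
    #     z = G.encode(x)                  # list of high/low-level feature maps
    #     d, z_depth = G.decoders["d"](z)  # e.g. feed it to the depth head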

    def decode(self, x=None, z=None, return_z=False, return_z_depth=False):
        """
        Computes the predictions of all available decoders from either x or z.
        If using spade for the masker with 15 channels, x *must* be provided,
        whether z is too or not.

        Args:
            x (torch.Tensor, optional): Input tensor (B3HW). Defaults to None.
            z (list, optional): List of high and low-level features as BCHW.
                Defaults to None.
            return_z (bool, optional): whether or not to return z in the dict.
                Defaults to False.
            return_z_depth (bool, optional): whether or not to return z_depth
                in the dict. Defaults to False.

        Raises:
            ValueError: If using spade for the masker with 15 channels but x is None

        Returns:
            dict: {task: prediction_tensor} (may include z and z_depth
                depending on args)
        """

        assert x is not None or z is not None
        if self.opts.gen.m.use_spade and self.opts.gen.m.spade.cond_nc == 15:
            if x is None:
                raise ValueError(
                    "When using spade for the Masker with 15 channels,"
                    + " x MUST be provided"
                )

        z_depth = cond = d = s = None
        out = {}

        if z is None:
            z = self.encode(x)

        if return_z:
            out["z"] = z

        if "d" in self.decoders:
            d, z_depth = self.decoders["d"](z)
            out["d"] = d

        if return_z_depth:
            out["z_depth"] = z_depth

        if "s" in self.decoders:
            s = self.decoders["s"](z, z_depth)
            out["s"] = s

        if "m" in self.decoders:
            if s is not None and d is not None:
                cond = self.make_m_cond(d, s, x)
            m = self.mask(z=z, cond=cond)
            out["m"] = m

        return out
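    # Illustrative sketch (comment only; assumes a generator `G` built with
    # tasks "msd" and a (B, 3, H, W) batch `x`): decode() runs every available
    # decoder once and returns a dict keyed by task initial.
    #
    #     with torch.no_grad():
    #         preds = G.decode(x=x, return_z=True)
    #     # preds["d"]: depth,  preds["s"]: segmentation logits,
    #     # preds["m"]: sigmoided flood mask,  preds["z"]: encoder features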

    def sample_painter_z(self, batch_size, device, force_half=False):
        if self.opts.gen.p.no_z:
            return None

        z = torch.empty(
            batch_size,
            self.opts.gen.p.latent_dim,
            self.painter.z_h,
            self.painter.z_w,
            device=device,
        ).normal_(mean=0, std=1.0)

        if force_half:
            z = z.half()

        return z

    def make_m_cond(self, d, s, x=None):
        """
        Create the masker's conditioning input when using spade from the
        d and s predictions and from the input x when cond_nc == 15.

        d and s are assumed to have the same spatial resolution.
        If cond_nc == 15 then x is interpolated to match that resolution.

        Args:
            d (torch.Tensor): Raw depth prediction (B1HW)
            s (torch.Tensor): Raw segmentation prediction (BCHW)
            x (torch.Tensor, optional): Input tensor (B3HW). Mandatory
                when opts.gen.m.spade.cond_nc == 15

        Raises:
            ValueError: opts.gen.m.spade.cond_nc == 15 but x is None

        Returns:
            torch.Tensor: B x cond_nc x H x W conditioning tensor.
        """
        if self.opts.gen.m.spade.detach:
            d = d.detach()
            s = s.detach()
        cats = [normalize(d), softmax(s, dim=1)]
        if self.opts.gen.m.spade.cond_nc == 15:
            if x is None:
                raise ValueError(
                    "When using spade for the Masker with 15 channels,"
                    + " x MUST be provided"
                )
            cats += [
                F.interpolate(x, s.shape[-2:], mode="bilinear", align_corners=True)
            ]

        return torch.cat(cats, dim=1)
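    # Illustrative layout sketch (comment only; channel counts are assumptions
    # consistent with cond_nc == 15 reading as 1 depth + S segmentation + 3
    # image channels): the conditioning tensor concatenates, along C, the
    # normalized depth map, the softmaxed segmentation and, when
    # cond_nc == 15, the resized input image.
    #
    #     d, z_d = G.decoders["d"](z)      # (B, 1, h, w)
    #     s = G.decoders["s"](z, z_d)      # (B, S, h, w)
    #     cond = G.make_m_cond(d, s, x)    # (B, cond_nc, h, w)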

    def mask(self, x=None, z=None, cond=None, z_depth=None, sigmoid=True):
        """
        Create a mask from either an input x or a latent vector z.
        Optionally, if the Masker has a spade architecture, the conditioning
        tensor may be provided (cond). Default behavior applies an element-wise
        sigmoid, but it can be deactivated (sigmoid=False).

        At least one of x or z must be provided (i.e. not None).
        If the Masker has a spade architecture and cond_nc == 15 then x cannot
        be None.

        Args:
            x (torch.Tensor, optional): Input tensor B3HW. Defaults to None.
            z (list, optional): High and Low level features of the encoder.
                Will be computed if None. Defaults to None.
            cond (torch.Tensor, optional): SPADE conditioning tensor as built
                by make_m_cond. Computed if None and the Masker uses spade.
                Defaults to None.
            z_depth (torch.Tensor, optional): Depth decoder features for DADA.
                Computed if None and the Masker uses dada. Defaults to None.
            sigmoid (bool, optional): Whether to apply a sigmoid to the
                Masker's logits. Defaults to True.

        Returns:
            torch.Tensor: B1HW mask tensor
        """
        assert x is not None or z is not None
        if z is None:
            z = self.encode(x)

        if cond is None and self.opts.gen.m.use_spade:
            assert "s" in self.opts.tasks and "d" in self.opts.tasks
            with torch.no_grad():
                d_pred, z_d = self.decoders["d"](z)
                s_pred = self.decoders["s"](z, z_d)
                cond = self.make_m_cond(d_pred, s_pred, x)
        if z_depth is None and self.opts.gen.m.use_dada:
            assert "d" in self.opts.tasks
            with torch.no_grad():
                _, z_depth = self.decoders["d"](z)

        if cond is not None:
            device = z[0].device if isinstance(z, (tuple, list)) else z.device
            cond = cond.to(device)

        logits = self.decoders["m"](z, cond, z_depth)

        if not sigmoid:
            return logits

        return torch.sigmoid(logits)
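    # Illustrative sketch (comment only; assumes a generator `G` with the
    # Masker task and a (B, 3, H, W) batch `x`): when only x is given, mask()
    # handles encoding, spade conditioning and DADA depth features on its own;
    # the returned mask can then be fed to paint().
    #
    #     with torch.no_grad():
    #         m = G.mask(x=x)                        # (B, 1, H, W), values in [0, 1]
    #         m_logits = G.mask(x=x, sigmoid=False)  # raw logits instead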

    def paint(self, m, x, no_paste=False):
        """
        Paints an image x given a mask m by calling painter(z, x * (1.0 - m)).
        The mask has 1s where water should be painted.

        Args:
            m (torch.Tensor): Mask
            x (torch.Tensor): Image to paint
            no_paste (bool, optional): Do not paste the original content back,
                even if opts.gen.p.paste_original_content is True.
                Defaults to False.

        Returns:
            torch.Tensor: painted image
        """
        z_paint = self.sample_painter_z(x.shape[0], x.device)
        m = m.to(x.dtype)
        fake = self.painter(z_paint, x * (1.0 - m))
        if self.opts.gen.p.paste_original_content and not no_paste:
            return x * (1.0 - m) + fake * m
        return fake
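    # Illustrative sketch (comment only; assumes a generator `G` with Masker
    # and Painter tasks and a batch `x` in the Painter's expected value
    # range): typical masker -> painter chaining; water is painted where m == 1.
    #
    #     with torch.no_grad():
    #         m = G.mask(x=x)
    #         flooded = G.paint(m, x)
    #         raw_paint = G.paint(m, x, no_paste=True)  # painter output, no pasting of x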

    def paint_cloudy(self, m, x, s, sky_idx=9, res=(8, 8), weight=0.8):
        """
        Paints x with water in m through an intermediary cloudy image
        where the sky has been replaced with perlin noise to imitate clouds.

        The intermediary cloudy image is only used to control the painter's
        painting mode, probing it with a cloudy input.

        Args:
            m (torch.Tensor): water mask
            x (torch.Tensor): input tensor
            s (torch.Tensor): segmentation prediction (BCHW)
            sky_idx (int, optional): Index of the sky class along s's C dimension.
                Defaults to 9.
            res (tuple, optional): Perlin noise spatial resolution. Defaults to (8, 8).
            weight (float, optional): Intermediate image's cloud proportion
                (w * cloud + (1 - w) * original_sky). Defaults to 0.8.

        Returns:
            torch.Tensor: painted image with original content pasted.
        """
        sky_mask = (
            torch.argmax(
                F.interpolate(s, x.shape[-2:], mode="bilinear"), dim=1, keepdim=True
            )
            == sky_idx
        ).to(x.dtype)
        noised_x = mix_noise(x, sky_mask, res=res, weight=weight).to(x.dtype)
        fake = self.paint(m, noised_x, no_paste=True)
        return x * (1.0 - m) + fake * m
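    # Illustrative sketch (comment only; assumes a generator `G` with "m", "s"
    # and "p" tasks): paint_cloudy() needs the segmentation prediction to
    # locate the sky, so it is typically fed the output of decode().
    #
    #     with torch.no_grad():
    #         preds = G.decode(x=x)
    #         cloudy_flood = G.paint_cloudy(preds["m"], x, preds["s"])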

    def depth(self, x=None, z=None, return_z_depth=False):
        """
        Compute the depth head's output

        Args:
            x (torch.Tensor, optional): Input B3HW tensor. Defaults to None.
            z (list, optional): High and Low level features of the encoder.
                Defaults to None.
            return_z_depth (bool, optional): Whether to also return the depth
                decoder's intermediate features. Defaults to False.

        Returns:
            torch.Tensor: B1HW tensor of depth predictions
        """
        assert x is not None or z is not None
        assert not (x is not None and z is not None)
        if z is None:
            z = self.encode(x)
        depth, z_depth = self.decoders["d"](z)

        if depth.shape[1] > 1:
            depth = torch.argmax(depth, dim=1)
            depth = depth / depth.max()

        if return_z_depth:
            return depth, z_depth

        return depth
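    # Illustrative sketch (comment only; assumes a generator `G` with the "d"
    # task): exactly one of x or z must be given.
    #
    #     with torch.no_grad():
    #         d = G.depth(x=x)
    #         d, z_d = G.depth(x=x, return_z_depth=True)
    #     # When the head outputs more than one channel, the per-pixel argmax
    #     # over classes is taken and rescaled to [0, 1] instead.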

    def load_val_painter(self):
        """
        Loads a validation painter if available in opts.val.val_painter

        Returns:
            bool: operation success status
        """
        try:
            # checkpoint must be specified and point to an existing file
            assert self.opts.val.val_painter
            ckpt_path = Path(self.opts.val.val_painter).resolve()
            assert ckpt_path.exists()
            assert ckpt_path.is_file()

            # load the opts the validation painter was trained with
            opts_path = ckpt_path.parent.parent / "opts.yaml"
            assert opts_path.exists()
            with opts_path.open("r") as f:
                val_painter_opts = Dict(yaml.safe_load(f))

            # build the painter and load its weights from the checkpoint
            state_dict = torch.load(ckpt_path)
            painter = create_painter(val_painter_opts)
            painter.load_state_dict(
                {k.replace("painter.", ""): v for k, v in state_dict["G"].items()}
            )

            # use it for inference only, on the generator's current device
            device = next(self.parameters()).device
            self.painter = painter.eval().to(device)
            for p in self.painter.parameters():
                p.requires_grad = False

            print(" - Loaded validation-only painter")
            return True

        except Exception as e:
            print(traceback.format_exc())
            print(e)
            print(">>> WARNING: error (^) in load_val_painter, aborting.")
            return False
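# Illustrative sketch (comment only, not part of the module): swapping in a
# frozen, validation-only painter before inference, assuming
# `opts.val.val_painter` points to a checkpoint whose run directory (two
# levels up from the checkpoint file, as resolved above) contains an opts.yaml.
#
#     G = create_generator(opts, device="cuda")
#     if G.load_val_painter():
#         with torch.no_grad():
#             preds = G.decode(x=x)
#             flooded = G.paint(preds["m"], x)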