# feng2022's picture
# anothertry
# 89d1ee7
from argparse import (
ArgumentParser,
Namespace,
)
from typing import (
Dict,
Iterable,
Optional,
Tuple,
)
import numpy as np
import torch
from torch import nn
from utils.misc import (
optional_string,
iterable_to_str,
)
from .contextual_loss import ContextualLoss
from .color_transfer_loss import ColorTransferLoss
from .regularize_noise import NoiseRegularizer
from .reconstruction import (
EyeLoss,
FaceLoss,
create_perceptual_loss,
ReconstructionArguments,
)
class LossArguments:
    """Command-line options for the loss weights and a string tag derived from them."""

    @staticmethod
    def add_arguments(parser: ArgumentParser):
        """Register the reconstruction options and the loss options on ``parser``.

        Args:
            parser: argument parser that receives the new options.
        """
        ReconstructionArguments.add_arguments(parser)

        # (flag, default weight, help text) for every float-valued loss weight,
        # listed in registration order. ``help=None`` is argparse's own default.
        weight_options = (
            ("--color_transfer", 1e10, "color transfer loss weight"),
            ("--eye", 0.1, "eye loss weight"),
            ("--noise_regularize", 5e4, None),
            ("--contextual", 0.1, "contextual loss weight"),
        )
        for flag, default_weight, help_text in weight_options:
            parser.add_argument(flag, type=float, default=default_weight, help=help_text)

        # contextual loss: which VGG layers to compare
        parser.add_argument(
            "--cx_layers",
            nargs="*",
            choices=["relu1_2", "relu2_2", "relu3_4", "relu4_4", "relu5_4"],
            default=["relu3_4", "relu2_2", "relu1_2"],
            help="contextual loss layers",
        )

    @staticmethod
    def to_string(args: Namespace) -> str:
        """Concatenate the reconstruction tag with one piece per loss option.

        Each piece is whatever ``optional_string`` returns for the given
        condition and text (presumably the text when the condition is truthy
        and an empty string otherwise -- confirm in ``utils.misc``).

        Args:
            args: parsed arguments holding the options added by ``add_arguments``.

        Returns:
            The concatenated tag string.
        """
        prefix = ReconstructionArguments.to_string(args)
        eye_tag = optional_string(args.eye > 0, f"-eye{args.eye}")
        color_tag = optional_string(args.color_transfer, f"-color{args.color_transfer:.1e}")
        contextual_tag = optional_string(
            args.contextual,
            f"-cx{args.contextual}({iterable_to_str(args.cx_layers)})",
        )
        # NOTE: an "-mse" piece (args.mse) is commented out in the original and stays disabled.
        noise_tag = optional_string(args.noise_regularize, f"-NR{args.noise_regularize:.1e}")
        return prefix + eye_tag + color_tag + contextual_tag + noise_tag
class BakedMultiContextualLoss(nn.Module):
    """Random sample different image patches for different vgg layers.

    One ``ContextualLoss`` is created per entry of ``args.cx_layers``. On every
    forward pass each of them is evaluated on its own randomly placed square
    crop, taken at the same location from the stored sibling and from the input.
    """

    def __init__(self, sibling: torch.Tensor, args: Namespace, size: int = 256):
        """
        Args:
            sibling: reference tensor; stored detached from the autograd graph.
            args: namespace providing ``cx_layers``, the VGG layer names.
            size: side length of the square crop.
        """
        super().__init__()
        self.cxs = nn.ModuleList(
            [ContextualLoss(use_vgg=True, vgg_layers=[layer]) for layer in args.cx_layers]
        )
        self.size = size
        self.sibling = sibling.detach()

    def forward(self, img: torch.Tensor):
        """Sum the per-layer contextual losses over independently sampled crops.

        Args:
            img: tensor whose last two dimensions are cropped.
                Assumes ``self.sibling`` has the same spatial size -- TODO confirm.

        Returns:
            The accumulated loss, or the integer ``0`` when there are no layers.
        """
        size = self.size
        # Largest valid top-left offset per axis. Clamping at 0 means an image
        # that is exactly (or less than) ``size`` wide/tall is used whole
        # instead of making ``randint`` raise on an empty range.
        max_top = max(img.shape[-2] - size, 0)
        max_left = max(img.shape[-1] - size, 0)

        cx_loss = 0
        for cx in self.cxs:
            # ``high`` is exclusive, hence the +1 so the last offset is reachable.
            top = int(np.random.randint(0, max_top + 1))
            left = int(np.random.randint(0, max_left + 1))
            sibling_patch = self.sibling[..., top:top + size, left:left + size]
            img_patch = img[..., top:top + size, left:left + size]
            cx_loss = cx(sibling_patch, img_patch) + cx_loss
        return cx_loss
class BakedContextualLoss(ContextualLoss):
    """Contextual loss against a stored sibling tensor, evaluated on a random crop.

    All layers in ``args.cx_layers`` are handled by the single parent
    ``ContextualLoss``, so one crop location is shared by every layer.
    """

    def __init__(self, sibling: torch.Tensor, args: Namespace, size: int = 256):
        """
        Args:
            sibling: reference tensor; stored detached from the autograd graph.
            args: namespace providing ``cx_layers``, the VGG layer names.
            size: side length of the square crop.
        """
        super().__init__(use_vgg=True, vgg_layers=args.cx_layers)
        self.size = size
        self.sibling = sibling.detach()

    def forward(self, img: torch.Tensor):
        """Evaluate the parent loss on matching random crops of sibling and ``img``.

        Args:
            img: tensor whose last two dimensions are cropped.
                Assumes ``self.sibling`` has the same spatial size -- TODO confirm.

        Returns:
            Whatever ``ContextualLoss.forward`` returns for the two crops.
        """
        size = self.size
        # Largest valid top-left offset per axis. Clamping at 0 means an image
        # that is exactly (or less than) ``size`` wide/tall is used whole
        # instead of making ``randint`` raise on an empty range.
        max_top = max(img.shape[-2] - size, 0)
        max_left = max(img.shape[-1] - size, 0)
        # ``high`` is exclusive, hence the +1 so the last offset is reachable.
        top = int(np.random.randint(0, max_top + 1))
        left = int(np.random.randint(0, max_left + 1))
        return super().forward(
            self.sibling[..., top:top + size, left:left + size],
            img[..., top:top + size, left:left + size],
        )
class JointLoss(nn.Module):
    """Weighted combination of reconstruction, exemplar and noise losses.

    Which terms exist is decided once, in the constructor, from the weights in
    ``args``; ``forward`` evaluates every term that exists and returns both the
    weighted total and the individual unweighted values.
    """

    def __init__(
        self,
        args: Namespace,
        target: torch.Tensor,
        sibling: Optional[torch.Tensor],
        sibling_rgbs: Optional[Iterable[torch.Tensor]] = None,
    ):
        """
        Args:
            args: parsed options; reads ``eye``, ``contextual``, ``color_transfer``,
                ``noise_regularize``, ``vgg``, ``vggface``, ``generator_size``,
                ``recon_size`` and ``cx_layers``.
            target: tensor handed to the face and eye reconstruction losses.
            sibling: tensor for the contextual loss; must not be ``None`` when
                that loss is enabled.
            sibling_rgbs: tensors for the color transfer loss; must not be
                ``None`` when that loss is enabled.
        """
        super().__init__()

        # Multiplier applied to each named loss in ``forward``.
        self.weights = dict(
            face=1.,
            eye=args.eye,
            contextual=args.contextual,
            color_transfer=args.color_transfer,
            noise=args.noise_regularize,
        )

        # Reconstruction terms: both share one perceptual network, so the eye
        # term is only created alongside the face term.
        recon_terms = {}
        if args.vgg > 0 or args.vggface > 0:
            percept = create_perceptual_loss(args)
            recon_terms["face"] = FaceLoss(
                target,
                input_size=args.generator_size,
                size=args.recon_size,
                percept=percept,
            )
            if args.eye > 0:
                recon_terms["eye"] = EyeLoss(
                    target,
                    input_size=args.generator_size,
                    percept=percept,
                )
        self.reconstruction = nn.ModuleDict(recon_terms)

        # Exemplar terms: computed against the sibling rather than the target.
        exemplar_terms = {}
        if args.contextual > 0 and len(args.cx_layers) > 0:
            assert sibling is not None
            exemplar_terms["contextual"] = BakedContextualLoss(sibling, args)
        if args.color_transfer > 0:
            assert sibling_rgbs is not None
            self.sibling_rgbs = sibling_rgbs
            exemplar_terms["color_transfer"] = ColorTransferLoss(init_rgbs=sibling_rgbs)
        self.exemplar = nn.ModuleDict(exemplar_terms)

        if args.noise_regularize > 0:
            self.noise_criterion = NoiseRegularizer()

    def forward(
        self, img, degrade=None, noises=None, rgbs=None, rgb_level: Optional[int] = None
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """Evaluate every configured loss term.

        Args:
            img: image passed to the reconstruction and contextual losses.
            degrade: forwarded to each reconstruction loss as ``degrade``.
            noises: input of the noise regularizer.
            rgbs: results from the ToRGB layers; required when the color
                transfer loss is configured.
            rgb_level: forwarded to the color transfer loss as ``level``.

        Returns:
            ``(total_loss, losses)`` where ``losses`` maps each term name to
            its unweighted value and ``total_loss`` is the weighted sum
            (the integer ``0`` when no term is configured).
        """
        # TODO: add current optimization resolution for noises

        # reconstruction losses
        losses = {
            name: criterion(img, degrade=degrade)
            for name, criterion in self.reconstruction.items()
        }

        # exemplar losses
        if "contextual" in self.exemplar:
            losses["contextual"] = self.exemplar["contextual"](img)
        if "color_transfer" in self.exemplar:
            assert rgbs is not None
            losses["color_transfer"] = self.exemplar["color_transfer"](rgbs, level=rgb_level)

        # noise regularizer
        if self.weights["noise"] > 0:
            losses["noise"] = self.noise_criterion(noises)

        total_loss = sum(self.weights[name] * value for name, value in losses.items())
        return total_loss, losses

    def update_sibling(self, sibling: torch.Tensor):
        """Replace the sibling tensor used by the contextual loss with a detached ``sibling``."""
        assert "contextual" in self.exemplar
        contextual = self.exemplar["contextual"]
        contextual.sibling = sibling.detach()