# feng2022's picture
# anothertry
# 89d1ee7
from argparse import (
ArgumentParser,
Namespace,
)
from typing import Optional
import numpy as np
import torch
from torch import nn
from losses.perceptual_loss import PerceptualLoss
from models.degrade import Downsample
from utils.misc import optional_string
class ReconstructionArguments:
    """Command-line options for the reconstruction losses."""

    @staticmethod
    def add_arguments(parser: ArgumentParser):
        """Register ``--vggface``, ``--vgg`` and ``--recon_size`` on ``parser``."""
        # (flag, type, default, help) -- one row per option.
        option_specs = (
            ("--vggface", float, 0.3, "vggface"),
            ("--vgg", float, 1, "vgg"),
            ("--recon_size", int, 256, "size for face reconstruction loss"),
        )
        for flag, value_type, default_value, help_text in option_specs:
            parser.add_argument(
                flag, type=value_type, default=default_value, help=help_text
            )

    @staticmethod
    def to_string(args: Namespace) -> str:
        """Return a short tag describing the reconstruction settings in ``args``.

        The tag always starts with ``s<recon_size>``. The ``-vgg<value>`` and
        ``-vggface<value>`` pieces are each handed to ``optional_string``
        together with the condition that the corresponding value is positive.
        """
        size_tag = f"s{args.recon_size}"
        vgg_tag = optional_string(args.vgg > 0, f"-vgg{args.vgg}")
        vggface_tag = optional_string(args.vggface > 0, f"-vggface{args.vggface}")
        return size_tag + vgg_tag + vggface_tag
def create_perceptual_loss(args: Namespace):
    """Build a ``PerceptualLoss`` configured from parsed arguments.

    ``args.vgg`` and ``args.vggface`` are forwarded as ``lambda_vgg`` and
    ``lambda_vggface``; ``cos_dist`` is always ``False``.
    """
    loss_kwargs = dict(
        lambda_vgg=args.vgg,
        lambda_vggface=args.vggface,
        cos_dist=False,
    )
    return PerceptualLoss(**loss_kwargs)
class EyeLoss(nn.Module):
    """Perceptual loss evaluated on two fixed eye crops of an image.

    The crop boxes are defined on a 1024-pixel reference grid and rescaled to
    the width of ``target``. The same boxes are applied to both the input
    image and the target, so the image compared in ``forward`` is presumably
    expected to have the same spatial size as ``target`` -- confirm with
    callers.

    Attributes:
        target: the target image, stored as given.
        downsample: ``Downsample(input_size, target_size, input_channels)``
            when the sizes differ, otherwise an identity callable. It is
            handed to ``degrade`` in ``forward``.
        percept: the perceptual loss callable.
        btlrs: integer array of shape ``(2, 4)``; each row is
            ``(b, t, l, r)`` and is used as ``[..., b:t, l:r]``.
    """

    # Crop geometry, in pixels of the reference grid.
    _REFERENCE_SIZE = 1024
    _EYE_SIZE = (224, 224)    # (height, width)
    _EYE_CENTER = (480, 384)  # (y, x)

    def __init__(
            self,
            target: torch.Tensor,
            input_size: int = 1024,
            input_channels: int = 3,
            percept: Optional[nn.Module] = None,
            args: Optional[Namespace] = None
    ):
        """
        Args:
            target: target image; its last dimension is taken as its size.
            input_size: size of the images later passed through ``degrade``.
            input_channels: channel count given to ``Downsample``.
            percept: perceptual loss to use; built from ``args`` when ``None``.
            args: parsed arguments used to build the perceptual loss when
                ``percept`` is not supplied.

        Raises:
            ValueError: if both ``percept`` and ``args`` are ``None``.
        """
        # Explicit raise instead of ``assert`` so the check survives ``python -O``.
        if percept is None and args is None:
            raise ValueError("EyeLoss requires either `percept` or `args`")
        super().__init__()

        self.target = target
        target_size = target.shape[-1]
        self.downsample = (
            Downsample(input_size, target_size, input_channels)
            if target_size != input_size
            else (lambda x: x)
        )
        self.percept = percept if percept is not None else create_perceptual_loss(args)
        self.btlrs = self._eye_boxes(target_size)

    @classmethod
    def _eye_boxes(cls, target_size: int) -> np.ndarray:
        """Return the two ``(b, t, l, r)`` crop boxes scaled to ``target_size``."""
        half_h = cls._EYE_SIZE[0] // 2
        half_w = cls._EYE_SIZE[1] // 2
        center_y, center_x = cls._EYE_CENTER

        boxes = []
        for sgn in (1, -1):
            # For sgn == -1 the column bounds are negative; used as slice
            # indices they count from the right edge, which mirrors the first
            # box horizontally.
            x = center_x * sgn
            box = np.array(
                (center_y - half_h, center_y + half_h, x - half_w, x + half_w)
            )
            boxes.append((box / cls._REFERENCE_SIZE * target_size).astype(int))
        return np.stack(boxes, axis=0)

    def forward(self, img: torch.Tensor, degrade: Optional[nn.Module] = None):
        """Sum the perceptual loss over both eye crops.

        Args:
            img: image to compare with the target. When ``degrade`` is given
                it is first replaced by
                ``degrade(img, downsample=self.downsample)``; otherwise it is
                used as is (i.e. it should already be degraded).
            degrade: optional callable applied to ``img`` before cropping.

        Returns:
            The sum of ``self.percept`` over the two crops, called with
            ``use_vggface=False`` and ``max_vgg_layer=4``.
        """
        if degrade is not None:
            img = degrade(img, downsample=self.downsample)

        loss = 0
        for b, t, l, r in self.btlrs:
            loss = loss + self.percept(
                img[:, :, b:t, l:r], self.target[:, :, b:t, l:r],
                use_vggface=False, max_vgg_layer=4,
            )
        return loss
class FaceLoss(nn.Module):
    """Perceptual loss between a whole image and a target at a fixed size.

    Attributes:
        target: the target image, passed through ``Downsample`` to ``size``
            when its last dimension differs from ``size``.
        downsample: ``Downsample(input_size, size, input_channels)`` when the
            sizes differ, otherwise an identity callable. It is handed to
            ``degrade`` in ``forward``.
        percept: the perceptual loss callable.
    """

    def __init__(
            self,
            target: torch.Tensor,
            input_size: int = 1024,
            input_channels: int = 3,
            size: int = 256,
            percept: Optional[nn.Module] = None,
            args: Optional[Namespace] = None
    ):
        """
        Args:
            target: target image; its last dimension is taken as its size and
                ``target.shape[1]`` as its channel count.
            input_size: size of the images later passed through ``degrade``.
            input_channels: channel count given to the input ``Downsample``.
            size: size at which the loss is evaluated.
            percept: perceptual loss to use; built from ``args`` when ``None``.
            args: parsed arguments used to build the perceptual loss when
                ``percept`` is not supplied.

        Raises:
            ValueError: if both ``percept`` and ``args`` are ``None``.
        """
        # Explicit raise instead of ``assert`` so the check survives ``python -O``.
        if percept is None and args is None:
            raise ValueError("FaceLoss requires either `percept` or `args`")
        super().__init__()

        target_size = target.shape[-1]
        if target_size != size:
            # Resize the target once, on the device it already lives on.
            resize_target = Downsample(target_size, size, target.shape[1]).to(target.device)
            target = resize_target(target)
        self.target = target

        self.downsample = (
            Downsample(input_size, size, input_channels)
            if size != input_size
            else (lambda x: x)
        )
        self.percept = percept if percept is not None else create_perceptual_loss(args)

    def forward(self, img: torch.Tensor, degrade: Optional[nn.Module] = None):
        """Return the perceptual loss between ``img`` and the stored target.

        Args:
            img: image to compare with the target. When ``degrade`` is given
                it is first replaced by
                ``degrade(img, downsample=self.downsample)``; otherwise it is
                used as is (i.e. it should already be degraded).
            degrade: optional callable applied to ``img`` before the loss.

        Returns:
            ``self.percept(img, self.target)``.
        """
        if degrade is not None:
            img = degrade(img, downsample=self.downsample)
        return self.percept(img, self.target)