# feng2022's picture
# anothertry
# 89d1ee7
from argparse import (
ArgumentParser,
Namespace,
)
import torch
from torch import nn
from torch.nn import functional as F
from utils.misc import optional_string
from .gaussian_smoothing import GaussianSmoothing
class DegradeArguments:
    """Command-line options that configure the simulated film degradation."""

    @staticmethod
    def add_arguments(parser: ArgumentParser):
        """Register the degradation options on ``parser``.

        Adds ``--spectral_sensitivity`` (one of ``g``, ``b``, ``gb``; default ``g``)
        and ``--gaussian`` (float blur radius in pixels; default 0, i.e. no blur).
        """
        parser.add_argument('--spectral_sensitivity', choices=["g", "b", "gb"], default="g",
                            help="Type of spectral sensitivity. g: grayscale (panchromatic), b: blue-sensitive, gb: green+blue (orthochromatic)")
        parser.add_argument('--gaussian', type=float, default=0,
                            help="estimated blur radius in pixels of the input photo if it is scaled to 1024x1024")

    @staticmethod
    def to_string(args: Namespace) -> str:
        """Return a short tag identifying the degradation settings in ``args``.

        The tag starts with the spectral sensitivity. It is followed by whatever
        ``optional_string`` yields for the ``-G<radius>`` suffix, where the flag
        passed is "blur is enabled".
        """
        # Degrade treats ``gaussian=None`` as "no blur"; mirror that here instead of
        # failing with TypeError on ``None > 0``.
        blur_enabled = args.gaussian is not None and args.gaussian > 0
        return (
            f"{args.spectral_sensitivity}"
            + optional_string(blur_enabled, f"-G{args.gaussian}")
        )
class CameraResponse(nn.Module):
    """Learnable camera response curve for images in the [-1, 1] range.

    The input is clamped, mapped to [0, 1], and passed through
    ``offset + gain * x ** gamma``. The result is then mapped back to the [-1, 1]
    convention. The output is not clamped afterwards.
    """

    def __init__(self):
        super().__init__()
        # Starts as the identity curve: gamma = 1, offset = 0, gain = 1.
        self.gamma = nn.Parameter(torch.ones(1))
        self.offset = nn.Parameter(torch.zeros(1))
        self.gain = nn.Parameter(torch.ones(1))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # The lower bound sits slightly above -1, presumably to keep the base of the
        # power below strictly positive.
        clipped = torch.clamp(x, min=-1 + 1e-2, max=1)
        unit = (1 + clipped) * 0.5  # [-1, 1] -> [0, 1]
        response = self.offset + self.gain * torch.pow(unit, self.gamma)
        return (response - 0.5) * 2  # [0, 1] -> [-1, 1]
class SpectralResponse(nn.Module):
    """Collapse an RGB tensor to one channel according to a film's spectral sensitivity.

    Modes:
        ``"b"``  -- keep only the last (blue) channel.
        ``"gb"`` -- average of channel 1 (green) and the last (blue) channel.
        ``"g"``  -- weighted sum with weights (0.299, 0.587, 0.114).
    Channels are read along dim 1 and the result keeps that dim with size 1.
    """

    # TODO: use enum instead for color mode
    def __init__(self, spectral_sensitivity: str = 'b'):
        assert spectral_sensitivity in ("g", "b", "gb"), f"spectral_sensitivity {spectral_sensitivity} is not implemented."
        super().__init__()
        self.spectral_sensitivity = spectral_sensitivity
        if spectral_sensitivity == "g":
            # Per-channel weights shaped (1, 3, 1, 1) to broadcast over dim 1.
            weights = torch.tensor([0.299, 0.587, 0.114]).view(1, 3, 1, 1)
            self.register_buffer("to_gray", weights)

    def forward(self, rgb: torch.Tensor) -> torch.Tensor:
        mode = self.spectral_sensitivity
        if mode == "b":
            return rgb[:, -1:]
        if mode == "gb":
            green, blue = rgb[:, 1:2], rgb[:, -1:]
            return (green + blue) * 0.5
        assert mode == "g"
        weighted = rgb * self.to_gray
        return weighted.sum(dim=1, keepdim=True)
class Downsample(nn.Module):
    """Antialiasing downsampling.

    The input is blurred with a Gaussian of sigma ``0.5 * input_size / output_size``.
    If ``input_size`` is an integer multiple of ``output_size``, the blur is applied
    with that ratio as its ``stride``. Otherwise it is applied with stride 1, and the
    result is resampled with ``F.grid_sample`` at the positions
    ``0, step, 2*step, ...`` (in input pixels), where ``step = input_size / output_size``.
    """

    def __init__(self, input_size: int, output_size: int, channels: int):
        super().__init__()
        if input_size % output_size == 0:
            self.stride = input_size // output_size
            self.grid = None
        else:
            self.stride = 1
            self.register_buffer("grid", self._make_grid(input_size, output_size))
        sigma = 0.5 * input_size / output_size
        self.blur = GaussianSmoothing(channels, self._kernel_size(sigma), sigma)

    @staticmethod
    def _kernel_size(sigma: float) -> int:
        """Return an odd Gaussian kernel size spanning roughly +/- 2 sigma."""
        # Round the radius rather than the full size so the result is always odd
        # (the old ``int(2 * (sigma * 2) + 1 + 0.5)`` gave 4 for sigma=0.75). This is the
        # same formula Degrade uses; for integer ``2 * sigma`` it matches the old value.
        return 2 * int(sigma * 2 + 0.5) + 1

    @staticmethod
    def _make_grid(input_size: int, output_size: int) -> torch.Tensor:
        """Return an (output_size, output_size, 2) grid of normalized (x, y) sample points."""
        step = input_size / output_size
        coords = torch.arange(output_size) * step  # pixel positions 0, step, 2*step, ...
        n = output_size
        # Same layout as torch.meshgrid(coords, coords, indexing="ij"), without relying
        # on the ``indexing`` keyword (absent in older torch, warned about in newer).
        ys = coords.view(n, 1).expand(n, n)
        xs = coords.view(1, n).expand(n, n)
        grid = torch.stack((xs, ys), dim=-1)
        # Map pixel index 0 -> -1 and input_size - 1 -> +1, i.e. the align_corners=True
        # convention. The max() guards input_size == 1, where every grid value samples
        # the single pixel.
        grid = grid / max(input_size - 1, 1)
        return grid * 2 - 1

    def forward(self, im: torch.Tensor):
        # NOTE(review): assumes self.blur with stride=1 preserves spatial size, so the
        # grid's input_size normalization applies to ``out`` -- confirm in GaussianSmoothing.
        out = self.blur(im, stride=self.stride)
        if self.grid is not None:
            # align_corners=True must match the normalization in _make_grid. With the
            # default (False), samples shift by half a pixel and hit zero padding at
            # the borders.
            out = F.grid_sample(
                out,
                self.grid[None].expand(im.shape[0], -1, -1, -1),
                align_corners=True,
            )
        return out
class Degrade(nn.Module):
    """Simulate the degradation of antique film.

    Pipeline: optional Gaussian blur -> optional caller-supplied downsampling ->
    spectral response (collapse to film sensitivity) -> learnable camera response.
    Single-channel results are replicated back to 3 channels.
    """

    def __init__(self, args: Namespace):
        super().__init__()
        self.srf = SpectralResponse(args.spectral_sensitivity)
        self.crf = CameraResponse()
        radius = args.gaussian
        blur_enabled = radius is not None and radius > 0
        # The blur is built on 3 channels, with kernel size 2 * round(2 * radius) + 1.
        self.gaussian = (
            GaussianSmoothing(3, 2 * int(radius * 2 + 0.5) + 1, radius)
            if blur_enabled
            else None
        )

    def forward(self, img: torch.Tensor, downsample: nn.Module = None):
        if self.gaussian is not None:
            img = self.gaussian(img)
        if downsample is not None:
            img = downsample(img)
        film = self.crf(self.srf(img))
        # Expand a single-channel result back to 3 channels.
        if film.shape[1] == 1:
            film = film.repeat((1, 3, 1, 1))
        return film