| from dataclasses import dataclass |
| from typing import Optional |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
| from transformers import PreTrainedModel |
| from transformers.utils import ModelOutput |
|
|
| from .configuration_upscaler import UpscalerConfig |
|
|
|
|
| |
| |
| |
|
|
class ResidualBlock(nn.Module):
    """Two 3x3 convolutions with a ReLU in between, wrapped in an identity skip.

    Both convolutions keep the channel count and use ``padding=1`` with a 3x3
    kernel, so the output has the same shape as the input.

    Args:
        channels: Number of input (and output) channels.
    """

    def __init__(self, channels: int):
        super().__init__()
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        # In-place ReLU only overwrites conv1's freshly allocated output, never ``x``,
        # so the skip connection below still sees the unmodified input.
        self.act = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)

    def forward(self, x):
        residual = self.conv2(self.act(self.conv1(x)))
        return x + residual
|
|
|
|
class RestorationNet(nn.Module):
    """Residual restoration network operating at the input resolution.

    A 3x3 stem convolution lifts the image to ``width`` channels, a stack of
    ``num_blocks`` ``ResidualBlock`` modules processes the features, and a 3x3
    convolution projects back to ``in_channels``. That projection is added to
    the input (global skip), so the convolutional body predicts an additive
    correction to ``lr``.

    Args:
        in_channels: Number of image channels in the input and output.
        width: Number of feature channels inside the network.
        num_blocks: Number of residual blocks in the body.
    """

    def __init__(self, in_channels=3, width=32, num_blocks=3):
        super().__init__()
        self.in_conv = nn.Conv2d(in_channels, width, kernel_size=3, padding=1)
        self.blocks = nn.Sequential(*(ResidualBlock(width) for _ in range(num_blocks)))
        self.out_conv = nn.Conv2d(width, in_channels, kernel_size=3, padding=1)

    def forward(self, lr):
        features = self.in_conv(lr)
        features = self.blocks(features)
        correction = self.out_conv(features)
        return lr + correction
|
|
|
|
class ESPCNUpsampler(nn.Module):
    """ESPCN-style sub-pixel upsampler.

    Two convolutional feature layers (5x5 then 3x3, each followed by ReLU) feed a
    3x3 convolution that emits ``in_channels * scale**2`` maps. ``nn.PixelShuffle``
    then rearranges these into an image ``scale`` times larger in each spatial
    dimension. An optional 3x3 convolution refines the upsampled result.

    All convolutions use "same" padding, so an input of shape
    ``(B, in_channels, H, W)`` yields ``(B, in_channels, H * scale, W * scale)``.

    Args:
        in_channels: Number of image channels in the input and output.
        scale: Integer upscaling factor; must be 2, 3 or 4.
        feat1: Channels produced by the first (5x5) convolution.
        feat2: Channels produced by the second (3x3) convolution.
        use_refine: If True, apply an extra 3x3 convolution after the pixel
            shuffle; otherwise ``self.refine`` is ``None``.

    Raises:
        ValueError: If ``scale`` is not one of 2, 3 or 4.
    """

    def __init__(self, in_channels=3, scale=2, feat1=64, feat2=32, use_refine=False):
        super().__init__()
        # Explicit check rather than ``assert`` so validation is not stripped under ``python -O``.
        if scale not in (2, 3, 4):
            raise ValueError(f"scale must be one of (2, 3, 4), got {scale!r}")

        self.conv1 = nn.Conv2d(in_channels, feat1, 5, padding=2)
        self.act1 = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(feat1, feat2, 3, padding=1)
        self.act2 = nn.ReLU(inplace=True)

        # Produce scale**2 sub-pixel maps per output channel for PixelShuffle to interleave.
        self.conv3 = nn.Conv2d(feat2, in_channels * (scale ** 2), 3, padding=1)
        self.ps = nn.PixelShuffle(scale)

        self.refine = nn.Conv2d(in_channels, in_channels, 3, padding=1) if use_refine else None

    def forward(self, x):
        y = self.act1(self.conv1(x))
        y = self.act2(self.conv2(y))
        y = self.ps(self.conv3(y))
        if self.refine is not None:
            y = self.refine(y)
        return y
|
|
|
|
class TwoStageSR(nn.Module):
    """Two-stage super-resolution: restore at low resolution, then upsample.

    Stage 1 (``self.restoration``) is a ``RestorationNet`` that keeps the input
    resolution. Stage 2 (``self.upsampler``) is an ``ESPCNUpsampler`` that
    enlarges the restored image by ``scale``.

    Args:
        in_channels: Number of image channels.
        scale: Upscaling factor passed to the upsampler (also stored as ``self.scale``).
        width: Feature width of the restoration network.
        num_blocks: Number of residual blocks in the restoration network.
        feat1: Channels of the upsampler's first convolution.
        feat2: Channels of the upsampler's second convolution.
        use_refine: Whether the upsampler applies a post-shuffle refinement conv.
    """

    def __init__(self, in_channels=3, scale=2, width=32, num_blocks=3, feat1=64, feat2=32, use_refine=False):
        super().__init__()
        self.scale = scale
        self.restoration = RestorationNet(in_channels=in_channels, width=width, num_blocks=num_blocks)
        upsampler_kwargs = dict(
            in_channels=in_channels,
            scale=scale,
            feat1=feat1,
            feat2=feat2,
            use_refine=use_refine,
        )
        self.upsampler = ESPCNUpsampler(**upsampler_kwargs)

    def forward(self, lr):
        return self.upsampler(self.restoration(lr))
|
|
|
|
| |
| |
| |
|
|
@dataclass
class UpscalerOutput(ModelOutput):
    """Output of ``UpscalerModel.forward``.

    Args:
        sr: The super-resolved image batch produced by the model.
    """

    # Optional with a None default, following the transformers ModelOutput
    # convention (ModelOutput ignores None-valued fields when indexed like a tuple).
    sr: Optional[torch.FloatTensor] = None
|
|
|
|
class UpscalerModel(PreTrainedModel):
    """Hugging Face ``PreTrainedModel`` wrapper around ``TwoStageSR``.

    The wrapped network is stored as ``self.model``. Its hyper-parameters are read
    from the ``UpscalerConfig`` attributes of the same name.
    """

    config_class = UpscalerConfig
    main_input_name = "pixel_values"

    def __init__(self, config: UpscalerConfig):
        super().__init__(config)

        # Config attribute names match TwoStageSR's keyword arguments one-to-one.
        arch_fields = ("in_channels", "scale", "width", "num_blocks", "feat1", "feat2", "use_refine")
        self.model = TwoStageSR(**{name: getattr(config, name) for name in arch_fields})

        self.post_init()

    def forward(self, pixel_values: torch.FloatTensor, **kwargs) -> UpscalerOutput:
        """Upscale a batch of images.

        Args:
            pixel_values: Float image batch of shape ``(B, C, H, W)`` with
                ``C == config.in_channels``. Values are expected to lie in
                ``[0, 1]``; this method neither checks nor clamps them.
            **kwargs: Accepted for call-site compatibility and ignored.

        Returns:
            ``UpscalerOutput`` whose ``sr`` field has shape
            ``(B, C, H * config.scale, W * config.scale)``.
        """
        return UpscalerOutput(sr=self.model(pixel_values))