|
import torch |
|
from torch.nn import functional as F |
|
|
|
from basicsr.utils.registry import MODEL_REGISTRY |
|
from .sr_model import SRModel |
|
|
|
|
|
@MODEL_REGISTRY.register() |
|
class SwinIRModel(SRModel): |
|
|
|
def test(self): |
|
|
|
window_size = self.opt['network_g']['window_size'] |
|
scale = self.opt.get('scale', 1) |
|
mod_pad_h, mod_pad_w = 0, 0 |
|
_, _, h, w = self.lq.size() |
|
if h % window_size != 0: |
|
mod_pad_h = window_size - h % window_size |
|
if w % window_size != 0: |
|
mod_pad_w = window_size - w % window_size |
|
img = F.pad(self.lq, (0, mod_pad_w, 0, mod_pad_h), 'reflect') |
|
if hasattr(self, 'net_g_ema'): |
|
self.net_g_ema.eval() |
|
with torch.no_grad(): |
|
self.output = self.net_g_ema(img) |
|
else: |
|
self.net_g.eval() |
|
with torch.no_grad(): |
|
self.output = self.net_g(img) |
|
self.net_g.train() |
|
|
|
_, _, h, w = self.output.size() |
|
self.output = self.output[:, :, 0:h - mod_pad_h * scale, 0:w - mod_pad_w * scale] |
|
|