import torch
import torch.nn as nn
import torch.nn.functional as F

import models
from models import register
from utils import make_coord


@register('metasr')
class MetaSR(nn.Module):
    """Meta-SR style upsampler: an encoder plus a weight-predicting MLP.

    The encoder produces a feature map. For every query coordinate, an MLP
    (``imnet``) predicts a ``(C * 9) x 3`` weight matrix from two inputs:
    the query's offset from its nearest feature cell, and the query's cell
    size. That matrix is applied to the 3x3-unfolded feature at the cell.

    Registered via ``models.register`` under the name ``'metasr'``.
    """

    def __init__(self, encoder_spec):
        """Build the encoder and the weight-prediction MLP.

        Args:
            encoder_spec: spec passed to ``models.make`` to build the encoder.
                The resulting module must expose ``out_dim`` (its channel
                count), which sizes the MLP output.
        """
        super().__init__()
        self.encoder = models.make(encoder_spec)
        # The MLP maps (rel_y, rel_x, cell size) -> a flattened weight matrix
        # with one row per unfolded feature channel (out_dim * 9) and one
        # column per output channel (3).
        imnet_spec = {
            'name': 'mlp',
            'args': {
                'in_dim': 3,
                'out_dim': self.encoder.out_dim * 9 * 3,
                'hidden_list': [256],
            },
        }
        self.imnet = models.make(imnet_spec)

    def gen_feat(self, inp):
        """Run the encoder on ``inp``, cache the result in ``self.feat`` and return it."""
        self.feat = self.encoder(inp)
        return self.feat

    def query_rgb(self, coord, cell=None):
        """Predict 3 output values per query coordinate from cached features.

        ``gen_feat`` must have been called first; this method reads
        ``self.feat``.

        Args:
            coord: tensor of shape ``(B, Q, 2)`` holding query coordinates as
                ``(y, x)`` pairs. Dim 0 is scaled by feature height and dim 1
                by width; both are flipped to ``(x, y)`` for ``grid_sample``.
                Values are expected in the normalized ``[-1, 1]`` range and
                are clamped just inside it before sampling.
            cell: tensor of shape ``(B, Q, 2)`` giving the size of each
                query's output cell, in the same coordinate units. Required
                despite the ``None`` default, which is kept for signature
                compatibility.

        Returns:
            Tensor of shape ``(B, Q, 3)``.

        Raises:
            TypeError: if ``cell`` is ``None``.
        """
        if cell is None:
            raise TypeError('MetaSR.query_rgb requires `cell`')

        feat = self.feat
        # Gather each position's 3x3 neighbourhood:
        # (B, C, H, W) -> (B, C*9, H, W).
        feat = F.unfold(feat, 3, padding=1).view(
            feat.shape[0], feat.shape[1] * 9, feat.shape[2], feat.shape[3])

        # Coordinate of every feature cell, shifted back by half a cell on
        # both axes. make_coord presumably returns pixel centres, so the
        # shift would give each cell's top-left corner.
        # The grid is built on feat's own device rather than the current
        # CUDA device, so CPU and non-default GPUs work.
        feat_coord = make_coord(feat.shape[-2:], flatten=False).to(feat.device)
        feat_coord[:, :, 0] -= (2 / feat.shape[-2]) / 2
        feat_coord[:, :, 1] -= (2 / feat.shape[-1]) / 2
        # (H, W, 2) -> (B, 2, H, W) so it can be sampled like a feature map.
        feat_coord = feat_coord.permute(2, 0, 1) \
            .unsqueeze(0).expand(feat.shape[0], 2, *feat.shape[-2:])

        # Shift each query back by half its cell, mirroring the shift above.
        coord_ = coord.clone()
        coord_[:, :, 0] -= cell[:, :, 0] / 2
        coord_[:, :, 1] -= cell[:, :, 1] / 2
        # Nudge forward slightly (presumably to avoid ties exactly on a cell
        # boundary) and keep strictly inside (-1, 1).
        coord_q = (coord_ + 1e-6).clamp(-1 + 1e-6, 1 - 1e-6)
        # grid_sample expects (x, y) order and a (B, 1, Q, 2) grid.
        grid = coord_q.flip(-1).unsqueeze(1)

        # Nearest-cell feature and that cell's coordinate, per query:
        # (B, C*9, 1, Q) -> (B, Q, C*9) and (B, 2, 1, Q) -> (B, Q, 2).
        q_feat = F.grid_sample(
            feat, grid, mode='nearest', align_corners=False)[:, :, 0, :] \
            .permute(0, 2, 1)
        q_coord = F.grid_sample(
            feat_coord, grid, mode='nearest', align_corners=False)[:, :, 0, :] \
            .permute(0, 2, 1)

        # Offset from the sampled cell, rescaled so that a difference of 2 in
        # normalized coordinates becomes H (rows) or W (columns) units.
        rel_coord = coord_ - q_coord
        rel_coord[:, :, 0] *= feat.shape[-2] / 2
        rel_coord[:, :, 1] *= feat.shape[-1] / 2
        # Query cell height in the same units; only the height is used.
        r_rev = cell[:, :, 0] * (feat.shape[-2] / 2)
        inp = torch.cat([rel_coord, r_rev.unsqueeze(-1)], dim=-1)

        bs, q = coord.shape[:2]
        # Per-query weight matrix: (B*Q, C*9, 3).
        pred = self.imnet(inp.view(bs * q, -1)).view(bs * q, feat.shape[1], 3)
        # (B*Q, 1, C*9) @ (B*Q, C*9, 3) -> (B*Q, 1, 3).
        pred = torch.bmm(q_feat.contiguous().view(bs * q, 1, -1), pred)
        pred = pred.view(bs, q, 3)
        return pred

    def forward(self, inp, coord, cell):
        """Encode ``inp`` and query it at ``coord``/``cell``; see ``query_rgb``."""
        self.gen_feat(inp)
        return self.query_rgb(coord, cell)