|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
|
|
import models |
|
from models import register |
|
from utils import make_coord |
|
|
|
|
|
@register('liif-sampler') |
|
class LIIF_Sampler(nn.Module): |
|
|
|
def __init__(self, imnet_spec, in_dim, out_dim): |
|
super().__init__() |
|
self.imnet = models.make(imnet_spec, args={'in_dim': in_dim+4, 'out_dim': out_dim}) |
|
|
|
def make_inp(self, feat, coord, cell): |
|
feat_coord = make_coord(feat.shape[-2:], flatten=False, device=feat.device)\ |
|
.permute(2,0,1).unsqueeze(0).expand(feat.shape[0], 2, *feat.shape[-2:]) |
|
q_feat = F.grid_sample(feat, coord.flip(-1).unsqueeze(1), mode='nearest', |
|
align_corners=False)[:,:,0,:].permute(0,2,1) |
|
q_coord = F.grid_sample(feat_coord, coord.flip(-1).unsqueeze(1), mode='nearest', |
|
align_corners=False)[:,:,0,:].permute(0,2,1) |
|
|
|
rel_coord = coord - q_coord |
|
rel_coord[:,:,0] *= feat.shape[-2] |
|
rel_coord[:,:,1] *= feat.shape[-1] |
|
|
|
rel_cell = cell.clone() |
|
rel_cell[:,:,0] *= feat.shape[-2] |
|
rel_cell[:,:,1] *= feat.shape[-1] |
|
|
|
inp = torch.cat([q_feat, rel_coord, rel_cell], dim=-1) |
|
return inp |
|
|
|
def forward(self, x, coord=None, cell=None): |
|
if coord is not None: |
|
x = self.make_inp(x, coord, cell) |
|
x = self.imnet(x) |
|
return x |