# encoding: utf-8 | |
import torch.nn.functional as F | |
from torch.autograd import Variable | |
def grid_sample(input, grid, canvas=None): | |
output = F.grid_sample(input, grid) | |
if canvas is None: | |
return output | |
else: | |
input_mask = Variable(input.data.new(input.size()).fill_(1)) | |
output_mask = F.grid_sample(input_mask, grid) | |
padded_output = output * output_mask + canvas * (1 - output_mask) | |
return padded_output | |