| | |
| | |
| | |
| | |
| | |
| |
|
| | import torch |
| | import torch.nn as nn |
| | from torch.autograd import Function |
| |
|
| | from ..utils import ext_loader |
| |
|
# Load the `_ext` extension through the project's ext_loader, requesting the
# two TIN shift entry points that TINShiftFunction calls below.
ext_module = ext_loader.load_ext('_ext',
                                 ['tin_shift_forward', 'tin_shift_backward'])
| |
|
| |
|
class TINShiftFunction(Function):
    """Autograd function wrapping the ``tin_shift`` extension kernels.

    ``forward`` validates the argument shapes and then calls
    ``ext_module.tin_shift_forward``; ``backward`` calls
    ``ext_module.tin_shift_backward`` to produce the gradient w.r.t.
    ``input``.
    """

    @staticmethod
    def forward(ctx, input, shift):
        """Apply the temporal interlace shift to ``input``.

        Args:
            ctx: Autograd context; ``shift`` is saved on it for backward.
            input (Tensor): Feature map. Dim 0 is compared against dim 0 of
                ``shift`` and dim 2 (``C``) against ``num_segments``.
            shift (Tensor): Shift tensor whose dim 1 is ``num_segments``.

        Returns:
            Tensor: A new tensor shaped like ``input``, filled in by the
            extension kernel.

        Raises:
            ValueError: If ``input`` and ``shift`` disagree on dim 0, or if
                ``C`` is not a positive multiple of ``num_segments``.
        """
        # Reject mismatched batches up front instead of handing inconsistent
        # shapes to the extension kernel.
        if input.size(0) != shift.size(0):
            raise ValueError(
                'The first dim (batch) of `input` and `shift` should be '
                f'same, but got {input.size(0)} and {shift.size(0)}.')

        C = input.size(2)
        num_segments = shift.size(1)
        # `num_segments <= 0` is tested first so an empty shift tensor yields
        # the ValueError below rather than a ZeroDivisionError.
        if (num_segments <= 0 or C < num_segments
                or C % num_segments != 0):
            raise ValueError('C should be a multiple of num_segments, '
                             f'but got C={C} and num_segments={num_segments}.')

        ctx.save_for_backward(shift)

        out = torch.zeros_like(input)
        ext_module.tin_shift_forward(input, shift, out)

        return out

    @staticmethod
    def backward(ctx, grad_output):
        """Compute gradients for ``input`` and ``shift``.

        Args:
            ctx: Autograd context holding the ``shift`` saved in forward.
            grad_output (Tensor): Gradient w.r.t. the forward output.

        Returns:
            tuple[Tensor, Tensor]: The gradient w.r.t. ``input`` (filled in
            by the extension kernel) and an all-zero tensor shaped like
            ``shift``.
        """
        shift = ctx.saved_tensors[0]
        data_grad_input = grad_output.new_zeros(grad_output.size())
        # The kernel only writes the data gradient; the shift gradient is
        # returned as zeros.
        shift_grad_input = shift.new_zeros(shift.size())
        ext_module.tin_shift_backward(grad_output, shift, data_grad_input)

        return data_grad_input, shift_grad_input
| |
|
| |
|
# Functional alias: calling ``tin_shift(input, shift)`` invokes
# ``TINShiftFunction.apply`` with the same arguments.
tin_shift = TINShiftFunction.apply
| |
|
| |
|
class TINShift(nn.Module):
    """Module form of the Temporal Interlace Shift operator.

    Temporal interlace shift is a differentiable, temporal-wise frame
    shifting operation introduced in "Temporal Interlacing Network".

    See https://arxiv.org/abs/2001.06499 for the method itself.
    The code is adapted from
    https://github.com/mit-han-lab/temporal-shift-module
    """

    def forward(self, input, shift):
        """Run the temporal interlace shift on a feature map.

        Args:
            input (Tensor): Feature map with shape
                [N, num_segments, C, H * W].
            shift (Tensor): Shift tensor with shape [N, num_segments].

        Returns:
            Tensor: The feature map after temporal interlace shift.
        """
        shifted = tin_shift(input, shift)
        return shifted
| |
|