import torch
import torch.nn as nn


class MaskedApply(nn.Module):
""" |
|
Uses an index mask to select a sbuset of the input and apply a layer to it. |
|
|
|
E.g. if mask is [[0, 1, 0]] layers[0] will be applied to the first and third element |
|
and layers[1] will be applied to the second element. |
|
""" |
|
|
|
    def __init__(self, layers, strict=False):
        super(MaskedApply, self).__init__()
        self.num_layers = len(layers)
        self.layers = nn.ModuleList(layers)
        self.strict = strict

        # Track the largest mask index seen so far. The CPU copy lives in
        # pinned memory so the device-to-host transfer can be non-blocking;
        # strict mode then validates it on the *next* forward pass instead of
        # forcing a GPU sync on every step.
        self._maximum_found_cpu = torch.tensor([-1], device='cpu')
        self._maximum_found = torch.tensor([-1])
        if torch.cuda.is_available():
            self._maximum_found_cpu = self._maximum_found_cpu.pin_memory()

    def forward(self, x, mask):
        # Strict validation is one step late by design: the maximum index from
        # the previous forward pass arrives on the CPU asynchronously.
        if self._maximum_found_cpu >= self.num_layers:
            raise ValueError(f'Unexpected index value found {self._maximum_found_cpu.item()}. '
                             f'Should be less than {self.num_layers}.')

        # Ensure the mask is an integer tensor so it can be compared to layer indices.
        mask = mask.long()

        batch_size, seq_length, embedding_size = x.shape

        # Flatten the batch and sequence dimensions so rows can be gathered
        # and scattered by flat index.
        x_flat = x.view(-1, embedding_size)
        mask_flat = mask.view(-1)

        output_flat = torch.zeros_like(x_flat)

        # Apply each layer to the rows whose mask value selects it.
        for i in range(self.num_layers):
            indices = torch.where(mask_flat == i)[0]
            selected_inputs = torch.index_select(x_flat, 0, indices)

            transformed = self.layers[i](selected_inputs)
            # Cast back in case the layer computes in a different precision.
            transformed = transformed.to(x_flat.dtype)

            # Scatter the transformed rows back into their original positions.
            output_flat.index_copy_(0, indices, transformed)

        if self.strict:
            # Record the largest index seen and ship it to the CPU without
            # blocking; the range check happens at the top of the next call.
            # torch.maximum keeps the reduction on-device, and moving
            # _maximum_found to mask_flat's device avoids a CPU/CUDA
            # comparison error on the first step.
            self._maximum_found = torch.maximum(
                mask_flat.max(), self._maximum_found.to(mask_flat.device))
            self._maximum_found_cpu.copy_(self._maximum_found, non_blocking=True)
        else:
            # Rows whose index has no matching layer fall through unchanged.
            indices = torch.where(mask_flat >= self.num_layers)[0]
            selected_inputs = torch.index_select(x_flat, 0, indices)
            output_flat.index_copy_(0, indices, selected_inputs)

        # Restore the original (batch, seq_length, embedding_size) shape.
        output = output_flat.view(batch_size, seq_length, embedding_size)
return output |
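

if __name__ == '__main__':
    # Minimal usage sketch: route each position's embedding through one of
    # two linear layers according to an index mask. The layer sizes and mask
    # values here are illustrative, not part of the module itself.
    torch.manual_seed(0)
    embedding_size = 4
    layers = [nn.Linear(embedding_size, embedding_size) for _ in range(2)]
    masked_apply = MaskedApply(layers, strict=True)

    x = torch.randn(1, 3, embedding_size)  # (batch, seq_length, embedding)
    mask = torch.tensor([[0, 1, 0]])       # layer index for each position
    out = masked_apply(x, mask)
    assert out.shape == x.shape

    # In non-strict mode, indices with no matching layer pass through unchanged.
    lenient = MaskedApply(layers, strict=False)
    out = lenient(x, torch.tensor([[0, 1, 5]]))
    assert torch.equal(out[0, 2], x[0, 2])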