import torch
import torch.nn as nn


class MaskedApply(nn.Module):
    """
    Uses an index mask to select a subset of the input and apply a layer to it.

    E.g. if mask is [[0, 1, 0]], layers[0] will be applied to the first and
    third element and layers[1] will be applied to the second element.
    """

    def __init__(self, layers, strict=False):
        super(MaskedApply, self).__init__()
        self.num_layers = len(layers)
        self.layers = nn.ModuleList(layers)
        self.strict = strict

        # Create a CPU tensor to store the maximum index value found.
        # This prevents the GPU being blocked while we check whether an
        # index is >= num_layers in strict mode.
        self._maximum_found_cpu = torch.tensor([-1], device='cpu')
        self._maximum_found = torch.tensor([-1])
        if torch.cuda.is_available():
            self._maximum_found_cpu = self._maximum_found_cpu.pin_memory()

    def forward(self, x, mask):
        # In strict mode, check whether a previous forward pass recorded an
        # out-of-range index. The check is deferred by one step so that the
        # device-to-host copy below can remain asynchronous.
        if self.strict and self._maximum_found_cpu >= self.num_layers:
            raise ValueError(
                f'Unexpected index value found {self._maximum_found_cpu.item()}. '
                f'Should be less than {self.num_layers}.')

        # Ensure mask is a long tensor so it can be used for indexing.
        mask = mask.long()

        # Flatten x and mask for easier processing.
        batch_size, seq_length, embedding_size = x.shape
        x_flat = x.view(-1, embedding_size)
        mask_flat = mask.view(-1)

        # Output placeholder.
        output_flat = torch.zeros_like(x_flat)

        # Process each mask value.
        for i in range(self.num_layers):
            # Find the positions assigned to layer i.
            indices = torch.where(mask_flat == i)[0]

            # Select the relevant inputs for the current layer.
            selected_inputs = torch.index_select(x_flat, 0, indices)

            # Apply the layer.
            transformed = self.layers[i](selected_inputs)

            # The layer may return a different dtype than the input (e.g.
            # float16 under autocast), so cast back before index_copy_.
            transformed = transformed.to(x_flat.dtype)

            # Place the results back in the output tensor.
            output_flat.index_copy_(0, indices, transformed)

        if self.strict:
            # Record the largest index seen. The device-to-host copy is
            # non-blocking; the value is checked on the next forward call.
            self._maximum_found = torch.maximum(
                mask_flat.max(), self._maximum_found.to(mask_flat.device))
            self._maximum_found_cpu.copy_(self._maximum_found, non_blocking=True)
        else:
            # Pass through any elements whose index is out of range.
            indices = torch.where(mask_flat >= self.num_layers)[0]
            selected_inputs = torch.index_select(x_flat, 0, indices)
            output_flat.index_copy_(0, indices, selected_inputs)

        # Reshape the output to the original dimensions.
        output = output_flat.view(batch_size, seq_length, embedding_size)
        return output
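

# Usage sketch (illustrative, not part of the original module): routes each
# position in a (batch, seq, embedding) input to one of two linear layers
# based on an integer mask. The layer sizes, shapes, and mask values below
# are assumptions chosen for the example.
if __name__ == '__main__':
    torch.manual_seed(0)
    embedding_size = 8
    model = MaskedApply(
        [nn.Linear(embedding_size, embedding_size),
         nn.Linear(embedding_size, embedding_size)],
        strict=False,
    )
    x = torch.randn(2, 3, embedding_size)  # (batch, seq, embedding)
    mask = torch.tensor([[0, 1, 0],
                         [1, 1, 0]])       # layer index for each position
    out = model(x, mask)
    print(out.shape)  # torch.Size([2, 3, 8])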