laurencer's picture
Step 6000
261dbc8 verified
raw
history blame
2.86 kB
import torch
import torch.nn as nn
class MaskedApply(nn.Module):
    """Apply a different layer to each position of the input, chosen by a mask.

    ``mask`` holds one integer per position of ``x`` (i.e. one per slice along
    the last dimension). Positions whose mask value is ``i`` are gathered,
    passed through ``layers[i]`` as a 2-D ``(n_selected, embedding_size)``
    batch, and scattered back to their original locations.

    E.g. if mask is ``[[0, 1, 0]]``, ``layers[0]`` is applied to the first and
    third elements and ``layers[1]`` is applied to the second element.

    Every layer must return a tensor with the same shape as the batch it
    receives, because results are written into a zero tensor shaped like the
    flattened input.

    Mask values outside ``[0, len(layers))``:

    * ``strict=False``: positions with a value ``>= len(layers)`` are copied
      from the input unchanged. Negative values match neither a layer nor the
      copy rule, so those positions stay zero.
    * ``strict=True``: such positions stay zero and the largest mask value
      seen so far is recorded. The check is deferred: the *next* call to
      ``forward`` raises ``ValueError`` if the recorded maximum was
      ``>= len(layers)``.

    Args:
        layers: Iterable of ``nn.Module`` objects, indexed by mask value.
        strict: Whether out-of-range mask values are an error (see above).
    """

    def __init__(self, layers, strict=False):
        super().__init__()
        self.layers = nn.ModuleList(layers)
        self.num_layers = len(self.layers)
        self.strict = strict
        # Create a CPU tensor to store the maximum value found.
        # This will prevent the GPU being blocked while we check
        # whether an index is > num_layers in strict mode.
        self._maximum_found_cpu = torch.tensor([-1], device='cpu')
        # Running maximum kept on whatever device the mask lives on. It is a
        # plain attribute (not a buffer), so it is moved lazily in forward().
        self._maximum_found = torch.tensor([-1])
        if torch.cuda.is_available():
            # Pinned memory is what allows the non_blocking copy in forward().
            self._maximum_found_cpu = self._maximum_found_cpu.pin_memory()

    def forward(self, x, mask):
        """Return a tensor shaped like ``x`` with each position transformed.

        Args:
            x: Tensor of shape ``(..., embedding_size)``.
            mask: Integer tensor with one entry per position of ``x`` (it is
                flattened, so only its number of elements matters).

        Raises:
            ValueError: In strict mode, if an earlier call saw a mask value
                that was not smaller than the number of layers.
        """
        # If in strict mode, check if we previously violated the maximum found.
        if self.strict and self._maximum_found_cpu >= self.num_layers:
            raise ValueError(
                f'Unexpected index value found {self._maximum_found_cpu}. '
                f'Should be less than {self.num_layers}'
            )

        # Ensure mask is a long tensor
        mask = mask.long()

        # Flatten x and mask for easier processing. reshape (rather than view)
        # also accepts non-contiguous inputs such as transposed tensors.
        embedding_size = x.shape[-1]
        x_flat = x.reshape(-1, embedding_size)
        mask_flat = mask.reshape(-1)

        # Output placeholder; positions that no rule below writes stay zero.
        output_flat = torch.zeros_like(x_flat)

        # Process each mask value
        for i, layer in enumerate(self.layers):
            # Find indices for current mask value
            indices = torch.where(mask_flat == i)[0]
            # Select relevant inputs for the current layer
            selected_inputs = torch.index_select(x_flat, 0, indices)
            transformed = layer(selected_inputs)
            # NOTE(review): presumably needed because a layer may return a
            # different dtype than its input (e.g. under autocast), which
            # index_copy_ would not accept -- confirm.
            transformed = transformed.to(x_flat.dtype)
            # Place results back in the output tensor
            output_flat.index_copy_(0, indices, transformed)

        if self.strict:
            # An empty mask has no maximum and cannot violate the bound.
            if mask_flat.numel() > 0:
                # Tensor-level reduction keeps the work on the mask's device;
                # the builtin max() would iterate element by element and force
                # a host sync on every comparison.
                running_max = self._maximum_found.to(mask_flat.device)
                self._maximum_found = torch.maximum(running_max, mask_flat.max())
                # This check is done asynchronously; the value is inspected at
                # the start of the next forward() call.
                self._maximum_found_cpu.copy_(self._maximum_found, non_blocking=True)
        else:
            # Copy any out of range indices straight from the input.
            indices = torch.where(mask_flat >= self.num_layers)[0]
            selected_inputs = torch.index_select(x_flat, 0, indices)
            output_flat.index_copy_(0, indices, selected_inputs)

        # Reshape output to original dimensions
        return output_flat.reshape(x.shape)