import torch
import torch.nn as nn


class MaskedApply(nn.Module):
    """Route each input position through the layer selected by an index mask.

    ``mask`` holds one integer per position of ``x`` (every dimension of ``x``
    except the last).  Positions whose mask value is ``i`` are gathered,
    passed through ``layers[i]`` and scattered back to the same positions in
    the output.

    E.g. if ``mask`` is ``[[0, 1, 0]]``, ``layers[0]`` is applied to the first
    and third elements and ``layers[1]`` is applied to the second element.

    Positions whose mask value does not correspond to any layer (negative, or
    ``>= len(layers)``) are left as zeros in the output.

    Args:
        layers: Sized iterable of modules.  Each one receives a 2-D tensor of
            shape ``(num_selected, features)`` and must return a 2-D tensor
            with ``num_selected`` rows.  All layers must agree on the output
            feature size.
    """

    def __init__(self, layers):
        super().__init__()
        self.num_layers = len(layers)
        self.layers = nn.ModuleList(layers)

    def forward(self, x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        """Apply ``layers[mask[p]]`` to every position ``p`` of ``x``.

        Args:
            x: Tensor of shape ``(*leading, features)``, e.g.
                ``(batch, seq, features)``.  It does not need to be
                contiguous.
            mask: Integer-valued tensor with one entry per position of ``x``
                (``leading`` shape, or anything with the same number of
                elements in the same row-major order).

        Returns:
            Tensor of shape ``(*leading, out_features)`` with the dtype of
            ``x``, where ``out_features`` is the layers' output feature size.
        """
        # Ensure mask is a long tensor.
        mask = mask.long()

        # Flatten every leading dimension into one "row" axis.  ``reshape``
        # (unlike ``view``) also accepts non-contiguous inputs such as
        # transposed tensors.
        leading_shape = x.shape[:-1]
        x_flat = x.reshape(-1, x.shape[-1])
        mask_flat = mask.reshape(-1)
        num_rows = x_flat.shape[0]

        # Allocated on the first iteration, once the layers' output feature
        # size is known, so that layers may change the feature dimension.
        output_flat = None

        # NOTE: every layer is called even when it selects no rows, exactly
        # as before, so each layer takes part in every forward pass.
        for i, layer in enumerate(self.layers):
            # Rows routed to layer ``i``.
            indices = torch.where(mask_flat == i)[0]
            selected_inputs = torch.index_select(x_flat, 0, indices)
            transformed = layer(selected_inputs)

            # ``index_copy_`` expects the source to have the same dtype as
            # the destination, and a layer may return a different dtype than
            # its input (presumably e.g. under autocast / mixed precision --
            # the original TODO left the exact cause open).
            transformed = transformed.to(x_flat.dtype)

            if output_flat is None:
                output_flat = x_flat.new_zeros((num_rows, transformed.shape[-1]))

            # Place results back at their original row positions.
            output_flat.index_copy_(0, indices, transformed)

        if output_flat is None:
            # No layers at all: nothing is ever written, output is all zeros.
            output_flat = torch.zeros_like(x_flat)

        # Restore the original leading dimensions.
        return output_flat.reshape(*leading_shape, output_flat.shape[-1])