import torch
import torch.nn as nn


class MaskedApply(nn.Module):
    """
    Uses an index mask to select a sbuset of the input and apply a layer to it.

    E.g. if mask is [[0, 1, 0]] layers[0] will be applied to the first and third element
    and layers[1] will be applied to the second element.
    """

    def __init__(self, layers, strict=False):
        super().__init__()
        self.num_layers = len(layers)
        self.layers = nn.ModuleList(layers)
        self.strict = strict

        # Keep a CPU-side copy of the maximum index seen so far. Checking it
        # instead of the device tensor prevents the GPU from being blocked
        # while we verify that no index is >= num_layers in strict mode.
        self._maximum_found_cpu = torch.tensor([-1], device='cpu')
        # Device-side running maximum; updated asynchronously in forward().
        self._maximum_found = torch.tensor([-1])
        if torch.cuda.is_available():
            # Pinned memory enables the non-blocking device-to-host copy below.
            self._maximum_found_cpu = self._maximum_found_cpu.pin_memory()
    
    def forward(self, x, mask):
        # In strict mode, raise if a previous forward pass saw an out-of-range
        # index. The copy back to the CPU is asynchronous, so the error
        # surfaces one step late rather than stalling the GPU.
        if self.strict and self._maximum_found_cpu.item() >= self.num_layers:
            raise ValueError(f'Unexpected index value found {self._maximum_found_cpu.item()}. '
                             f'Should be less than {self.num_layers}.')

        # Ensure mask is a long tensor
        mask = mask.long()

        # Flatten x and mask for easier processing
        batch_size, seq_length, embedding_size = x.shape

        x_flat = x.view(-1, embedding_size)
        mask_flat = mask.view(-1)
        
        # Output placeholder
        output_flat = torch.zeros_like(x_flat)
        
        # Process each mask value
        for i in range(self.num_layers):
            # Find indices for current mask value
            indices = torch.where(mask_flat == i)[0]

            # Select the inputs routed to the current layer
            selected_inputs = torch.index_select(x_flat, 0, indices)

            # Apply the layer
            transformed = self.layers[i](selected_inputs)

            # TODO: figure out why this is necessary.
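            # (A conjecture: under mixed-precision autocast a layer can return
            # a lower-precision dtype, and index_copy_ below requires the
            # source dtype to match output_flat.)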
            transformed = transformed.to(x_flat.dtype)
            
            # Place results back in the output tensor
            output_flat.index_copy_(0, indices, transformed)

        if self.strict:
            # Record the largest index seen so far. The device-to-host copy is
            # asynchronous, so the range check at the top of forward() fires on
            # a later call without forcing a GPU sync here.
            self._maximum_found = torch.maximum(mask_flat.max(), self._maximum_found.to(mask_flat.device))
            self._maximum_found_cpu.copy_(self._maximum_found, non_blocking=True)
        else:
            # Pass inputs with out-of-range indices through unchanged.
            indices = torch.where(mask_flat >= self.num_layers)[0]
            selected_inputs = torch.index_select(x_flat, 0, indices)
            output_flat.index_copy_(0, indices, selected_inputs)
        
        # Reshape output to original dimensions
        output = output_flat.view(batch_size, seq_length, embedding_size)
        return output
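

if __name__ == '__main__':
    # Minimal usage sketch (not part of the original module): route each
    # position's embedding through one of two linear layers chosen by a
    # per-position mask. The shapes and layers below are illustrative
    # assumptions.
    batch_size, seq_length, embedding_size = 2, 4, 8
    layers = [nn.Linear(embedding_size, embedding_size) for _ in range(2)]
    masked_apply = MaskedApply(layers)

    x = torch.randn(batch_size, seq_length, embedding_size)
    mask = torch.randint(0, 2, (batch_size, seq_length))

    out = masked_apply(x, mask)
    print(out.shape)  # torch.Size([2, 4, 8])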