File size: 2,034 Bytes
6cf191b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54

#############################
#   Imports
#############################

# Python modules
from typing import List
from random import randint

# Remote modules
import torch

# Local modules
from utils import Head_Mask

#############################
#   Constants
#############################

#############################
#   Head mask helpers
#############################

def create_layers_head_mask(config, head_mask_type: Head_Mask=Head_Mask.ALL, specific_heads: List[int] = None):
    """Build a per-layer, per-head 0/1 mask as a nested Python list.

    The mask starts as all zeros with shape
    ``(config.encoder_layers, config.encoder_attention_heads)`` and is then
    filled in according to ``head_mask_type``:

    * ``Head_Mask.RANDOM``   -- one randomly chosen head per layer is set to 1.
    * ``Head_Mask.NONE``     -- every entry is set to 1.
    * ``Head_Mask.ALL``      -- every entry stays 0.
    * ``Head_Mask.SPECIFIC`` -- if ``specific_heads`` is non-empty, entry
      ``specific_heads[i] - 1`` of layer ``i`` is set to 1 (head numbers are
      1-based); otherwise a fixed, hard-coded table is returned.

    Args:
        config: Object exposing ``encoder_layers`` and
            ``encoder_attention_heads`` integer attributes.
        head_mask_type: Which of the strategies above to apply.
        specific_heads: One 1-based head number per layer; only read when
            ``head_mask_type`` is ``Head_Mask.SPECIFIC``. Entries beyond the
            number of layers are ignored.

    Returns:
        The mask converted with ``Tensor.tolist()``: a list of rows (one per
        layer) of floating-point 0/1 values.

    Raises:
        IndexError: If ``specific_heads`` has fewer entries than there are
            layers, or contains a head number outside
            ``1..config.encoder_attention_heads``.
        NotImplementedError: If ``head_mask_type`` is not one of the handled
            ``Head_Mask`` members.
    """
    num_layers = config.encoder_layers
    num_heads = config.encoder_attention_heads
    mask_heads = torch.zeros((num_layers, num_heads))

    if head_mask_type == Head_Mask.RANDOM:
        for layer_i in range(num_layers):
            # randint is inclusive on both ends, hence the "- 1".
            mask_heads[layer_i, randint(0, num_heads - 1)] = 1
    elif head_mask_type == Head_Mask.NONE:
        mask_heads[:, :] = 1
    elif head_mask_type == Head_Mask.ALL:
        # Nothing to do: the mask is already all zeros.
        pass
    elif head_mask_type == Head_Mask.SPECIFIC:
        if specific_heads:
            if len(specific_heads) < num_layers:
                raise IndexError(
                    f"specific_heads has {len(specific_heads)} entries but "
                    f"{num_layers} layers need one head each"
                )
            for layer_i, head_number in enumerate(specific_heads[:num_layers]):
                # Head numbers are 1-based. Reject 0 and negatives explicitly:
                # after the "- 1" shift they would become negative indices and
                # silently select a head counted from the end of the row.
                if not 1 <= head_number <= num_heads:
                    raise IndexError(
                        f"head number {head_number} for layer {layer_i} is "
                        f"outside the valid range 1..{num_heads}"
                    )
                mask_heads[layer_i, head_number - 1] = 1
        else:
            # NOTE(review): this fallback is a fixed 12x16 table and does not
            # depend on config -- confirm it matches the model it is used with.
            mask_heads = torch.Tensor([[0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 0, 0, 0, 1, 1, 0],
                [1, 0, 1, 1, 0, 1, 1, 1, 1, 0, 1, 0, 0, 0, 1, 0],
                [0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 0, 0, 0],
                [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0],
                [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1],
                [1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1],
                [0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0],
                [0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0],
                [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1],
                [0, 0, 1, 0, 0, 1, 1, 0, 0, 0, 0, 1, 0, 0, 0, 1],
                [0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 1],
                [0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1]])
    else:
        raise NotImplementedError(f"Unsupported head mask type: {head_mask_type!r}")
    return mask_heads.tolist()