Question Answering
import torch

class Dataset(torch.utils.data.Dataset):
    """
    This class loads and preprocesses the given text data for masked-language-model pretraining, one file at a time
    """
    def __init__(self, paths, tokenizer):
        """
        This function initialises the object. It takes the given paths and tokeniser.
        """
        # the last file might not have 10000 samples, which makes it difficult to get the total length of the ds
        self.paths = paths[:len(paths)-1]
        self.tokenizer = tokenizer
        self.data = self.read_file(self.paths[0])
        self.current_file = 1
        self.remaining = len(self.data)
        self.encodings = self.get_encodings(self.data)

    def __len__(self):
        """
        Returns the length of the dataset (10000 samples per file)
        """
        return 10000*len(self.paths)
    
    def read_file(self, path):
        """
        Reads the file at the given path and returns its lines
        """
        with open(path, 'r', encoding='utf-8') as f:
            lines = f.read().split('\n')
        return lines

    def get_encodings(self, lines_all):
        """
        Creates masked-language-model encodings (input_ids, attention_mask, labels) for the given lines
        """
        # tokenise all text 
        batch = self.tokenizer(lines_all, max_length=512, padding='max_length', truncation=True)

        # Ground Truth
        labels = torch.tensor(batch['input_ids'])
        # Attention Masks
        mask = torch.tensor(batch['attention_mask'])

        # Input to be masked
        input_ids = labels.detach().clone()
        rand = torch.rand(input_ids.shape)

        # with a probability of 15%, mask a token; leave out the special tokens CLS, SEP and PAD
        mask_arr = (rand < .15) * (input_ids != 0) * (input_ids != 2) * (input_ids != 3)
        # replace the selected tokens with the MASK token (ID 4)
        input_ids[mask_arr] = 4
        
        return {'input_ids':input_ids, 'attention_mask':mask, 'labels':labels}

    def __getitem__(self, i):
        """
        Returns item i.
        Note: do not shuffle this dataset - files are read strictly in order
        """
        # if we have looked at all items in the file - take next
        if self.remaining == 0:
            self.data = self.read_file(self.paths[self.current_file])
            self.current_file += 1
            self.remaining = len(self.data)
            self.encodings = self.get_encodings(self.data)
        
        # if we are at the end of the dataset, start over again
        if self.current_file == len(self.paths):
            self.current_file = 0
                 
        self.remaining -= 1    
        return {key: tensor[i%10000] for key, tensor in self.encodings.items()}  

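# A minimal construction sketch (not part of the original file): it shows how the
# Dataset above might be wired to a DataLoader. The directory layout, glob pattern
# and helper name are assumptions; the tokeniser must be one whose special-token
# IDs match the values hard-coded in get_encodings (PAD=0, CLS=2, SEP=3, MASK=4).
def build_mlm_loader(data_dir, tokenizer, batch_size=16):
    """
    Builds a sequential DataLoader over the pre-split text files in data_dir (hypothetical helper)
    """
    from pathlib import Path
    from torch.utils.data import DataLoader

    paths = sorted(str(p) for p in Path(data_dir).glob('*.txt'))
    dataset = Dataset(paths, tokenizer)
    # no shuffling: __getitem__ walks through the files strictly in order
    return DataLoader(dataset, batch_size=batch_size, shuffle=False)
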
def test_model(model, optim, test_ds_loader, device):
    """
    This function checks that the frozen parameters of the model do not change, that the non-frozen ones do change,
    and that no parameter becomes NaN or Inf after a single optimisation step
    :param model: model to be tested
    :param optim: optimiser used for training
    :param test_ds_loader: dataset to perform the forward pass on
    :param device: current device
    :raises Exception: if any of the above conditions are not met
    """
    ## Check if non-frozen parameters changed and frozen ones did not

    # get initial parameters to check against
    params = [ np for np in model.named_parameters() if np[1].requires_grad ]
    initial_params = [ (name, p.clone()) for (name, p) in params ]

    params_frozen = [ np for np in model.named_parameters() if not np[1].requires_grad ]
    initial_params_frozen = [ (name, p.clone()) for (name, p) in params_frozen ]

    optim.zero_grad()

    # get data
    batch = next(iter(test_ds_loader))

    input_ids = batch['input_ids'].to(device)
    attention_mask = batch['attention_mask'].to(device)
    labels = batch['labels'].to(device)

    # forward pass and backpropagation
    outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
    loss = outputs.loss
    loss.backward()
    optim.step()

    # check that the trainable parameters have changed and contain only finite values
    for (_, p0), (name, p1) in zip(initial_params, params):
        # parameter must differ from its initial value after the optimisation step
        if torch.equal(p0.to(device), p1.to(device)):
            raise Exception(f"{name} did not change!")
        # parameter must not contain NaN
        if torch.isnan(p1).any():
            raise Exception(f"{name} is NaN!")
        # parameter must contain only finite values
        if not torch.isfinite(p1).all():
            raise Exception(f"{name} is Inf!")

    # check that the frozen parameters have not changed and contain only finite values
    for (_, p0), (name, p1) in zip(initial_params_frozen, params_frozen):
        # parameter must be identical to its initial value
        if not torch.equal(p0.to(device), p1.to(device)):
            raise Exception(f"{name} changed!")
        # parameter must not contain NaN
        if torch.isnan(p1).any():
            raise Exception(f"{name} is NaN!")
        # parameter must contain only finite values
        if not torch.isfinite(p1).all():
            raise Exception(f"{name} is Inf!")
    print("Passed")