meghanaraok committed on
Commit
bf97eb3
1 Parent(s): e4d23ac

Upload 4 files

HiLATmain/models/__init__.py ADDED
File without changes
HiLATmain/models/modeling - Copy1.py ADDED
@@ -0,0 +1,337 @@
+ import collections
+ import logging
+
+ import torch
+ from torch.nn import BCEWithLogitsLoss, Dropout, Linear
+ from transformers import AutoModel, XLNetModel, LongformerModel, LongformerConfig
+ from transformers.models.longformer.modeling_longformer import LongformerEncoder, LongformerClassificationHead, LongformerLayer
+
+ from hilat.models.utils import initial_code_title_vectors
+
+ logger = logging.getLogger("lwat")
+
+
+ class CodingModelConfig:
+     def __init__(self,
+                  transformer_model_name_or_path,
+                  transformer_tokenizer_name,
+                  transformer_layer_update_strategy,
+                  num_chunks,
+                  max_seq_length,
+                  dropout,
+                  dropout_att,
+                  d_model,
+                  label_dictionary,
+                  num_labels,
+                  use_code_representation,
+                  code_max_seq_length,
+                  code_batch_size,
+                  multi_head_att,
+                  chunk_att,
+                  linear_init_mean,
+                  linear_init_std,
+                  document_pooling_strategy,
+                  multi_head_chunk_attention):
+         super(CodingModelConfig, self).__init__()
+         self.transformer_model_name_or_path = transformer_model_name_or_path
+         self.transformer_tokenizer_name = transformer_tokenizer_name
+         self.transformer_layer_update_strategy = transformer_layer_update_strategy
+         self.num_chunks = num_chunks
+         self.max_seq_length = max_seq_length
+         self.dropout = dropout
+         self.dropout_att = dropout_att
+         self.d_model = d_model
+         # label_dictionary is a dataframe with columns: icd9_code, long_title
+         self.label_dictionary = label_dictionary
+         self.num_labels = num_labels
+         self.use_code_representation = use_code_representation
+         self.code_max_seq_length = code_max_seq_length
+         self.code_batch_size = code_batch_size
+         self.multi_head_att = multi_head_att
+         self.chunk_att = chunk_att
+         self.linear_init_mean = linear_init_mean
+         self.linear_init_std = linear_init_std
+         self.document_pooling_strategy = document_pooling_strategy
+         self.multi_head_chunk_attention = multi_head_chunk_attention
+
+
+ class LableWiseAttentionLayer(torch.nn.Module):
+     def __init__(self, coding_model_config, args):
+         super(LableWiseAttentionLayer, self).__init__()
+
+         self.config = coding_model_config
+         self.args = args
+
+         # layers
+         self.l1_linear = torch.nn.Linear(self.config.d_model,
+                                          self.config.d_model, bias=False)
+         self.tanh = torch.nn.Tanh()
+         self.l2_linear = torch.nn.Linear(self.config.d_model, self.config.num_labels, bias=False)
+         self.softmax = torch.nn.Softmax(dim=1)
+
+         # Mean pooling of the last hidden state of each code title from the transformer model as the initial code vectors
+         self._init_linear_weights(mean=self.config.linear_init_mean, std=self.config.linear_init_std)
+
+     def _init_linear_weights(self, mean, std):
+         # normalize the l1 weights
+         torch.nn.init.normal_(self.l1_linear.weight, mean, std)
+         if self.l1_linear.bias is not None:
+             self.l1_linear.bias.data.fill_(0)
+         # initialize the l2
+         if self.config.use_code_representation:
+             code_vectors = initial_code_title_vectors(self.config.label_dictionary,
+                                                       self.config.transformer_model_name_or_path,
+                                                       self.config.transformer_tokenizer_name
+                                                       if self.config.transformer_tokenizer_name
+                                                       else self.config.transformer_model_name_or_path,
+                                                       self.config.code_max_seq_length,
+                                                       self.config.code_batch_size,
+                                                       self.config.d_model,
+                                                       self.args.device)
+
+             self.l2_linear.weight = torch.nn.Parameter(code_vectors, requires_grad=True)
+         torch.nn.init.normal_(self.l2_linear.weight, mean, std)
+         if self.l2_linear.bias is not None:
+             self.l2_linear.bias.data.fill_(0)
+
+     def forward(self, x):
+         # input: (batch_size, max_seq_length, transformer_hidden_size)
+         # output: (batch_size, max_seq_length, transformer_hidden_size)
+         # Z = tanh(WH)
+         l1_output = self.tanh(self.l1_linear(x))
+         # softmax(UZ)
+         # l2_linear output shape: (batch_size, max_seq_length, num_labels)
+         # attention_weight shape: (batch_size, num_labels, max_seq_length)
+         attention_weight = self.softmax(self.l2_linear(l1_output)).transpose(1, 2)
+         # attention_output shape: (batch_size, num_labels, transformer_hidden_size)
+         attention_output = torch.matmul(attention_weight, x)
+
+         return attention_output, attention_weight
+
+ class ChunkAttentionLayer(torch.nn.Module):
+     def __init__(self, coding_model_config, args):
+         super(ChunkAttentionLayer, self).__init__()
+
+         self.config = coding_model_config
+         self.args = args
+
+         # layers
+         self.l1_linear = torch.nn.Linear(self.config.d_model,
+                                          self.config.d_model, bias=False)
+         self.tanh = torch.nn.Tanh()
+         self.l2_linear = torch.nn.Linear(self.config.d_model, 1, bias=False)
+         self.softmax = torch.nn.Softmax(dim=1)
+
+         self._init_linear_weights(mean=self.config.linear_init_mean, std=self.config.linear_init_std)
+
+     def _init_linear_weights(self, mean, std):
+         # initialize the l1
+         torch.nn.init.normal_(self.l1_linear.weight, mean, std)
+         if self.l1_linear.bias is not None:
+             self.l1_linear.bias.data.fill_(0)
+         # initialize the l2
+         torch.nn.init.normal_(self.l2_linear.weight, mean, std)
+         if self.l2_linear.bias is not None:
+             self.l2_linear.bias.data.fill_(0)
+
+     def forward(self, x):
+         # input: (batch_size, num_chunks, transformer_hidden_size)
+         # output: (batch_size, num_chunks, transformer_hidden_size)
+         # Z = tanh(WH)
+         l1_output = self.tanh(self.l1_linear(x))
+         # softmax(UZ)
+         # l2_linear output shape: (batch_size, num_chunks, 1)
+         # attention_weight shape: (batch_size, 1, num_chunks)
+         attention_weight = self.softmax(self.l2_linear(l1_output)).transpose(1, 2)
+         # attention_output shape: (batch_size, 1, transformer_hidden_size)
+         attention_output = torch.matmul(attention_weight, x)
+         return attention_output, attention_weight
+
+
+ class CodingModel(torch.nn.Module):
+     def __init__(self, coding_model_config, args):
+         super(CodingModel, self).__init__()
+         self.coding_model_config = coding_model_config
+         self.args = args
+         # layers
+         self.transformer_layer = AutoModel.from_pretrained(self.coding_model_config.transformer_model_name_or_path)
+         if isinstance(self.transformer_layer, XLNetModel):
+             self.transformer_layer.config.use_mems_eval = False
+         self.dropout = Dropout(p=self.coding_model_config.dropout)
+
+         if self.coding_model_config.multi_head_att:
+             # initialise one label-wise attention head per chunk
+             self.label_wise_attention_layer = torch.nn.ModuleList(
+                 [LableWiseAttentionLayer(coding_model_config, args)
+                  for _ in range(self.coding_model_config.num_chunks)])
+         else:
+             self.label_wise_attention_layer = LableWiseAttentionLayer(coding_model_config, args)
+         self.dropout_att = Dropout(p=self.coding_model_config.dropout_att)
+
+         # initialise chunk attention
+         if self.coding_model_config.chunk_att:
+             if self.coding_model_config.multi_head_chunk_attention:
+                 self.chunk_attention_layer = torch.nn.ModuleList([ChunkAttentionLayer(coding_model_config, args)
+                                                                   for _ in range(self.coding_model_config.num_labels)])
+             else:
+                 self.chunk_attention_layer = ChunkAttentionLayer(coding_model_config, args)
+
+             self.classifier_layer = Linear(self.coding_model_config.d_model,
+                                            self.coding_model_config.num_labels)
+         else:
+             if self.coding_model_config.document_pooling_strategy == "flat":
+                 self.classifier_layer = Linear(self.coding_model_config.num_chunks * self.coding_model_config.d_model,
+                                                self.coding_model_config.num_labels)
+             else:  # max or mean pooling
+                 self.classifier_layer = Linear(self.coding_model_config.d_model,
+                                                self.coding_model_config.num_labels)
+         self.sigmoid = torch.nn.Sigmoid()
+
+         if self.coding_model_config.transformer_layer_update_strategy == "no":
+             self.freeze_all_transformer_layers()
+         elif self.coding_model_config.transformer_layer_update_strategy == "last":
+             self.freeze_all_transformer_layers()
+             self.unfreeze_transformer_last_layers()
+
+         # initialize the classifier weights
+         self._init_linear_weights(mean=self.coding_model_config.linear_init_mean, std=self.coding_model_config.linear_init_std)
+
+     def _init_linear_weights(self, mean, std):
+         torch.nn.init.normal_(self.classifier_layer.weight, mean, std)
+
+     def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, targets=None):
+         # input ids/mask/type_ids shape: (batch_size, num_chunks, max_seq_length)
+         # labels shape: (batch_size, num_labels)
+         transformer_output = []
+
+         # pass chunks one by one through the transformer layer in batches.
+         # input (batch_size, sequence_length)
+         for i in range(self.coding_model_config.num_chunks):
+             l1_output = self.transformer_layer(input_ids=input_ids[:, i, :],
+                                                attention_mask=attention_mask[:, i, :],
+                                                token_type_ids=token_type_ids[:, i, :])
+             # output hidden state shape: (batch_size, sequence_length, hidden_size)
+             transformer_output.append(l1_output[0])
+
+         # transpose back chunk and batch size dimensions
+         transformer_output = torch.stack(transformer_output)
+         transformer_output = transformer_output.transpose(0, 1)
+         # dropout transformer output
+         l2_dropout = self.dropout(transformer_output)
+
+         # Label-wise attention layers
+         # output: (batch_size, num_chunks, num_labels, hidden_size)
+         attention_output = []
+         attention_weights = []
+
+         for i in range(self.coding_model_config.num_chunks):
+             # input: (batch_size, max_seq_length, transformer_hidden_size)
+             if self.coding_model_config.multi_head_att:
+                 attention_layer = self.label_wise_attention_layer[i]
+             else:
+                 attention_layer = self.label_wise_attention_layer
+             l3_attention, attention_weight = attention_layer(l2_dropout[:, i, :])
+             # l3_attention shape: (batch_size, num_labels, hidden_size)
+             # attention_weight: (batch_size, num_labels, max_seq_length)
+             attention_output.append(l3_attention)
+             attention_weights.append(attention_weight)
+
+         attention_output = torch.stack(attention_output)
+         attention_output = attention_output.transpose(0, 1)
+         attention_weights = torch.stack(attention_weights)
+         attention_weights = attention_weights.transpose(0, 1)
+
+         config = LongformerConfig.from_pretrained("allenai/longformer-base-4096")
+         config.num_labels = 5
+         config.num_hidden_layers = 1
+         longformer_layer = LongformerLayer(config)
+         l2_dropout = l2_dropout.reshape(l2_dropout.shape[0], l2_dropout.shape[1] * l2_dropout.shape[2], l2_dropout.shape[3])
+         attention_mask = attention_mask.reshape(attention_mask.shape[0], attention_mask.shape[1] * attention_mask.shape[2])
+         is_index_masked = attention_mask < 0
+         output = longformer_layer(l2_dropout, attention_mask=attention_mask, output_attentions=True, is_index_masked=is_index_masked)
+         l3_dropout = self.dropout_att(output[0])
+         l3_dropout = l3_dropout.reshape(l3_dropout.shape[0], self.coding_model_config.num_chunks, self.coding_model_config.max_seq_length, self.coding_model_config.d_model)
+         self.softmax = torch.nn.Softmax(dim=1)
+         self.l2_linear = torch.nn.Linear(self.coding_model_config.d_model, self.coding_model_config.num_labels, bias=False)
+         attention_weight = self.softmax(self.l2_linear(l3_dropout)).transpose(1, 2)
+         attention_weight = attention_weight.reshape(attention_weight.shape[0], self.coding_model_config.num_labels, self.coding_model_config.num_chunks, self.coding_model_config.max_seq_length)
+         # attention_weight = attention_weight.permute(0, 2, 1)
+         l2_dropout = l2_dropout.reshape(l2_dropout.shape[0], self.coding_model_config.num_chunks, self.coding_model_config.max_seq_length, self.coding_model_config.d_model)
+
+         attention_output = []
+
+         for i in range(self.coding_model_config.num_chunks):
+             l3_attention = torch.matmul(attention_weight[:, :, i], l2_dropout[:, i, :])
+             attention_output.append(l3_attention)
+
+         attention_output = torch.stack(attention_output)
+         l3_dropout = self.dropout_att(attention_output)
+         l3_dropout = l3_dropout.transpose(0, 1)
+
+
+         if self.coding_model_config.chunk_att:
+             # Chunk attention layers
+             # output: (batch_size, num_labels, hidden_size)
+             chunk_attention_output = []
+             chunk_attention_weights = []
+
+             for i in range(self.coding_model_config.num_labels):
+                 if self.coding_model_config.multi_head_chunk_attention:
+                     chunk_attention = self.chunk_attention_layer[i]
+                 else:
+                     chunk_attention = self.chunk_attention_layer
+                 l4_chunk_attention, l4_chunk_attention_weights = chunk_attention(l3_dropout[:, :, i])
+                 chunk_attention_output.append(l4_chunk_attention.squeeze())
+                 chunk_attention_weights.append(l4_chunk_attention_weights.squeeze())
+
+             chunk_attention_output = torch.stack(chunk_attention_output)
+             chunk_attention_output = chunk_attention_output.transpose(0, 1)
+             chunk_attention_weights = torch.stack(chunk_attention_weights)
+             chunk_attention_weights = chunk_attention_weights.transpose(0, 1)
+             # output shape: (batch_size, num_labels, hidden_size)
+             l4_dropout = self.dropout_att(chunk_attention_output)
+         else:
+             # output shape: (batch_size, num_labels, hidden_size*num_chunks)
+             l4_dropout = l3_dropout.transpose(1, 2)
+             if self.coding_model_config.document_pooling_strategy == "flat":
+                 # Flatten layer: concatenate representations per label
+                 l4_dropout = torch.flatten(l4_dropout, start_dim=2)
+             elif self.coding_model_config.document_pooling_strategy == "max":
+                 l4_dropout = torch.amax(l4_dropout, 2)
+             elif self.coding_model_config.document_pooling_strategy == "mean":
+                 l4_dropout = torch.mean(l4_dropout, 2)
+             else:
+                 raise ValueError("Not supported pooling strategy")
+
+         # classifier layer
+         # each code has a binary linear formula
+         logits = self.classifier_layer.weight.mul(l4_dropout).sum(dim=2).add(self.classifier_layer.bias)
+
+         loss_fct = BCEWithLogitsLoss()
+         loss = loss_fct(logits, targets)
+
+         return {
+             "loss": loss,
+             "logits": logits,
+             "label_attention_weights": attention_weights,
+             "chunk_attention_weights": chunk_attention_weights if self.coding_model_config.chunk_att else []
+         }
+
+     def freeze_all_transformer_layers(self):
+         """
+         Freeze all layer weight parameters. They will not be updated during training.
+         """
+         for param in self.transformer_layer.parameters():
+             param.requires_grad = False
+
+     def unfreeze_all_transformer_layers(self):
+         """
+         Unfreeze all layer weight parameters. They will be updated during training.
+         """
+         for param in self.transformer_layer.parameters():
+             param.requires_grad = True
+
+     def unfreeze_transformer_last_layers(self):
+         for name, param in self.transformer_layer.named_parameters():
+             if "layer.11" in name or "pooler" in name:
+                 param.requires_grad = True
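
For reference, a minimal sketch (not part of this commit) of what LableWiseAttentionLayer computes; the dimensions and the SimpleNamespace stand-ins for CodingModelConfig and the training args are illustrative assumptions:

import torch
from types import SimpleNamespace

cfg = SimpleNamespace(d_model=768, num_labels=5, linear_init_mean=0.0,
                      linear_init_std=0.03, use_code_representation=False)
args = SimpleNamespace(device="cpu")

layer = LableWiseAttentionLayer(cfg, args)         # class defined above
hidden_states = torch.randn(2, 512, 768)           # (batch_size, max_seq_length, d_model)
attn_out, attn_w = layer(hidden_states)
print(attn_out.shape)                              # torch.Size([2, 5, 768])
print(attn_w.shape)                                # torch.Size([2, 5, 512])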
HiLATmain/models/modeling.py ADDED
@@ -0,0 +1,314 @@
+ import collections
+ import logging
+
+ import torch
+ from torch.nn import BCEWithLogitsLoss, Dropout, Linear
+ from transformers import AutoModel, XLNetModel
+ from huggingface_hub import PyTorchModelHubMixin
+
+ from HiLATmain.hilat.models.utils import initial_code_title_vectors
+
+ logger = logging.getLogger("lwat")
+
+
+ class CodingModelConfig:
+     def __init__(self,
+                  transformer_model_name_or_path,
+                  transformer_tokenizer_name,
+                  transformer_layer_update_strategy,
+                  num_chunks,
+                  max_seq_length,
+                  dropout,
+                  dropout_att,
+                  d_model,
+                  label_dictionary,
+                  num_labels,
+                  use_code_representation,
+                  code_max_seq_length,
+                  code_batch_size,
+                  multi_head_att,
+                  chunk_att,
+                  linear_init_mean,
+                  linear_init_std,
+                  document_pooling_strategy,
+                  multi_head_chunk_attention,
+                  num_hidden_layers):
+         super(CodingModelConfig, self).__init__()
+         self.transformer_model_name_or_path = transformer_model_name_or_path
+         self.transformer_tokenizer_name = transformer_tokenizer_name
+         self.transformer_layer_update_strategy = transformer_layer_update_strategy
+         self.num_chunks = num_chunks
+         self.max_seq_length = max_seq_length
+         self.dropout = dropout
+         self.dropout_att = dropout_att
+         self.d_model = d_model
+         # label_dictionary is a dataframe with columns: icd9_code, long_title
+         self.label_dictionary = label_dictionary
+         self.num_labels = num_labels
+         self.use_code_representation = use_code_representation
+         self.code_max_seq_length = code_max_seq_length
+         self.code_batch_size = code_batch_size
+         self.multi_head_att = multi_head_att
+         self.chunk_att = chunk_att
+         self.linear_init_mean = linear_init_mean
+         self.linear_init_std = linear_init_std
+         self.document_pooling_strategy = document_pooling_strategy
+         self.multi_head_chunk_attention = multi_head_chunk_attention
+         self.num_hidden_layers = num_hidden_layers
+
+
+ class LableWiseAttentionLayer(torch.nn.Module):
+     def __init__(self, coding_model_config, args):
+         super(LableWiseAttentionLayer, self).__init__()
+
+         self.config = coding_model_config
+         self.args = args
+
+         # layers
+         self.l1_linear = torch.nn.Linear(self.config.d_model,
+                                          self.config.d_model, bias=False)
+         self.tanh = torch.nn.Tanh()
+         self.l2_linear = torch.nn.Linear(self.config.d_model, self.config.num_labels, bias=False)
+         self.softmax = torch.nn.Softmax(dim=1)
+
+         # Mean pooling of the last hidden state of each code title from the transformer model as the initial code vectors
+         self._init_linear_weights(mean=self.config.linear_init_mean, std=self.config.linear_init_std)
+
+     def _init_linear_weights(self, mean, std):
+         # normalize the l1 weights
+         torch.nn.init.normal_(self.l1_linear.weight, mean, std)
+         if self.l1_linear.bias is not None:
+             self.l1_linear.bias.data.fill_(0)
+         # initialize the l2
+         if self.config.use_code_representation:
+             code_vectors = initial_code_title_vectors(self.config.label_dictionary,
+                                                       self.config.transformer_model_name_or_path,
+                                                       self.config.transformer_tokenizer_name
+                                                       if self.config.transformer_tokenizer_name
+                                                       else self.config.transformer_model_name_or_path,
+                                                       self.config.code_max_seq_length,
+                                                       self.config.code_batch_size,
+                                                       self.config.d_model,
+                                                       self.args.device)
+
+             self.l2_linear.weight = torch.nn.Parameter(code_vectors, requires_grad=True)
+         torch.nn.init.normal_(self.l2_linear.weight, mean, std)
+         if self.l2_linear.bias is not None:
+             self.l2_linear.bias.data.fill_(0)
+
+     def forward(self, x):
+         # input: (batch_size, max_seq_length, transformer_hidden_size)
+         # output: (batch_size, max_seq_length, transformer_hidden_size)
+         # Z = tanh(WH)
+         l1_output = self.tanh(self.l1_linear(x))
+         # softmax(UZ)
+         # l2_linear output shape: (batch_size, max_seq_length, num_labels)
+         # attention_weight shape: (batch_size, num_labels, max_seq_length)
+         attention_weight = self.softmax(self.l2_linear(l1_output)).transpose(1, 2)
+         # attention_output shape: (batch_size, num_labels, transformer_hidden_size)
+         attention_output = torch.matmul(attention_weight, x)
+
+         return attention_output, attention_weight
+
+ class ChunkAttentionLayer(torch.nn.Module):
+     def __init__(self, coding_model_config, args):
+         super(ChunkAttentionLayer, self).__init__()
+
+         self.config = coding_model_config
+         self.args = args
+
+         # layers
+         self.l1_linear = torch.nn.Linear(self.config.d_model,
+                                          self.config.d_model, bias=False)
+         self.tanh = torch.nn.Tanh()
+         self.l2_linear = torch.nn.Linear(self.config.d_model, 1, bias=False)
+         self.softmax = torch.nn.Softmax(dim=1)
+
+         self._init_linear_weights(mean=self.config.linear_init_mean, std=self.config.linear_init_std)
+
+     def _init_linear_weights(self, mean, std):
+         # initialize the l1
+         torch.nn.init.normal_(self.l1_linear.weight, mean, std)
+         if self.l1_linear.bias is not None:
+             self.l1_linear.bias.data.fill_(0)
+         # initialize the l2
+         torch.nn.init.normal_(self.l2_linear.weight, mean, std)
+         if self.l2_linear.bias is not None:
+             self.l2_linear.bias.data.fill_(0)
+
+     def forward(self, x):
+         # input: (batch_size, num_chunks, transformer_hidden_size)
+         # output: (batch_size, num_chunks, transformer_hidden_size)
+         # Z = tanh(WH)
+         l1_output = self.tanh(self.l1_linear(x))
+         # softmax(UZ)
+         # l2_linear output shape: (batch_size, num_chunks, 1)
+         # attention_weight shape: (batch_size, 1, num_chunks)
+         attention_weight = self.softmax(self.l2_linear(l1_output)).transpose(1, 2)
+         # attention_output shape: (batch_size, 1, transformer_hidden_size)
+         attention_output = torch.matmul(attention_weight, x)
+
+         return attention_output, attention_weight
+
+ # define the model class
+ class CodingModel(torch.nn.Module, PyTorchModelHubMixin):
+     def __init__(self, coding_model_config, args, **kwargs):
+         super(CodingModel, self).__init__()
+         self.coding_model_config = coding_model_config
+         self.args = args
+         # layers
+         self.transformer_layer = AutoModel.from_pretrained(self.coding_model_config.transformer_model_name_or_path)
+         if isinstance(self.transformer_layer, XLNetModel):
+             self.transformer_layer.config.use_mems_eval = False
+         self.dropout = Dropout(p=self.coding_model_config.dropout)
+
+         if self.coding_model_config.multi_head_att:
+             # initialise one label-wise attention head per chunk
+             self.label_wise_attention_layer = torch.nn.ModuleList(
+                 [LableWiseAttentionLayer(coding_model_config, args)
+                  for _ in range(self.coding_model_config.num_chunks)])
+         else:
+             self.label_wise_attention_layer = LableWiseAttentionLayer(coding_model_config, args)
+         self.dropout_att = Dropout(p=self.coding_model_config.dropout_att)
+
+         # initialise chunk attention
+         if self.coding_model_config.chunk_att:
+             if self.coding_model_config.multi_head_chunk_attention:
+                 self.chunk_attention_layer = torch.nn.ModuleList([ChunkAttentionLayer(coding_model_config, args)
+                                                                   for _ in range(self.coding_model_config.num_labels)])
+             else:
+                 self.chunk_attention_layer = ChunkAttentionLayer(coding_model_config, args)
+
+             self.classifier_layer = Linear(self.coding_model_config.d_model,
+                                            self.coding_model_config.num_labels)
+         else:
+             if self.coding_model_config.document_pooling_strategy == "flat":
+                 self.classifier_layer = Linear(self.coding_model_config.num_chunks * self.coding_model_config.d_model,
+                                                self.coding_model_config.num_labels)
+             else:  # max or mean pooling
+                 self.classifier_layer = Linear(self.coding_model_config.d_model,
+                                                self.coding_model_config.num_labels)
+         self.sigmoid = torch.nn.Sigmoid()
+
+         if self.coding_model_config.transformer_layer_update_strategy == "no":
+             self.freeze_all_transformer_layers()
+         elif self.coding_model_config.transformer_layer_update_strategy == "last":
+             self.freeze_all_transformer_layers()
+             self.unfreeze_transformer_last_layers()
+
+         # initialize the classifier weights
+         self._init_linear_weights(mean=self.coding_model_config.linear_init_mean, std=self.coding_model_config.linear_init_std)
+
+     def _init_linear_weights(self, mean, std):
+         torch.nn.init.normal_(self.classifier_layer.weight, mean, std)
+
+     def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, targets=None):
+         # input ids/mask/type_ids shape: (batch_size, num_chunks, max_seq_length)
+         # labels shape: (batch_size, num_labels)
+         transformer_output = []
+
+         # pass chunks one by one through the transformer layer in batches.
+         # input (batch_size, sequence_length)
+         for i in range(self.coding_model_config.num_chunks):
+             l1_output = self.transformer_layer(input_ids=input_ids[:, i, :],
+                                                attention_mask=attention_mask[:, i, :],
+                                                token_type_ids=token_type_ids[:, i, :])
+             # output hidden state shape: (batch_size, sequence_length, hidden_size)
+             transformer_output.append(l1_output[0])
+
+         # transpose back chunk and batch size dimensions
+         transformer_output = torch.stack(transformer_output)
+         transformer_output = transformer_output.transpose(0, 1)
+         # dropout transformer output
+         l2_dropout = self.dropout(transformer_output)
+
+         # Label-wise attention layers
+         # output: (batch_size, num_chunks, num_labels, hidden_size)
+         attention_output = []
+         attention_weights = []
+
+         for i in range(self.coding_model_config.num_chunks):
+             # input: (batch_size, max_seq_length, transformer_hidden_size)
+             if self.coding_model_config.multi_head_att:
+                 attention_layer = self.label_wise_attention_layer[i]
+             else:
+                 attention_layer = self.label_wise_attention_layer
+             l3_attention, attention_weight = attention_layer(l2_dropout[:, i, :])
+             # l3_attention shape: (batch_size, num_labels, hidden_size)
+             # attention_weight: (batch_size, num_labels, max_seq_length)
+             attention_output.append(l3_attention)
+             attention_weights.append(attention_weight)
+
+         attention_output = torch.stack(attention_output)
+         attention_output = attention_output.transpose(0, 1)
+         attention_weights = torch.stack(attention_weights)
+         attention_weights = attention_weights.transpose(0, 1)
+
+         l3_dropout = self.dropout_att(attention_output)
+
+         if self.coding_model_config.chunk_att:
+             # Chunk attention layers
+             # output: (batch_size, num_labels, hidden_size)
+             chunk_attention_output = []
+             chunk_attention_weights = []
+
+             for i in range(self.coding_model_config.num_labels):
+                 if self.coding_model_config.multi_head_chunk_attention:
+                     chunk_attention = self.chunk_attention_layer[i]
+                 else:
+                     chunk_attention = self.chunk_attention_layer
+                 l4_chunk_attention, l4_chunk_attention_weights = chunk_attention(l3_dropout[:, :, i])
+                 chunk_attention_output.append(l4_chunk_attention.squeeze(dim=1))
+                 chunk_attention_weights.append(l4_chunk_attention_weights.squeeze(dim=1))
+
+             chunk_attention_output = torch.stack(chunk_attention_output)
+             chunk_attention_output = chunk_attention_output.transpose(0, 1)
+             chunk_attention_weights = torch.stack(chunk_attention_weights)
+             chunk_attention_weights = chunk_attention_weights.transpose(0, 1)
+             # output shape: (batch_size, num_labels, hidden_size)
+             l4_dropout = self.dropout_att(chunk_attention_output)
+         else:
+             # output shape: (batch_size, num_labels, hidden_size*num_chunks)
+             l4_dropout = l3_dropout.transpose(1, 2)
+             if self.coding_model_config.document_pooling_strategy == "flat":
+                 # Flatten layer: concatenate representations per label
+                 l4_dropout = torch.flatten(l4_dropout, start_dim=2)
+             elif self.coding_model_config.document_pooling_strategy == "max":
+                 l4_dropout = torch.amax(l4_dropout, 2)
+             elif self.coding_model_config.document_pooling_strategy == "mean":
+                 l4_dropout = torch.mean(l4_dropout, 2)
+             else:
+                 raise ValueError("Not supported pooling strategy")
+
+         # classifier layer
+         # each code has a binary linear formula
+         logits = self.classifier_layer.weight.mul(l4_dropout).sum(dim=2).add(self.classifier_layer.bias)
+
+         loss_fct = BCEWithLogitsLoss()
+         loss = loss_fct(logits, targets)
+
+         return {
+             "loss": loss,
+             "logits": logits,
+             "label_attention_weights": attention_weights,
+             "chunk_attention_weights": chunk_attention_weights if self.coding_model_config.chunk_att else []
+         }
+
+     def freeze_all_transformer_layers(self):
+         """
+         Freeze all layer weight parameters. They will not be updated during training.
+         """
+         for param in self.transformer_layer.parameters():
+             param.requires_grad = False
+
+     def unfreeze_all_transformer_layers(self):
+         """
+         Unfreeze all layer weight parameters. They will be updated during training.
+         """
+         for param in self.transformer_layer.parameters():
+             param.requires_grad = True
+
+     def unfreeze_transformer_last_layers(self):
+         for name, param in self.transformer_layer.named_parameters():
+             if "layer.11" in name or "pooler" in name:
+                 param.requires_grad = True
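
For reference, a minimal usage sketch (not part of this commit) of CodingModel from modeling.py; the checkpoint name, config values and dummy tensors are illustrative assumptions, and d_model must match the hidden size of whichever transformer is loaded:

import torch
from types import SimpleNamespace

config = CodingModelConfig(
    transformer_model_name_or_path="xlnet-base-cased",   # placeholder checkpoint
    transformer_tokenizer_name="xlnet-base-cased",
    transformer_layer_update_strategy="last",
    num_chunks=2, max_seq_length=16, dropout=0.1, dropout_att=0.1,
    d_model=768, label_dictionary=None, num_labels=5,
    use_code_representation=False, code_max_seq_length=64, code_batch_size=8,
    multi_head_att=True, chunk_att=True,
    linear_init_mean=0.0, linear_init_std=0.03,
    document_pooling_strategy="flat", multi_head_chunk_attention=True,
    num_hidden_layers=1)

model = CodingModel(config, SimpleNamespace(device="cpu"))
batch, chunks, seq_len = 2, config.num_chunks, 16
out = model(input_ids=torch.randint(0, 1000, (batch, chunks, seq_len)),
            attention_mask=torch.ones(batch, chunks, seq_len),
            token_type_ids=torch.zeros(batch, chunks, seq_len, dtype=torch.long),
            targets=torch.zeros(batch, config.num_labels))
print(out["logits"].shape)   # torch.Size([2, 5])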
HiLATmain/models/utils.py ADDED
@@ -0,0 +1,437 @@
+ import csv
+ import linecache
+ import pickle
+ import random
+ import subprocess
+
+ import numpy as np
+ import redis
+ import torch
+ import logging
+ import ast
+
+ from datasets import Dataset
+ from tqdm import tqdm
+
+ from sklearn.metrics import f1_score, precision_score, recall_score, accuracy_score, roc_auc_score, roc_curve, auc
+ from torch.utils.data import DataLoader
+ from transformers import AutoModel, DataCollatorWithPadding, XLNetTokenizer, XLNetTokenizerFast, AutoTokenizer, \
+     XLNetModel, is_torch_tpu_available
+
+ logger = logging.getLogger("lwat")
+
+
+ class MimicIIIDataset(Dataset):
+     def __init__(self, data):
+         self.input_ids = data["input_ids"]
+         self.attention_mask = data["attention_mask"]
+         self.token_type_ids = data["token_type_ids"]
+         self.labels = data["targets"]
+
+     def __len__(self):
+         return len(self.input_ids)
+
+     def __getitem__(self, item):
+         return {
+             "input_ids": torch.tensor(self.input_ids[item], dtype=torch.long),
+             "attention_mask": torch.tensor(self.attention_mask[item], dtype=torch.float),
+             "token_type_ids": torch.tensor(self.token_type_ids[item], dtype=torch.long),
+             "targets": torch.tensor(self.labels[item], dtype=torch.float)
+         }
+
+ class LazyMimicIIIDataset(Dataset):
+     def __init__(self, filename, task, dataset_type):
+         print("lazy load from {}".format(filename))
+         self.filename = filename
+         self.redis = redis.Redis(unix_socket_path="/tmp/redis.sock")
+         self.pipe = self.redis.pipeline()
+         self.num_examples = 0
+         self.task = task
+         self.dataset_type = dataset_type
+         with open(filename, 'r') as f:
+             for line_num, line in enumerate(f.readlines()):
+                 self.num_examples += 1
+                 example = eval(line)
+                 key = task + '_' + dataset_type + '_' + str(line_num)
+                 input_ids = eval(example[0])
+                 attention_mask = eval(example[1])
+                 token_type_ids = eval(example[2])
+                 labels = eval(example[3])
+                 example_tuple = (input_ids, attention_mask, token_type_ids, labels)
+
+                 self.pipe.set(key, pickle.dumps(example_tuple))
+                 if line_num % 100 == 0:
+                     self.pipe.execute()
+             self.pipe.execute()
+         if is_torch_tpu_available():
+             import torch_xla.core.xla_model as xm
+             xm.rendezvous(tag="featuresGenerated")
+
+     def __len__(self):
+         return self.num_examples
+
+     def __getitem__(self, item):
+         key = self.task + '_' + self.dataset_type + '_' + str(item)
+         example = pickle.loads(self.redis.get(key))
+
+         return {
+             "input_ids": torch.tensor(example[0], dtype=torch.long),
+             "attention_mask": torch.tensor(example[1], dtype=torch.float),
+             "token_type_ids": torch.tensor(example[2], dtype=torch.long),
+             "targets": torch.tensor(example[3], dtype=torch.float)
+         }
+
+
+ class ICDCodeDataset(Dataset):
+     def __init__(self, data):
+         self.input_ids = data["input_ids"]
+         self.attention_mask = data["attention_mask"]
+         self.token_type_ids = data["token_type_ids"]
+
+     def __len__(self):
+         return len(self.input_ids)
+
+     def __getitem__(self, item):
+         return {
+             "input_ids": torch.tensor(self.input_ids[item], dtype=torch.long),
+             "attention_mask": torch.tensor(self.attention_mask[item], dtype=torch.float),
+             "token_type_ids": torch.tensor(self.token_type_ids[item], dtype=torch.long)
+         }
+
+
+ def set_random_seed(random_seed):
+     random.seed(random_seed)
+     np.random.seed(random_seed)
+     torch.manual_seed(random_seed)
+     torch.cuda.manual_seed_all(random_seed)
+     torch.backends.cudnn.deterministic = True
+     torch.backends.cudnn.benchmark = False
+
+ def tokenize_inputs(text_list, tokenizer, max_seq_len=512):
+     """
+     Tokenizes the input text into ids. Appends the appropriate special
+     tokens to the end of the text to denote end of sentence. Truncates or pads
+     to the appropriate sequence length.
+     """
+     # tokenize the text, then truncate the sequence to the desired length minus 2 for
+     # the 2 special tokens
+     tokenized_texts = list(map(lambda t: tokenizer.tokenize(t)[:max_seq_len - 2], text_list))
+     # convert tokenized text into numeric ids for the appropriate LM
+     input_ids = [tokenizer.convert_tokens_to_ids(x) for x in tokenized_texts]
+     # get token type for token_ids_0
+     token_type_ids = [tokenizer.create_token_type_ids_from_sequences(x) for x in input_ids]
+     # append special tokens to the end of the sentence: <sep> <cls>
+     input_ids = [tokenizer.build_inputs_with_special_tokens(x) for x in input_ids]
+     # attention mask
+     attention_mask = [[1] * len(x) for x in input_ids]
+
+     # padding to max_length
+     def padding_to_max(sequence, value):
+         padding_len = max_seq_len - len(sequence)
+         padding = [value] * padding_len
+         return sequence + padding
+
+     input_ids = [padding_to_max(x, tokenizer.pad_token_id) for x in input_ids]
+     attention_mask = [padding_to_max(x, 0) for x in attention_mask]
+     token_type_ids = [padding_to_max(x, tokenizer.pad_token_type_id) for x in token_type_ids]
+
+     return input_ids, attention_mask, token_type_ids
+
+
+ def tokenize_dataset(tokenizer, text, labels, max_seq_len):
+     if (isinstance(tokenizer, XLNetTokenizer) or isinstance(tokenizer, XLNetTokenizerFast)):
+         data = list(map(lambda t: tokenize_inputs(t, tokenizer, max_seq_len=max_seq_len), text))
+         input_ids, attention_mask, token_type_ids = zip(*data)
+     else:
+         tokenizer.model_max_length = max_seq_len
+         input_dict = tokenizer(text, padding=True, truncation=True)
+         input_ids = input_dict["input_ids"]
+         attention_mask = input_dict["attention_mask"]
+         token_type_ids = input_dict["token_type_ids"]
+
+     return {
+         "input_ids": input_ids,
+         "attention_mask": attention_mask,
+         "token_type_ids": token_type_ids,
+         "targets": labels
+     }
+
+
+ def initial_code_title_vectors(label_dict, transformer_model_name, tokenizer_name, code_max_seq_length, code_batch_size,
+                                d_model, device):
+     logger.info("Generate code title representations from base transformer model")
+     model = AutoModel.from_pretrained(transformer_model_name)
+     if isinstance(model, XLNetModel):
+         model.config.use_mems_eval = False
+     #
+     # model.config.use_mems_eval = False
+     # model.config.reuse_len = 0
+     code_titles = label_dict["long_title"].fillna("").tolist()
+     tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, padding_side="right")
+     data = tokenizer(code_titles, padding=True, truncation=True)
+     code_dataset = ICDCodeDataset(data)
+
+     model.to(device)
+
+     data_collator = DataCollatorWithPadding(tokenizer, padding="max_length",
+                                             max_length=code_max_seq_length)
+     code_param = {"batch_size": code_batch_size, "collate_fn": data_collator}
+     code_dataloader = DataLoader(code_dataset, **code_param)
+
+     code_dataloader_progress_bar = tqdm(code_dataloader, unit="batches",
+                                         desc="Code title representations")
+     code_dataloader_progress_bar.clear()
+
+     # output shape: (num_labels, hidden_size)
+     initial_code_vectors = torch.zeros(len(code_dataset), d_model)
+
+     for i, data in enumerate(code_dataloader_progress_bar):
+         input_ids = data["input_ids"].to(device, dtype=torch.long)
+         attention_mask = data["attention_mask"].to(device, dtype=torch.float)
+         token_type_ids = data["token_type_ids"].to(device, dtype=torch.long)
+
+         # output shape: (batch_size, sequence_length, hidden_size)
+         output = model(input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
+         # Mean pooling. output shape: (batch_size, hidden_size)
+         mean_last_hidden_state = torch.mean(output[0], 1)
+         # Max pooling. output shape: (batch_size, hidden_size)
+         # max_last_hidden_state = torch.max((output[0] * attention_mask.unsqueeze(-1)), 1)[0]
+
+         initial_code_vectors[i * input_ids.shape[0]:(i + 1) * input_ids.shape[0], :] = mean_last_hidden_state
+
+     code_dataloader_progress_bar.refresh(True)
+     code_dataloader_progress_bar.clear(True)
+     code_dataloader_progress_bar.close()
+     logger.info("Code representations ready for use. Shape {}".format(initial_code_vectors.shape))
+     return initial_code_vectors
+
+
+ def normalise_labels(labels, n_label):
+     norm_labels = []
+     for label in labels:
+         one_hot_vector_label = [0] * n_label
+         one_hot_vector_label[label] = 1
+         norm_labels.append(one_hot_vector_label)
+     return np.asarray(norm_labels)
+
+
+ def segment_tokenize_inputs(text, tokenizer, max_seq_len, num_chunks):
+     # input is the full text of one document
+     tokenized_texts = []
+     tokens = tokenizer.tokenize(text)
+     start_idx = 0
+     seq_len = max_seq_len - 2
+     for i in range(num_chunks):
+         if start_idx > len(tokens):
+             tokenized_texts.append([])
+             continue
+         tokenized_texts.append(tokens[start_idx:(start_idx + seq_len)])
+         start_idx += seq_len
+
+     # convert tokenized text into numeric ids for the appropriate LM
+     input_ids = [tokenizer.convert_tokens_to_ids(x) for x in tokenized_texts]
+     # get token type for token_ids_0
+     token_type_ids = [tokenizer.create_token_type_ids_from_sequences(x) for x in input_ids]
+     # append special tokens to the end of the sentence: <sep> <cls>
+     input_ids = [tokenizer.build_inputs_with_special_tokens(x) for x in input_ids]
+     # attention mask
+     attention_mask = [[1] * len(x) for x in input_ids]
+
+     # padding to max_length
+     def padding_to_max(sequence, value):
+         padding_len = max_seq_len - len(sequence)
+         padding = [value] * padding_len
+         return sequence + padding
+
+     input_ids = [padding_to_max(x, tokenizer.pad_token_id) for x in input_ids]
+     attention_mask = [padding_to_max(x, 0) for x in attention_mask]
+     token_type_ids = [padding_to_max(x, tokenizer.pad_token_type_id) for x in token_type_ids]
+
+     return input_ids, attention_mask, token_type_ids
+
+
+ def segment_tokenize_dataset(tokenizer, text, labels, max_seq_len, num_chunks):
+     data = list(
+         map(lambda t: segment_tokenize_inputs(t, tokenizer, max_seq_len, num_chunks), text))
+     input_ids, attention_mask, token_type_ids = zip(*data)
+
+     return {
+         "input_ids": input_ids,
+         "attention_mask": attention_mask,
+         "token_type_ids": token_type_ids,
+         "targets": labels
+     }
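
For reference, a minimal sketch (not part of this commit) of the chunked tokenisation above; the tokenizer checkpoint and the note text are placeholder assumptions:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("xlnet-base-cased")   # placeholder tokenizer
note = "patient admitted with chest pain and shortness of breath " * 40
enc = segment_tokenize_dataset(tokenizer, [note], [[1, 0, 1]], max_seq_len=32, num_chunks=4)
print(len(enc["input_ids"][0]), len(enc["input_ids"][0][0]))    # 4 chunks, 32 ids per chunk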
+
+
+ # The following functions are modified from the relevant code of https://github.com/aehrc/LAAT
+ def roc_auc(true_labels, pred_probs, average="macro"):
+     if pred_probs.shape[0] <= 1:
+         return
+
+     fpr = {}
+     tpr = {}
+     if average == "macro":
+         # get AUC for each label individually
+         relevant_labels = []
+         auc_labels = {}
+         for i in range(true_labels.shape[1]):
+             # only if there are true positives for this label
+             if true_labels[:, i].sum() > 0:
+                 fpr[i], tpr[i], _ = roc_curve(true_labels[:, i], pred_probs[:, i])
+                 if len(fpr[i]) > 1 and len(tpr[i]) > 1:
+                     auc_score = auc(fpr[i], tpr[i])
+                     if not np.isnan(auc_score):
+                         auc_labels["auc_%d" % i] = auc_score
+                         relevant_labels.append(i)
+
+         # macro-AUC: just average the auc scores
+         aucs = []
+         for i in relevant_labels:
+             aucs.append(auc_labels['auc_%d' % i])
+         score = np.mean(aucs)
+     else:
+         # micro-AUC: just look at each individual prediction
+         flat_pred = pred_probs.ravel()
+         fpr["micro"], tpr["micro"], _ = roc_curve(true_labels.ravel(), flat_pred)
+         score = auc(fpr["micro"], tpr["micro"])
+
+     return score
+
+
+ def union_size(x, y, axis):
+     return np.logical_or(x, y).sum(axis=axis).astype(float)
+
+
+ def intersect_size(x, y, axis):
+     return np.logical_and(x, y).sum(axis=axis).astype(float)
+
+
+ def macro_accuracy(true_labels, pred_labels):
+     num = intersect_size(true_labels, pred_labels, 0) / (union_size(true_labels, pred_labels, 0) + 1e-10)
+     return np.mean(num)
+
+
+ def macro_precision(true_labels, pred_labels):
+     num = intersect_size(true_labels, pred_labels, 0) / (pred_labels.sum(axis=0) + 1e-10)
+     return np.mean(num)
+
+
+ def macro_recall(true_labels, pred_labels):
+     num = intersect_size(true_labels, pred_labels, 0) / (true_labels.sum(axis=0) + 1e-10)
+     return np.mean(num)
+
+
+ def macro_f1(true_labels, pred_labels):
+     prec = macro_precision(true_labels, pred_labels)
+     rec = macro_recall(true_labels, pred_labels)
+     if prec + rec == 0:
+         f1 = 0.
+     else:
+         f1 = 2 * (prec * rec) / (prec + rec)
+     return prec, rec, f1
+
+
+ def precision_at_k(true_labels, pred_probs, ks=[1, 5, 8, 10, 15]):
+     # num true labels in top k predictions / k
+     sorted_pred = np.argsort(pred_probs)[:, ::-1]
+     output = []
+     for k in ks:
+         topk = sorted_pred[:, :k]
+
+         # get precision at k for each example
+         vals = []
+         for i, tk in enumerate(topk):
+             if len(tk) > 0:
+                 num_true_in_top_k = true_labels[i, tk].sum()
+                 denom = len(tk)
+                 vals.append(num_true_in_top_k / float(denom))
+
+         output.append(np.mean(vals))
+     return output
+
+
+ def micro_recall(true_labels, pred_labels):
+     flat_true = true_labels.ravel()
+     flat_pred = pred_labels.ravel()
+     return intersect_size(flat_true, flat_pred, 0) / flat_true.sum(axis=0)
+
+
+ def micro_precision(true_labels, pred_labels):
+     flat_true = true_labels.ravel()
+     flat_pred = pred_labels.ravel()
+     if flat_pred.sum(axis=0) == 0:
+         return 0.0
+     return intersect_size(flat_true, flat_pred, 0) / flat_pred.sum(axis=0)
+
+
+ def micro_f1(true_labels, pred_labels):
+     prec = micro_precision(true_labels, pred_labels)
+     rec = micro_recall(true_labels, pred_labels)
+     if prec + rec == 0:
+         f1 = 0.
+     else:
+         f1 = 2 * (prec * rec) / (prec + rec)
+     return prec, rec, f1
+
+
+ def micro_accuracy(true_labels, pred_labels):
+     flat_true = true_labels.ravel()
+     flat_pred = pred_labels.ravel()
+     return intersect_size(flat_true, flat_pred, 0) / union_size(flat_true, flat_pred, 0)
+
+
+ def calculate_scores(true_labels, logits, average="macro", is_multilabel=True, threshold=0.5):
+     def sigmoid(x):
+         return 1 / (1 + np.exp(-x))
+
+     pred_probs = sigmoid(logits)
+     pred_labels = np.rint(pred_probs - threshold + 0.5)
+
+     max_size = min(len(true_labels), len(pred_labels))
+     true_labels = true_labels[: max_size]
+     pred_labels = pred_labels[: max_size]
+     pred_probs = pred_probs[: max_size]
+     p_1 = 0
+     p_5 = 0
+     p_8 = 0
+     p_10 = 0
+     p_15 = 0
+     if pred_probs is not None:
+         if not is_multilabel:
+             normalised_labels = normalise_labels(true_labels, len(pred_probs[0]))
+             auc_score = roc_auc(normalised_labels, pred_probs, average=average)
+             accuracy = accuracy_score(true_labels, pred_labels)
+             precision = precision_score(true_labels, pred_labels, average=average)
+             recall = recall_score(true_labels, pred_labels, average=average)
+             f1 = f1_score(true_labels, pred_labels, average=average)
+         else:
+             if average == "macro":
+                 accuracy = macro_accuracy(true_labels, pred_labels)  # categorical accuracy
+                 precision, recall, f1 = macro_f1(true_labels, pred_labels)
+                 p_ks = precision_at_k(true_labels, pred_probs, [1, 5, 8, 10, 15])
+                 p_1 = p_ks[0]
+                 p_5 = p_ks[1]
+                 p_8 = p_ks[2]
+                 p_10 = p_ks[3]
+                 p_15 = p_ks[4]
+
+             else:
+                 accuracy = micro_accuracy(true_labels, pred_labels)
+                 precision, recall, f1 = micro_f1(true_labels, pred_labels)
+             auc_score = roc_auc(true_labels, pred_probs, average)
+             labelwise_f1 = f1_score(true_labels, pred_labels, average=None)
+             labelwise_f1 = np.array2string(labelwise_f1, separator=',')
+
+     else:
+         auc_score = -1
+
+     output = {"{}_precision".format(average): precision, "{}_recall".format(average): recall,
+               "{}_f1".format(average): f1, "{}_accuracy".format(average): accuracy,
+               "{}_auc".format(average): auc_score, "{}_P@1".format(average): p_1, "{}_P@5".format(average): p_5,
+               "{}_P@8".format(average): p_8, "{}_P@10".format(average): p_10, "{}_P@15".format(average): p_15,
+               "labelwise_f1": labelwise_f1
+               }
+
+     return output
+
+
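
For reference, a minimal sketch (not part of this commit) of the evaluation helpers on synthetic data; the label matrix and logits are made up for illustration:

import numpy as np

true_labels = np.array([[1, 0, 1], [0, 1, 0]])
logits = np.array([[2.0, -1.0, 0.5], [-1.5, 1.0, -0.5]])
scores = calculate_scores(true_labels, logits, average="micro")
print(scores["micro_f1"], scores["micro_auc"])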