meghanaraok commited on
Commit
e4d23ac
1 Parent(s): be58120

Upload 3 files

Browse files
LongHiLATmain/models/__init__.py ADDED
File without changes
LongHiLATmain/models/modeling.py ADDED
@@ -0,0 +1,350 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import collections
2
+ import logging
3
+
4
+ import torch
5
+ from torch.nn import BCEWithLogitsLoss, Dropout, Linear
6
+ from transformers import AutoModel, XLNetModel, LongformerConfig
7
+ from transformers.models.longformer.modeling_longformer import LongformerEncoder
8
+ from huggingface_hub import PyTorchModelHubMixin
9
+ from LongHiLATmain.hilat.models.utils import initial_code_title_vectors
10
+
11
+ logger = logging.getLogger("lwat")
12
+
13
+
14
+ class CodingModelConfig:
15
+ def __init__(self,
16
+ transformer_model_name_or_path,
17
+ transformer_tokenizer_name,
18
+ transformer_layer_update_strategy,
19
+ num_chunks,
20
+ max_seq_length,
21
+ dropout,
22
+ dropout_att,
23
+ d_model,
24
+ label_dictionary,
25
+ num_labels,
26
+ use_code_representation,
27
+ code_max_seq_length,
28
+ code_batch_size,
29
+ multi_head_att,
30
+ chunk_att,
31
+ linear_init_mean,
32
+ linear_init_std,
33
+ document_pooling_strategy,
34
+ multi_head_chunk_attention,
35
+ num_hidden_layers):
36
+ super(CodingModelConfig, self).__init__()
37
+ self.transformer_model_name_or_path = transformer_model_name_or_path
38
+ self.transformer_tokenizer_name = transformer_tokenizer_name
39
+ self.transformer_layer_update_strategy = transformer_layer_update_strategy
40
+ self.num_chunks = num_chunks
41
+ self.max_seq_length = max_seq_length
42
+ self.dropout = dropout
43
+ self.dropout_att = dropout_att
44
+ self.d_model = d_model
45
+ # labels_dictionary is a dataframe with columns: icd9_code, long_title
46
+ self.label_dictionary = label_dictionary
47
+ self.num_labels = num_labels
48
+ self.use_code_representation = use_code_representation
49
+ self.code_max_seq_length = code_max_seq_length
50
+ self.code_batch_size = code_batch_size
51
+ self.multi_head_att = multi_head_att
52
+ self.chunk_att = chunk_att
53
+ self.linear_init_mean = linear_init_mean
54
+ self.linear_init_std = linear_init_std
55
+ self.document_pooling_strategy = document_pooling_strategy
56
+ self.multi_head_chunk_attention = multi_head_chunk_attention
57
+ self.num_hidden_layers = num_hidden_layers
58
+
59
+
60
+ class LableWiseAttentionLayer(torch.nn.Module):
61
+ def __init__(self, coding_model_config, args):
62
+ super(LableWiseAttentionLayer, self).__init__()
63
+
64
+ self.config = coding_model_config
65
+ self.args = args
66
+
67
+ # layers
68
+ self.l1_linear = torch.nn.Linear(self.config.d_model,
69
+ self.config.d_model, bias=False)
70
+ self.tanh = torch.nn.Tanh()
71
+ self.l2_linear = torch.nn.Linear(self.config.d_model, self.config.num_labels, bias=False)
72
+ self.softmax = torch.nn.Softmax(dim=1)
73
+
74
+ # Mean pooling last hidden state of code title from transformer model as the initial code vectors
75
+ self._init_linear_weights(mean=self.config.linear_init_mean, std=self.config.linear_init_std)
76
+
77
+ def _init_linear_weights(self, mean, std):
78
+ # normalize the l1 weights
79
+ torch.nn.init.normal_(self.l1_linear.weight, mean, std)
80
+ if self.l1_linear.bias is not None:
81
+ self.l1_linear.bias.data.fill_(0)
82
+ # initialize the l2
83
+ if self.config.use_code_representation:
84
+ code_vectors = initial_code_title_vectors(self.config.label_dictionary,
85
+ self.config.transformer_model_name_or_path,
86
+ self.config.transformer_tokenizer_name
87
+ if self.config.transformer_tokenizer_name
88
+ else self.config.transformer_model_name_or_path,
89
+ self.config.code_max_seq_length,
90
+ self.config.code_batch_size,
91
+ self.config.d_model,
92
+ self.args.device)
93
+
94
+ self.l2_linear.weight = torch.nn.Parameter(code_vectors, requires_grad=True)
95
+ torch.nn.init.normal_(self.l2_linear.weight, mean, std)
96
+ if self.l2_linear.bias is not None:
97
+ self.l2_linear.bias.data.fill_(0)
98
+
99
+ def forward(self, x):
100
+ # input: (batch_size, max_seq_length, transformer_hidden_size)
101
+ # output: (batch_size, max_seq_length, transformer_hidden_size)
102
+ # Z = Tan(WH)
103
+ l1_output = self.tanh(self.l1_linear(x))
104
+ # softmax(UZ)
105
+ # l2_linear output shape: (batch_size, max_seq_length, num_labels)
106
+ # attention_weight shape: (batch_size, num_labels, max_seq_length)
107
+ attention_weight = self.softmax(self.l2_linear(l1_output)).transpose(1, 2)
108
+ # attention_output shpae: (batch_size, num_labels, transformer_hidden_size)
109
+ attention_output = torch.matmul(attention_weight, x)
110
+
111
+ return attention_output, attention_weight
112
+
113
+ class ChunkAttentionLayer(torch.nn.Module):
114
+ def __init__(self, coding_model_config, args):
115
+ super(ChunkAttentionLayer, self).__init__()
116
+
117
+ self.config = coding_model_config
118
+ self.args = args
119
+
120
+ # layers
121
+ self.l1_linear = torch.nn.Linear(self.config.d_model,
122
+ self.config.d_model, bias=False)
123
+ self.tanh = torch.nn.Tanh()
124
+ self.l2_linear = torch.nn.Linear(self.config.d_model, 1, bias=False)
125
+ self.softmax = torch.nn.Softmax(dim=1)
126
+
127
+ self._init_linear_weights(mean=self.config.linear_init_mean, std=self.config.linear_init_std)
128
+
129
+ def _init_linear_weights(self, mean, std):
130
+ # initialize the l1
131
+ torch.nn.init.normal_(self.l1_linear.weight, mean, std)
132
+ if self.l1_linear.bias is not None:
133
+ self.l1_linear.bias.data.fill_(0)
134
+ # initialize the l2
135
+ torch.nn.init.normal_(self.l2_linear.weight, mean, std)
136
+ if self.l2_linear.bias is not None:
137
+ self.l2_linear.bias.data.fill_(0)
138
+
139
+ def forward(self, x):
140
+ # input: (batch_size, num_chunks, transformer_hidden_size)
141
+ # output: (batch_size, num_chunks, transformer_hidden_size)
142
+ # Z = Tan(WH)
143
+ l1_output = self.tanh(self.l1_linear(x))
144
+ # softmax(UZ)
145
+ # l2_linear output shape: (batch_size, num_chunks, 1)
146
+ # attention_weight shape: (batch_size, 1, num_chunks)
147
+ attention_weight = self.softmax(self.l2_linear(l1_output)).transpose(1, 2)
148
+ # attention_output shpae: (batch_size, 1, transformer_hidden_size)
149
+ attention_output = torch.matmul(attention_weight, x)
150
+
151
+ return attention_output, attention_weight
152
+
153
+ # define the model class
154
+ class CodingModel(torch.nn.Module, PyTorchModelHubMixin):
155
+ def __init__(self, coding_model_config, args, **kwargs):
156
+ super(CodingModel, self).__init__()
157
+ self.coding_model_config = coding_model_config
158
+ self.args = args
159
+ # layers
160
+ self.transformer_layer = AutoModel.from_pretrained(self.coding_model_config.transformer_model_name_or_path)
161
+ if isinstance(self.transformer_layer, XLNetModel):
162
+ self.transformer_layer.config.use_mems_eval = False
163
+ self.dropout = Dropout(p=self.coding_model_config.dropout)
164
+
165
+ if self.coding_model_config.multi_head_att:
166
+ # initial multi head attention according to the num_chunks
167
+ self.label_wise_attention_layer = torch.nn.ModuleList(
168
+ [LableWiseAttentionLayer(coding_model_config, args)
169
+ for _ in range(self.coding_model_config.num_chunks)])
170
+ else:
171
+ self.label_wise_attention_layer = LableWiseAttentionLayer(coding_model_config, args)
172
+ self.dropout_att = Dropout(p=self.coding_model_config.dropout_att)
173
+
174
+ # initial chunk attention
175
+ if self.coding_model_config.chunk_att:
176
+ if self.coding_model_config.multi_head_chunk_attention:
177
+ self.chunk_attention_layer = torch.nn.ModuleList([ChunkAttentionLayer(coding_model_config, args)
178
+ for _ in range(self.coding_model_config.num_labels)])
179
+ else:
180
+ self.chunk_attention_layer = ChunkAttentionLayer(coding_model_config, args)
181
+
182
+ self.classifier_layer = Linear(self.coding_model_config.d_model,
183
+ self.coding_model_config.num_labels)
184
+ else:
185
+ if self.coding_model_config.document_pooling_strategy == "flat":
186
+ self.classifier_layer = Linear(self.coding_model_config.num_chunks * self.coding_model_config.d_model,
187
+ self.coding_model_config.num_labels)
188
+ else: # max or mean pooling
189
+ self.classifier_layer = Linear(self.coding_model_config.d_model,
190
+ self.coding_model_config.num_labels)
191
+ self.sigmoid = torch.nn.Sigmoid()
192
+
193
+ if self.coding_model_config.transformer_layer_update_strategy == "no":
194
+ self.freeze_all_transformer_layers()
195
+ elif self.coding_model_config.transformer_layer_update_strategy == "last":
196
+ self.freeze_all_transformer_layers()
197
+ self.unfreeze_transformer_last_layers()
198
+
199
+ # initialize the weights of classifier
200
+ self._init_linear_weights(mean=self.coding_model_config.linear_init_mean, std=self.coding_model_config.linear_init_std)
201
+
202
+ def _init_linear_weights(self, mean, std):
203
+ torch.nn.init.normal_(self.classifier_layer.weight, mean, std)
204
+
205
+ def _merge_to_attention_mask(self, attention_mask: torch.Tensor, global_attention_mask: torch.Tensor):
206
+ # longformer self attention expects attention mask to have 0 (no attn), 1 (local attn), 2 (global attn)
207
+ # (global_attention_mask + 1) => 1 for local attention, 2 for global attention
208
+ # => final attention_mask => 0 for no attention, 1 for local attention 2 for global attention
209
+ if attention_mask is not None:
210
+ attention_mask = attention_mask * (global_attention_mask + 1)
211
+ else:
212
+ # simply use `global_attention_mask` as `attention_mask`
213
+ # if no `attention_mask` is given
214
+ attention_mask = global_attention_mask + 1
215
+ return attention_mask
216
+
217
+ def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, targets=None):
218
+ # input ids/mask/type_ids shape: (batch_size, num_chunks, max_seq_length)
219
+ # labels shape: (batch_size, num_labels)
220
+ transformer_output = []
221
+
222
+ # pass chunk by chunk into transformer layer in the batches.
223
+ # input (batch_size, sequence_length)
224
+ for i in range(self.coding_model_config.num_chunks):
225
+ l1_output = self.transformer_layer(input_ids=input_ids[:, i, :],
226
+ attention_mask=attention_mask[:, i, :],
227
+ token_type_ids=token_type_ids[:, i, :])
228
+ # output hidden state shape: (batch_size, sequence_length, hidden_size)
229
+ transformer_output.append(l1_output[0])
230
+
231
+ # transpose back chunk and batch size dimensions
232
+ transformer_output = torch.stack(transformer_output)
233
+ transformer_output = transformer_output.transpose(0, 1)
234
+ # dropout transformer output
235
+ l2_dropout = self.dropout(transformer_output)
236
+
237
+ config = LongformerConfig.from_pretrained("allenai/longformer-base-4096")
238
+ config.num_labels =5
239
+ config.num_hidden_layers = 2
240
+ # self.coding_model_config.num_hidden_layers
241
+ config.hidden_size = self.coding_model_config.d_model
242
+ config.attention_window = [512, 512]
243
+ longformer_layer = LongformerEncoder(config)
244
+ # longformer_layer = longformer_layer(config)
245
+ longformer_layer = longformer_layer.to(torch.device("cuda:0"))
246
+ l2_dropout= l2_dropout.reshape(l2_dropout.shape[0], l2_dropout.shape[1]*l2_dropout.shape[2], l2_dropout.shape[3])
247
+ attention_mask = attention_mask.reshape(attention_mask.shape[0], attention_mask.shape[1]*attention_mask.shape[2])
248
+ # is_index_masked = attention_mask < 0
249
+
250
+ global_attention_mask = torch.zeros_like(attention_mask)
251
+ # global attention on cls token
252
+ global_attention_positions = [0, 512, 1024, 1536, 2048, 2560, 3072, 3584, 4095]
253
+ global_attention_mask[:, global_attention_positions] = 1
254
+
255
+ if global_attention_mask is not None:
256
+ attention_mask = self._merge_to_attention_mask(attention_mask, global_attention_mask)
257
+ output = longformer_layer(l2_dropout, attention_mask=attention_mask,output_attentions=True)
258
+ l2_dropout = self.dropout_att(output[0])
259
+ l2_dropout = l2_dropout.reshape(l2_dropout.shape[0], self.coding_model_config.num_chunks, self.coding_model_config.max_seq_length, self.coding_model_config.d_model)
260
+
261
+ # Label-wise attention layers
262
+ # output: (batch_size, num_chunks, num_labels, hidden_size)
263
+ attention_output = []
264
+ attention_weights = []
265
+
266
+ for i in range(self.coding_model_config.num_chunks):
267
+ # input: (batch_size, max_seq_length, transformer_hidden_size)
268
+ if self.coding_model_config.multi_head_att:
269
+ attention_layer = self.label_wise_attention_layer[i]
270
+ else:
271
+ attention_layer = self.label_wise_attention_layer
272
+ l3_attention, attention_weight = attention_layer(l2_dropout[:, i, :])
273
+ # l3_attention shape: (batch_size, num_labels, hidden_size)
274
+ # attention_weight: (batch_size, num_labels, max_seq_length)
275
+ attention_output.append(l3_attention)
276
+ attention_weights.append(attention_weight)
277
+
278
+ attention_output = torch.stack(attention_output)
279
+ attention_output = attention_output.transpose(0, 1)
280
+ attention_weights = torch.stack(attention_weights)
281
+ attention_weights = attention_weights.transpose(0, 1)
282
+
283
+ l3_dropout = self.dropout_att(attention_output)
284
+
285
+ if self.coding_model_config.chunk_att:
286
+ # Chunk attention layers
287
+ # output: (batch_size, num_labels, hidden_size)
288
+ chunk_attention_output = []
289
+ chunk_attention_weights = []
290
+
291
+ for i in range(self.coding_model_config.num_labels):
292
+ if self.coding_model_config.multi_head_chunk_attention:
293
+ chunk_attention = self.chunk_attention_layer[i]
294
+ else:
295
+ chunk_attention = self.chunk_attention_layer
296
+ l4_chunk_attention, l4_chunk_attention_weights = chunk_attention(l3_dropout[:, :, i])
297
+ chunk_attention_output.append(l4_chunk_attention.squeeze(dim=1))
298
+ chunk_attention_weights.append(l4_chunk_attention_weights.squeeze(dim=1))
299
+
300
+ chunk_attention_output = torch.stack(chunk_attention_output)
301
+ chunk_attention_output = chunk_attention_output.transpose(0, 1)
302
+ chunk_attention_weights = torch.stack(chunk_attention_weights)
303
+ chunk_attention_weights = chunk_attention_weights.transpose(0, 1)
304
+ # output shape: (batch_size, num_labels, hidden_size)
305
+ l4_dropout = self.dropout_att(chunk_attention_output)
306
+ else:
307
+ # output shape: (batch_size, num_labels, hidden_size*num_chunks)
308
+ l4_dropout = l3_dropout.transpose(1, 2)
309
+ if self.coding_model_config.document_pooling_strategy == "flat":
310
+ # Flatten layer. concatenate representation by labels
311
+ l4_dropout = torch.flatten(l4_dropout, start_dim=2)
312
+ elif self.coding_model_config.document_pooling_strategy == "max":
313
+ l4_dropout = torch.amax(l4_dropout, 2)
314
+ elif self.coding_model_config.document_pooling_strategy == "mean":
315
+ l4_dropout = torch.mean(l4_dropout, 2)
316
+ else:
317
+ raise ValueError("Not supported pooling strategy")
318
+
319
+ # classifier layer
320
+ # each code has a binary linear formula
321
+ logits = self.classifier_layer.weight.mul(l4_dropout).sum(dim=2).add(self.classifier_layer.bias)
322
+
323
+ loss_fct = BCEWithLogitsLoss()
324
+ loss = loss_fct(logits, targets)
325
+
326
+ return {
327
+ "loss": loss,
328
+ "logits": logits,
329
+ "label_attention_weights": attention_weights,
330
+ "chunk_attention_weights": chunk_attention_weights if self.coding_model_config.chunk_att else []
331
+ }
332
+
333
+ def freeze_all_transformer_layers(self):
334
+ """
335
+ Freeze all layer weight parameters. They will not be updated during training.
336
+ """
337
+ for param in self.transformer_layer.parameters():
338
+ param.requires_grad = False
339
+
340
+ def unfreeze_all_transformer_layers(self):
341
+ """
342
+ Unfreeze all layers weight parameters. They will be updated during training.
343
+ """
344
+ for param in self.transformer_layer.parameters():
345
+ param.requires_grad = True
346
+
347
+ def unfreeze_transformer_last_layers(self):
348
+ for name, param in self.transformer_layer.named_parameters():
349
+ if "layer.11" in name or "pooler" in name:
350
+ param.requires_grad = True
LongHiLATmain/models/utils.py ADDED
@@ -0,0 +1,440 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import csv
2
+ import linecache
3
+ import pickle
4
+ import random
5
+ import subprocess
6
+
7
+ import numpy as np
8
+ import redis
9
+ import torch
10
+ import logging
11
+ import ast
12
+
13
+ from datasets import Dataset
14
+ from tqdm import tqdm
15
+
16
+ from sklearn.metrics import f1_score, precision_score, recall_score, accuracy_score, roc_auc_score, roc_curve, auc
17
+ from torch.utils.data import DataLoader
18
+ from transformers import AutoModel, DataCollatorWithPadding, XLNetTokenizer, XLNetTokenizerFast, AutoTokenizer, \
19
+ XLNetModel, is_torch_tpu_available
20
+
21
+ logger = logging.getLogger("lwat")
22
+
23
+
24
+ class MimicIIIDataset(Dataset):
25
+ def __init__(self, data):
26
+ self.input_ids = data["input_ids"]
27
+ self.attention_mask = data["attention_mask"]
28
+ self.token_type_ids = data["token_type_ids"]
29
+ self.labels = data["targets"]
30
+
31
+ def __len__(self):
32
+ return len(self.input_ids)
33
+
34
+ def __getitem__(self, item):
35
+ return {
36
+ "input_ids": torch.tensor(self.input_ids[item], dtype=torch.long),
37
+ "attention_mask": torch.tensor(self.attention_mask[item], dtype=torch.float),
38
+ "token_type_ids": torch.tensor(self.token_type_ids[item], dtype=torch.long),
39
+ "targets": torch.tensor(self.labels[item], dtype=torch.float)
40
+ }
41
+
42
+ class LazyMimicIIIDataset(Dataset):
43
+ def __init__(self, filename, task, dataset_type):
44
+ print("lazy load from {}".format(filename))
45
+ self.filename = filename
46
+ self.redis = redis.Redis(unix_socket_path="/tmp/redis.sock")
47
+ self.pipe = self.redis.pipeline()
48
+ self.num_examples = 0
49
+ self.task = task
50
+ self.dataset_type = dataset_type
51
+ with open(filename, 'r') as f:
52
+ for line_num, line in enumerate(f.readlines()):
53
+ self.num_examples += 1
54
+ example = eval(line)
55
+ key = task + '_' + dataset_type + '_' + str(line_num)
56
+ input_ids = eval(example[0])
57
+ attention_mask = eval(example[1])
58
+ token_type_ids = eval(example[2])
59
+ labels = eval(example[3])
60
+ example_tuple = (input_ids, attention_mask, token_type_ids, labels)
61
+
62
+ self.pipe.set(key, pickle.dumps(example_tuple))
63
+ if line_num % 100 == 0:
64
+ self.pipe.execute()
65
+ self.pipe.execute()
66
+ if is_torch_tpu_available():
67
+ import torch_xla.core.xla_model as xm
68
+ xm.rendezvous(tag="featuresGenerated")
69
+
70
+ def __len__(self):
71
+ return self.num_examples
72
+
73
+ def __getitem__(self, item):
74
+ key = self.task + '_' + self.dataset_type + '_' + str(item)
75
+ example = pickle.loads(self.redis.get(key))
76
+
77
+ return {
78
+ "input_ids": torch.tensor(example[0], dtype=torch.long),
79
+ "attention_mask": torch.tensor(example[1], dtype=torch.float),
80
+ "token_type_ids": torch.tensor(example[2], dtype=torch.long),
81
+ "targets": torch.tensor(example[3], dtype=torch.float)
82
+ }
83
+
84
+
85
+ class ICDCodeDataset(Dataset):
86
+ def __init__(self, data):
87
+ self.input_ids = data["input_ids"]
88
+ self.attention_mask = data["attention_mask"]
89
+ self.token_type_ids = data["token_type_ids"]
90
+
91
+ def __len__(self):
92
+ return len(self.input_ids)
93
+
94
+ def __getitem__(self, item):
95
+ return {
96
+ "input_ids": torch.tensor(self.input_ids[item], dtype=torch.long),
97
+ "attention_mask": torch.tensor(self.attention_mask[item], dtype=torch.float),
98
+ "token_type_ids": torch.tensor(self.token_type_ids[item], dtype=torch.long)
99
+ }
100
+
101
+
102
+ def set_random_seed(random_seed):
103
+ random.seed(random_seed)
104
+ np.random.seed(random_seed)
105
+ torch.manual_seed(random_seed)
106
+ torch.cuda.manual_seed_all(random_seed)
107
+ torch.backends.cudnn.deterministic = True
108
+ torch.backends.cudnn.benchmark = False
109
+
110
+ def tokenize_inputs(text_list, tokenizer, max_seq_len=512):
111
+ """
112
+ Tokenizes the input text input into ids. Appends the appropriate special
113
+ characters to the end of the text to denote end of sentence. Truncate or pad
114
+ the appropriate sequence length.
115
+ """
116
+ # tokenize the text, then truncate sequence to the desired length minus 2 for
117
+ # the 2 special characters
118
+ tokenized_texts = list(map(lambda t: tokenizer.tokenize(t)[:max_seq_len - 2], text_list))
119
+ # convert tokenized text into numeric ids for the appropriate LM
120
+ input_ids = [tokenizer.convert_tokens_to_ids(x) for x in tokenized_texts]
121
+ # get token type for token_ids_0
122
+ token_type_ids = [tokenizer.create_token_type_ids_from_sequences(x) for x in input_ids]
123
+ # append special token to end of sentence: <sep> <cls>
124
+ input_ids = [tokenizer.build_inputs_with_special_tokens(x) for x in input_ids]
125
+ # attention mask
126
+ attention_mask = [[1] * len(x) for x in input_ids]
127
+
128
+ # padding to max_length
129
+ def padding_to_max(sequence, value):
130
+ padding_len = max_seq_len - len(sequence)
131
+ padding = [value] * padding_len
132
+ return sequence + padding
133
+
134
+ input_ids = [padding_to_max(x, tokenizer.pad_token_id) for x in input_ids]
135
+ attention_mask = [padding_to_max(x, 0) for x in attention_mask]
136
+ token_type_ids = [padding_to_max(x, tokenizer.pad_token_type_id) for x in token_type_ids]
137
+
138
+ return input_ids, attention_mask, token_type_ids
139
+
140
+
141
+ def tokenize_dataset(tokenizer, text, labels, max_seq_len):
142
+ if (isinstance(tokenizer, XLNetTokenizer) or isinstance(tokenizer, XLNetTokenizerFast)):
143
+ data = list(map(lambda t: tokenize_inputs(t, tokenizer, max_seq_len=max_seq_len), text))
144
+ input_ids, attention_mask, token_type_ids = zip(*data)
145
+ else:
146
+ tokenizer.model_max_length = max_seq_len
147
+ input_dict = tokenizer(text, padding=True, truncation=True)
148
+ input_ids = input_dict["input_ids"]
149
+ attention_mask = input_dict["attention_mask"]
150
+ token_type_ids = input_dict["token_type_ids"]
151
+
152
+ return {
153
+ "input_ids": input_ids,
154
+ "attention_mask": attention_mask,
155
+ "token_type_ids": token_type_ids,
156
+ "targets": labels
157
+ }
158
+
159
+
160
+ def initial_code_title_vectors(label_dict, transformer_model_name, tokenizer_name, code_max_seq_length, code_batch_size,
161
+ d_model, device):
162
+ logger.info("Generate code title representations from base transformer model")
163
+ model = AutoModel.from_pretrained(transformer_model_name)
164
+ if isinstance(model, XLNetModel):
165
+ model.config.use_mems_eval = False
166
+ #
167
+ # model.config.use_mems_eval = False
168
+ # model.config.reuse_len = 0
169
+ code_titles = label_dict["long_title"].fillna("").tolist()
170
+ tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, padding_side="right")
171
+ data = tokenizer(code_titles, padding=True, truncation=True)
172
+ code_dataset = ICDCodeDataset(data)
173
+
174
+ model.to(device)
175
+
176
+ data_collator = DataCollatorWithPadding(tokenizer, padding="max_length",
177
+ max_length=code_max_seq_length)
178
+ code_param = {"batch_size": code_batch_size, "collate_fn": data_collator}
179
+ code_dataloader = DataLoader(code_dataset, **code_param)
180
+
181
+ code_dataloader_progress_bar = tqdm(code_dataloader, unit="batches",
182
+ desc="Code title representations")
183
+ code_dataloader_progress_bar.clear()
184
+
185
+ # output shape: (num_labels, hidden_size)
186
+ initial_code_vectors = torch.zeros(len(code_dataset), d_model)
187
+
188
+ for i, data in enumerate(code_dataloader_progress_bar):
189
+ input_ids = data["input_ids"].to(device, dtype=torch.long)
190
+ attention_mask = data["attention_mask"].to(device, dtype=torch.float)
191
+ token_type_ids = data["token_type_ids"].to(device, dtype=torch.long)
192
+
193
+ # output shape: (batch_size, sequence_length, hidden_size)
194
+ output = model(input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
195
+ # Mean pooling. output shape: (batch_size, hidden_size)
196
+ mean_last_hidden_state = torch.mean(output[0], 1)
197
+ # Max pooling. output shape: (batch_size, hidden_size)
198
+ # max_last_hidden_state = torch.max((output[0] * attention_mask.unsqueeze(-1)), 1)[0]
199
+
200
+ initial_code_vectors[i * input_ids.shape[0]:(i + 1) * input_ids.shape[0], :] = mean_last_hidden_state
201
+
202
+ code_dataloader_progress_bar.refresh(True)
203
+ code_dataloader_progress_bar.clear(True)
204
+ code_dataloader_progress_bar.close()
205
+ logger.info("Code representations ready for use. Shape {}".format(initial_code_vectors.shape))
206
+ return initial_code_vectors
207
+
208
+
209
+ def normalise_labels(labels, n_label):
210
+ norm_labels = []
211
+ for label in labels:
212
+ one_hot_vector_label = [0] * n_label
213
+ one_hot_vector_label[label] = 1
214
+ norm_labels.append(one_hot_vector_label)
215
+ return np.asarray(norm_labels)
216
+
217
+
218
+ def segment_tokenize_inputs(text, tokenizer, max_seq_len, num_chunks):
219
+ # input is full text of one document
220
+ tokenized_texts = []
221
+ tokens = tokenizer.tokenize(text)
222
+ start_idx = 0
223
+ seq_len = max_seq_len - 2
224
+ for i in range(num_chunks):
225
+ if start_idx > len(tokens):
226
+ tokenized_texts.append([])
227
+ continue
228
+ tokenized_texts.append(tokens[start_idx:(start_idx + seq_len)])
229
+ start_idx += seq_len
230
+
231
+ # convert tokenized text into numeric ids for the appropriate LM
232
+ input_ids = [tokenizer.convert_tokens_to_ids(x) for x in tokenized_texts]
233
+ # get token type for token_ids_0
234
+ token_type_ids = [tokenizer.create_token_type_ids_from_sequences(x) for x in input_ids]
235
+ # append special token to end of sentence: <sep> <cls>
236
+ input_ids = [tokenizer.build_inputs_with_special_tokens(x) for x in input_ids]
237
+ # attention mask
238
+ attention_mask = [[1] * len(x) for x in input_ids]
239
+
240
+ # padding to max_length
241
+ def padding_to_max(sequence, value):
242
+ padding_len = max_seq_len - len(sequence)
243
+ padding = [value] * padding_len
244
+ return sequence + padding
245
+
246
+ input_ids = [padding_to_max(x, tokenizer.pad_token_id) for x in input_ids]
247
+ attention_mask = [padding_to_max(x, 0) for x in attention_mask]
248
+ token_type_ids = [padding_to_max(x, tokenizer.pad_token_type_id) for x in token_type_ids]
249
+
250
+ return input_ids, attention_mask, token_type_ids
251
+
252
+
253
+ def segment_tokenize_dataset(tokenizer, text, labels, max_seq_len, num_chunks):
254
+ data = list(
255
+ map(lambda t: segment_tokenize_inputs(t, tokenizer, max_seq_len, num_chunks), text))
256
+ input_ids, attention_mask, token_type_ids = zip(*data)
257
+
258
+ return {
259
+ "input_ids": input_ids,
260
+ "attention_mask": attention_mask,
261
+ "token_type_ids": token_type_ids,
262
+ "targets": labels
263
+ }
264
+
265
+
266
+ # The following functions are modified from the relevant codes of https://github.com/aehrc/LAAT
267
+ def roc_auc(true_labels, pred_probs, average="macro"):
268
+ if pred_probs.shape[0] <= 1:
269
+ return
270
+
271
+ fpr = {}
272
+ tpr = {}
273
+ if average == "macro":
274
+ # get AUC for each label individually
275
+ relevant_labels = []
276
+ auc_labels = {}
277
+ for i in range(true_labels.shape[1]):
278
+ # only if there are true positives for this label
279
+ if true_labels[:, i].sum() > 0:
280
+ fpr[i], tpr[i], _ = roc_curve(true_labels[:, i], pred_probs[:, i])
281
+ if len(fpr[i]) > 1 and len(tpr[i]) > 1:
282
+ auc_score = auc(fpr[i], tpr[i])
283
+ if not np.isnan(auc_score):
284
+ auc_labels["auc_%d" % i] = auc_score
285
+ relevant_labels.append(i)
286
+
287
+ # macro-AUC: just average the auc scores
288
+ aucs = []
289
+ for i in relevant_labels:
290
+ aucs.append(auc_labels['auc_%d' % i])
291
+ score = np.mean(aucs)
292
+ else:
293
+ # micro-AUC: just look at each individual prediction
294
+ flat_pred = pred_probs.ravel()
295
+ fpr["micro"], tpr["micro"], _ = roc_curve(true_labels.ravel(), flat_pred)
296
+ score = auc(fpr["micro"], tpr["micro"])
297
+
298
+ return score
299
+
300
+
301
+ def union_size(x, y, axis):
302
+ return np.logical_or(x, y).sum(axis=axis).astype(float)
303
+
304
+
305
+ def intersect_size(x, y, axis):
306
+ return np.logical_and(x, y).sum(axis=axis).astype(float)
307
+
308
+
309
+ def macro_accuracy(true_labels, pred_labels):
310
+ num = intersect_size(true_labels, pred_labels, 0) / (union_size(true_labels, pred_labels, 0) + 1e-10)
311
+ return np.mean(num)
312
+
313
+
314
+ def macro_precision(true_labels, pred_labels):
315
+ num = intersect_size(true_labels, pred_labels, 0) / (pred_labels.sum(axis=0) + 1e-10)
316
+ return np.mean(num)
317
+
318
+
319
+ def macro_recall(true_labels, pred_labels):
320
+ num = intersect_size(true_labels, pred_labels, 0) / (true_labels.sum(axis=0) + 1e-10)
321
+ return np.mean(num)
322
+
323
+
324
+ def macro_f1(true_labels, pred_labels):
325
+ prec = macro_precision(true_labels, pred_labels)
326
+ rec = macro_recall(true_labels, pred_labels)
327
+ if prec + rec == 0:
328
+ f1 = 0.
329
+ else:
330
+ f1 = 2 * (prec * rec) / (prec + rec)
331
+ return prec, rec, f1
332
+
333
+
334
+ def precision_at_k(true_labels, pred_probs, ks=[1, 5, 8, 10, 15]):
335
+ # num true labels in top k predictions / k
336
+ sorted_pred = np.argsort(pred_probs)[:, ::-1]
337
+ output = []
338
+ for k in ks:
339
+ topk = sorted_pred[:, :k]
340
+
341
+ # get precision at k for each example
342
+ vals = []
343
+ for i, tk in enumerate(topk):
344
+ if len(tk) > 0:
345
+ num_true_in_top_k = true_labels[i, tk].sum()
346
+ denom = len(tk)
347
+ vals.append(num_true_in_top_k / float(denom))
348
+
349
+ output.append(np.mean(vals))
350
+ return output
351
+
352
+
353
+ def micro_recall(true_labels, pred_labels):
354
+ flat_true = true_labels.ravel()
355
+ flat_pred = pred_labels.ravel()
356
+ return intersect_size(flat_true, flat_pred, 0) / flat_true.sum(axis=0)
357
+
358
+
359
+ def micro_precision(true_labels, pred_labels):
360
+ flat_true = true_labels.ravel()
361
+ flat_pred = pred_labels.ravel()
362
+ if flat_pred.sum(axis=0) == 0:
363
+ return 0.0
364
+ return intersect_size(flat_true, flat_pred, 0) / flat_pred.sum(axis=0)
365
+
366
+
367
+ def micro_f1(true_labels, pred_labels):
368
+ prec = micro_precision(true_labels, pred_labels)
369
+ rec = micro_recall(true_labels, pred_labels)
370
+ if prec + rec == 0:
371
+ f1 = 0.
372
+ else:
373
+ f1 = 2 * (prec * rec) / (prec + rec)
374
+ return prec, rec, f1
375
+
376
+
377
+ def micro_accuracy(true_labels, pred_labels):
378
+ flat_true = true_labels.ravel()
379
+ flat_pred = pred_labels.ravel()
380
+ return intersect_size(flat_true, flat_pred, 0) / union_size(flat_true, flat_pred, 0)
381
+
382
+
383
+ def calculate_scores(true_labels, logits, average="macro", is_multilabel=True, threshold=0.5):
384
+ def sigmoid(x):
385
+ return 1 / (1 + np.exp(-x))
386
+
387
+ pred_probs = sigmoid(logits)
388
+ pred_labels = np.rint(pred_probs - threshold + 0.5)
389
+
390
+ max_size = min(len(true_labels), len(pred_labels))
391
+ true_labels = true_labels[: max_size]
392
+ pred_labels = pred_labels[: max_size]
393
+ pred_probs = pred_probs[: max_size]
394
+ p_1 = 0
395
+ p_5 = 0
396
+ p_8 = 0
397
+ p_10 = 0
398
+ p_15 = 0
399
+ if pred_probs is not None:
400
+ if not is_multilabel:
401
+ normalised_labels = normalise_labels(true_labels, len(pred_probs[0]))
402
+ auc_score = roc_auc(normalised_labels, pred_probs, average=average)
403
+ accuracy = accuracy_score(true_labels, pred_labels)
404
+ precision = precision_score(true_labels, pred_labels, average=average)
405
+ recall = recall_score(true_labels, pred_labels, average=average)
406
+ f1 = f1_score(true_labels, pred_labels, average=average)
407
+ else:
408
+ if average == "macro":
409
+ accuracy = macro_accuracy(true_labels, pred_labels) # categorical accuracy
410
+ precision, recall, f1 = macro_f1(true_labels, pred_labels)
411
+ p_ks = precision_at_k(true_labels, pred_probs, [1, 5, 8, 10, 15])
412
+ p_1 = p_ks[0]
413
+ p_5 = p_ks[1]
414
+ p_8 = p_ks[2]
415
+ p_10 = p_ks[3]
416
+ p_15 = p_ks[4]
417
+
418
+ else:
419
+ accuracy = micro_accuracy(true_labels, pred_labels)
420
+ precision, recall, f1 = micro_f1(true_labels, pred_labels)
421
+ auc_score = roc_auc(true_labels, pred_probs, average)
422
+
423
+ # Calculate label-wise F1 scores
424
+ labelwise_f1 = f1_score(true_labels, pred_labels, average=None)
425
+ labelwise_f1 = np.array2string(labelwise_f1, separator=',')
426
+
427
+
428
+ else:
429
+ auc_score = -1
430
+
431
+ output = {"{}_precision".format(average): precision, "{}_recall".format(average): recall,
432
+ "{}_f1".format(average): f1, "{}_accuracy".format(average): accuracy,
433
+ "{}_auc".format(average): auc_score, "{}_P@1".format(average): p_1, "{}_P@5".format(average): p_5,
434
+ "{}_P@8".format(average): p_8, "{}_P@10".format(average): p_10, "{}_P@15".format(average): p_15,
435
+ "labelwise_f1": labelwise_f1
436
+ }
437
+
438
+ return output
439
+
440
+