meghanaraok committed on
Commit efcc711
1 Parent(s): df6e4da

Delete models

Files changed (3)
  1. models/__init__.py +0 -0
  2. models/modeling.py +0 -343
  3. models/utils.py +0 -440
models/__init__.py DELETED
File without changes
models/modeling.py DELETED
@@ -1,343 +0,0 @@
import collections
import logging

import torch
from torch.nn import BCEWithLogitsLoss, Dropout, Linear
from transformers import AutoModel, XLNetModel, LongformerModel, LongformerConfig
from transformers.models.longformer.modeling_longformer import LongformerEncoder, LongformerClassificationHead, LongformerLayer
from huggingface_hub import PyTorchModelHubMixin

from LongLAT.hilat.models.utils import initial_code_title_vectors

logger = logging.getLogger("lwat")

class CodingModelConfig:
    def __init__(self,
                 transformer_model_name_or_path,
                 transformer_tokenizer_name,
                 transformer_layer_update_strategy,
                 num_chunks,
                 max_seq_length,
                 dropout,
                 dropout_att,
                 d_model,
                 label_dictionary,
                 num_labels,
                 use_code_representation,
                 code_max_seq_length,
                 code_batch_size,
                 multi_head_att,
                 chunk_att,
                 linear_init_mean,
                 linear_init_std,
                 document_pooling_strategy,
                 multi_head_chunk_attention,
                 num_hidden_layers):
        super(CodingModelConfig, self).__init__()
        self.transformer_model_name_or_path = transformer_model_name_or_path
        self.transformer_tokenizer_name = transformer_tokenizer_name
        self.transformer_layer_update_strategy = transformer_layer_update_strategy
        self.num_chunks = num_chunks
        self.max_seq_length = max_seq_length
        self.dropout = dropout
        self.dropout_att = dropout_att
        self.d_model = d_model
        # label_dictionary is a dataframe with columns: icd9_code, long_title
        self.label_dictionary = label_dictionary
        self.num_labels = num_labels
        self.use_code_representation = use_code_representation
        self.code_max_seq_length = code_max_seq_length
        self.code_batch_size = code_batch_size
        self.multi_head_att = multi_head_att
        self.chunk_att = chunk_att
        self.linear_init_mean = linear_init_mean
        self.linear_init_std = linear_init_std
        self.document_pooling_strategy = document_pooling_strategy
        self.multi_head_chunk_attention = multi_head_chunk_attention
        self.num_hidden_layers = num_hidden_layers

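
# Illustrative sketch (an editor addition, not recovered from the original file): one
# plausible way to instantiate CodingModelConfig. Every concrete value below (the
# 8 x 512 chunking, 5 labels, hidden size 768, the tiny label dictionary) is an
# assumption chosen for illustration.
import pandas as pd

example_label_dictionary = pd.DataFrame({
    "icd9_code": ["401.9", "428.0", "427.31", "414.01", "584.9"],
    "long_title": [
        "Unspecified essential hypertension",
        "Congestive heart failure, unspecified",
        "Atrial fibrillation",
        "Coronary atherosclerosis of native coronary artery",
        "Acute kidney failure, unspecified",
    ],
})

example_config = CodingModelConfig(
    transformer_model_name_or_path="yikuan8/Clinical-Longformer",
    transformer_tokenizer_name="yikuan8/Clinical-Longformer",
    transformer_layer_update_strategy="last",  # "no" freezes all layers, "last" unfreezes layer 11 + pooler
    num_chunks=8,
    max_seq_length=512,
    dropout=0.1,
    dropout_att=0.1,
    d_model=768,
    label_dictionary=example_label_dictionary,
    num_labels=5,
    use_code_representation=False,
    code_max_seq_length=64,
    code_batch_size=8,
    multi_head_att=True,
    chunk_att=False,
    linear_init_mean=0.0,
    linear_init_std=0.03,
    document_pooling_strategy="flat",
    multi_head_chunk_attention=True,
    num_hidden_layers=1,
)
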
class LableWiseAttentionLayer(torch.nn.Module):
    def __init__(self, coding_model_config, args):
        super(LableWiseAttentionLayer, self).__init__()

        self.config = coding_model_config
        self.args = args

        # layers
        self.l1_linear = torch.nn.Linear(self.config.d_model,
                                         self.config.d_model, bias=False)
        self.tanh = torch.nn.Tanh()
        self.l2_linear = torch.nn.Linear(self.config.d_model, self.config.num_labels, bias=False)
        self.softmax = torch.nn.Softmax(dim=1)

        # Mean pooling of the last hidden state of each code title from the transformer model is used as the initial code vectors
        self._init_linear_weights(mean=self.config.linear_init_mean, std=self.config.linear_init_std)

    def _init_linear_weights(self, mean, std):
        # normalize the l1 weights
        torch.nn.init.normal_(self.l1_linear.weight, mean, std)
        if self.l1_linear.bias is not None:
            self.l1_linear.bias.data.fill_(0)
        # initialize the l2 weights, optionally from code-title representations
        if self.config.use_code_representation:
            code_vectors = initial_code_title_vectors(self.config.label_dictionary,
                                                      self.config.transformer_model_name_or_path,
                                                      self.config.transformer_tokenizer_name
                                                      if self.config.transformer_tokenizer_name
                                                      else self.config.transformer_model_name_or_path,
                                                      self.config.code_max_seq_length,
                                                      self.config.code_batch_size,
                                                      self.config.d_model,
                                                      self.args.device)

            self.l2_linear.weight = torch.nn.Parameter(code_vectors, requires_grad=True)
        torch.nn.init.normal_(self.l2_linear.weight, mean, std)
        if self.l2_linear.bias is not None:
            self.l2_linear.bias.data.fill_(0)

    def forward(self, x):
        # input: (batch_size, max_seq_length, transformer_hidden_size)
        # output: (batch_size, num_labels, transformer_hidden_size)
        # Z = Tan(WH)
        l1_output = self.tanh(self.l1_linear(x))
        # softmax(UZ)
        # l2_linear output shape: (batch_size, max_seq_length, num_labels)
        # attention_weight shape: (batch_size, num_labels, max_seq_length)
        attention_weight = self.softmax(self.l2_linear(l1_output)).transpose(1, 2)
        # attention_output shape: (batch_size, num_labels, transformer_hidden_size)
        attention_output = torch.matmul(attention_weight, x)

        return attention_output, attention_weight

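
# Illustrative sketch (an editor addition, not recovered from the original file): a quick
# shape check of the label-wise attention above, using a dummy config built from
# types.SimpleNamespace. The sizes (batch 2, sequence length 16, hidden size 32, 5 labels)
# are arbitrary assumptions.
import torch
from types import SimpleNamespace

dummy_config = SimpleNamespace(d_model=32, num_labels=5, use_code_representation=False,
                               linear_init_mean=0.0, linear_init_std=0.03)
dummy_args = SimpleNamespace(device="cpu")
label_attention = LableWiseAttentionLayer(dummy_config, dummy_args)
hidden_states = torch.randn(2, 16, 32)          # (batch_size, max_seq_length, d_model)
attn_output, attn_weights = label_attention(hidden_states)
assert attn_output.shape == (2, 5, 32)          # (batch_size, num_labels, d_model)
assert attn_weights.shape == (2, 5, 16)         # (batch_size, num_labels, max_seq_length)
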
class ChunkAttentionLayer(torch.nn.Module):
    def __init__(self, coding_model_config, args):
        super(ChunkAttentionLayer, self).__init__()

        self.config = coding_model_config
        self.args = args

        # layers
        self.l1_linear = torch.nn.Linear(self.config.d_model,
                                         self.config.d_model, bias=False)
        self.tanh = torch.nn.Tanh()
        self.l2_linear = torch.nn.Linear(self.config.d_model, 1, bias=False)
        self.softmax = torch.nn.Softmax(dim=1)

        self._init_linear_weights(mean=self.config.linear_init_mean, std=self.config.linear_init_std)

    def _init_linear_weights(self, mean, std):
        # initialize the l1 weights
        torch.nn.init.normal_(self.l1_linear.weight, mean, std)
        if self.l1_linear.bias is not None:
            self.l1_linear.bias.data.fill_(0)
        # initialize the l2 weights
        torch.nn.init.normal_(self.l2_linear.weight, mean, std)
        if self.l2_linear.bias is not None:
            self.l2_linear.bias.data.fill_(0)

    def forward(self, x):
        # input: (batch_size, num_chunks, transformer_hidden_size)
        # output: (batch_size, 1, transformer_hidden_size)
        # Z = Tan(WH)
        l1_output = self.tanh(self.l1_linear(x))
        # softmax(UZ)
        # l2_linear output shape: (batch_size, num_chunks, 1)
        # attention_weight shape: (batch_size, 1, num_chunks)
        attention_weight = self.softmax(self.l2_linear(l1_output)).transpose(1, 2)
        # attention_output shape: (batch_size, 1, transformer_hidden_size)
        attention_output = torch.matmul(attention_weight, x)
        return attention_output, attention_weight

class CodingModel(torch.nn.Module, PyTorchModelHubMixin):
    def __init__(self, coding_model_config, args, **kwargs):
        super(CodingModel, self).__init__()
        self.coding_model_config = coding_model_config
        # layers
        self.transformer_layer = AutoModel.from_pretrained('yikuan8/Clinical-Longformer')
        if isinstance(self.transformer_layer, XLNetModel):
            self.transformer_layer.config.use_mems_eval = False
        # if torch.cuda.is_available():
        #     self.transformer_layer = self.transformer_layer.to(torch.device("cuda:0"))
        # self.transformer_layer.to(torch.device("cuda:0"))
        self.dropout = Dropout(p=self.coding_model_config.dropout)
        # self.longformer_transformer = AutoModel.from_pretrained("yikuan8/Clinical-Longformer")

        if self.coding_model_config.multi_head_att:
            # initialize one label-wise attention head per chunk
            self.label_wise_attention_layer = torch.nn.ModuleList(
                [LableWiseAttentionLayer(coding_model_config, args)
                 for _ in range(self.coding_model_config.num_chunks)])
        else:
            self.label_wise_attention_layer = LableWiseAttentionLayer(coding_model_config, args)
        self.dropout_att = Dropout(p=self.coding_model_config.dropout_att)

        # initialize chunk attention
        if self.coding_model_config.chunk_att:
            if self.coding_model_config.multi_head_chunk_attention:
                self.chunk_attention_layer = torch.nn.ModuleList([ChunkAttentionLayer(coding_model_config, args)
                                                                  for _ in range(self.coding_model_config.num_labels)])
            else:
                self.chunk_attention_layer = ChunkAttentionLayer(coding_model_config, args)

            self.classifier_layer = Linear(self.coding_model_config.d_model,
                                           self.coding_model_config.num_labels)
        else:
            if self.coding_model_config.document_pooling_strategy == "flat":
                self.classifier_layer = Linear(self.coding_model_config.num_chunks * self.coding_model_config.d_model,
                                               self.coding_model_config.num_labels)
            else:  # max or mean pooling
                self.classifier_layer = Linear(self.coding_model_config.d_model,
                                               self.coding_model_config.num_labels)
        self.sigmoid = torch.nn.Sigmoid()

        if self.coding_model_config.transformer_layer_update_strategy == "no":
            self.freeze_all_transformer_layers()
        elif self.coding_model_config.transformer_layer_update_strategy == "last":
            self.freeze_all_transformer_layers()
            self.unfreeze_transformer_last_layers()

        # initialize the weights of the classifier
        self._init_linear_weights(mean=self.coding_model_config.linear_init_mean, std=self.coding_model_config.linear_init_std)

    def _init_linear_weights(self, mean, std):
        torch.nn.init.normal_(self.classifier_layer.weight, mean, std)

    def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, targets=None):
        # input ids/mask/type_ids shape: (batch_size, num_chunks, max_seq_length)
        # targets shape: (batch_size, num_labels)
        transformer_output = []

        # pass chunk by chunk into the transformer layer in batches.
        # input (batch_size, sequence_length)
        # for i in range(self.coding_model_config.num_chunks):
        #     l1_output = self.transformer_layer(input_ids=input_ids[:, i, :],
        #                                        attention_mask=attention_mask[:, i, :],
        #                                        token_type_ids=token_type_ids[:, i, :])
        #     # output hidden state shape: (batch_size, sequence_length, hidden_size)
        #     transformer_output.append(l1_output[0])

        # flatten the chunk dimension so the whole document is passed to Longformer at once
        input_ids = input_ids.reshape(input_ids.shape[0], input_ids.shape[1] * input_ids.shape[2])
        global_attention_mask = torch.zeros_like(input_ids)
        global_attention_positions = [1, 510, 1022, 1534, 2046, 2558, 3070, 3582, 4094]
        global_attention_mask[:, global_attention_positions] = 1
        attention_mask = attention_mask.reshape(attention_mask.shape[0], attention_mask.shape[1] * attention_mask.shape[2])
        token_type_ids = token_type_ids.reshape(token_type_ids.shape[0], token_type_ids.shape[1] * token_type_ids.shape[2])
        l1_output = self.transformer_layer(input_ids=input_ids, attention_mask=attention_mask,
                                           global_attention_mask=global_attention_mask, token_type_ids=token_type_ids)

        transformer_output.append(l1_output[0])
        # transpose back chunk and batch size dimensions
        transformer_output = torch.stack(transformer_output)
        transformer_output = transformer_output.transpose(0, 1)
        # dropout transformer output
        l2_dropout = self.dropout(transformer_output)

        # config = LongformerConfig.from_pretrained("allenai/longformer-base-4096")
        # config.num_labels = 5
        # config.num_hidden_layers = 1
        # longformer_layer = LongformerLayer(config)
        # # longformer_layer = longformer_layer(config)
        # # longformer_layer = longformer_layer.to(torch.device("cuda:0"))
        # l2_dropout = l2_dropout.reshape(l2_dropout.shape[0], l2_dropout.shape[1]*l2_dropout.shape[2], l2_dropout.shape[3])
        # attention_mask = attention_mask.reshape(attention_mask.shape[0], attention_mask.shape[1]*attention_mask.shape[2])
        # is_index_masked = attention_mask < 0
        # is_index_global_attn = attention_mask > 0
        # is_global_attn = is_index_global_attn.flatten().any().item()
        # output = longformer_layer(l2_dropout, attention_mask=attention_mask, output_attentions=True, is_index_masked=is_index_masked, is_index_global_attn=is_index_global_attn, is_global_attn=is_global_attn)
        # l2_dropout = self.dropout_att(output[0])  # l2_dropout - torch.Size([4, 4096, 768])
        # l2_dropout = l2_dropout.reshape(l2_dropout.shape[0], self.coding_model_config.num_chunks, self.coding_model_config.max_seq_length, self.coding_model_config.d_model)
        # # l2_dropout - torch.Size([4, 8, 512, 768])

        # Label-wise attention layers
        # output: (batch_size, num_chunks, num_labels, hidden_size)
        attention_output = []
        attention_weights = []

        for i in range(self.coding_model_config.num_chunks):
            # input: (batch_size, max_seq_length, transformer_hidden_size)
            if self.coding_model_config.multi_head_att:
                attention_layer = self.label_wise_attention_layer[i]
            else:
                attention_layer = self.label_wise_attention_layer
            l3_attention, attention_weight = attention_layer(l2_dropout[:, i, :])
            # l3_attention shape: (batch_size, num_labels, hidden_size), e.g. torch.Size([4, 5, 768])
            # attention_weight shape: (batch_size, num_labels, max_seq_length), e.g. torch.Size([4, 5, 512])
            attention_output.append(l3_attention)
            attention_weights.append(attention_weight)

        attention_output = torch.stack(attention_output)
        attention_output = attention_output.transpose(0, 1)  # torch.Size([4, 8, 5, 768])
        attention_weights = torch.stack(attention_weights)
        attention_weights = attention_weights.transpose(0, 1)  # torch.Size([4, 8, 5, 512])

        l3_dropout = self.dropout_att(attention_output)  # torch.Size([4, 8, 5, 768])

        if self.coding_model_config.chunk_att:  # set to False in this configuration
            # Chunk attention layers
            # output: (batch_size, num_labels, hidden_size)
            chunk_attention_output = []
            chunk_attention_weights = []

            for i in range(self.coding_model_config.num_labels):
                if self.coding_model_config.multi_head_chunk_attention:
                    chunk_attention = self.chunk_attention_layer[i]
                else:
                    chunk_attention = self.chunk_attention_layer
                l4_chunk_attention, l4_chunk_attention_weights = chunk_attention(l3_dropout[:, :, i])
                chunk_attention_output.append(l4_chunk_attention.squeeze())
                chunk_attention_weights.append(l4_chunk_attention_weights.squeeze())

            chunk_attention_output = torch.stack(chunk_attention_output)  # torch.Size([5, 4, 768])
            chunk_attention_output = chunk_attention_output.transpose(0, 1)  # torch.Size([4, 5, 768])
            chunk_attention_weights = torch.stack(chunk_attention_weights)
            chunk_attention_weights = chunk_attention_weights.transpose(0, 1)
            # output shape: (batch_size, num_labels, hidden_size)
            l4_dropout = self.dropout_att(chunk_attention_output)  # torch.Size([4, 5, 768])
        else:
            # output shape: (batch_size, num_labels, hidden_size*num_chunks)
            l4_dropout = l3_dropout.transpose(1, 2)
            if self.coding_model_config.document_pooling_strategy == "flat":
                # Flatten layer: concatenate the chunk representations for each label
                l4_dropout = torch.flatten(l4_dropout, start_dim=2)
            elif self.coding_model_config.document_pooling_strategy == "max":
                l4_dropout = torch.amax(l4_dropout, 2)
            elif self.coding_model_config.document_pooling_strategy == "mean":
                l4_dropout = torch.mean(l4_dropout, 2)
            else:
                raise ValueError("Not supported pooling strategy")

        # classifier layer
        # each code has a binary linear formula
        logits = self.classifier_layer.weight.mul(l4_dropout).sum(dim=2).add(self.classifier_layer.bias)
        # torch.Size([4, 5])
        loss_fct = BCEWithLogitsLoss()
        loss = loss_fct(logits, targets)

        return {
            "loss": loss,
            "logits": logits,
            "label_attention_weights": attention_weights,
            "chunk_attention_weights": chunk_attention_weights if self.coding_model_config.chunk_att else []
        }

    def freeze_all_transformer_layers(self):
        """
        Freeze all layer weight parameters. They will not be updated during training.
        """
        for param in self.transformer_layer.parameters():
            param.requires_grad = False

    def unfreeze_all_transformer_layers(self):
        """
        Unfreeze all layer weight parameters. They will be updated during training.
        """
        for param in self.transformer_layer.parameters():
            param.requires_grad = True

    def unfreeze_transformer_last_layers(self):
        for name, param in self.transformer_layer.named_parameters():
            if "layer.11" in name or "pooler" in name:
                param.requires_grad = True
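
# Illustrative sketch (an editor addition, not recovered from the original file): the
# hard-coded global_attention_positions in CodingModel.forward() appear to mark one token
# near the start of the flattened 8 x 512 = 4096-token document plus one position per
# 512-token chunk. Assuming that reading, the list can be derived from num_chunks and
# max_seq_length as below; this is an interpretation, not documented behaviour.
def derive_global_attention_positions(num_chunks=8, max_seq_length=512):
    return [1] + [max_seq_length - 2 + i * max_seq_length for i in range(num_chunks)]

assert derive_global_attention_positions() == [1, 510, 1022, 1534, 2046, 2558, 3070, 3582, 4094]
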
models/utils.py DELETED
@@ -1,440 +0,0 @@
import csv
import linecache
import pickle
import random
import subprocess

import numpy as np
import redis
import torch
import logging
import ast

from datasets import Dataset
from tqdm import tqdm

from sklearn.metrics import f1_score, precision_score, recall_score, accuracy_score, roc_auc_score, roc_curve, auc
from torch.utils.data import DataLoader
from transformers import AutoModel, DataCollatorWithPadding, XLNetTokenizer, XLNetTokenizerFast, AutoTokenizer, \
    XLNetModel, is_torch_tpu_available

logger = logging.getLogger("lwat")

class MimicIIIDataset(Dataset):
    def __init__(self, data):
        self.input_ids = data["input_ids"]
        self.attention_mask = data["attention_mask"]
        self.token_type_ids = data["token_type_ids"]
        self.labels = data["targets"]

    def __len__(self):
        return len(self.input_ids)

    def __getitem__(self, item):
        return {
            "input_ids": torch.tensor(self.input_ids[item], dtype=torch.long),
            "attention_mask": torch.tensor(self.attention_mask[item], dtype=torch.float),
            "token_type_ids": torch.tensor(self.token_type_ids[item], dtype=torch.long),
            "targets": torch.tensor(self.labels[item], dtype=torch.float)
        }

class LazyMimicIIIDataset(Dataset):
    def __init__(self, filename, task, dataset_type):
        print("lazy load from {}".format(filename))
        self.filename = filename
        self.redis = redis.Redis(unix_socket_path="/tmp/redis.sock")
        self.pipe = self.redis.pipeline()
        self.num_examples = 0
        self.task = task
        self.dataset_type = dataset_type
        with open(filename, 'r') as f:
            for line_num, line in enumerate(f.readlines()):
                self.num_examples += 1
                example = eval(line)
                key = task + '_' + dataset_type + '_' + str(line_num)
                input_ids = eval(example[0])
                attention_mask = eval(example[1])
                token_type_ids = eval(example[2])
                labels = eval(example[3])
                example_tuple = (input_ids, attention_mask, token_type_ids, labels)

                self.pipe.set(key, pickle.dumps(example_tuple))
                if line_num % 100 == 0:
                    self.pipe.execute()
            self.pipe.execute()
        if is_torch_tpu_available():
            import torch_xla.core.xla_model as xm
            xm.rendezvous(tag="featuresGenerated")

    def __len__(self):
        return self.num_examples

    def __getitem__(self, item):
        key = self.task + '_' + self.dataset_type + '_' + str(item)
        example = pickle.loads(self.redis.get(key))

        return {
            "input_ids": torch.tensor(example[0], dtype=torch.long),
            "attention_mask": torch.tensor(example[1], dtype=torch.float),
            "token_type_ids": torch.tensor(example[2], dtype=torch.long),
            "targets": torch.tensor(example[3], dtype=torch.float)
        }

class ICDCodeDataset(Dataset):
    def __init__(self, data):
        self.input_ids = data["input_ids"]
        self.attention_mask = data["attention_mask"]
        self.token_type_ids = data["token_type_ids"]

    def __len__(self):
        return len(self.input_ids)

    def __getitem__(self, item):
        return {
            "input_ids": torch.tensor(self.input_ids[item], dtype=torch.long),
            "attention_mask": torch.tensor(self.attention_mask[item], dtype=torch.float),
            "token_type_ids": torch.tensor(self.token_type_ids[item], dtype=torch.long)
        }


def set_random_seed(random_seed):
    random.seed(random_seed)
    np.random.seed(random_seed)
    torch.manual_seed(random_seed)
    torch.cuda.manual_seed_all(random_seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

def tokenize_inputs(text_list, tokenizer, max_seq_len=512):
    """
    Tokenizes the input text into ids, appends the appropriate special
    tokens to denote the end of the sentence, and truncates or pads
    to the desired sequence length.
    """
    # tokenize the text, then truncate the sequence to the desired length minus 2 for
    # the 2 special tokens
    tokenized_texts = list(map(lambda t: tokenizer.tokenize(t)[:max_seq_len - 2], text_list))
    # convert tokenized text into numeric ids for the appropriate LM
    input_ids = [tokenizer.convert_tokens_to_ids(x) for x in tokenized_texts]
    # get token type ids for token_ids_0
    token_type_ids = [tokenizer.create_token_type_ids_from_sequences(x) for x in input_ids]
    # append special tokens to the end of the sentence: <sep> <cls>
    input_ids = [tokenizer.build_inputs_with_special_tokens(x) for x in input_ids]
    # attention mask
    attention_mask = [[1] * len(x) for x in input_ids]

    # padding to max_length
    def padding_to_max(sequence, value):
        padding_len = max_seq_len - len(sequence)
        padding = [value] * padding_len
        return sequence + padding

    input_ids = [padding_to_max(x, tokenizer.pad_token_id) for x in input_ids]
    attention_mask = [padding_to_max(x, 0) for x in attention_mask]
    token_type_ids = [padding_to_max(x, tokenizer.pad_token_type_id) for x in token_type_ids]

    return input_ids, attention_mask, token_type_ids


def tokenize_dataset(tokenizer, text, labels, max_seq_len):
    if isinstance(tokenizer, XLNetTokenizer) or isinstance(tokenizer, XLNetTokenizerFast):
        data = list(map(lambda t: tokenize_inputs(t, tokenizer, max_seq_len=max_seq_len), text))
        input_ids, attention_mask, token_type_ids = zip(*data)
    else:
        tokenizer.model_max_length = max_seq_len
        input_dict = tokenizer(text, padding=True, truncation=True)
        input_ids = input_dict["input_ids"]
        attention_mask = input_dict["attention_mask"]
        token_type_ids = input_dict["token_type_ids"]

    return {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "token_type_ids": token_type_ids,
        "targets": labels
    }

def initial_code_title_vectors(label_dict, transformer_model_name, tokenizer_name, code_max_seq_length, code_batch_size,
                               d_model, device):
    logger.info("Generate code title representations from base transformer model")
    model = AutoModel.from_pretrained(transformer_model_name)
    if isinstance(model, XLNetModel):
        model.config.use_mems_eval = False
    #
    # model.config.use_mems_eval = False
    # model.config.reuse_len = 0
    code_titles = label_dict["long_title"].fillna("").tolist()
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, padding_side="right")
    data = tokenizer(code_titles, padding=True, truncation=True)
    code_dataset = ICDCodeDataset(data)

    model.to(device)

    data_collator = DataCollatorWithPadding(tokenizer, padding="max_length",
                                            max_length=code_max_seq_length)
    code_param = {"batch_size": code_batch_size, "collate_fn": data_collator}
    code_dataloader = DataLoader(code_dataset, **code_param)

    code_dataloader_progress_bar = tqdm(code_dataloader, unit="batches",
                                        desc="Code title representations")
    code_dataloader_progress_bar.clear()

    # output shape: (num_labels, hidden_size)
    initial_code_vectors = torch.zeros(len(code_dataset), d_model)

    for i, data in enumerate(code_dataloader_progress_bar):
        input_ids = data["input_ids"].to(device, dtype=torch.long)
        attention_mask = data["attention_mask"].to(device, dtype=torch.float)
        token_type_ids = data["token_type_ids"].to(device, dtype=torch.long)

        # output shape: (batch_size, sequence_length, hidden_size)
        output = model(input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
        # Mean pooling. output shape: (batch_size, hidden_size)
        mean_last_hidden_state = torch.mean(output[0], 1)
        # Max pooling. output shape: (batch_size, hidden_size)
        # max_last_hidden_state = torch.max((output[0] * attention_mask.unsqueeze(-1)), 1)[0]

        initial_code_vectors[i * input_ids.shape[0]:(i + 1) * input_ids.shape[0], :] = mean_last_hidden_state

    code_dataloader_progress_bar.refresh(True)
    code_dataloader_progress_bar.clear(True)
    code_dataloader_progress_bar.close()
    logger.info("Code representations ready for use. Shape {}".format(initial_code_vectors.shape))
    return initial_code_vectors

def normalise_labels(labels, n_label):
    norm_labels = []
    for label in labels:
        one_hot_vector_label = [0] * n_label
        one_hot_vector_label[label] = 1
        norm_labels.append(one_hot_vector_label)
    return np.asarray(norm_labels)


def segment_tokenize_inputs(text, tokenizer, max_seq_len, num_chunks):
    # input is the full text of one document
    tokenized_texts = []
    tokens = tokenizer.tokenize(text)
    start_idx = 0
    seq_len = max_seq_len - 2
    for i in range(num_chunks):
        if start_idx > len(tokens):
            tokenized_texts.append([])
            continue
        tokenized_texts.append(tokens[start_idx:(start_idx + seq_len)])
        start_idx += seq_len

    # convert tokenized text into numeric ids for the appropriate LM
    input_ids = [tokenizer.convert_tokens_to_ids(x) for x in tokenized_texts]
    # get token type ids for token_ids_0
    token_type_ids = [tokenizer.create_token_type_ids_from_sequences(x) for x in input_ids]
    # append special tokens to the end of the sentence: <sep> <cls>
    input_ids = [tokenizer.build_inputs_with_special_tokens(x) for x in input_ids]
    # attention mask
    attention_mask = [[1] * len(x) for x in input_ids]

    # padding to max_length
    def padding_to_max(sequence, value):
        padding_len = max_seq_len - len(sequence)
        padding = [value] * padding_len
        return sequence + padding

    input_ids = [padding_to_max(x, tokenizer.pad_token_id) for x in input_ids]
    attention_mask = [padding_to_max(x, 0) for x in attention_mask]
    token_type_ids = [padding_to_max(x, tokenizer.pad_token_type_id) for x in token_type_ids]

    return input_ids, attention_mask, token_type_ids


def segment_tokenize_dataset(tokenizer, text, labels, max_seq_len, num_chunks):
    data = list(
        map(lambda t: segment_tokenize_inputs(t, tokenizer, max_seq_len, num_chunks), text))
    input_ids, attention_mask, token_type_ids = zip(*data)

    return {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "token_type_ids": token_type_ids,
        "targets": labels
    }

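
# Illustrative sketch (an editor addition, not recovered from the original file): chunked
# tokenization with segment_tokenize_dataset. The tokenizer checkpoint matches the one
# hard-coded in models/modeling.py; the document text, the 4 chunks of 32 tokens and the
# single binary label vector are assumptions for illustration.
from transformers import AutoTokenizer

example_tokenizer = AutoTokenizer.from_pretrained("yikuan8/Clinical-Longformer")
example_docs = ["Patient admitted with shortness of breath and chest pain. "
                "History of hypertension and atrial fibrillation."]
example_labels = [[1, 0, 1, 0, 0]]
encoded = segment_tokenize_dataset(example_tokenizer, example_docs, example_labels,
                                   max_seq_len=32, num_chunks=4)
# each document becomes num_chunks rows of max_seq_len token ids
assert len(encoded["input_ids"][0]) == 4
assert all(len(chunk) == 32 for chunk in encoded["input_ids"][0])
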
# The following functions are modified from the relevant code of https://github.com/aehrc/LAAT
def roc_auc(true_labels, pred_probs, average="macro"):
    if pred_probs.shape[0] <= 1:
        return

    fpr = {}
    tpr = {}
    if average == "macro":
        # get AUC for each label individually
        relevant_labels = []
        auc_labels = {}
        for i in range(true_labels.shape[1]):
            # only if there are true positives for this label
            if true_labels[:, i].sum() > 0:
                fpr[i], tpr[i], _ = roc_curve(true_labels[:, i], pred_probs[:, i])
                if len(fpr[i]) > 1 and len(tpr[i]) > 1:
                    auc_score = auc(fpr[i], tpr[i])
                    if not np.isnan(auc_score):
                        auc_labels["auc_%d" % i] = auc_score
                        relevant_labels.append(i)

        # macro-AUC: just average the AUC scores
        aucs = []
        for i in relevant_labels:
            aucs.append(auc_labels['auc_%d' % i])
        score = np.mean(aucs)
    else:
        # micro-AUC: just look at each individual prediction
        flat_pred = pred_probs.ravel()
        fpr["micro"], tpr["micro"], _ = roc_curve(true_labels.ravel(), flat_pred)
        score = auc(fpr["micro"], tpr["micro"])

    return score


def union_size(x, y, axis):
    return np.logical_or(x, y).sum(axis=axis).astype(float)


def intersect_size(x, y, axis):
    return np.logical_and(x, y).sum(axis=axis).astype(float)


def macro_accuracy(true_labels, pred_labels):
    num = intersect_size(true_labels, pred_labels, 0) / (union_size(true_labels, pred_labels, 0) + 1e-10)
    return np.mean(num)


def macro_precision(true_labels, pred_labels):
    num = intersect_size(true_labels, pred_labels, 0) / (pred_labels.sum(axis=0) + 1e-10)
    return np.mean(num)


def macro_recall(true_labels, pred_labels):
    num = intersect_size(true_labels, pred_labels, 0) / (true_labels.sum(axis=0) + 1e-10)
    return np.mean(num)


def macro_f1(true_labels, pred_labels):
    prec = macro_precision(true_labels, pred_labels)
    rec = macro_recall(true_labels, pred_labels)
    if prec + rec == 0:
        f1 = 0.
    else:
        f1 = 2 * (prec * rec) / (prec + rec)
    return prec, rec, f1


def precision_at_k(true_labels, pred_probs, ks=[1, 5, 8, 10, 15]):
    # num true labels in top k predictions / k
    sorted_pred = np.argsort(pred_probs)[:, ::-1]
    output = []
    for k in ks:
        topk = sorted_pred[:, :k]

        # get precision at k for each example
        vals = []
        for i, tk in enumerate(topk):
            if len(tk) > 0:
                num_true_in_top_k = true_labels[i, tk].sum()
                denom = len(tk)
                vals.append(num_true_in_top_k / float(denom))

        output.append(np.mean(vals))
    return output


def micro_recall(true_labels, pred_labels):
    flat_true = true_labels.ravel()
    flat_pred = pred_labels.ravel()
    return intersect_size(flat_true, flat_pred, 0) / flat_true.sum(axis=0)


def micro_precision(true_labels, pred_labels):
    flat_true = true_labels.ravel()
    flat_pred = pred_labels.ravel()
    if flat_pred.sum(axis=0) == 0:
        return 0.0
    return intersect_size(flat_true, flat_pred, 0) / flat_pred.sum(axis=0)


def micro_f1(true_labels, pred_labels):
    prec = micro_precision(true_labels, pred_labels)
    rec = micro_recall(true_labels, pred_labels)
    if prec + rec == 0:
        f1 = 0.
    else:
        f1 = 2 * (prec * rec) / (prec + rec)
    return prec, rec, f1


def micro_accuracy(true_labels, pred_labels):
    flat_true = true_labels.ravel()
    flat_pred = pred_labels.ravel()
    return intersect_size(flat_true, flat_pred, 0) / union_size(flat_true, flat_pred, 0)

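
# Illustrative sketch (an editor addition, not recovered from the original file): a worked
# example for precision_at_k on a toy 2-example, 4-label problem. With k = 2, the top-2
# predictions of the first example contain 1 true label and those of the second contain 2,
# so P@2 = mean(1/2, 2/2) = 0.75. The numbers are invented for illustration.
import numpy as np

toy_true = np.array([[1, 0, 0, 1],
                     [0, 1, 1, 0]])
toy_probs = np.array([[0.9, 0.8, 0.1, 0.2],
                      [0.1, 0.7, 0.6, 0.3]])
assert precision_at_k(toy_true, toy_probs, ks=[2]) == [0.75]
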
def calculate_scores(true_labels, logits, average="macro", is_multilabel=True, threshold=0.5):
    def sigmoid(x):
        return 1 / (1 + np.exp(-x))

    pred_probs = sigmoid(logits)
    pred_labels = np.rint(pred_probs - threshold + 0.5)

    max_size = min(len(true_labels), len(pred_labels))
    true_labels = true_labels[: max_size]
    pred_labels = pred_labels[: max_size]
    pred_probs = pred_probs[: max_size]
    p_1 = 0
    p_5 = 0
    p_8 = 0
    p_10 = 0
    p_15 = 0
    if pred_probs is not None:
        if not is_multilabel:
            normalised_labels = normalise_labels(true_labels, len(pred_probs[0]))
            auc_score = roc_auc(normalised_labels, pred_probs, average=average)
            accuracy = accuracy_score(true_labels, pred_labels)
            precision = precision_score(true_labels, pred_labels, average=average)
            recall = recall_score(true_labels, pred_labels, average=average)
            f1 = f1_score(true_labels, pred_labels, average=average)
        else:
            if average == "macro":
                accuracy = macro_accuracy(true_labels, pred_labels)  # categorical accuracy
                precision, recall, f1 = macro_f1(true_labels, pred_labels)
                p_ks = precision_at_k(true_labels, pred_probs, [1, 5, 8, 10, 15])
                p_1 = p_ks[0]
                p_5 = p_ks[1]
                p_8 = p_ks[2]
                p_10 = p_ks[3]
                p_15 = p_ks[4]
            else:
                accuracy = micro_accuracy(true_labels, pred_labels)
                precision, recall, f1 = micro_f1(true_labels, pred_labels)
            auc_score = roc_auc(true_labels, pred_probs, average)

            # Calculate label-wise F1 scores
            labelwise_f1 = f1_score(true_labels, pred_labels, average=None)
            labelwise_f1 = np.array2string(labelwise_f1, separator=',')

    else:
        auc_score = -1

    output = {"{}_precision".format(average): precision, "{}_recall".format(average): recall,
              "{}_f1".format(average): f1, "{}_accuracy".format(average): accuracy,
              "{}_auc".format(average): auc_score, "{}_P@1".format(average): p_1, "{}_P@5".format(average): p_5,
              "{}_P@8".format(average): p_8, "{}_P@10".format(average): p_10, "{}_P@15".format(average): p_15,
              "labelwise_f1": labelwise_f1
              }

    return output
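
# Illustrative sketch (an editor addition, not recovered from the original file): minimal
# end-to-end use of calculate_scores on random data in the multi-label macro setting.
# The shapes (8 examples, 5 labels) are arbitrary assumptions.
import numpy as np

np.random.seed(0)
toy_true_labels = np.random.randint(0, 2, size=(8, 5))
toy_logits = np.random.randn(8, 5)
toy_scores = calculate_scores(toy_true_labels, toy_logits, average="macro", is_multilabel=True)
print(sorted(toy_scores.keys()))
# ['labelwise_f1', 'macro_P@1', 'macro_P@10', 'macro_P@15', 'macro_P@5', 'macro_P@8',
#  'macro_accuracy', 'macro_auc', 'macro_f1', 'macro_precision', 'macro_recall']
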