manan committed on
Commit
3107fce
1 Parent(s): 30543b8

Create model.py

Browse files
Files changed (1) hide show
  1. model.py +252 -0
model.py ADDED
@@ -0,0 +1,252 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import gc

import numpy as np
import pandas as pd
import torch
import transformers
from torch import nn
from tqdm.notebook import tqdm, trange
from transformers import AutoModel, AutoTokenizer, AutoConfig


# Global settings read by the dataset, the model and the inference helpers
# in this file.
config = dict(
    # basic
    seed=3407,
    num_jobs=1,
    num_labels=2,

    # model info
    tokenizer_path='allenai/biomed_roberta_base',  # 'roberta-base',
    model_checkpoint='../input/biomed-roberta',  # 'roberta-base',
    # Use the GPU when one is visible to torch, otherwise fall back to CPU.
    device='cuda' if torch.cuda.is_available() else 'cpu',

    # training parameters
    max_length=512,
    batch_size=16,

    # for this notebook
    debug=False,
)


def create_sample_test(data_dir="../input/nbme-score-clinical-patient-notes"):
    """Return one randomly sampled, fully merged row of the test set.

    Reads ``features.csv``, ``patient_notes.csv`` and ``test.csv`` from
    *data_dir*, left-joins the notes and then the features onto the test rows
    (pandas picks the join keys from the columns the frames share), and
    normalises the ``feature_text`` column.

    Args:
        data_dir: Directory containing the three CSV files. Defaults to the
            location the original code had hard-coded.

    Returns:
        A one-row ``DataFrame`` with a fresh ``0`` index.
    """
    feats = pd.read_csv(f"{data_dir}/features.csv")
    # Manual override of the feature text stored at row label 27.
    # NOTE(review): presumably corrects a bad entry in the raw data - confirm.
    feats.loc[27, 'feature_text'] = "Last-Pap-smear-1-year-ago"

    notes = pd.read_csv(f"{data_dir}/patient_notes.csv")
    test = pd.read_csv(f"{data_dir}/test.csv")

    merged = test.merge(notes, how="left").merge(feats, how="left")

    def process_feature_text(text):
        # "-OR-" becomes ";-" first so that the generic dash -> space
        # replacement leaves "; " between the alternatives.
        return text.replace("-OR-", ";-").replace("-", " ")

    merged["feature_text"] = [process_feature_text(x) for x in merged["feature_text"]]

    return merged.sample(1).reset_index(drop=True)


class NBMETestData(torch.utils.data.Dataset):
    """Inference dataset that tokenizes (feature_text, pn_history) pairs.

    Args:
        feature_text: Indexable collection of feature descriptions.
        pn_history: Indexable collection of patient notes, aligned with
            ``feature_text``.
        tokenizer: Callable invoked as ``tokenizer(feature, note, ...)``. Its
            result must support ``["input_ids"]``, ``["attention_mask"]``,
            ``["offset_mapping"]`` and a ``sequence_ids()`` method.
        max_length: Padded/truncated sequence length. When ``None`` (default)
            the module-level ``config['max_length']`` is read at item access
            time, as the original code did.
    """

    def __init__(self, feature_text, pn_history, tokenizer, max_length=None):
        self.feature_text = feature_text
        self.pn_history = pn_history
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.feature_text)

    def __getitem__(self, idx):
        """Return the tokenized example at ``idx`` as a dict of numpy arrays."""
        max_length = self.max_length if self.max_length is not None else config['max_length']
        tokenized = self.tokenizer(
            self.feature_text[idx],
            self.pn_history[idx],
            # Only the second text (the note) may be truncated.
            truncation="only_second",
            max_length=max_length,
            padding="max_length",
            return_offsets_mapping=True,
        )

        return {
            'input_ids': np.array(tokenized["input_ids"]),
            'attention_mask': np.array(tokenized["attention_mask"]),
            'offset_mapping': np.array(tokenized["offset_mapping"]),
            # The float cast turns any ``None`` entry into NaN so the ids fit
            # in a numeric array.
            'sequence_ids': np.array(tokenized.sequence_ids()).astype("float16"),
        }


class NBMEModel(nn.Module):
    """Pretrained transformer encoder with a one-unit per-token linear head.

    The encoder is loaded from ``config['model_checkpoint']``. When ``path``
    is given it must point to a ``torch.save``-d dict; its ``'model'`` entry
    is loaded here, and its ``'optimizer'`` / ``'scheduler'`` entries are
    loaded by :meth:`get_optimizer` / :meth:`get_scheduler`.

    Args:
        num_labels: Stored on the instance. The output head is always a
            single unit regardless of this value.
        path: Optional checkpoint file to resume from.
    """

    def __init__(self, num_labels=1, path=None):
        super().__init__()

        layer_norm_eps: float = 1e-6

        self.path = path
        self.num_labels = num_labels
        self.config = transformers.AutoConfig.from_pretrained(config['model_checkpoint'])
        self.config.update({"layer_norm_eps": layer_norm_eps})

        self.transformer = transformers.AutoModel.from_pretrained(
            config['model_checkpoint'], config=self.config
        )
        self.dropout = nn.Dropout(0.2)
        self.output = nn.Linear(self.config.hidden_size, 1)

        if self.path is not None:
            self.load_state_dict(self._load_checkpoint()['model'])

    def _load_checkpoint(self):
        """Read the checkpoint at ``self.path`` onto the configured device."""
        # map_location lets a checkpoint saved on a GPU be opened on a
        # CPU-only host instead of failing while deserializing CUDA tensors.
        return torch.load(self.path, map_location=config['device'])

    def forward(self, data):
        """Run the model on a batch mapping.

        Args:
            data: Mapping with ``'input_ids'`` and ``'attention_mask'`` and,
                optionally, ``'targets'``.

        Returns:
            Dict with ``'logits'`` (note: these are the head outputs *after*
            a sigmoid). When targets are supplied it also holds ``'loss'``
            and ``'targets'``.
        """
        ids = data['input_ids']
        mask = data['attention_mask']
        # Targets are optional (absent at inference time).
        try:
            target = data['targets']
        except KeyError:
            target = None

        transformer_out = self.transformer(ids, mask)
        sequence_output = self.dropout(transformer_out[0])
        logits = self.output(sequence_output)

        ret = {
            "logits": torch.sigmoid(logits),
        }

        if target is not None:
            # The loss works on the raw (pre-sigmoid) head outputs.
            ret['loss'] = self.get_loss(logits, target)
            ret['targets'] = target

        return ret

    def get_optimizer(self, learning_rate, weigth_decay):
        """Create an AdamW optimizer, restoring its state when resuming."""
        optimizer = torch.optim.AdamW(
            self.parameters(),
            lr=learning_rate,
            weight_decay=weigth_decay,
        )
        if self.path is not None:
            optimizer.load_state_dict(self._load_checkpoint()['optimizer'])

        return optimizer

    def get_scheduler(self, optimizer, num_warmup_steps, num_training_steps):
        """Create a linear warmup schedule, restoring its state when resuming."""
        scheduler = transformers.get_linear_schedule_with_warmup(
            optimizer,
            num_warmup_steps=num_warmup_steps,
            num_training_steps=num_training_steps,
        )
        if self.path is not None:
            scheduler.load_state_dict(self._load_checkpoint()['scheduler'])

        return scheduler

    def get_loss(self, output, target):
        """Mean binary cross-entropy over positions whose target is not -100.

        Args:
            output: Raw (pre-sigmoid) scores.
            target: Targets with the same number of elements as ``output``;
                entries equal to ``-100`` are excluded from the mean.
        """
        loss_fn = nn.BCEWithLogitsLoss(reduction="none")
        flat_target = target.view(-1, 1)
        loss = loss_fn(output.view(-1, 1), flat_target)
        return torch.masked_select(loss, flat_target != -100).mean()


def get_location_predictions(preds, offset_mapping, sequence_ids, test=False):
    """Turn per-token scores into character spans, one list per example.

    Consecutive tokens scoring above 0.5 are merged into one span running
    from the first token's start offset to the last token's end offset.

    Args:
        preds: Iterable of per-example score sequences (one score per token).
        offset_mapping: Iterable of per-example ``(start, end)`` offset
            sequences, aligned with ``preds``.
        sequence_ids: Iterable of per-example sequence-id sequences. Tokens
            whose id is ``None``, NaN or ``0`` are ignored; they neither open
            nor close a span.
        test: When true, each example is rendered as a single string
            ``"start end; start end"`` instead of a list of tuples.

    Returns:
        A list with one entry per example: a list of ``(start, end)`` tuples,
        or a string when ``test`` is true.
    """
    all_predictions = []
    for pred, offsets, seq_ids in zip(preds, offset_mapping, sequence_ids):
        start_idx = None
        end_idx = None
        spans = []
        for p, o, s_id in zip(pred, offsets, seq_ids):
            # ``s_id != s_id`` is true only for NaN, which is what ``None``
            # becomes once sequence ids have been cast to a float array.
            if s_id is None or s_id != s_id or s_id == 0:
                continue
            if p > 0.5:
                if start_idx is None:
                    start_idx = o[0]
                end_idx = o[1]
            elif start_idx is not None:
                spans.append((start_idx, end_idx))
                start_idx = None
        # Flush a span that is still open when the tokens run out; without
        # this a prediction ending on the final token would be lost.
        if start_idx is not None:
            spans.append((start_idx, end_idx))

        if test:
            all_predictions.append("; ".join(f"{s} {e}" for s, e in spans))
        else:
            all_predictions.append(spans)
    return all_predictions


def predict_location_preds(tokenizer, model, feature_text, pn_history):
    """Predict the spans of the first note that match its feature text.

    All supplied examples are run through ``model``, but only the first one
    is decoded: the returned spans and text refer to ``pn_history[0]``.

    Args:
        tokenizer: Tokenizer handed to ``NBMETestData``.
        model: Callable taking a batch dict and returning a dict whose
            ``'logits'`` entry holds per-token scores.
        feature_text: Sequence of feature descriptions.
        pn_history: Sequence of patient notes aligned with ``feature_text``.

    Returns:
        ``(location_preds, text)`` where ``location_preds`` is a list of
        ``(start, end)`` character offsets into ``pn_history[0]`` and
        ``text`` is the corresponding substrings joined with ``', '``.
        Empty input yields ``([], '')``.
    """
    test_ds = NBMETestData(feature_text, pn_history, tokenizer)
    if len(test_ds) == 0:
        return [], ''

    test_dl = torch.utils.data.DataLoader(
        test_ds,
        batch_size=config['batch_size'],
        # Pinned memory only pays off (and only avoids a warning) with CUDA.
        pin_memory=torch.cuda.is_available(),
        shuffle=False,
        drop_last=False,
    )

    preds = []
    offsets = []
    seq_ids = []
    # Bookkeeping tensors are consumed on the host, so they stay off-device.
    host_only = ('offset_mapping', 'sequence_ids')

    with torch.no_grad():
        for batch in tqdm(test_dl):
            model_inputs = {
                k: (v if k in host_only else v.to(config['device']))
                for k, v in batch.items()
            }
            preds.append(model(model_inputs)['logits'].cpu().numpy())
            offsets.append(batch['offset_mapping'].cpu().numpy())
            seq_ids.append(batch['sequence_ids'].cpu().numpy())
        torch.cuda.empty_cache()

    all_preds = np.concatenate(preds, axis=0).astype(np.float32)
    # Collapse any trailing singleton dimension to (examples, tokens)
    # while keeping the example axis even for a single example.
    all_preds = all_preds.reshape(all_preds.shape[0], -1)
    offsets = np.concatenate(offsets, axis=0)
    seq_ids = np.concatenate(seq_ids, axis=0)

    location_preds = get_location_predictions(
        all_preds[:1], offsets[:1], seq_ids[:1], test=False
    )[0]

    snippets = [pn_history[0][start:end] for start, end in location_preds]

    return location_preds, ', '.join(snippets)


def get_predictions(feature_text, pn_history):
    """Print the note text matching one feature and return the results.

    Uses the module-level ``tokenizer`` and ``model``.

    Args:
        feature_text: A single feature description.
        pn_history: A single patient note.

    Returns:
        ``(location_preds, pred_string)`` as produced by
        ``predict_location_preds``; ``pred_string`` is also printed.
    """
    location_preds, pred_string = predict_location_preds(
        tokenizer, model, [feature_text], [pn_history]
    )
    print(pred_string)
    return location_preds, pred_string


# Module-level inference objects used by get_predictions(). Importing this
# file loads the tokenizer and the model weights.
tokenizer = AutoTokenizer.from_pretrained(config['tokenizer_path'])
path = '../input/nbme-training-biomed-roberta-base/best_model_0.bin'

model = NBMEModel().to(config['device'])
# map_location lets the checkpoint be opened on whichever device is in use.
model.load_state_dict(torch.load(path, map_location=torch.device(config['device']))['model'])
model.eval()

# Example usage:
# input_text = create_sample_test()
# feature_text = input_text.feature_text[0]
# pn_history = input_text.pn_history[0]
# get_predictions(feature_text, pn_history)