manan committed
Commit 4aa121f
1 Parent(s): 20de366

model replaced with large

Files changed (1)
  1. model.py +109 -11
model.py CHANGED
@@ -17,8 +17,8 @@ config = dict(
     num_labels=2,
 
     # model info
-    tokenizer_path = 'allenai/biomed_roberta_base', # 'roberta-base',
-    model_checkpoint = 'allenai/biomed_roberta_base', # 'roberta-base',
+    tokenizer_path = 'roberta-large', # 'allenai/biomed_roberta_base',
+    model_checkpoint = 'model_large_pseudo_label.pth', # 'allenai/biomed_roberta_base',
     device = 'cuda' if torch.cuda.is_available() else 'cpu',
 
     # training paramters
@@ -78,22 +78,106 @@ class NBMETestData(torch.utils.data.Dataset):
             'sequence_ids': sequence_ids,
         }
 
+# class NBMEModel(nn.Module):
+#     def __init__(self, num_labels=1, path=None):
+#         super().__init__()
+
+#         layer_norm_eps: float = 1e-6
+
+#         self.path = path
+#         self.num_labels = num_labels
+
+#         self.transformer = transformers.AutoModel.from_pretrained(config['model_checkpoint'])
+#         self.dropout = nn.Dropout(0.2)
+#         self.output = nn.Linear(768, 1)
+
+#         if self.path is not None:
+#             self.load_state_dict(torch.load(self.path)['model'])
+
+#     def forward(self, data):
+
+#         ids = data['input_ids']
+#         mask = data['attention_mask']
+#         try:
+#             target = data['targets']
+#         except:
+#             target = None
+
+#         transformer_out = self.transformer(ids, mask)
+#         sequence_output = transformer_out[0]
+#         sequence_output = self.dropout(sequence_output)
+#         logits = self.output(sequence_output)
+
+#         ret = {
+#             "logits": torch.sigmoid(logits),
+#         }
+
+#         if target is not None:
+#             loss = self.get_loss(logits, target)
+#             ret['loss'] = loss
+#             ret['targets'] = target
+
+#         return ret
+
+
+#     def get_optimizer(self, learning_rate, weigth_decay):
+#         optimizer = torch.optim.AdamW(
+#             self.parameters(),
+#             lr=learning_rate,
+#             weight_decay=weigth_decay,
+#         )
+#         if self.path is not None:
+#             optimizer.load_state_dict(torch.load(self.path)['optimizer'])
+
+#         return optimizer
+
+#     def get_scheduler(self, optimizer, num_warmup_steps, num_training_steps):
+#         scheduler = transformers.get_linear_schedule_with_warmup(
+#             optimizer,
+#             num_warmup_steps=num_warmup_steps,
+#             num_training_steps=num_training_steps,
+#         )
+#         if self.path is not None:
+#             scheduler.load_state_dict(torch.load(self.path)['scheduler'])
+
+#         return scheduler
+
+#     def get_loss(self, output, target):
+#         loss_fn = nn.BCEWithLogitsLoss(reduction="none")
+#         loss = loss_fn(output.view(-1, 1), target.view(-1, 1))
+#         loss = torch.masked_select(loss, target.view(-1, 1) != -100).mean()
+#         return loss
+
+
 class NBMEModel(nn.Module):
-    def __init__(self, num_labels=1, path=None):
+    def __init__(self, num_labels=2, path=None):
         super().__init__()
 
         layer_norm_eps: float = 1e-6
 
         self.path = path
         self.num_labels = num_labels
+        self.config = transformers.AutoConfig.from_pretrained(config['model_checkpoint'])
+
+        self.config.update(
+            {
+                "layer_norm_eps": layer_norm_eps,
+            }
+        )
+        self.transformer = transformers.AutoModel.from_pretrained(config['model_checkpoint'], config=self.config)
+        self.dropout = nn.Dropout(0.1)
 
-        self.transformer = transformers.AutoModel.from_pretrained(config['model_checkpoint'])
-        self.dropout = nn.Dropout(0.2)
-        self.output = nn.Linear(768, 1)
+        self.dropout1 = nn.Dropout(0.1)
+        self.dropout2 = nn.Dropout(0.2)
+        self.dropout3 = nn.Dropout(0.3)
+        self.dropout4 = nn.Dropout(0.4)
+        self.dropout5 = nn.Dropout(0.5)
+
+        self.output = nn.Linear(self.config.hidden_size, 1)
 
         if self.path is not None:
             self.load_state_dict(torch.load(self.path)['model'])
-
+
     def forward(self, data):
 
         ids = data['input_ids']
@@ -106,16 +190,29 @@ class NBMEModel(nn.Module):
         transformer_out = self.transformer(ids, mask)
         sequence_output = transformer_out[0]
         sequence_output = self.dropout(sequence_output)
-        logits = self.output(sequence_output)
+
+        logits1 = self.output(self.dropout1(sequence_output))
+        logits2 = self.output(self.dropout2(sequence_output))
+        logits3 = self.output(self.dropout3(sequence_output))
+        logits4 = self.output(self.dropout4(sequence_output))
+        logits5 = self.output(self.dropout5(sequence_output))
 
+        logits = (logits1 + logits2 + logits3 + logits4 + logits5) / 5
         ret = {
-            "logits": torch.sigmoid(logits),
+            'logits': torch.sigmoid(logits),
         }
 
+        loss = 0
+
         if target is not None:
-            loss = self.get_loss(logits, target)
+            loss1 = self.get_loss(logits1, target)
+            loss2 = self.get_loss(logits2, target)
+            loss3 = self.get_loss(logits3, target)
+            loss4 = self.get_loss(logits4, target)
+            loss5 = self.get_loss(logits5, target)
+            loss = (loss1 + loss2 + loss3 + loss4 + loss5) / 5
             ret['loss'] = loss
-            ret['targets'] = target
+            ret['target'] = target
 
         return ret
 
@@ -148,6 +245,7 @@ class NBMEModel(nn.Module):
         loss = torch.masked_select(loss, target.view(-1, 1) != -100).mean()
         return loss
 
+
 def get_location_predictions(preds, offset_mapping, sequence_ids, test=False):
     all_predictions = []
     for pred, offsets, seq_ids in zip(preds, offset_mapping, sequence_ids):
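For readers skimming the diff: the rewritten forward pass uses multi-sample dropout, applying one shared linear head to five differently-dropped copies of the encoder output and averaging the five logits (during training, the five per-sample losses are averaged the same way). Below is a minimal standalone sketch of that pattern; the class name, the hidden size of 768, and the random inputs are illustrative assumptions, not part of model.py.

import torch
import torch.nn as nn

class MultiSampleDropoutHead(nn.Module):
    """Sketch of the multi-sample dropout head introduced in this commit."""

    def __init__(self, hidden_size=768):  # toy size; not tied to roberta-large
        super().__init__()
        # Five dropout rates, mirroring dropout1..dropout5 in the diff.
        self.dropouts = nn.ModuleList(nn.Dropout(p) for p in (0.1, 0.2, 0.3, 0.4, 0.5))
        # One output layer shared across all five dropout samples.
        self.output = nn.Linear(hidden_size, 1)

    def forward(self, sequence_output):
        # Apply the head under each dropout mask, then average the logits.
        logits = torch.stack([self.output(d(sequence_output)) for d in self.dropouts])
        return logits.mean(dim=0)

# Toy usage: batch of 2 sequences, 8 tokens each, hidden size 768.
head = MultiSampleDropoutHead()
fake_encoder_output = torch.randn(2, 8, 768)
print(head(fake_encoder_output).shape)  # torch.Size([2, 8, 1])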
 
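The hunk context also shows get_loss, which keeps only token positions whose target is not -100 before averaging the binary cross-entropy. A standalone sketch of that masking pattern, with made-up tensors:

import torch
import torch.nn as nn

# Elementwise BCE so that individual positions can be masked out.
loss_fn = nn.BCEWithLogitsLoss(reduction="none")

# Made-up logits and targets; -100 marks a position to ignore
# (e.g. padding or tokens outside the patient note).
logits = torch.tensor([0.5, -1.2, 2.0, 0.3])
target = torch.tensor([1.0, 0.0, -100.0, 1.0])

loss = loss_fn(logits.view(-1, 1), target.view(-1, 1))
# Drop the ignored position, then average over the labeled ones.
loss = torch.masked_select(loss, target.view(-1, 1) != -100).mean()
print(loss)  # mean BCE over the three labeled positions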