sxtforreal committed on
Commit d09e211
1 Parent(s): 9af29de

Upload 3 files

Files changed (3)
  1. dataset.py +120 -657
  2. loss.py +24 -177
  3. model.py +97 -393
dataset.py CHANGED
@@ -1,21 +1,66 @@
1
- import pandas as pd
2
  import torch
 
3
  from torch.utils.data import Dataset, DataLoader
4
  from torch.nn.utils.rnn import pad_sequence
5
  import lightning.pytorch as pl
6
  import config
 
 
 
 
7
  import sys
8
 
9
  sys.path.append("/home/sunx/data/aiiih/projects/sunx/ccf_fuzzy_diag")
10
  from data_proc.data_gen import (
11
  positive_generator,
 
12
  negative_generator,
 
 
 
13
  get_mentioned_code,
14
  )
15
 
16
 
17
- ##### General
18
- class ContrastiveLearningDataset(Dataset):
19
  def __init__(
20
  self,
21
  data: pd.DataFrame,
@@ -31,728 +76,146 @@ class ContrastiveLearningDataset(Dataset):
31
  return sentence
32
 
33
 
34
- def max_pairwise_sim(sentence1, sentence2, current_df, query_df, sim_df, all_d):
35
- """Returns the maximum ontology similarity score between concept pairs mentioned in sentence1 and sentence2.
36
-
37
- Args:
38
- sentence1: anchor sentence
39
- sentence2: negative sentence
40
- current_df: the dataset where anchor sentence stays
41
- query_df: the union of training and validation sets
42
- dictionary: cardiac-related {concepts: synonyms}
43
- sim_df: the dataset of pairwise ontology similarity score
44
- all_d: the dataset of [concepts, synonyms, list of ancestor concepts]
45
- """
46
- # retrieve concepts from the two sentences
47
- anchor_codes = get_mentioned_code(sentence1, current_df)
48
- other_codes = get_mentioned_code(sentence2, query_df)
49
-
50
- # create snomed-ct code pairs and calculate the score using sim_df
51
- code_pairs = list(zip(anchor_codes, other_codes))
52
- sim_scores = []
53
- for pair in code_pairs:
54
- code1 = pair[0]
55
- code2 = pair[1]
56
- if code1 == code2:
57
- result = len(all_d.loc[all_d["concept"] == code1, "ancestors"].values[0])
58
- sim_scores.append(result)
59
- else:
60
- try:
61
- result = sim_df.loc[
62
- (sim_df["Code1"] == code1) & (sim_df["Code2"] == code2), "score"
63
- ].values[0]
64
- sim_scores.append(result)
65
- except:
66
- result = sim_df.loc[
67
- (sim_df["Code1"] == code2) & (sim_df["Code2"] == code1), "score"
68
- ].values[0]
69
- sim_scores.append(result)
70
- if len(sim_scores) > 0:
71
- return max(sim_scores)
72
- else:
73
- return 0
74
-
75
-
76
- ##### SimCSE
77
- def collate_simcse(batch, tokenizer):
78
- """
79
- Use the first sample in the batch as the anchor,
80
- use the duplicate of anchor as the positive,
81
- use the rest of the batch as negatives.
82
- """
83
- anchor = batch[0] # use the first sample in the batch as anchor
84
- positive = anchor[:] # create a duplicate of anchor as positive
85
- negatives = batch[1:] # everything else as negatives
86
- df = pd.DataFrame(columns=["label", "input_ids", "attention_mask"])
87
-
88
- anchor_token = tokenizer.encode_plus(
89
- anchor,
90
- return_token_type_ids=False,
91
- return_attention_mask=True,
92
- return_tensors="pt",
93
- )
94
- anchor_row = pd.DataFrame(
95
- {
96
- "label": 0,
97
- "input_ids": anchor_token["input_ids"].tolist(),
98
- "attention_mask": anchor_token["attention_mask"].tolist(),
99
- }
100
- )
101
- df = pd.concat([df, anchor_row])
102
-
103
- pos_token = tokenizer.encode_plus(
104
- positive,
105
- return_token_type_ids=False,
106
- return_attention_mask=True,
107
- return_tensors="pt",
108
- )
109
- pos_row = pd.DataFrame(
110
- {
111
- "label": 1,
112
- "input_ids": pos_token["input_ids"].tolist(),
113
- "attention_mask": pos_token["attention_mask"].tolist(),
114
- }
115
- )
116
- df = pd.concat([df, pos_row])
117
-
118
- for neg in negatives:
119
- neg_token = tokenizer.encode_plus(
120
- neg,
121
- return_token_type_ids=False,
122
- return_attention_mask=True,
123
- return_tensors="pt",
124
- )
125
- neg_row = pd.DataFrame(
126
- {
127
- "label": 2,
128
- "input_ids": neg_token["input_ids"].tolist(),
129
- "attention_mask": neg_token["attention_mask"].tolist(),
130
- }
131
- )
132
- df = pd.concat([df, neg_row])
133
-
134
- label = torch.tensor(df["label"].tolist())
135
-
136
- input_ids_tsr = list(map(lambda x: torch.tensor(x), df["input_ids"]))
137
- padded_input_ids = pad_sequence(input_ids_tsr, padding_value=tokenizer.pad_token_id)
138
- padded_input_ids = torch.transpose(padded_input_ids, 0, 1)
139
-
140
- attention_mask_tsr = list(map(lambda x: torch.tensor(x), df["attention_mask"]))
141
- padded_attention_mask = pad_sequence(attention_mask_tsr, padding_value=0)
142
- padded_attention_mask = torch.transpose(padded_attention_mask, 0, 1)
143
-
144
- return {
145
- "label": label,
146
- "input_ids": padded_input_ids,
147
- "attention_mask": padded_attention_mask,
148
- }
149
-
150
-
151
- def create_dataloader_simcse(
152
- dataset,
153
- tokenizer,
154
- shuffle,
155
- ):
156
- return DataLoader(
157
- dataset,
158
- batch_size=config.batch_size_simcse,
159
- shuffle=shuffle,
160
- num_workers=config.num_workers,
161
- collate_fn=lambda batch: collate_simcse(
162
- batch,
163
- tokenizer,
164
- ),
165
- )
166
-
167
-
168
- class ContrastiveLearningDataModule_simcse(pl.LightningDataModule):
169
- def __init__(
170
- self,
171
- train_df,
172
- val_df,
173
- tokenizer,
174
- ):
175
- super().__init__()
176
- self.train_df = train_df
177
- self.val_df = val_df
178
- self.tokenizer = tokenizer
179
-
180
- def setup(self, stage=None):
181
- self.train_dataset = ContrastiveLearningDataset(self.train_df)
182
- self.val_dataset = ContrastiveLearningDataset(self.val_df)
183
-
184
- def train_dataloader(self):
185
- return create_dataloader_simcse(
186
- self.train_dataset,
187
- self.tokenizer,
188
- shuffle=True,
189
- )
190
-
191
- def val_dataloader(self):
192
- return create_dataloader_simcse(
193
- self.val_dataset,
194
- self.tokenizer,
195
- shuffle=False,
196
- )
197
-
198
 
199
- ##### SimCSE_w
200
- def collate_simcse_w(
201
- batch,
202
- current_df,
203
- query_df,
204
- tokenizer,
205
- sim_df,
206
- all_d,
207
- ):
208
- """
209
- Anchor: 0
210
- Positive: 1
211
- Negative: 2
212
- """
213
  anchor = batch[0]
214
- positive = anchor[:]
215
- negatives = batch[1:]
216
- df = pd.DataFrame(columns=["label", "input_ids", "attention_mask", "score"])
217
-
218
- anchor_token = tokenizer.encode_plus(
219
- anchor,
220
- return_token_type_ids=False,
221
- return_attention_mask=True,
222
- return_tensors="pt",
223
- )
224
-
225
- anchor_row = pd.DataFrame(
226
- {
227
- "label": 0,
228
- "input_ids": anchor_token["input_ids"].tolist(),
229
- "attention_mask": anchor_token["attention_mask"].tolist(),
230
- "score": 1,
231
- }
232
- )
233
- df = pd.concat([df, anchor_row])
234
-
235
- pos_token = tokenizer.encode_plus(
236
- positive,
237
- return_token_type_ids=False,
238
- return_attention_mask=True,
239
- return_tensors="pt",
240
- )
241
- pos_row = pd.DataFrame(
242
- {
243
- "label": 1,
244
- "input_ids": pos_token["input_ids"].tolist(),
245
- "attention_mask": pos_token["attention_mask"].tolist(),
246
- "score": 1,
247
- }
248
- )
249
- df = pd.concat([df, pos_row])
250
-
251
- for neg in negatives:
252
- neg_token = tokenizer.encode_plus(
253
- neg,
254
- return_token_type_ids=False,
255
- return_attention_mask=True,
256
- return_tensors="pt",
257
- )
258
- score = max_pairwise_sim(anchor, neg, current_df, query_df, sim_df, all_d)
259
- offset = 8
260
- score = score + offset
261
- neg_row = pd.DataFrame(
262
- {
263
- "label": 2,
264
- "input_ids": neg_token["input_ids"].tolist(),
265
- "attention_mask": neg_token["attention_mask"].tolist(),
266
- "score": score,
267
- }
268
- )
269
- df = pd.concat([df, neg_row])
270
-
271
- label = torch.tensor(df["label"].tolist())
272
-
273
- input_ids_tsr = list(map(lambda x: torch.tensor(x), df["input_ids"]))
274
- padded_input_ids = pad_sequence(input_ids_tsr, padding_value=tokenizer.pad_token_id)
275
- padded_input_ids = torch.transpose(padded_input_ids, 0, 1)
276
-
277
- attention_mask_tsr = list(map(lambda x: torch.tensor(x), df["attention_mask"]))
278
- padded_attention_mask = pad_sequence(attention_mask_tsr, padding_value=0)
279
- padded_attention_mask = torch.transpose(padded_attention_mask, 0, 1)
280
-
281
- score = torch.tensor(df["score"].tolist())
282
-
283
- return {
284
- "label": label,
285
- "input_ids": padded_input_ids,
286
- "attention_mask": padded_attention_mask,
287
- "score": score,
288
- }
289
-
290
-
291
- def create_dataloader_simcse_w(
292
- dataset,
293
- current_df,
294
- query_df,
295
- tokenizer,
296
- sim_df,
297
- all_d,
298
- shuffle,
299
- ):
300
- return DataLoader(
301
- dataset,
302
- batch_size=config.batch_size_simcse,
303
- shuffle=shuffle,
304
- num_workers=config.num_workers,
305
- collate_fn=lambda batch: collate_simcse_w(
306
- batch,
307
- current_df,
308
- query_df,
309
- tokenizer,
310
- sim_df,
311
- all_d,
312
- ),
313
- )
314
-
315
-
316
- class ContrastiveLearningDataModule_simcse_w(pl.LightningDataModule):
317
- def __init__(
318
- self,
319
- train_df,
320
- val_df,
321
- query_df,
322
- tokenizer,
323
- sim_df,
324
- all_d,
325
- ):
326
- super().__init__()
327
- self.train_df = train_df
328
- self.val_df = val_df
329
- self.query_df = query_df
330
- self.tokenizer = tokenizer
331
- self.sim_df = sim_df
332
- self.all_d = all_d
333
-
334
- def setup(self, stage=None):
335
- self.train_dataset = ContrastiveLearningDataset(self.train_df)
336
- self.val_dataset = ContrastiveLearningDataset(self.val_df)
337
-
338
- def train_dataloader(self):
339
- return create_dataloader_simcse_w(
340
- self.train_dataset,
341
- self.train_df,
342
- self.query_df,
343
- self.tokenizer,
344
- self.sim_df,
345
- self.all_d,
346
- shuffle=True,
347
- )
348
-
349
- def val_dataloader(self):
350
- return create_dataloader_simcse_w(
351
- self.val_dataset,
352
- self.val_df,
353
- self.query_df,
354
- self.tokenizer,
355
- self.sim_df,
356
- self.all_d,
357
- shuffle=False,
358
- )
359
-
360
-
361
- ##### Samp
362
- def collate_samp(
363
- sentence,
364
- current_df,
365
- query_df,
366
- tokenizer,
367
- dictionary,
368
- sim_df,
369
- ):
370
-
371
- anchor = sentence[0]
372
- positives = positive_generator(
373
- anchor, current_df, query_df, dictionary, num_pos=config.num_pos
374
- )
375
- negatives = negative_generator(
376
  anchor,
377
  current_df,
378
- query_df,
379
- dictionary,
380
- sim_df,
381
- num_neg=config.num_neg,
382
- )
383
- df = pd.DataFrame(columns=["label", "input_ids", "attention_mask"])
384
- anchor_token = tokenizer.encode_plus(
385
- anchor,
386
- return_token_type_ids=False,
387
- return_attention_mask=True,
388
- return_tensors="pt",
389
- )
390
-
391
- anchor_row = pd.DataFrame(
392
- {
393
- "label": 0,
394
- "input_ids": anchor_token["input_ids"].tolist(),
395
- "attention_mask": anchor_token["attention_mask"].tolist(),
396
- }
397
- )
398
- df = pd.concat([df, anchor_row])
399
-
400
- for pos in positives:
401
- token = tokenizer.encode_plus(
402
- pos,
403
- return_token_type_ids=False,
404
- return_attention_mask=True,
405
- return_tensors="pt",
406
- )
407
- row = pd.DataFrame(
408
- {
409
- "label": 1,
410
- "input_ids": token["input_ids"].tolist(),
411
- "attention_mask": token["attention_mask"].tolist(),
412
- }
413
- )
414
- df = pd.concat([df, row])
415
-
416
- for neg in negatives:
417
- token = tokenizer.encode_plus(
418
- neg,
419
- return_token_type_ids=False,
420
- return_attention_mask=True,
421
- return_tensors="pt",
422
- )
423
- row = pd.DataFrame(
424
- {
425
- "label": 2,
426
- "input_ids": token["input_ids"].tolist(),
427
- "attention_mask": token["attention_mask"].tolist(),
428
- }
429
- )
430
- df = pd.concat([df, row])
431
-
432
- label = torch.tensor(df["label"].tolist())
433
-
434
- input_ids_tsr = list(map(lambda x: torch.tensor(x), df["input_ids"]))
435
- padded_input_ids = pad_sequence(input_ids_tsr, padding_value=tokenizer.pad_token_id)
436
- padded_input_ids = torch.transpose(padded_input_ids, 0, 1)
437
-
438
- attention_mask_tsr = list(map(lambda x: torch.tensor(x), df["attention_mask"]))
439
- padded_attention_mask = pad_sequence(attention_mask_tsr, padding_value=0)
440
- padded_attention_mask = torch.transpose(padded_attention_mask, 0, 1)
441
-
442
- return {
443
- "label": label,
444
- "input_ids": padded_input_ids,
445
- "attention_mask": padded_attention_mask,
446
- }
447
-
448
-
449
- def create_dataloader_samp(
450
- dataset,
451
- current_df,
452
- query_df,
453
- tokenizer,
454
- dictionary,
455
- sim_df,
456
- shuffle,
457
- ):
458
- return DataLoader(
459
- dataset,
460
- batch_size=config.batch_size,
461
- shuffle=shuffle,
462
- num_workers=config.num_workers,
463
- collate_fn=lambda batch: collate_samp(
464
- batch,
465
- current_df,
466
- query_df,
467
- tokenizer,
468
- dictionary,
469
- sim_df,
470
- ),
471
- )
472
-
473
-
474
- class ContrastiveLearningDataModule_samp(pl.LightningDataModule):
475
- def __init__(
476
- self,
477
- train_df,
478
- val_df,
479
- query_df,
480
- tokenizer,
481
  dictionary,
482
- sim_df,
483
- ):
484
- super().__init__()
485
- self.train_df = train_df
486
- self.val_df = val_df
487
- self.query_df = query_df
488
- self.tokenizer = tokenizer
489
- self.dictionary = dictionary
490
- self.sim_df = sim_df
491
-
492
- def setup(self, stage=None):
493
- self.train_dataset = ContrastiveLearningDataset(self.train_df)
494
- self.val_dataset = ContrastiveLearningDataset(self.val_df)
495
-
496
- def train_dataloader(self):
497
- return create_dataloader_samp(
498
- self.train_dataset,
499
- self.train_df,
500
- self.query_df,
501
- self.tokenizer,
502
- self.dictionary,
503
- self.sim_df,
504
- shuffle=True,
505
- )
506
-
507
- def val_dataloader(self):
508
- return create_dataloader_samp(
509
- self.val_dataset,
510
- self.val_df,
511
- self.query_df,
512
- self.tokenizer,
513
- self.dictionary,
514
- self.sim_df,
515
- shuffle=False,
516
- )
517
-
518
-
519
- ##### Samp_w
520
- def collate_samp_w(
521
- sentence,
522
- current_df,
523
- query_df,
524
- tokenizer,
525
- dictionary,
526
- sim_df,
527
- all_d,
528
- ):
529
- """
530
- Anchor: 0
531
- Positive: 1
532
- Negative: 2
533
- """
534
- anchor = sentence[0]
535
- positives = positive_generator(
536
- anchor, current_df, query_df, dictionary, num_pos=config.num_pos
537
  )
538
- negatives = negative_generator(
539
  anchor,
540
  current_df,
541
  query_df,
542
- dictionary,
543
- sim_df,
544
  num_neg=config.num_neg,
545
  )
546
- df = pd.DataFrame(columns=["label", "input_ids", "attention_mask", "score"])
547
- anchor_token = tokenizer.encode_plus(
548
- anchor,
549
- return_token_type_ids=False,
550
- return_attention_mask=True,
551
- return_tensors="pt",
552
- )
553
 
554
- anchor_row = pd.DataFrame(
555
- {
556
- "label": 0,
557
- "input_ids": anchor_token["input_ids"].tolist(),
558
- "attention_mask": anchor_token["attention_mask"].tolist(),
559
- "score": 1,
560
- }
561
- )
562
- df = pd.concat([df, anchor_row])
563
 
564
  for pos in positives:
565
- token = tokenizer.encode_plus(
566
- pos,
567
- return_token_type_ids=False,
568
- return_attention_mask=True,
569
- return_tensors="pt",
570
- )
571
- row = pd.DataFrame(
572
- {
573
- "label": 1,
574
- "input_ids": token["input_ids"].tolist(),
575
- "attention_mask": token["attention_mask"].tolist(),
576
- "score": 1,
577
- }
578
- )
579
- df = pd.concat([df, row])
580
 
581
  for neg in negatives:
582
- token = tokenizer.encode_plus(
583
- neg,
584
- return_token_type_ids=False,
585
- return_attention_mask=True,
586
- return_tensors="pt",
587
- )
588
- score = max_pairwise_sim(anchor, neg, current_df, query_df, sim_df, all_d)
589
- offset = 8 # all negative scores start with 8 to distinguish from the positives
590
- score = score + offset
591
- row = pd.DataFrame(
592
- {
593
- "label": 2,
594
- "input_ids": token["input_ids"].tolist(),
595
- "attention_mask": token["attention_mask"].tolist(),
596
- "score": score,
597
- }
598
- )
599
- df = pd.concat([df, row])
600
 
601
- label = torch.tensor(df["label"].tolist())
602
 
603
- input_ids_tsr = list(map(lambda x: torch.tensor(x), df["input_ids"]))
604
- padded_input_ids = pad_sequence(input_ids_tsr, padding_value=tokenizer.pad_token_id)
605
  padded_input_ids = torch.transpose(padded_input_ids, 0, 1)
606
 
607
- attention_mask_tsr = list(map(lambda x: torch.tensor(x), df["attention_mask"]))
608
  padded_attention_mask = pad_sequence(attention_mask_tsr, padding_value=0)
609
  padded_attention_mask = torch.transpose(padded_attention_mask, 0, 1)
610
 
611
- score = torch.tensor(df["score"].tolist())
 
 
 
 
 
 
612
 
613
  return {
614
- "label": label,
615
  "input_ids": padded_input_ids,
616
  "attention_mask": padded_attention_mask,
617
- "score": score,
 
618
  }
619
 
620
 
621
- def create_dataloader_samp_w(
622
- dataset,
623
- current_df,
624
- query_df,
625
- tokenizer,
626
- dictionary,
627
- sim_df,
628
- all_d,
629
- shuffle,
630
  ):
631
  return DataLoader(
632
  dataset,
633
  batch_size=config.batch_size,
634
  shuffle=shuffle,
635
- num_workers=config.num_workers,
636
- collate_fn=lambda batch: collate_samp_w(
637
- batch,
638
- current_df,
639
- query_df,
640
- tokenizer,
641
- dictionary,
642
- sim_df,
643
- all_d,
644
  ),
645
  )
646
 
647
 
648
- class ContrastiveLearningDataModule_samp_w(pl.LightningDataModule):
649
  def __init__(
650
  self,
651
  train_df,
652
  val_df,
653
- query_df,
654
  tokenizer,
 
655
  dictionary,
656
- sim_df,
657
  all_d,
658
  ):
659
  super().__init__()
660
  self.train_df = train_df
661
  self.val_df = val_df
662
- self.query_df = query_df
663
  self.tokenizer = tokenizer
 
664
  self.dictionary = dictionary
665
- self.sim_df = sim_df
666
  self.all_d = all_d
667
 
668
  def setup(self, stage=None):
669
- self.train_dataset = ContrastiveLearningDataset(self.train_df)
670
- self.val_dataset = ContrastiveLearningDataset(self.val_df)
671
 
672
  def train_dataloader(self):
673
- return create_dataloader_samp_w(
674
  self.train_dataset,
675
- self.train_df,
676
- self.query_df,
677
  self.tokenizer,
678
- self.dictionary,
679
- self.sim_df,
680
- self.all_d,
681
  shuffle=True,
 
 
 
 
682
  )
683
 
684
  def val_dataloader(self):
685
- return create_dataloader_samp_w(
686
  self.val_dataset,
687
- self.val_df,
688
- self.query_df,
689
  self.tokenizer,
690
- self.dictionary,
691
- self.sim_df,
692
- self.all_d,
693
  shuffle=False,
 
 
 
 
694
  )
695
 
696
 
697
- #### Test
698
- from transformers import AutoTokenizer
699
- from ast import literal_eval
700
- from sklearn.model_selection import train_test_split
 
 
 
 
 
 
 
701
 
702
- query_df = pd.read_csv(
703
- "/home/sunx/data/aiiih/projects/sunx/ccf_fuzzy_diag/data_proc/mimic_data/processed_train/processed.csv"
704
- )
705
- query_df["concepts"] = query_df["concepts"].apply(literal_eval)
706
- query_df["codes"] = query_df["codes"].apply(literal_eval)
707
- query_df["codes"] = query_df["codes"].apply(
708
- lambda x: [val for val in x if val is not None]
709
- ) # remove None in lists
710
- query_df = query_df.drop(columns=["one_hot"])
711
- train_df, val_df = train_test_split(query_df, test_size=config.split_ratio)
712
-
713
- tokenizer = AutoTokenizer.from_pretrained("emilyalsentzer/Bio_ClinicalBERT")
714
-
715
- sim_df = pd.read_csv(
716
- "/home/sunx/data/aiiih/projects/sunx/ccf_fuzzy_diag/data_proc/pairwise_scores.csv"
717
- )
718
 
719
- all_d = pd.read_csv(
720
- "/home/sunx/data/aiiih/projects/sunx/ccf_fuzzy_diag/data_proc/all_d_full.csv"
721
- )
722
- all_d["synonyms"] = all_d["synonyms"].apply(literal_eval)
723
- all_d["ancestors"] = all_d["ancestors"].apply(literal_eval)
724
- dictionary = dict(zip(all_d["concept"], all_d["synonyms"]))
725
-
726
- d1 = ContrastiveLearningDataModule_simcse(train_df, val_df, tokenizer)
727
- d1.setup()
728
- train_d1 = d1.train_dataloader()
729
- for batch in train_d1:
730
- b1 = batch
731
- break
732
-
733
- d2 = ContrastiveLearningDataModule_simcse_w(
734
- train_df, val_df, query_df, tokenizer, sim_df, all_d
735
- )
736
- d2.setup()
737
- train_d2 = d2.train_dataloader()
738
- for batch in train_d2:
739
- b2 = batch
740
- break
741
-
742
- d3 = ContrastiveLearningDataModule_samp(
743
- train_df, val_df, query_df, tokenizer, dictionary, sim_df
744
- )
745
- d3.setup()
746
- train_d3 = d3.train_dataloader()
747
- for batch in train_d3:
748
- b3 = batch
749
- break
750
-
751
- d4 = ContrastiveLearningDataModule_samp_w(
752
- train_df, val_df, query_df, tokenizer, dictionary, sim_df, all_d
753
- )
754
- d4.setup()
755
- train_d4 = d4.train_dataloader()
756
- for batch in train_d4:
757
- b4 = batch
758
- break
 
 
1
  import torch
2
+ from transformers import AutoTokenizer
3
  from torch.utils.data import Dataset, DataLoader
4
  from torch.nn.utils.rnn import pad_sequence
5
  import lightning.pytorch as pl
6
  import config
7
+ import pandas as pd
8
+ import copy
9
+ from ast import literal_eval
10
+ from sklearn.model_selection import train_test_split
11
  import sys
12
 
13
  sys.path.append("/home/sunx/data/aiiih/projects/sunx/ccf_fuzzy_diag")
14
  from data_proc.data_gen import (
15
  positive_generator,
16
+ positive_generator_alter,
17
  negative_generator,
18
+ negative_generator_alter,
19
+ negative_generator_random,
20
+ negative_generator_v2,
21
  get_mentioned_code,
22
  )
23
 
24
 
25
+ def tokenize(text, tokenizer, tag):
26
+ inputs = tokenizer(
27
+ text,
28
+ return_token_type_ids=False,
29
+ return_tensors="pt",
30
+ )
31
+
32
+ inputs["input_ids"] = inputs["input_ids"][0]
33
+ inputs["attention_mask"] = inputs["attention_mask"][0]
34
+ inputs["mlm_ids"] = copy.deepcopy(inputs["input_ids"])
35
+ inputs["mlm_labels"] = copy.deepcopy(inputs["input_ids"])
36
+
37
+ tokens_to_ignore = torch.tensor([101, 102, 0]) # [CLS], [SEP], [PAD]
38
+ valid_tokens = inputs["input_ids"][
39
+ ~torch.isin(inputs["input_ids"], tokens_to_ignore)
40
+ ]
41
+ num_of_token_to_mask = int(len(valid_tokens) * config.mask_pct)
42
+ token_to_mask = valid_tokens[
43
+ torch.randperm(valid_tokens.size(0))[:num_of_token_to_mask]
44
+ ]
45
+ inputs["mlm_ids"] = [
46
+ 103 if x in token_to_mask else x for x in inputs["mlm_ids"]
47
+ ] # [MASK]
48
+ inputs["mlm_labels"] = [
49
+ y if y in token_to_mask else -100 for y in inputs["mlm_labels"]
50
+ ]
51
+
52
+ inputs["mlm_ids"] = torch.tensor(inputs["mlm_ids"])
53
+ inputs["mlm_labels"] = torch.tensor(inputs["mlm_labels"])
54
+ if tag == "A":
55
+ inputs["tag"] = 0
56
+ elif tag == "P":
57
+ inputs["tag"] = 1
58
+ elif tag == "N":
59
+ inputs["tag"] = 2
60
+ return inputs
61
+
62
+
63
+ class CLDataset(Dataset):
64
  def __init__(
65
  self,
66
  data: pd.DataFrame,
 
76
  return sentence
77
 
78
 
79
+ def collate_func(batch, tokenizer, current_df, query_df, dictionary, all_d):
81
  anchor = batch[0]
82
+ positives = positive_generator_alter(
83
  anchor,
84
  current_df,
85
  dictionary,
86
+ num_pos=config.num_pos,
87
  )
88
+ negatives = negative_generator_v2(
89
  anchor,
90
  current_df,
91
  query_df,
92
+ all_d,
 
93
  num_neg=config.num_neg,
94
  )
 
 
 
 
 
 
 
95
 
96
+ inputs = []
97
+
98
+ anchor_dict = tokenize(anchor, tokenizer, "A")
99
+ inputs.append(anchor_dict)
 
 
 
 
 
100
 
101
  for pos in positives:
102
+ pos_dict = tokenize(pos, tokenizer, "P")
103
+ inputs.append(pos_dict)
104
 
105
  for neg in negatives:
106
+ neg_dict = tokenize(neg, tokenizer, "N")
107
+ inputs.append(neg_dict)
108
 
109
+ tags = torch.tensor([d["tag"] for d in inputs])
110
 
111
+ input_ids_tsr = [d["input_ids"] for d in inputs]
112
+ padded_input_ids = pad_sequence(input_ids_tsr, padding_value=0)
113
  padded_input_ids = torch.transpose(padded_input_ids, 0, 1)
114
 
115
+ attention_mask_tsr = [d["attention_mask"] for d in inputs]
116
  padded_attention_mask = pad_sequence(attention_mask_tsr, padding_value=0)
117
  padded_attention_mask = torch.transpose(padded_attention_mask, 0, 1)
118
 
119
+ mlm_ids_tsr = [d["mlm_ids"] for d in inputs]
120
+ padded_mlm_ids = pad_sequence(mlm_ids_tsr, padding_value=0)
121
+ padded_mlm_ids = torch.transpose(padded_mlm_ids, 0, 1)
122
+
123
+ mlm_labels_tsr = [d["mlm_labels"] for d in inputs]
124
+ padded_mlm_labels = pad_sequence(mlm_labels_tsr, padding_value=-100)
125
+ padded_mlm_labels = torch.transpose(padded_mlm_labels, 0, 1)
126
 
127
  return {
128
+ "tags": tags,
129
  "input_ids": padded_input_ids,
130
  "attention_mask": padded_attention_mask,
131
+ "mlm_ids": padded_mlm_ids,
132
+ "mlm_labels": padded_mlm_labels,
133
  }
134
 
135
 
136
+ def create_dataloader(
137
+ dataset, tokenizer, shuffle, current_df, query_df, dictionary, all_d
 
 
 
 
 
 
 
138
  ):
139
  return DataLoader(
140
  dataset,
141
  batch_size=config.batch_size,
142
  shuffle=shuffle,
143
+ num_workers=1,
144
+ collate_fn=lambda batch: collate_func(
145
+ batch, tokenizer, current_df, query_df, dictionary, all_d
 
 
 
 
 
 
146
  ),
147
  )
148
 
149
 
150
+ class CLDataModule(pl.LightningDataModule):
151
  def __init__(
152
  self,
153
  train_df,
154
  val_df,
 
155
  tokenizer,
156
+ query_df,
157
  dictionary,
 
158
  all_d,
159
  ):
160
  super().__init__()
161
  self.train_df = train_df
162
  self.val_df = val_df
 
163
  self.tokenizer = tokenizer
164
+ self.query_df = query_df
165
  self.dictionary = dictionary
 
166
  self.all_d = all_d
167
 
168
  def setup(self, stage=None):
169
+ self.train_dataset = CLDataset(self.train_df)
170
+ self.val_dataset = CLDataset(self.val_df)
171
 
172
  def train_dataloader(self):
173
+ return create_dataloader(
174
  self.train_dataset,
 
 
175
  self.tokenizer,
 
 
 
176
  shuffle=True,
177
+ current_df=self.train_df,
178
+ query_df=self.query_df,
179
+ dictionary=self.dictionary,
180
+ all_d=self.all_d,
181
  )
182
 
183
  def val_dataloader(self):
184
+ return create_dataloader(
185
  self.val_dataset,
 
 
186
  self.tokenizer,
 
 
 
187
  shuffle=False,
188
+ current_df=self.val_df,
189
+ query_df=self.query_df,
190
+ dictionary=self.dictionary,
191
+ all_d=self.all_d,
192
  )
193
 
194
 
195
+ if __name__ == "__main__":
196
+ query_df = pd.read_csv(
197
+ "/home/sunx/data/aiiih/projects/sunx/ccf_fuzzy_diag/data_proc/query_df.csv"
198
+ )
199
+ query_df["concepts"] = query_df["concepts"].apply(literal_eval)
200
+ query_df["codes"] = query_df["codes"].apply(literal_eval)
201
+ query_df["codes"] = query_df["codes"].apply(
202
+ lambda x: [val for val in x if val is not None]
203
+ )
204
+ train_df, val_df = train_test_split(query_df, test_size=config.split_ratio)
205
+ tokenizer = AutoTokenizer.from_pretrained("emilyalsentzer/Bio_ClinicalBERT")
206
 
207
+ all_d = pd.read_csv(
208
+ "/home/sunx/data/aiiih/projects/sunx/ccf_fuzzy_diag/data_proc/query_all_d.csv"
209
+ )
210
+ all_d["synonyms"] = all_d["synonyms"].apply(literal_eval)
211
+ all_d["ancestors"] = all_d["ancestors"].apply(literal_eval)
212
+ all_d["finding_sites"] = all_d["finding_sites"].apply(literal_eval)
213
+ all_d["morphology"] = all_d["morphology"].apply(literal_eval)
214
+ dictionary = dict(zip(all_d["concept"], all_d["synonyms"]))
215
 
216
+ d = CLDataModule(train_df, val_df, tokenizer, query_df, dictionary, all_d)
217
+ d.setup()
218
+ train = d.train_dataloader()
219
+ for batch in train:
220
+ b = batch
221
+ break
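
A minimal, self-contained sketch (not part of the commit) of the masking scheme that tokenize() implements above: a fraction of the non-special token ids (config.mask_pct) is replaced by the [MASK] id 103, and mlm_labels keeps the original id only at masked positions, with -100 everywhere else so the MLM cross-entropy ignores them. The helper name mask_for_mlm and the 0.15 rate are illustrative; note that this sketch samples positions, whereas tokenize() samples token values, so every occurrence of a sampled token is masked there.

import torch

def mask_for_mlm(input_ids: torch.Tensor, mask_pct: float = 0.15):
    # [CLS]=101, [SEP]=102, [PAD]=0 in BERT-style vocabularies such as Bio_ClinicalBERT
    special = torch.tensor([101, 102, 0])
    mlm_ids = input_ids.clone()
    mlm_labels = input_ids.clone()

    # positions eligible for masking, then a random mask_pct subset of them
    candidates = (~torch.isin(input_ids, special)).nonzero().squeeze(dim=1)
    num_to_mask = int(len(candidates) * mask_pct)
    picked = candidates[torch.randperm(len(candidates))[:num_to_mask]]

    mask = torch.zeros_like(input_ids, dtype=torch.bool)
    mask[picked] = True
    mlm_ids[mask] = 103        # [MASK]
    mlm_labels[~mask] = -100   # ignored by nn.CrossEntropyLoss
    return mlm_ids, mlm_labels

# token ids below are illustrative, not a real tokenization
ids = torch.tensor([101, 8647, 19341, 28995, 102])
print(mask_for_mlm(ids, mask_pct=0.5))
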
loss.py CHANGED
@@ -4,124 +4,17 @@ import torch.nn.functional as F
4
  import config
5
 
6
 
7
- class ContrastiveLoss_simcse(nn.Module):
8
- """SimCSE loss"""
9
-
10
- def __init__(self):
11
- super(ContrastiveLoss_simcse, self).__init__()
12
- self.temperature = config.temperature
13
-
14
- def forward(self, feature_vectors, labels):
15
- normalized_features = F.normalize(
16
- feature_vectors, p=2, dim=0
17
- ) # normalize along columns
18
-
19
- # Identify indices for each label
20
- anchor_indices = (labels == 0).nonzero().squeeze(dim=1)
21
- positive_indices = (labels == 1).nonzero().squeeze(dim=1)
22
- negative_indices = (labels == 2).nonzero().squeeze(dim=1)
23
-
24
- # Extract tensors based on labels
25
- anchor = normalized_features[anchor_indices]
26
- positives = normalized_features[positive_indices]
27
- negatives = normalized_features[negative_indices]
28
- pos_and_neg = torch.cat([positives, negatives])
29
-
30
- denominator = torch.sum(
31
- torch.exp(
32
- torch.div(
33
- torch.matmul(anchor, torch.transpose(pos_and_neg, 0, 1)),
34
- self.temperature,
35
- )
36
- )
37
- )
38
-
39
- numerator = torch.exp(
40
- torch.div(
41
- torch.matmul(anchor, torch.transpose(positives, 0, 1)),
42
- self.temperature,
43
- )
44
- )
45
-
46
- loss = -torch.log(
47
- torch.div(
48
- numerator,
49
- denominator,
50
- )
51
- )
52
-
53
- return loss
54
-
55
-
56
- class ContrastiveLoss_simcse_w(nn.Module):
57
- """SimCSE loss with weighting."""
58
-
59
- def __init__(self):
60
- super(ContrastiveLoss_simcse_w, self).__init__()
61
- self.temperature = config.temperature
62
-
63
- def forward(self, feature_vectors, labels, scores):
64
- normalized_features = F.normalize(
65
- feature_vectors, p=2, dim=0
66
- ) # normalize along columns
67
-
68
- # Identify indices for each label
69
- anchor_indices = (labels == 0).nonzero().squeeze(dim=1)
70
- positive_indices = (labels == 1).nonzero().squeeze(dim=1)
71
- negative_indices = (labels == 2).nonzero().squeeze(dim=1)
72
-
73
- pos_scores = scores[positive_indices].float()
74
- normalized_neg_scores = F.normalize(
75
- scores[negative_indices].float(), p=2, dim=0
76
- ) # l2-norm
77
- normalized_neg_scores += 1
78
- scores = torch.cat([pos_scores, normalized_neg_scores])
79
-
80
- # Extract tensors based on labels
81
- anchor = normalized_features[anchor_indices]
82
- positives = normalized_features[positive_indices]
83
- negatives = normalized_features[negative_indices]
84
- pos_and_neg = torch.cat([positives, negatives])
85
-
86
- denominator = torch.sum(
87
- torch.exp(
88
- scores
89
- * torch.div(
90
- torch.matmul(anchor, torch.transpose(pos_and_neg, 0, 1)),
91
- self.temperature,
92
- )
93
- )
94
- )
95
-
96
- numerator = torch.exp(
97
- torch.div(
98
- torch.matmul(anchor, torch.transpose(positives, 0, 1)),
99
- self.temperature,
100
- )
101
- )
102
-
103
- loss = -torch.log(
104
- torch.div(
105
- numerator,
106
- denominator,
107
- )
108
- )
109
-
110
- return loss
111
-
112
-
113
- class ContrastiveLoss_samp(nn.Module):
114
  """Supervised contrastive loss without weighting."""
115
 
116
  def __init__(self):
117
- super(ContrastiveLoss_samp, self).__init__()
118
  self.temperature = config.temperature
119
 
120
  def forward(self, feature_vectors, labels):
121
- # Normalize feature vectors
122
  normalized_features = F.normalize(
123
- feature_vectors, p=2, dim=0
124
- ) # normalize along columns
125
 
126
  # Identify indices for each label
127
  anchor_indices = (labels == 0).nonzero().squeeze(dim=1)
@@ -139,82 +32,35 @@ class ContrastiveLoss_samp(nn.Module):
139
  denominator = torch.sum(
140
  torch.exp(
141
  torch.div(
142
- torch.matmul(anchor, torch.transpose(pos_and_neg, 0, 1)),
143
  self.temperature,
144
  )
145
  )
146
  )
147
 
148
- sum_log_ent = torch.sum(
149
- torch.log(
150
- torch.div(
151
- torch.exp(
152
- torch.div(
153
- torch.matmul(anchor, torch.transpose(positives, 0, 1)),
154
- self.temperature,
155
- )
156
- ),
157
- denominator,
158
- )
159
- )
160
- )
161
-
162
- scale = -1 / pos_cardinal
163
-
164
- return scale * sum_log_ent
165
-
166
-
167
- class ContrastiveLoss_samp_w(nn.Module):
168
- """Supervised contrastive loss with weighting."""
169
-
170
- def __init__(self):
171
- super(ContrastiveLoss_samp_w, self).__init__()
172
- self.temperature = config.temperature
173
-
174
- def forward(self, feature_vectors, labels, scores):
175
- # Normalize feature vectors
176
- normalized_features = F.normalize(
177
- feature_vectors, p=2, dim=0
178
- ) # normalize along columns
179
-
180
- # Identify indices for each label
181
- anchor_indices = (labels == 0).nonzero().squeeze(dim=1)
182
- positive_indices = (labels == 1).nonzero().squeeze(dim=1)
183
- negative_indices = (labels == 2).nonzero().squeeze(dim=1)
184
-
185
- # Normalize score vector
186
- num_skip = len(positive_indices) + 1
187
- pos_scores = scores[: (num_skip - 1)].float() # exclude anchor
188
- normalized_neg_scores = F.normalize(
189
- scores[num_skip:].float(), p=2, dim=0
190
- ) # l2-norm
191
- normalized_neg_scores += 1
192
- scores = torch.cat([pos_scores, normalized_neg_scores])
193
-
194
- # Extract tensors based on labels
195
- anchor = normalized_features[anchor_indices]
196
- positives = normalized_features[positive_indices]
197
- negatives = normalized_features[negative_indices]
198
- pos_and_neg = torch.cat([positives, negatives])
199
-
200
- pos_cardinal = positives.shape[0]
201
-
202
- denominator = torch.sum(
203
- torch.exp(
204
- scores
205
- * torch.div(
206
- torch.matmul(anchor, torch.transpose(pos_and_neg, 0, 1)),
207
- self.temperature,
208
- )
209
- )
210
- )
211
 
212
  sum_log_ent = torch.sum(
213
  torch.log(
214
  torch.div(
215
  torch.exp(
216
  torch.div(
217
- torch.matmul(anchor, torch.transpose(positives, 0, 1)),
218
  self.temperature,
219
  )
220
  ),
@@ -224,5 +70,6 @@ class ContrastiveLoss_samp_w(nn.Module):
224
  )
225
 
226
  scale = -1 / pos_cardinal
 
227
 
228
- return scale * sum_log_ent
 
4
  import config
5
 
6
 
7
+ class CL_loss(nn.Module):
8
  """Supervised contrastive loss without weighting."""
9
 
10
  def __init__(self):
11
+ super(CL_loss, self).__init__()
12
  self.temperature = config.temperature
13
 
14
  def forward(self, feature_vectors, labels):
 
15
  normalized_features = F.normalize(
16
+ feature_vectors, p=2, dim=1
17
+ ) # L2-normalize each row so every embedding has (approximately) unit norm
18
 
19
  # Identify indices for each label
20
  anchor_indices = (labels == 0).nonzero().squeeze(dim=1)
 
32
  denominator = torch.sum(
33
  torch.exp(
34
  torch.div(
35
+ F.cosine_similarity(anchor, pos_and_neg, dim=1),
36
  self.temperature,
37
  )
38
  )
39
  )
40
 
41
+ # if not torch.isfinite(denominator):
42
+ # print("Denominator is Inf!")
43
+
44
+ # if not torch.isfinite(
45
+ # torch.exp(
46
+ # torch.div(F.cosine_similarity(anchor, pos_and_neg, dim=1)),
47
+ # self.temperature,
48
+ # )
49
+ # ).all():
50
+ # print("Exp is Inf!")
51
+ # print(
52
+ # torch.exp(
53
+ # torch.div(F.cosine_similarity(anchor, pos_and_neg, dim=1)),
54
+ # self.temperature,
55
+ # )
56
+ # )
 
57
 
58
  sum_log_ent = torch.sum(
59
  torch.log(
60
  torch.div(
61
  torch.exp(
62
  torch.div(
63
+ F.cosine_similarity(anchor, positives, dim=1),
64
  self.temperature,
65
  )
66
  ),
 
70
  )
71
 
72
  scale = -1 / pos_cardinal
73
+ out = scale * sum_log_ent
74
 
75
+ return out
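
A hedged usage sketch of CL_loss (not in the commit): it expects row-wise embeddings plus the tag labels produced by collate_func (0 = anchor, 1 = positive, 2 = negative) and computes, with temperature t = config.temperature,
    loss = -(1/|P|) * sum_{p in P} log( exp(cos(a, p) / t) / sum_{x in P u N} exp(cos(a, x) / t) )
where a is the anchor embedding, P the positives and N the negatives. The tensors below are random and purely illustrative; config.temperature must be defined for the class to instantiate.

import torch
from loss import CL_loss

embeddings = torch.randn(6, 128)         # one anchor, two positives, three negatives
tags = torch.tensor([0, 1, 1, 2, 2, 2])  # same convention as the "tags" key from collate_func
criterion = CL_loss()
loss = criterion(embeddings, tags)       # scalar; smaller when the anchor is closer to the positives
print(loss)
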
model.py CHANGED
@@ -2,31 +2,35 @@ import lightning.pytorch as pl
2
  from transformers import (
3
  AdamW,
4
  AutoModel,
 
5
  get_linear_schedule_with_warmup,
6
  )
 
7
  import torch
8
  from torch import nn
9
- from loss import (
10
- ContrastiveLoss_simcse,
11
- ContrastiveLoss_simcse_w,
12
- ContrastiveLoss_samp,
13
- ContrastiveLoss_samp_w,
14
- )
15
 
16
 
17
- class BERTContrastiveLearning_simcse(pl.LightningModule):
18
- def __init__(self, n_batches=None, n_epochs=None, lr=None, **kwargs):
 
 
19
  super().__init__()
20
- ### Parameters
 
21
  self.n_batches = n_batches
22
  self.n_epochs = n_epochs
23
  self.lr = lr
 
 
 
24
 
25
- ### Architecture
26
  self.bert = AutoModel.from_pretrained(
27
  "emilyalsentzer/Bio_ClinicalBERT", return_dict=True
28
  )
29
- # Unfreeze encoder
30
  self.bert_layer_num = sum(1 for _ in self.bert.named_parameters())
31
  self.num_unfreeze_layer = self.bert_layer_num
32
  self.ratio_unfreeze_layer = 0.0
@@ -43,378 +47,138 @@ class BERTContrastiveLearning_simcse(pl.LightningModule):
43
  )
44
  for param in list(self.bert.parameters())[: -self.num_unfreeze_layer]:
45
  param.requires_grad = False
46
- # Random dropouts
47
- self.dropout1 = nn.Dropout(p=0.1)
48
- self.dropout2 = nn.Dropout(p=0.1)
49
- # Linear projector
50
  self.projector = nn.Linear(self.bert.config.hidden_size, 128)
51
  print("Model Initialized!")
52
 
53
- ### Loss
54
- self.criterion = ContrastiveLoss_simcse()
55
-
56
- ### Logs
57
- self.train_loss, self.val_loss, self.test_loss = [], [], []
58
- self.training_step_outputs = []
59
- self.validation_step_outputs = []
60
-
61
- def configure_optimizers(self):
62
- # Optimizer
63
- self.trainable_params = [
64
- param for param in self.parameters() if param.requires_grad
65
- ]
66
- optimizer = AdamW(self.trainable_params, lr=self.lr)
67
-
68
- # Scheduler
69
- # warmup_steps = self.n_batches // 3
70
- # total_steps = self.n_batches * self.n_epochs - warmup_steps
71
- # scheduler = get_linear_schedule_with_warmup(
72
- # optimizer, warmup_steps, total_steps
73
- # )
74
- return [optimizer]
75
-
76
- def forward(self, input_ids, attention_mask):
77
- emb = self.bert(input_ids=input_ids, attention_mask=attention_mask)
78
- cls = emb.pooler_output
79
- out = self.projector(cls)
80
- anchor_out = self.dropout1(out[0:1])
81
- rest_out = self.dropout2(out[1:])
82
- output = torch.cat([anchor_out, rest_out])
83
- return cls, output
84
 
85
  def training_step(self, batch, batch_idx):
86
- label = batch["label"]
87
  input_ids = batch["input_ids"]
88
  attention_mask = batch["attention_mask"]
89
- cls, out = self(
90
- input_ids,
91
- attention_mask,
92
- )
93
- loss = self.criterion(out, label)
94
- logs = {"loss": loss}
 
95
  self.training_step_outputs.append(logs)
96
  self.log("train_loss", loss, prog_bar=True, logger=True, sync_dist=True)
97
  return loss
98
 
99
  def on_train_epoch_end(self):
100
- loss = (
101
  torch.stack([x["loss"] for x in self.training_step_outputs])
102
  .mean()
103
  .detach()
104
  .cpu()
105
  .numpy()
106
  )
107
- self.train_loss.append(loss)
108
- print("train_epoch:", self.current_epoch, "avg_loss:", loss)
109
- self.training_step_outputs.clear()
110
-
111
- def validation_step(self, batch, batch_idx):
112
- label = batch["label"]
113
- input_ids = batch["input_ids"]
114
- attention_mask = batch["attention_mask"]
115
- cls, out = self(
116
- input_ids,
117
- attention_mask,
118
- )
119
- loss = self.criterion(out, label)
120
- logs = {"loss": loss}
121
- self.validation_step_outputs.append(logs)
122
- self.log("validation_loss", loss, prog_bar=True, logger=True, sync_dist=True)
123
- return loss
124
-
125
- def on_validation_epoch_end(self):
126
- loss = (
127
- torch.stack([x["loss"] for x in self.validation_step_outputs])
128
  .mean()
129
  .detach()
130
  .cpu()
131
  .numpy()
132
  )
133
- self.val_loss.append(loss)
134
- print("val_epoch:", self.current_epoch, "avg_loss:", loss)
135
- self.validation_step_outputs.clear()
136
-
137
-
138
- class BERTContrastiveLearning_simcse_w(pl.LightningModule):
139
- def __init__(self, n_batches=None, n_epochs=None, lr=None, **kwargs):
140
- super().__init__()
141
- ### Parameters
142
- self.n_batches = n_batches
143
- self.n_epochs = n_epochs
144
- self.lr = lr
145
-
146
- ### Architecture
147
- self.bert = AutoModel.from_pretrained(
148
- "emilyalsentzer/Bio_ClinicalBERT", return_dict=True
149
- )
150
- # Unfreeze encoder
151
- self.bert_layer_num = sum(1 for _ in self.bert.named_parameters())
152
- self.num_unfreeze_layer = self.bert_layer_num
153
- self.ratio_unfreeze_layer = 0.0
154
- if kwargs:
155
- for key, value in kwargs.items():
156
- if key == "unfreeze" and isinstance(value, float):
157
- assert (
158
- value >= 0.0 and value <= 1.0
159
- ), "ValueError: value must be a ratio between 0.0 and 1.0"
160
- self.ratio_unfreeze_layer = value
161
- if self.ratio_unfreeze_layer > 0.0:
162
- self.num_unfreeze_layer = int(
163
- self.bert_layer_num * self.ratio_unfreeze_layer
164
- )
165
- for param in list(self.bert.parameters())[: -self.num_unfreeze_layer]:
166
- param.requires_grad = False
167
- # Random dropouts
168
- self.dropout1 = nn.Dropout(p=0.1)
169
- self.dropout2 = nn.Dropout(p=0.1)
170
- # Linear projector
171
- self.projector = nn.Linear(self.bert.config.hidden_size, 128)
172
- print("Model Initialized!")
173
-
174
- ### Loss
175
- self.criterion = ContrastiveLoss_simcse_w()
176
-
177
- ### Logs
178
- self.train_loss, self.val_loss, self.test_loss = [], [], []
179
- self.training_step_outputs = []
180
- self.validation_step_outputs = []
181
-
182
- def configure_optimizers(self):
183
- # Optimizer
184
- self.trainable_params = [
185
- param for param in self.parameters() if param.requires_grad
186
- ]
187
- optimizer = AdamW(self.trainable_params, lr=self.lr)
188
-
189
- # Scheduler
190
- # warmup_steps = self.n_batches // 3
191
- # total_steps = self.n_batches * self.n_epochs - warmup_steps
192
- # scheduler = get_linear_schedule_with_warmup(
193
- # optimizer, warmup_steps, total_steps
194
- # )
195
- return [optimizer]
196
-
197
- def forward(self, input_ids, attention_mask):
198
- emb = self.bert(input_ids=input_ids, attention_mask=attention_mask)
199
- cls = emb.pooler_output
200
- out = self.projector(cls)
201
- anchor_out = self.dropout1(out[0:1])
202
- rest_out = self.dropout2(out[1:])
203
- output = torch.cat([anchor_out, rest_out])
204
- return cls, output
205
-
206
- def training_step(self, batch, batch_idx):
207
- label = batch["label"]
208
- input_ids = batch["input_ids"]
209
- attention_mask = batch["attention_mask"]
210
- score = batch["score"]
211
- cls, out = self(
212
- input_ids,
213
- attention_mask,
214
- )
215
- loss = self.criterion(out, label, score)
216
- logs = {"loss": loss}
217
- self.training_step_outputs.append(logs)
218
- self.log("train_loss", loss, prog_bar=True, logger=True, sync_dist=True)
219
- return loss
220
-
221
- def on_train_epoch_end(self):
222
- loss = (
223
- torch.stack([x["loss"] for x in self.training_step_outputs])
224
  .mean()
225
  .detach()
226
  .cpu()
227
  .numpy()
228
  )
229
- self.train_loss.append(loss)
230
- print("train_epoch:", self.current_epoch, "avg_loss:", loss)
231
  self.training_step_outputs.clear()
232
 
233
  def validation_step(self, batch, batch_idx):
234
- label = batch["label"]
235
  input_ids = batch["input_ids"]
236
  attention_mask = batch["attention_mask"]
237
- score = batch["score"]
238
- cls, out = self(
239
- input_ids,
240
- attention_mask,
241
- )
242
- loss = self.criterion(out, label, score)
243
- logs = {"loss": loss}
244
  self.validation_step_outputs.append(logs)
245
  self.log("validation_loss", loss, prog_bar=True, logger=True, sync_dist=True)
246
  return loss
247
 
248
  def on_validation_epoch_end(self):
249
- loss = (
250
  torch.stack([x["loss"] for x in self.validation_step_outputs])
251
  .mean()
252
  .detach()
253
  .cpu()
254
  .numpy()
255
  )
256
- self.val_loss.append(loss)
257
- print("val_epoch:", self.current_epoch, "avg_loss:", loss)
258
- self.validation_step_outputs.clear()
259
-
260
-
261
- class BERTContrastiveLearning_samp(pl.LightningModule):
262
-
263
- def __init__(self, n_batches=None, n_epochs=None, lr=None, **kwargs):
264
- super().__init__()
265
- ### Parameters
266
- self.n_batches = n_batches
267
- self.n_epochs = n_epochs
268
- self.lr = lr
269
-
270
- ### Architecture
271
- self.bert = AutoModel.from_pretrained(
272
- "emilyalsentzer/Bio_ClinicalBERT", return_dict=True
273
- )
274
- # Unfreeze encoder
275
- self.bert_layer_num = sum(1 for _ in self.bert.named_parameters())
276
- self.num_unfreeze_layer = self.bert_layer_num
277
- self.ratio_unfreeze_layer = 0.0
278
- if kwargs:
279
- for key, value in kwargs.items():
280
- if key == "unfreeze" and isinstance(value, float):
281
- assert (
282
- value >= 0.0 and value <= 1.0
283
- ), "ValueError: value must be a ratio between 0.0 and 1.0"
284
- self.ratio_unfreeze_layer = value
285
- if self.ratio_unfreeze_layer > 0.0:
286
- self.num_unfreeze_layer = int(
287
- self.bert_layer_num * self.ratio_unfreeze_layer
288
- )
289
- for param in list(self.bert.parameters())[: -self.num_unfreeze_layer]:
290
- param.requires_grad = False
291
- # Linear projector
292
- self.projector = nn.Linear(self.bert.config.hidden_size, 128)
293
- print("Model Initialized!")
294
-
295
- ### Loss
296
- self.criterion = ContrastiveLoss_samp()
297
-
298
- ### Logs
299
- self.train_loss, self.val_loss, self.test_loss = [], [], []
300
- self.training_step_outputs = []
301
- self.validation_step_outputs = []
302
-
303
- def configure_optimizers(self):
304
- # Optimizer
305
- self.trainable_params = [
306
- param for param in self.parameters() if param.requires_grad
307
- ]
308
- optimizer = AdamW(self.trainable_params, lr=self.lr)
309
-
310
- # Scheduler
311
- # warmup_steps = self.n_batches // 3
312
- # total_steps = self.n_batches * self.n_epochs - warmup_steps
313
- # scheduler = get_linear_schedule_with_warmup(
314
- # optimizer, warmup_steps, total_steps
315
- # )
316
- return [optimizer]
317
-
318
- def forward(self, input_ids, attention_mask):
319
- emb = self.bert(input_ids=input_ids, attention_mask=attention_mask)
320
- cls = emb.pooler_output
321
- out = self.projector(cls)
322
- return cls, out
323
-
324
- def training_step(self, batch, batch_idx):
325
- label = batch["label"]
326
- input_ids = batch["input_ids"]
327
- attention_mask = batch["attention_mask"]
328
- cls, out = self(
329
- input_ids,
330
- attention_mask,
331
- )
332
- loss = self.criterion(out, label)
333
- logs = {"loss": loss}
334
- self.training_step_outputs.append(logs)
335
- self.log("train_loss", loss, prog_bar=True, logger=True, sync_dist=True)
336
- return loss
337
-
338
- def on_train_epoch_end(self):
339
- loss = (
340
- torch.stack([x["loss"] for x in self.training_step_outputs])
341
  .mean()
342
  .detach()
343
  .cpu()
344
  .numpy()
345
  )
346
- self.train_loss.append(loss)
347
- print("train_epoch:", self.current_epoch, "avg_loss:", loss)
348
- self.training_step_outputs.clear()
349
-
350
- def validation_step(self, batch, batch_idx):
351
- label = batch["label"]
352
- input_ids = batch["input_ids"]
353
- attention_mask = batch["attention_mask"]
354
- cls, out = self(
355
- input_ids,
356
- attention_mask,
357
- )
358
- loss = self.criterion(out, label)
359
- logs = {"loss": loss}
360
- self.validation_step_outputs.append(logs)
361
- self.log("validation_loss", loss, prog_bar=True, logger=True, sync_dist=True)
362
- return loss
363
-
364
- def on_validation_epoch_end(self):
365
- loss = (
366
- torch.stack([x["loss"] for x in self.validation_step_outputs])
367
  .mean()
368
  .detach()
369
  .cpu()
370
  .numpy()
371
  )
372
- self.val_loss.append(loss)
373
- print("val_epoch:", self.current_epoch, "avg_loss:", loss)
374
- self.validation_step_outputs.clear()
375
-
376
-
377
- class BERTContrastiveLearning_samp_w(pl.LightningModule):
378
-
379
- def __init__(self, n_batches=None, n_epochs=None, lr=None, **kwargs):
380
- super().__init__()
381
- ### Parameters
382
- self.n_batches = n_batches
383
- self.n_epochs = n_epochs
384
- self.lr = lr
385
-
386
- ### Architecture
387
- self.bert = AutoModel.from_pretrained(
388
- "emilyalsentzer/Bio_ClinicalBERT", return_dict=True
389
  )
390
- # Unfreeze encoder
391
- self.bert_layer_num = sum(1 for _ in self.bert.named_parameters())
392
- self.num_unfreeze_layer = self.bert_layer_num
393
- self.ratio_unfreeze_layer = 0.0
394
- if kwargs:
395
- for key, value in kwargs.items():
396
- if key == "unfreeze" and isinstance(value, float):
397
- assert (
398
- value >= 0.0 and value <= 1.0
399
- ), "ValueError: value must be a ratio between 0.0 and 1.0"
400
- self.ratio_unfreeze_layer = value
401
- if self.ratio_unfreeze_layer > 0.0:
402
- self.num_unfreeze_layer = int(
403
- self.bert_layer_num * self.ratio_unfreeze_layer
404
- )
405
- for param in list(self.bert.parameters())[: -self.num_unfreeze_layer]:
406
- param.requires_grad = False
407
- # Linear projector
408
- self.projector = nn.Linear(self.bert.config.hidden_size, 128)
409
- print("Model Initialized!")
410
-
411
- ### Loss
412
- self.criterion = ContrastiveLoss_samp_w()
413
-
414
- ### Logs
415
- self.train_loss, self.val_loss, self.test_loss = [], [], []
416
- self.training_step_outputs = []
417
- self.validation_step_outputs = []
418
 
419
  def configure_optimizers(self):
420
  # Optimizer
@@ -424,69 +188,9 @@ class BERTContrastiveLearning_samp_w(pl.LightningModule):
424
  optimizer = AdamW(self.trainable_params, lr=self.lr)
425
 
426
  # Scheduler
427
- # warmup_steps = self.n_batches // 3
428
- # total_steps = self.n_batches * self.n_epochs - warmup_steps
429
- # scheduler = get_linear_schedule_with_warmup(
430
- # optimizer, warmup_steps, total_steps
431
- # )
432
- return [optimizer]
433
-
434
- def forward(self, input_ids, attention_mask):
435
- emb = self.bert(input_ids=input_ids, attention_mask=attention_mask)
436
- cls = emb.pooler_output
437
- out = self.projector(cls)
438
- return cls, out
439
-
440
- def training_step(self, batch, batch_idx):
441
- label = batch["label"]
442
- input_ids = batch["input_ids"]
443
- attention_mask = batch["attention_mask"]
444
- score = batch["score"]
445
- cls, out = self(
446
- input_ids,
447
- attention_mask,
448
  )
449
- loss = self.criterion(out, label, score)
450
- logs = {"loss": loss}
451
- self.training_step_outputs.append(logs)
452
- self.log("train_loss", loss, prog_bar=True, logger=True, sync_dist=True)
453
- return loss
454
-
455
- def on_train_epoch_end(self):
456
- loss = (
457
- torch.stack([x["loss"] for x in self.training_step_outputs])
458
- .mean()
459
- .detach()
460
- .cpu()
461
- .numpy()
462
- )
463
- self.train_loss.append(loss)
464
- print("train_epoch:", self.current_epoch, "avg_loss:", loss)
465
- self.training_step_outputs.clear()
466
-
467
- def validation_step(self, batch, batch_idx):
468
- label = batch["label"]
469
- input_ids = batch["input_ids"]
470
- attention_mask = batch["attention_mask"]
471
- score = batch["score"]
472
- cls, out = self(
473
- input_ids,
474
- attention_mask,
475
- )
476
- loss = self.criterion(out, label, score)
477
- logs = {"loss": loss}
478
- self.validation_step_outputs.append(logs)
479
- self.log("validation_loss", loss, prog_bar=True, logger=True, sync_dist=True)
480
- return loss
481
-
482
- def on_validation_epoch_end(self):
483
- loss = (
484
- torch.stack([x["loss"] for x in self.validation_step_outputs])
485
- .mean()
486
- .detach()
487
- .cpu()
488
- .numpy()
489
- )
490
- self.val_loss.append(loss)
491
- print("val_epoch:", self.current_epoch, "avg_loss:", loss)
492
- self.validation_step_outputs.clear()
 
2
  from transformers import (
3
  AdamW,
4
  AutoModel,
5
+ AutoConfig,
6
  get_linear_schedule_with_warmup,
7
  )
8
+ from transformers.models.bert.modeling_bert import BertLMPredictionHead
9
  import torch
10
  from torch import nn
11
+ from loss import CL_loss
12
+ import pandas as pd
 
 
 
 
13
 
14
 
15
+ class CL_model(pl.LightningModule):
16
+ def __init__(
17
+ self, n_batches=None, n_epochs=None, lr=None, mlm_weight=None, **kwargs
18
+ ):
19
  super().__init__()
20
+
21
+ ## Params
22
  self.n_batches = n_batches
23
  self.n_epochs = n_epochs
24
  self.lr = lr
25
+ self.mlm_weight = mlm_weight
26
+ # self.first_neg_idx = 0
27
+ self.config = AutoConfig.from_pretrained("emilyalsentzer/Bio_ClinicalBERT")
28
 
29
+ ## Encoder
30
  self.bert = AutoModel.from_pretrained(
31
  "emilyalsentzer/Bio_ClinicalBERT", return_dict=True
32
  )
33
+ # Unfreeze layers
34
  self.bert_layer_num = sum(1 for _ in self.bert.named_parameters())
35
  self.num_unfreeze_layer = self.bert_layer_num
36
  self.ratio_unfreeze_layer = 0.0
 
47
  )
48
  for param in list(self.bert.parameters())[: -self.num_unfreeze_layer]:
49
  param.requires_grad = False
50
+
51
+ self.lm_head = BertLMPredictionHead(self.config)
 
 
52
  self.projector = nn.Linear(self.bert.config.hidden_size, 128)
53
  print("Model Initialized!")
54
 
55
+ ## Losses
56
+ self.cl_loss = CL_loss()
57
+ self.mlm_loss = nn.CrossEntropyLoss()
58
+
59
+ ## Logs
60
+ self.train_loss, self.val_loss = [], []
61
+ self.train_cl_loss, self.val_cl_loss = [], []
62
+ self.train_mlm_loss, self.val_mlm_loss = [], []
63
+ self.training_step_outputs, self.validation_step_outputs = [], []
64
+
65
+ def forward(self, input_ids, attention_mask, mlm_ids, eval=False):
66
+ # Contrastive
67
+ unmasked = self.bert(input_ids=input_ids, attention_mask=attention_mask)
68
+ cls = unmasked.pooler_output
69
+ if eval is True:
70
+ return cls
71
+ output = self.projector(cls)
72
+
73
+ # MLM
74
+ masked = self.bert(input_ids=mlm_ids, attention_mask=attention_mask)
75
+ pred = self.lm_head(masked.last_hidden_state)
76
+ pred = pred.view(-1, self.config.vocab_size)
77
+ return cls, output, pred
 
78
 
79
  def training_step(self, batch, batch_idx):
80
+ tags = batch["tags"]
81
  input_ids = batch["input_ids"]
82
  attention_mask = batch["attention_mask"]
83
+ mlm_ids = batch["mlm_ids"]
84
+ mlm_labels = batch["mlm_labels"].reshape(-1)
85
+ cls, output, pred = self(input_ids, attention_mask, mlm_ids)
86
+ loss_cl = self.cl_loss(output, tags)
87
+ loss_mlm = self.mlm_loss(pred, mlm_labels)
88
+ loss = loss_cl + self.mlm_weight * loss_mlm
89
+ logs = {"loss": loss, "loss_cl": loss_cl, "loss_mlm": loss_mlm}
90
  self.training_step_outputs.append(logs)
91
  self.log("train_loss", loss, prog_bar=True, logger=True, sync_dist=True)
92
  return loss
93
 
94
  def on_train_epoch_end(self):
95
+ avg_loss = (
96
  torch.stack([x["loss"] for x in self.training_step_outputs])
97
  .mean()
98
  .detach()
99
  .cpu()
100
  .numpy()
101
  )
102
+ self.train_loss.append(avg_loss)
103
+ avg_cl_loss = (
104
+ torch.stack([x["loss_cl"] for x in self.training_step_outputs])
105
  .mean()
106
  .detach()
107
  .cpu()
108
  .numpy()
109
  )
110
+ self.train_cl_loss.append(avg_cl_loss)
111
+ avg_mlm_loss = (
112
+ torch.stack([x["loss_mlm"] for x in self.training_step_outputs])
113
  .mean()
114
  .detach()
115
  .cpu()
116
  .numpy()
117
  )
118
+ self.train_mlm_loss.append(avg_mlm_loss)
119
+ print(
120
+ "train_epoch:",
121
+ self.current_epoch,
122
+ "avg_loss:",
123
+ avg_loss,
124
+ "avg_cl_loss:",
125
+ avg_cl_loss,
126
+ "avg_mlm_loss:",
127
+ avg_mlm_loss,
128
+ )
129
  self.training_step_outputs.clear()
130
 
131
  def validation_step(self, batch, batch_idx):
132
+ tags = batch["tags"]
133
  input_ids = batch["input_ids"]
134
  attention_mask = batch["attention_mask"]
135
+ mlm_ids = batch["mlm_ids"]
136
+ mlm_labels = batch["mlm_labels"].reshape(-1)
137
+ cls, output, pred = self(input_ids, attention_mask, mlm_ids)
138
+ loss_cl = self.cl_loss(output, tags)
139
+ loss_mlm = self.mlm_loss(pred, mlm_labels)
140
+ loss = loss_cl + self.mlm_weight * loss_mlm
141
+ logs = {"loss": loss, "loss_cl": loss_cl, "loss_mlm": loss_mlm}
142
  self.validation_step_outputs.append(logs)
143
  self.log("validation_loss", loss, prog_bar=True, logger=True, sync_dist=True)
144
  return loss
145
 
146
  def on_validation_epoch_end(self):
147
+ avg_loss = (
148
  torch.stack([x["loss"] for x in self.validation_step_outputs])
149
  .mean()
150
  .detach()
151
  .cpu()
152
  .numpy()
153
  )
154
+ self.val_loss.append(avg_loss)
155
+ avg_cl_loss = (
156
+ torch.stack([x["loss_cl"] for x in self.validation_step_outputs])
157
  .mean()
158
  .detach()
159
  .cpu()
160
  .numpy()
161
  )
162
+ self.val_cl_loss.append(avg_cl_loss)
163
+ avg_mlm_loss = (
164
+ torch.stack([x["loss_mlm"] for x in self.validation_step_outputs])
165
  .mean()
166
  .detach()
167
  .cpu()
168
  .numpy()
169
  )
170
+ self.val_mlm_loss.append(avg_mlm_loss)
171
+ print(
172
+ "val_epoch:",
173
+ self.current_epoch,
174
+ "avg_loss:",
175
+ avg_loss,
176
+ "avg_cl_loss:",
177
+ avg_cl_loss,
178
+ "avg_mlm_loss:",
179
+ avg_mlm_loss,
180
  )
181
+ self.validation_step_outputs.clear()
182
 
183
  def configure_optimizers(self):
184
  # Optimizer
 
188
  optimizer = AdamW(self.trainable_params, lr=self.lr)
189
 
190
  # Scheduler
191
+ warmup_steps = self.n_batches // 3
192
+ total_steps = self.n_batches * self.n_epochs - warmup_steps
193
+ scheduler = get_linear_schedule_with_warmup(
194
+ optimizer, warmup_steps, total_steps
195
  )
196
+ return [optimizer], [scheduler]
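
A minimal end-to-end wiring sketch (not part of the commit), assuming the objects built in the __main__ block of dataset.py (train_df, val_df, tokenizer, query_df, dictionary, all_d) are in scope; the hyperparameter values are placeholders. The total objective is loss_cl + mlm_weight * loss_mlm, so mlm_weight trades off the contrastive and masked-language-model terms.

import lightning.pytorch as pl
import config
from dataset import CLDataModule
from model import CL_model

dm = CLDataModule(train_df, val_df, tokenizer, query_df, dictionary, all_d)
model = CL_model(
    n_batches=len(train_df) // config.batch_size,  # sizes the linear warmup in configure_optimizers
    n_epochs=10,
    lr=2e-5,
    mlm_weight=0.1,
    unfreeze=0.5,  # fine-tune the top half of Bio_ClinicalBERT
)
trainer = pl.Trainer(max_epochs=10, accelerator="auto", devices=1)
trainer.fit(model, datamodule=dm)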