sxtforreal committed on
Commit
504db9e
1 Parent(s): 4f3b260

Create dataset.py


This file holds four dataset modules, one for each of the four models: SimCSE, SimCSE_w, Samp, and Samp_w. Run the test at the end to see what's in each training batch.
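
For example, a minimal sketch of how one of these modules is meant to be used (assuming a train/validation split and a Hugging Face tokenizer are already prepared, as in the test block at the end of the file):

    from transformers import AutoTokenizer
    from dataset import ContrastiveLearningDataModule_simcse

    tokenizer = AutoTokenizer.from_pretrained("emilyalsentzer/Bio_ClinicalBERT")
    dm = ContrastiveLearningDataModule_simcse(train_df, val_df, tokenizer)
    dm.setup()
    first_batch = next(iter(dm.train_dataloader()))
    print(first_batch.keys())  # label, input_ids, attention_mask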

Files changed (1)
  1. dataset.py +758 -0
dataset.py ADDED
@@ -0,0 +1,758 @@

import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
import lightning.pytorch as pl
import config
import sys

sys.path.append("/home/sunx/data/aiiih/projects/sunx/ccf_fuzzy_diag")
from data_proc.data_gen import (
    positive_generator,
    negative_generator,
    get_mentioned_code,
)


##### General
class ContrastiveLearningDataset(Dataset):
    def __init__(
        self,
        data: pd.DataFrame,
    ):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        data_row = self.data.iloc[index]
        sentence = data_row.sentences
        return sentence


def max_pairwise_sim(sentence1, sentence2, current_df, query_df, sim_df, all_d):
    """Returns the maximum ontology similarity score over the concept pairs mentioned in sentence1 and sentence2.

    Args:
        sentence1: anchor sentence
        sentence2: negative sentence
        current_df: the dataset the anchor sentence comes from
        query_df: the union of the training and validation sets
        sim_df: the dataset of pairwise ontology similarity scores
        all_d: the dataset of [concepts, synonyms, list of ancestor concepts]
    """
    # Retrieve the concept codes mentioned in each sentence.
    anchor_codes = get_mentioned_code(sentence1, current_df)
    other_codes = get_mentioned_code(sentence2, query_df)

    # Create SNOMED CT code pairs and look up each pair's score in sim_df.
    code_pairs = list(zip(anchor_codes, other_codes))
    sim_scores = []
    for code1, code2 in code_pairs:
        if code1 == code2:
            # Identical codes: use the size of the concept's ancestor set.
            result = len(all_d.loc[all_d["concept"] == code1, "ancestors"].values[0])
        else:
            # sim_df stores each unordered pair once, so try both orderings.
            try:
                result = sim_df.loc[
                    (sim_df["Code1"] == code1) & (sim_df["Code2"] == code2), "score"
                ].values[0]
            except IndexError:
                result = sim_df.loc[
                    (sim_df["Code1"] == code2) & (sim_df["Code2"] == code1), "score"
                ].values[0]
        sim_scores.append(result)
    if len(sim_scores) > 0:
        return max(sim_scores)
    else:
        return 0
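

# A toy check of the symmetric lookup above (illustrative sketch, not part of
# the original commit; assumes scores are stored once per unordered pair):
#
#   demo_sim = pd.DataFrame({"Code1": ["A"], "Code2": ["B"], "score": [3]})
#   hit = demo_sim.loc[
#       (demo_sim["Code1"] == "B") & (demo_sim["Code2"] == "A"), "score"
#   ]
#   assert len(hit) == 0  # the reversed order misses, hence the except branch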


##### SimCSE
def collate_simcse(batch, tokenizer):
    """
    Use the first sample in the batch as the anchor,
    use a duplicate of the anchor as the positive,
    and use the rest of the batch as negatives.
    """
    anchor = batch[0]  # the first sample in the batch is the anchor
    positive = anchor[:]  # a duplicate of the anchor serves as the positive
    negatives = batch[1:]  # everything else in the batch is a negative
    df = pd.DataFrame(columns=["label", "input_ids", "attention_mask"])

    anchor_token = tokenizer.encode_plus(
        anchor,
        return_token_type_ids=False,
        return_attention_mask=True,
        return_tensors="pt",
    )
    anchor_row = pd.DataFrame(
        {
            "label": 0,
            "input_ids": anchor_token["input_ids"].tolist(),
            "attention_mask": anchor_token["attention_mask"].tolist(),
        }
    )
    df = pd.concat([df, anchor_row])

    pos_token = tokenizer.encode_plus(
        positive,
        return_token_type_ids=False,
        return_attention_mask=True,
        return_tensors="pt",
    )
    pos_row = pd.DataFrame(
        {
            "label": 1,
            "input_ids": pos_token["input_ids"].tolist(),
            "attention_mask": pos_token["attention_mask"].tolist(),
        }
    )
    df = pd.concat([df, pos_row])

    for neg in negatives:
        neg_token = tokenizer.encode_plus(
            neg,
            return_token_type_ids=False,
            return_attention_mask=True,
            return_tensors="pt",
        )
        neg_row = pd.DataFrame(
            {
                "label": 2,
                "input_ids": neg_token["input_ids"].tolist(),
                "attention_mask": neg_token["attention_mask"].tolist(),
            }
        )
        df = pd.concat([df, neg_row])

    label = torch.tensor(df["label"].tolist())

    input_ids_tsr = list(map(lambda x: torch.tensor(x), df["input_ids"]))
    padded_input_ids = pad_sequence(input_ids_tsr, padding_value=tokenizer.pad_token_id)
    padded_input_ids = torch.transpose(padded_input_ids, 0, 1)

    attention_mask_tsr = list(map(lambda x: torch.tensor(x), df["attention_mask"]))
    padded_attention_mask = pad_sequence(attention_mask_tsr, padding_value=0)
    padded_attention_mask = torch.transpose(padded_attention_mask, 0, 1)

    return {
        "label": label,
        "input_ids": padded_input_ids,
        "attention_mask": padded_attention_mask,
    }
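

# Note on the padding step shared by all four collate functions:
# pad_sequence defaults to (seq_len, batch) layout, so the result is
# transposed back to (batch, seq_len). An equivalent sketch:
#
#   a = pad_sequence(seqs, padding_value=0).transpose(0, 1)
#   b = pad_sequence(seqs, batch_first=True, padding_value=0)
#   # torch.equal(a, b) -> True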


def create_dataloader_simcse(
    dataset,
    tokenizer,
    shuffle,
):
    return DataLoader(
        dataset,
        batch_size=config.batch_size_simcse,
        shuffle=shuffle,
        num_workers=config.num_workers,
        collate_fn=lambda batch: collate_simcse(
            batch,
            tokenizer,
        ),
    )


class ContrastiveLearningDataModule_simcse(pl.LightningDataModule):
    def __init__(
        self,
        train_df,
        val_df,
        tokenizer,
    ):
        super().__init__()
        self.train_df = train_df
        self.val_df = val_df
        self.tokenizer = tokenizer

    def setup(self, stage=None):
        self.train_dataset = ContrastiveLearningDataset(self.train_df)
        self.val_dataset = ContrastiveLearningDataset(self.val_df)

    def train_dataloader(self):
        return create_dataloader_simcse(
            self.train_dataset,
            self.tokenizer,
            shuffle=True,
        )

    def val_dataloader(self):
        return create_dataloader_simcse(
            self.val_dataset,
            self.tokenizer,
            shuffle=False,
        )


##### SimCSE_w
def collate_simcse_w(
    batch,
    current_df,
    query_df,
    tokenizer,
    sim_df,
    all_d,
):
    """
    Row labels: 0 = anchor, 1 = positive, 2 = negative.
    Each row also carries a score: 1 for the anchor and positive, and the
    maximum ontology similarity to the anchor (plus an offset) for negatives.
    """
    anchor = batch[0]
    positive = anchor[:]
    negatives = batch[1:]
    df = pd.DataFrame(columns=["label", "input_ids", "attention_mask", "score"])

    anchor_token = tokenizer.encode_plus(
        anchor,
        return_token_type_ids=False,
        return_attention_mask=True,
        return_tensors="pt",
    )

    anchor_row = pd.DataFrame(
        {
            "label": 0,
            "input_ids": anchor_token["input_ids"].tolist(),
            "attention_mask": anchor_token["attention_mask"].tolist(),
            "score": 1,
        }
    )
    df = pd.concat([df, anchor_row])

    pos_token = tokenizer.encode_plus(
        positive,
        return_token_type_ids=False,
        return_attention_mask=True,
        return_tensors="pt",
    )
    pos_row = pd.DataFrame(
        {
            "label": 1,
            "input_ids": pos_token["input_ids"].tolist(),
            "attention_mask": pos_token["attention_mask"].tolist(),
            "score": 1,
        }
    )
    df = pd.concat([df, pos_row])

    for neg in negatives:
        neg_token = tokenizer.encode_plus(
            neg,
            return_token_type_ids=False,
            return_attention_mask=True,
            return_tensors="pt",
        )
        score = max_pairwise_sim(anchor, neg, current_df, query_df, sim_df, all_d)
        offset = 8  # all negative scores start at 8 to distinguish them from the positives
        score = score + offset
        neg_row = pd.DataFrame(
            {
                "label": 2,
                "input_ids": neg_token["input_ids"].tolist(),
                "attention_mask": neg_token["attention_mask"].tolist(),
                "score": score,
            }
        )
        df = pd.concat([df, neg_row])

    label = torch.tensor(df["label"].tolist())

    input_ids_tsr = list(map(lambda x: torch.tensor(x), df["input_ids"]))
    padded_input_ids = pad_sequence(input_ids_tsr, padding_value=tokenizer.pad_token_id)
    padded_input_ids = torch.transpose(padded_input_ids, 0, 1)

    attention_mask_tsr = list(map(lambda x: torch.tensor(x), df["attention_mask"]))
    padded_attention_mask = pad_sequence(attention_mask_tsr, padding_value=0)
    padded_attention_mask = torch.transpose(padded_attention_mask, 0, 1)

    score = torch.tensor(df["score"].tolist())

    return {
        "label": label,
        "input_ids": padded_input_ids,
        "attention_mask": padded_attention_mask,
        "score": score,
    }
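

# Sketch of the tensors collate_simcse_w produces for a batch of four
# sentences whose three negatives have raw ontology scores 1, 4, 0
# (hypothetical values, for illustration only):
#
#   label = torch.tensor([0, 1, 2, 2, 2])   # anchor, positive, negatives
#   score = torch.tensor([1, 1, 9, 12, 8])  # negatives offset by 8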


def create_dataloader_simcse_w(
    dataset,
    current_df,
    query_df,
    tokenizer,
    sim_df,
    all_d,
    shuffle,
):
    return DataLoader(
        dataset,
        batch_size=config.batch_size_simcse,
        shuffle=shuffle,
        num_workers=config.num_workers,
        collate_fn=lambda batch: collate_simcse_w(
            batch,
            current_df,
            query_df,
            tokenizer,
            sim_df,
            all_d,
        ),
    )


class ContrastiveLearningDataModule_simcse_w(pl.LightningDataModule):
    def __init__(
        self,
        train_df,
        val_df,
        query_df,
        tokenizer,
        sim_df,
        all_d,
    ):
        super().__init__()
        self.train_df = train_df
        self.val_df = val_df
        self.query_df = query_df
        self.tokenizer = tokenizer
        self.sim_df = sim_df
        self.all_d = all_d

    def setup(self, stage=None):
        self.train_dataset = ContrastiveLearningDataset(self.train_df)
        self.val_dataset = ContrastiveLearningDataset(self.val_df)

    def train_dataloader(self):
        return create_dataloader_simcse_w(
            self.train_dataset,
            self.train_df,
            self.query_df,
            self.tokenizer,
            self.sim_df,
            self.all_d,
            shuffle=True,
        )

    def val_dataloader(self):
        return create_dataloader_simcse_w(
            self.val_dataset,
            self.val_df,
            self.query_df,
            self.tokenizer,
            self.sim_df,
            self.all_d,
            shuffle=False,
        )


##### Samp
def collate_samp(
    batch,
    current_df,
    query_df,
    tokenizer,
    dictionary,
    sim_df,
):
    """
    Row labels: 0 = anchor, 1 = positive, 2 = negative.
    Positives and negatives are sampled for the anchor rather than taken
    from the rest of the batch.
    """
    anchor = batch[0]
    positives = positive_generator(
        anchor, current_df, query_df, dictionary, num_pos=config.num_pos
    )
    negatives = negative_generator(
        anchor,
        current_df,
        query_df,
        dictionary,
        sim_df,
        num_neg=config.num_neg,
    )
    df = pd.DataFrame(columns=["label", "input_ids", "attention_mask"])
    anchor_token = tokenizer.encode_plus(
        anchor,
        return_token_type_ids=False,
        return_attention_mask=True,
        return_tensors="pt",
    )

    anchor_row = pd.DataFrame(
        {
            "label": 0,
            "input_ids": anchor_token["input_ids"].tolist(),
            "attention_mask": anchor_token["attention_mask"].tolist(),
        }
    )
    df = pd.concat([df, anchor_row])

    for pos in positives:
        token = tokenizer.encode_plus(
            pos,
            return_token_type_ids=False,
            return_attention_mask=True,
            return_tensors="pt",
        )
        row = pd.DataFrame(
            {
                "label": 1,
                "input_ids": token["input_ids"].tolist(),
                "attention_mask": token["attention_mask"].tolist(),
            }
        )
        df = pd.concat([df, row])

    for neg in negatives:
        token = tokenizer.encode_plus(
            neg,
            return_token_type_ids=False,
            return_attention_mask=True,
            return_tensors="pt",
        )
        row = pd.DataFrame(
            {
                "label": 2,
                "input_ids": token["input_ids"].tolist(),
                "attention_mask": token["attention_mask"].tolist(),
            }
        )
        df = pd.concat([df, row])

    label = torch.tensor(df["label"].tolist())

    input_ids_tsr = list(map(lambda x: torch.tensor(x), df["input_ids"]))
    padded_input_ids = pad_sequence(input_ids_tsr, padding_value=tokenizer.pad_token_id)
    padded_input_ids = torch.transpose(padded_input_ids, 0, 1)

    attention_mask_tsr = list(map(lambda x: torch.tensor(x), df["attention_mask"]))
    padded_attention_mask = pad_sequence(attention_mask_tsr, padding_value=0)
    padded_attention_mask = torch.transpose(padded_attention_mask, 0, 1)

    return {
        "label": label,
        "input_ids": padded_input_ids,
        "attention_mask": padded_attention_mask,
    }
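

# Sketch of the label layout collate_samp produces, assuming hypothetical
# config values num_pos = 2 and num_neg = 3: one anchor row, then the
# sampled positives, then the sampled negatives.
#
#   num_pos, num_neg = 2, 3
#   expected_label = torch.tensor([0] + [1] * num_pos + [2] * num_neg)
#   # -> tensor([0, 1, 1, 2, 2, 2])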


def create_dataloader_samp(
    dataset,
    current_df,
    query_df,
    tokenizer,
    dictionary,
    sim_df,
    shuffle,
):
    return DataLoader(
        dataset,
        batch_size=config.batch_size,
        shuffle=shuffle,
        num_workers=config.num_workers,
        collate_fn=lambda batch: collate_samp(
            batch,
            current_df,
            query_df,
            tokenizer,
            dictionary,
            sim_df,
        ),
    )


class ContrastiveLearningDataModule_samp(pl.LightningDataModule):
    def __init__(
        self,
        train_df,
        val_df,
        query_df,
        tokenizer,
        dictionary,
        sim_df,
    ):
        super().__init__()
        self.train_df = train_df
        self.val_df = val_df
        self.query_df = query_df
        self.tokenizer = tokenizer
        self.dictionary = dictionary
        self.sim_df = sim_df

    def setup(self, stage=None):
        self.train_dataset = ContrastiveLearningDataset(self.train_df)
        self.val_dataset = ContrastiveLearningDataset(self.val_df)

    def train_dataloader(self):
        return create_dataloader_samp(
            self.train_dataset,
            self.train_df,
            self.query_df,
            self.tokenizer,
            self.dictionary,
            self.sim_df,
            shuffle=True,
        )

    def val_dataloader(self):
        return create_dataloader_samp(
            self.val_dataset,
            self.val_df,
            self.query_df,
            self.tokenizer,
            self.dictionary,
            self.sim_df,
            shuffle=False,
        )


##### Samp_w
def collate_samp_w(
    batch,
    current_df,
    query_df,
    tokenizer,
    dictionary,
    sim_df,
    all_d,
):
    """
    Row labels: 0 = anchor, 1 = positive, 2 = negative.
    Each row also carries a score: 1 for the anchor and positives, and the
    maximum ontology similarity to the anchor (plus an offset) for negatives.
    """
    anchor = batch[0]
    positives = positive_generator(
        anchor, current_df, query_df, dictionary, num_pos=config.num_pos
    )
    negatives = negative_generator(
        anchor,
        current_df,
        query_df,
        dictionary,
        sim_df,
        num_neg=config.num_neg,
    )
    df = pd.DataFrame(columns=["label", "input_ids", "attention_mask", "score"])
    anchor_token = tokenizer.encode_plus(
        anchor,
        return_token_type_ids=False,
        return_attention_mask=True,
        return_tensors="pt",
    )

    anchor_row = pd.DataFrame(
        {
            "label": 0,
            "input_ids": anchor_token["input_ids"].tolist(),
            "attention_mask": anchor_token["attention_mask"].tolist(),
            "score": 1,
        }
    )
    df = pd.concat([df, anchor_row])

    for pos in positives:
        token = tokenizer.encode_plus(
            pos,
            return_token_type_ids=False,
            return_attention_mask=True,
            return_tensors="pt",
        )
        row = pd.DataFrame(
            {
                "label": 1,
                "input_ids": token["input_ids"].tolist(),
                "attention_mask": token["attention_mask"].tolist(),
                "score": 1,
            }
        )
        df = pd.concat([df, row])

    for neg in negatives:
        token = tokenizer.encode_plus(
            neg,
            return_token_type_ids=False,
            return_attention_mask=True,
            return_tensors="pt",
        )
        score = max_pairwise_sim(anchor, neg, current_df, query_df, sim_df, all_d)
        offset = 8  # all negative scores start at 8 to distinguish them from the positives
        score = score + offset
        row = pd.DataFrame(
            {
                "label": 2,
                "input_ids": token["input_ids"].tolist(),
                "attention_mask": token["attention_mask"].tolist(),
                "score": score,
            }
        )
        df = pd.concat([df, row])

    label = torch.tensor(df["label"].tolist())

    input_ids_tsr = list(map(lambda x: torch.tensor(x), df["input_ids"]))
    padded_input_ids = pad_sequence(input_ids_tsr, padding_value=tokenizer.pad_token_id)
    padded_input_ids = torch.transpose(padded_input_ids, 0, 1)

    attention_mask_tsr = list(map(lambda x: torch.tensor(x), df["attention_mask"]))
    padded_attention_mask = pad_sequence(attention_mask_tsr, padding_value=0)
    padded_attention_mask = torch.transpose(padded_attention_mask, 0, 1)

    score = torch.tensor(df["score"].tolist())

    return {
        "label": label,
        "input_ids": padded_input_ids,
        "attention_mask": padded_attention_mask,
        "score": score,
    }


def create_dataloader_samp_w(
    dataset,
    current_df,
    query_df,
    tokenizer,
    dictionary,
    sim_df,
    all_d,
    shuffle,
):
    return DataLoader(
        dataset,
        batch_size=config.batch_size,
        shuffle=shuffle,
        num_workers=config.num_workers,
        collate_fn=lambda batch: collate_samp_w(
            batch,
            current_df,
            query_df,
            tokenizer,
            dictionary,
            sim_df,
            all_d,
        ),
    )


class ContrastiveLearningDataModule_samp_w(pl.LightningDataModule):
    def __init__(
        self,
        train_df,
        val_df,
        query_df,
        tokenizer,
        dictionary,
        sim_df,
        all_d,
    ):
        super().__init__()
        self.train_df = train_df
        self.val_df = val_df
        self.query_df = query_df
        self.tokenizer = tokenizer
        self.dictionary = dictionary
        self.sim_df = sim_df
        self.all_d = all_d

    def setup(self, stage=None):
        self.train_dataset = ContrastiveLearningDataset(self.train_df)
        self.val_dataset = ContrastiveLearningDataset(self.val_df)

    def train_dataloader(self):
        return create_dataloader_samp_w(
            self.train_dataset,
            self.train_df,
            self.query_df,
            self.tokenizer,
            self.dictionary,
            self.sim_df,
            self.all_d,
            shuffle=True,
        )

    def val_dataloader(self):
        return create_dataloader_samp_w(
            self.val_dataset,
            self.val_df,
            self.query_df,
            self.tokenizer,
            self.dictionary,
            self.sim_df,
            self.all_d,
            shuffle=False,
        )


#### Test
if __name__ == "__main__":
    from transformers import AutoTokenizer
    from ast import literal_eval
    from sklearn.model_selection import train_test_split

    query_df = pd.read_csv(
        "/home/sunx/data/aiiih/projects/sunx/ccf_fuzzy_diag/data_proc/mimic_data/processed_train/processed.csv"
    )
    query_df["concepts"] = query_df["concepts"].apply(literal_eval)
    query_df["codes"] = query_df["codes"].apply(literal_eval)
    query_df["codes"] = query_df["codes"].apply(
        lambda x: [val for val in x if val is not None]
    )  # remove None entries from the code lists
    query_df = query_df.drop(columns=["one_hot"])
    train_df, val_df = train_test_split(query_df, test_size=config.split_ratio)

    tokenizer = AutoTokenizer.from_pretrained("emilyalsentzer/Bio_ClinicalBERT")

    sim_df = pd.read_csv(
        "/home/sunx/data/aiiih/projects/sunx/ccf_fuzzy_diag/data_proc/pairwise_scores.csv"
    )

    all_d = pd.read_csv(
        "/home/sunx/data/aiiih/projects/sunx/ccf_fuzzy_diag/data_proc/all_d_full.csv"
    )
    all_d["synonyms"] = all_d["synonyms"].apply(literal_eval)
    all_d["ancestors"] = all_d["ancestors"].apply(literal_eval)
    dictionary = dict(zip(all_d["concept"], all_d["synonyms"]))

    # Pull the first training batch from each of the four DataModules.
    d1 = ContrastiveLearningDataModule_simcse(train_df, val_df, tokenizer)
    d1.setup()
    b1 = next(iter(d1.train_dataloader()))

    d2 = ContrastiveLearningDataModule_simcse_w(
        train_df, val_df, query_df, tokenizer, sim_df, all_d
    )
    d2.setup()
    b2 = next(iter(d2.train_dataloader()))

    d3 = ContrastiveLearningDataModule_samp(
        train_df, val_df, query_df, tokenizer, dictionary, sim_df
    )
    d3.setup()
    b3 = next(iter(d3.train_dataloader()))

    d4 = ContrastiveLearningDataModule_samp_w(
        train_df, val_df, query_df, tokenizer, dictionary, sim_df, all_d
    )
    d4.setup()
    b4 = next(iter(d4.train_dataloader()))
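
    # Peek at what's in each training batch; the weighted variants (b2, b4)
    # additionally carry per-row similarity scores.
    print(b1["label"], b1["input_ids"].shape, b1["attention_mask"].shape)
    print(b2["label"], b2["score"])
    print(b3["label"])
    print(b4["label"], b4["score"])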