# CHOPT / dataset.py
import sys

import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
import lightning.pytorch as pl

import config

# Make the shared data-processing helpers importable.
sys.path.append("/home/sunx/data/aiiih/projects/sunx/ccf_fuzzy_diag")
from data_proc.data_gen import (
    positive_generator,
    negative_generator,
    get_mentioned_code,
)
##### General
class ContrastiveLearningDataset(Dataset):
    """Wrap a DataFrame and yield the raw text from its ``sentences`` column."""

    def __init__(
        self,
        data: pd.DataFrame,
    ):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        data_row = self.data.iloc[index]
        sentence = data_row.sentences
        return sentence
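
# Minimal usage sketch (hedged; assumes a DataFrame with a "sentences" column,
# as used throughout this module -- the toy data below is illustrative only):
#
#     toy_df = pd.DataFrame({"sentences": ["mild mitral regurgitation",
#                                          "normal left ventricular size"]})
#     toy_ds = ContrastiveLearningDataset(toy_df)
#     assert len(toy_ds) == 2 and toy_ds[0] == "mild mitral regurgitation"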
def max_pairwise_sim(sentence1, sentence2, current_df, query_df, sim_df, all_d):
    """Return the maximum ontology similarity score between the concepts
    mentioned in sentence1 and sentence2.

    Args:
        sentence1: anchor sentence
        sentence2: negative sentence
        current_df: the dataset the anchor sentence comes from
        query_df: the union of the training and validation sets
        sim_df: the dataset of pairwise ontology similarity scores
        all_d: the dataset of [concepts, synonyms, list of ancestor concepts]
    """
    # Retrieve the SNOMED-CT codes mentioned in the two sentences.
    anchor_codes = get_mentioned_code(sentence1, current_df)
    other_codes = get_mentioned_code(sentence2, query_df)

    # Pair the codes position-wise and look up each pair's score in sim_df.
    code_pairs = list(zip(anchor_codes, other_codes))
    sim_scores = []
    for code1, code2 in code_pairs:
        if code1 == code2:
            # Identical codes: use the size of the concept's ancestor set.
            result = len(all_d.loc[all_d["concept"] == code1, "ancestors"].values[0])
            sim_scores.append(result)
        else:
            # sim_df stores each pair once, in either (Code1, Code2) order.
            try:
                result = sim_df.loc[
                    (sim_df["Code1"] == code1) & (sim_df["Code2"] == code2), "score"
                ].values[0]
            except IndexError:
                result = sim_df.loc[
                    (sim_df["Code1"] == code2) & (sim_df["Code2"] == code1), "score"
                ].values[0]
            sim_scores.append(result)
    if len(sim_scores) > 0:
        return max(sim_scores)
    else:
        return 0
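
# Illustrative call (hedged sketch; the real current_df/query_df/sim_df/all_d
# frames are loaded in the test block at the bottom of this file, and
# get_mentioned_code comes from data_proc.data_gen):
#
#     score = max_pairwise_sim(
#         "mild mitral regurgitation",   # anchor sentence
#         "severe aortic stenosis",      # candidate negative sentence
#         current_df=train_df,
#         query_df=query_df,
#         sim_df=sim_df,
#         all_d=all_d,
#     )
#     # `score` is the largest ontology score among the matched code pairs,
#     # or 0 when no code pairs could be formed.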
##### SimCSE
def collate_simcse(batch, tokenizer):
    """SimCSE-style collate function.

    The first sample in the batch is the anchor (label 0), a duplicate of the
    anchor serves as the positive (label 1), and the remaining samples in the
    batch are used as negatives (label 2).
    """
    anchor = batch[0]  # the first sample in the batch is the anchor
    positive = anchor[:]  # a duplicate of the anchor serves as the positive
    negatives = batch[1:]  # everything else in the batch is a negative
df = pd.DataFrame(columns=["label", "input_ids", "attention_mask"])
anchor_token = tokenizer.encode_plus(
anchor,
return_token_type_ids=False,
return_attention_mask=True,
return_tensors="pt",
)
anchor_row = pd.DataFrame(
{
"label": 0,
"input_ids": anchor_token["input_ids"].tolist(),
"attention_mask": anchor_token["attention_mask"].tolist(),
}
)
df = pd.concat([df, anchor_row])
pos_token = tokenizer.encode_plus(
positive,
return_token_type_ids=False,
return_attention_mask=True,
return_tensors="pt",
)
pos_row = pd.DataFrame(
{
"label": 1,
"input_ids": pos_token["input_ids"].tolist(),
"attention_mask": pos_token["attention_mask"].tolist(),
}
)
df = pd.concat([df, pos_row])
for neg in negatives:
neg_token = tokenizer.encode_plus(
neg,
return_token_type_ids=False,
return_attention_mask=True,
return_tensors="pt",
)
neg_row = pd.DataFrame(
{
"label": 2,
"input_ids": neg_token["input_ids"].tolist(),
"attention_mask": neg_token["attention_mask"].tolist(),
}
)
df = pd.concat([df, neg_row])
label = torch.tensor(df["label"].tolist())
input_ids_tsr = list(map(lambda x: torch.tensor(x), df["input_ids"]))
padded_input_ids = pad_sequence(input_ids_tsr, padding_value=tokenizer.pad_token_id)
padded_input_ids = torch.transpose(padded_input_ids, 0, 1)
attention_mask_tsr = list(map(lambda x: torch.tensor(x), df["attention_mask"]))
padded_attention_mask = pad_sequence(attention_mask_tsr, padding_value=0)
padded_attention_mask = torch.transpose(padded_attention_mask, 0, 1)
return {
"label": label,
"input_ids": padded_input_ids,
"attention_mask": padded_attention_mask,
}
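
# Shape sketch (hedged): for a batch of B sentences, the returned dict holds
#   label:          LongTensor of shape (B + 1,)   -> [0, 1, 2, 2, ...]
#   input_ids:      LongTensor of shape (B + 1, max_len), padded with
#                   tokenizer.pad_token_id
#   attention_mask: LongTensor of shape (B + 1, max_len), padded with 0
# (B + 1 rows because the anchor is duplicated as its own positive.)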
def create_dataloader_simcse(
dataset,
tokenizer,
shuffle,
):
return DataLoader(
dataset,
batch_size=config.batch_size_simcse,
shuffle=shuffle,
num_workers=config.num_workers,
collate_fn=lambda batch: collate_simcse(
batch,
tokenizer,
),
)
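
# Note: with collate_simcse, config.batch_size_simcse controls how many
# sentences are drawn per step; the first becomes the anchor/positive pair and
# the remaining (batch_size_simcse - 1) sentences serve as in-batch negatives.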
class ContrastiveLearningDataModule_simcse(pl.LightningDataModule):
def __init__(
self,
train_df,
val_df,
tokenizer,
):
super().__init__()
self.train_df = train_df
self.val_df = val_df
self.tokenizer = tokenizer
def setup(self, stage=None):
self.train_dataset = ContrastiveLearningDataset(self.train_df)
self.val_dataset = ContrastiveLearningDataset(self.val_df)
def train_dataloader(self):
return create_dataloader_simcse(
self.train_dataset,
self.tokenizer,
shuffle=True,
)
def val_dataloader(self):
return create_dataloader_simcse(
self.val_dataset,
self.tokenizer,
shuffle=False,
)
##### SimCSE_w
def collate_simcse_w(
    batch,
    current_df,
    query_df,
    tokenizer,
    sim_df,
    all_d,
):
    """Weighted SimCSE collate function.

    Same layout as ``collate_simcse`` (anchor: 0, positive: 1, negative: 2),
    plus a ``score`` column: the anchor and positive get a score of 1, while
    each negative carries an ontology similarity score shifted by a fixed
    offset to distinguish it from the positives.
    """
    anchor = batch[0]
    positive = anchor[:]
    negatives = batch[1:]
df = pd.DataFrame(columns=["label", "input_ids", "attention_mask", "score"])
anchor_token = tokenizer.encode_plus(
anchor,
return_token_type_ids=False,
return_attention_mask=True,
return_tensors="pt",
)
anchor_row = pd.DataFrame(
{
"label": 0,
"input_ids": anchor_token["input_ids"].tolist(),
"attention_mask": anchor_token["attention_mask"].tolist(),
"score": 1,
}
)
df = pd.concat([df, anchor_row])
pos_token = tokenizer.encode_plus(
positive,
return_token_type_ids=False,
return_attention_mask=True,
return_tensors="pt",
)
pos_row = pd.DataFrame(
{
"label": 1,
"input_ids": pos_token["input_ids"].tolist(),
"attention_mask": pos_token["attention_mask"].tolist(),
"score": 1,
}
)
df = pd.concat([df, pos_row])
for neg in negatives:
neg_token = tokenizer.encode_plus(
neg,
return_token_type_ids=False,
return_attention_mask=True,
return_tensors="pt",
)
score = max_pairwise_sim(anchor, neg, current_df, query_df, sim_df, all_d)
        offset = 8  # all negative scores start at 8 to distinguish them from the positives
        score = score + offset
neg_row = pd.DataFrame(
{
"label": 2,
"input_ids": neg_token["input_ids"].tolist(),
"attention_mask": neg_token["attention_mask"].tolist(),
"score": score,
}
)
df = pd.concat([df, neg_row])
label = torch.tensor(df["label"].tolist())
input_ids_tsr = list(map(lambda x: torch.tensor(x), df["input_ids"]))
padded_input_ids = pad_sequence(input_ids_tsr, padding_value=tokenizer.pad_token_id)
padded_input_ids = torch.transpose(padded_input_ids, 0, 1)
attention_mask_tsr = list(map(lambda x: torch.tensor(x), df["attention_mask"]))
padded_attention_mask = pad_sequence(attention_mask_tsr, padding_value=0)
padded_attention_mask = torch.transpose(padded_attention_mask, 0, 1)
score = torch.tensor(df["score"].tolist())
return {
"label": label,
"input_ids": padded_input_ids,
"attention_mask": padded_attention_mask,
"score": score,
}
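
# As in collate_simcse, the returned tensors have B + 1 rows; the extra
# "score" entry is a 1-D tensor aligned with `label`: 1 for the anchor and
# positive rows, and offset ontology scores (>= 8) for the negative rows.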
def create_dataloader_simcse_w(
dataset,
current_df,
query_df,
tokenizer,
sim_df,
all_d,
shuffle,
):
return DataLoader(
dataset,
batch_size=config.batch_size_simcse,
shuffle=shuffle,
num_workers=config.num_workers,
collate_fn=lambda batch: collate_simcse_w(
batch,
current_df,
query_df,
tokenizer,
sim_df,
all_d,
),
)
class ContrastiveLearningDataModule_simcse_w(pl.LightningDataModule):
def __init__(
self,
train_df,
val_df,
query_df,
tokenizer,
sim_df,
all_d,
):
super().__init__()
self.train_df = train_df
self.val_df = val_df
self.query_df = query_df
self.tokenizer = tokenizer
self.sim_df = sim_df
self.all_d = all_d
def setup(self, stage=None):
self.train_dataset = ContrastiveLearningDataset(self.train_df)
self.val_dataset = ContrastiveLearningDataset(self.val_df)
def train_dataloader(self):
return create_dataloader_simcse_w(
self.train_dataset,
self.train_df,
self.query_df,
self.tokenizer,
self.sim_df,
self.all_d,
shuffle=True,
)
def val_dataloader(self):
return create_dataloader_simcse_w(
self.val_dataset,
self.val_df,
self.query_df,
self.tokenizer,
self.sim_df,
self.all_d,
shuffle=False,
)
##### Samp
def collate_samp(
    sentence,
    current_df,
    query_df,
    tokenizer,
    dictionary,
    sim_df,
):
    """Sampling-based collate function.

    The first sample is the anchor (label 0); positives (label 1) and
    negatives (label 2) are generated by ``positive_generator`` and
    ``negative_generator`` rather than taken from the rest of the batch.
    """
    anchor = sentence[0]
positives = positive_generator(
anchor, current_df, query_df, dictionary, num_pos=config.num_pos
)
negatives = negative_generator(
anchor,
current_df,
query_df,
dictionary,
sim_df,
num_neg=config.num_neg,
)
df = pd.DataFrame(columns=["label", "input_ids", "attention_mask"])
anchor_token = tokenizer.encode_plus(
anchor,
return_token_type_ids=False,
return_attention_mask=True,
return_tensors="pt",
)
anchor_row = pd.DataFrame(
{
"label": 0,
"input_ids": anchor_token["input_ids"].tolist(),
"attention_mask": anchor_token["attention_mask"].tolist(),
}
)
df = pd.concat([df, anchor_row])
for pos in positives:
token = tokenizer.encode_plus(
pos,
return_token_type_ids=False,
return_attention_mask=True,
return_tensors="pt",
)
row = pd.DataFrame(
{
"label": 1,
"input_ids": token["input_ids"].tolist(),
"attention_mask": token["attention_mask"].tolist(),
}
)
df = pd.concat([df, row])
for neg in negatives:
token = tokenizer.encode_plus(
neg,
return_token_type_ids=False,
return_attention_mask=True,
return_tensors="pt",
)
row = pd.DataFrame(
{
"label": 2,
"input_ids": token["input_ids"].tolist(),
"attention_mask": token["attention_mask"].tolist(),
}
)
df = pd.concat([df, row])
label = torch.tensor(df["label"].tolist())
input_ids_tsr = list(map(lambda x: torch.tensor(x), df["input_ids"]))
padded_input_ids = pad_sequence(input_ids_tsr, padding_value=tokenizer.pad_token_id)
padded_input_ids = torch.transpose(padded_input_ids, 0, 1)
attention_mask_tsr = list(map(lambda x: torch.tensor(x), df["attention_mask"]))
padded_attention_mask = pad_sequence(attention_mask_tsr, padding_value=0)
padded_attention_mask = torch.transpose(padded_attention_mask, 0, 1)
return {
"label": label,
"input_ids": padded_input_ids,
"attention_mask": padded_attention_mask,
}
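
# Shape sketch (hedged): assuming positive_generator / negative_generator
# return exactly config.num_pos and config.num_neg sentences, each batch from
# collate_samp has 1 + num_pos + num_neg rows, padded to the longest sequence.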
def create_dataloader_samp(
dataset,
current_df,
query_df,
tokenizer,
dictionary,
sim_df,
shuffle,
):
return DataLoader(
dataset,
batch_size=config.batch_size,
shuffle=shuffle,
num_workers=config.num_workers,
collate_fn=lambda batch: collate_samp(
batch,
current_df,
query_df,
tokenizer,
dictionary,
sim_df,
),
)
class ContrastiveLearningDataModule_samp(pl.LightningDataModule):
def __init__(
self,
train_df,
val_df,
query_df,
tokenizer,
dictionary,
sim_df,
):
super().__init__()
self.train_df = train_df
self.val_df = val_df
self.query_df = query_df
self.tokenizer = tokenizer
self.dictionary = dictionary
self.sim_df = sim_df
def setup(self, stage=None):
self.train_dataset = ContrastiveLearningDataset(self.train_df)
self.val_dataset = ContrastiveLearningDataset(self.val_df)
def train_dataloader(self):
return create_dataloader_samp(
self.train_dataset,
self.train_df,
self.query_df,
self.tokenizer,
self.dictionary,
self.sim_df,
shuffle=True,
)
def val_dataloader(self):
return create_dataloader_samp(
self.val_dataset,
self.val_df,
self.query_df,
self.tokenizer,
self.dictionary,
self.sim_df,
shuffle=False,
)
##### Samp_w
def collate_samp_w(
    sentence,
    current_df,
    query_df,
    tokenizer,
    dictionary,
    sim_df,
    all_d,
):
    """Weighted sampling-based collate function.

    Same layout as ``collate_samp`` (anchor: 0, positive: 1, negative: 2),
    plus a ``score`` column: the anchor and positives get a score of 1, while
    each negative carries an offset ontology similarity score.
    """
    anchor = sentence[0]
positives = positive_generator(
anchor, current_df, query_df, dictionary, num_pos=config.num_pos
)
negatives = negative_generator(
anchor,
current_df,
query_df,
dictionary,
sim_df,
num_neg=config.num_neg,
)
df = pd.DataFrame(columns=["label", "input_ids", "attention_mask", "score"])
anchor_token = tokenizer.encode_plus(
anchor,
return_token_type_ids=False,
return_attention_mask=True,
return_tensors="pt",
)
anchor_row = pd.DataFrame(
{
"label": 0,
"input_ids": anchor_token["input_ids"].tolist(),
"attention_mask": anchor_token["attention_mask"].tolist(),
"score": 1,
}
)
df = pd.concat([df, anchor_row])
for pos in positives:
token = tokenizer.encode_plus(
pos,
return_token_type_ids=False,
return_attention_mask=True,
return_tensors="pt",
)
row = pd.DataFrame(
{
"label": 1,
"input_ids": token["input_ids"].tolist(),
"attention_mask": token["attention_mask"].tolist(),
"score": 1,
}
)
df = pd.concat([df, row])
for neg in negatives:
token = tokenizer.encode_plus(
neg,
return_token_type_ids=False,
return_attention_mask=True,
return_tensors="pt",
)
score = max_pairwise_sim(anchor, neg, current_df, query_df, sim_df, all_d)
        offset = 8  # all negative scores start at 8 to distinguish them from the positives
score = score + offset
row = pd.DataFrame(
{
"label": 2,
"input_ids": token["input_ids"].tolist(),
"attention_mask": token["attention_mask"].tolist(),
"score": score,
}
)
df = pd.concat([df, row])
label = torch.tensor(df["label"].tolist())
input_ids_tsr = list(map(lambda x: torch.tensor(x), df["input_ids"]))
padded_input_ids = pad_sequence(input_ids_tsr, padding_value=tokenizer.pad_token_id)
padded_input_ids = torch.transpose(padded_input_ids, 0, 1)
attention_mask_tsr = list(map(lambda x: torch.tensor(x), df["attention_mask"]))
padded_attention_mask = pad_sequence(attention_mask_tsr, padding_value=0)
padded_attention_mask = torch.transpose(padded_attention_mask, 0, 1)
score = torch.tensor(df["score"].tolist())
return {
"label": label,
"input_ids": padded_input_ids,
"attention_mask": padded_attention_mask,
"score": score,
}
def create_dataloader_samp_w(
dataset,
current_df,
query_df,
tokenizer,
dictionary,
sim_df,
all_d,
shuffle,
):
return DataLoader(
dataset,
batch_size=config.batch_size,
shuffle=shuffle,
num_workers=config.num_workers,
collate_fn=lambda batch: collate_samp_w(
batch,
current_df,
query_df,
tokenizer,
dictionary,
sim_df,
all_d,
),
)
class ContrastiveLearningDataModule_samp_w(pl.LightningDataModule):
def __init__(
self,
train_df,
val_df,
query_df,
tokenizer,
dictionary,
sim_df,
all_d,
):
super().__init__()
self.train_df = train_df
self.val_df = val_df
self.query_df = query_df
self.tokenizer = tokenizer
self.dictionary = dictionary
self.sim_df = sim_df
self.all_d = all_d
def setup(self, stage=None):
self.train_dataset = ContrastiveLearningDataset(self.train_df)
self.val_dataset = ContrastiveLearningDataset(self.val_df)
def train_dataloader(self):
return create_dataloader_samp_w(
self.train_dataset,
self.train_df,
self.query_df,
self.tokenizer,
self.dictionary,
self.sim_df,
self.all_d,
shuffle=True,
)
def val_dataloader(self):
return create_dataloader_samp_w(
self.val_dataset,
self.val_df,
self.query_df,
self.tokenizer,
self.dictionary,
self.sim_df,
self.all_d,
shuffle=False,
)
#### Test
if __name__ == "__main__":
    from transformers import AutoTokenizer
    from ast import literal_eval
    from sklearn.model_selection import train_test_split

    query_df = pd.read_csv(
        "/home/sunx/data/aiiih/projects/sunx/ccf_fuzzy_diag/data_proc/mimic_data/processed_train/processed.csv"
    )
    query_df["concepts"] = query_df["concepts"].apply(literal_eval)
    query_df["codes"] = query_df["codes"].apply(literal_eval)
    query_df["codes"] = query_df["codes"].apply(
        lambda x: [val for val in x if val is not None]
    )  # remove None entries from the code lists
    query_df = query_df.drop(columns=["one_hot"])
    train_df, val_df = train_test_split(query_df, test_size=config.split_ratio)
    tokenizer = AutoTokenizer.from_pretrained("emilyalsentzer/Bio_ClinicalBERT")
    sim_df = pd.read_csv(
        "/home/sunx/data/aiiih/projects/sunx/ccf_fuzzy_diag/data_proc/pairwise_scores.csv"
    )
    all_d = pd.read_csv(
        "/home/sunx/data/aiiih/projects/sunx/ccf_fuzzy_diag/data_proc/all_d_full.csv"
    )
    all_d["synonyms"] = all_d["synonyms"].apply(literal_eval)
    all_d["ancestors"] = all_d["ancestors"].apply(literal_eval)
    dictionary = dict(zip(all_d["concept"], all_d["synonyms"]))

    # Smoke-test each data module by pulling a single batch.
    d1 = ContrastiveLearningDataModule_simcse(train_df, val_df, tokenizer)
    d1.setup()
    train_d1 = d1.train_dataloader()
    for batch in train_d1:
        b1 = batch
        break

    d2 = ContrastiveLearningDataModule_simcse_w(
        train_df, val_df, query_df, tokenizer, sim_df, all_d
    )
    d2.setup()
    train_d2 = d2.train_dataloader()
    for batch in train_d2:
        b2 = batch
        break

    d3 = ContrastiveLearningDataModule_samp(
        train_df, val_df, query_df, tokenizer, dictionary, sim_df
    )
    d3.setup()
    train_d3 = d3.train_dataloader()
    for batch in train_d3:
        b3 = batch
        break

    d4 = ContrastiveLearningDataModule_samp_w(
        train_df, val_df, query_df, tokenizer, dictionary, sim_df, all_d
    )
    d4.setup()
    train_d4 = d4.train_dataloader()
    for batch in train_d4:
        b4 = batch
        break