# CHOPT / dataset.py
# (HuggingFace file-viewer residue removed: upload metadata, "raw / history
#  blame / No virus / 6.58 kB" are page chrome, not part of the module.)
import torch
from transformers import AutoTokenizer
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
import lightning.pytorch as pl
import config
import pandas as pd
import copy
from ast import literal_eval
from sklearn.model_selection import train_test_split
import sys
sys.path.append("/home/sunx/data/aiiih/projects/sunx/ccf_fuzzy_diag")
from data_proc.data_gen import (
positive_generator,
positive_generator_alter,
negative_generator,
negative_generator_alter,
negative_generator_random,
negative_generator_v2,
get_mentioned_code,
)
def tokenize(text, tokenizer, tag, mask_pct=None):
    """Tokenize ``text`` and attach masked-LM inputs/labels plus a role tag.

    Args:
        text: sentence to encode.
        tokenizer: HuggingFace-style tokenizer; called with
            ``return_token_type_ids=False, return_tensors="pt"``.
        tag: role of the sentence in a contrastive batch —
            ``"A"`` (anchor) -> 0, ``"P"`` (positive) -> 1,
            ``"N"`` (negative) -> 2. Any other value leaves the
            ``"tag"`` key unset (matches historical behavior).
        mask_pct: fraction of maskable tokens to mask; defaults to
            ``config.mask_pct`` when ``None``.

    Returns:
        Mapping with 1-D tensors ``input_ids``, ``attention_mask``,
        ``mlm_ids`` (inputs with sampled positions replaced by [MASK]),
        ``mlm_labels`` (original ids at masked positions, -100 elsewhere),
        and optionally the integer ``tag``.
    """
    if mask_pct is None:
        mask_pct = config.mask_pct
    inputs = tokenizer(
        text,
        return_token_type_ids=False,
        return_tensors="pt",
    )
    # Drop the batch dimension added by return_tensors="pt".
    inputs["input_ids"] = inputs["input_ids"][0]
    inputs["attention_mask"] = inputs["attention_mask"][0]
    input_ids = inputs["input_ids"]

    # Special-token ids for Bio_ClinicalBERT's vocab: [CLS], [SEP], [PAD].
    # NOTE(review): hard-coded; assumes a BERT-style vocab — confirm if the
    # tokenizer is ever swapped.
    special = torch.tensor([101, 102, 0])

    # Fix: sample *positions* to mask, not token values. The old code picked
    # values and masked every occurrence of them, so the effective mask rate
    # exceeded mask_pct whenever a sentence contained repeated tokens.
    maskable = (~torch.isin(input_ids, special)).nonzero(as_tuple=True)[0]
    num_to_mask = int(maskable.numel() * mask_pct)
    picked = maskable[torch.randperm(maskable.numel())[:num_to_mask]]

    mlm_ids = input_ids.clone()  # clone() instead of deepcopy on tensors
    mlm_labels = torch.full_like(input_ids, -100)  # -100 is ignored by CE loss
    mlm_labels[picked] = input_ids[picked]
    mlm_ids[picked] = 103  # [MASK]
    inputs["mlm_ids"] = mlm_ids
    inputs["mlm_labels"] = mlm_labels

    tag_map = {"A": 0, "P": 1, "N": 2}
    if tag in tag_map:
        inputs["tag"] = tag_map[tag]
    return inputs
class CLDataset(Dataset):
    """Dataset of raw sentences backed by a pandas DataFrame.

    The DataFrame must expose a ``sentences`` column; items are returned
    as plain strings (tokenization happens later, in the collate step).
    """

    def __init__(self, data: pd.DataFrame):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        return self.data.iloc[index].sentences
def collate_func(batch, tokenizer, current_df, query_df, dictionary, all_d):
    """Build one contrastive batch from a single anchor sentence.

    The batch holds the anchor first, then ``config.num_pos`` generated
    positives, then ``config.num_neg`` generated negatives. All token
    tensors are padded to a common length and returned as (batch, seq)
    tensors; MLM labels pad with -100 so padding is ignored by the loss.
    """
    anchor = batch[0]
    positives = positive_generator_alter(
        anchor,
        current_df,
        dictionary,
        num_pos=config.num_pos,
    )
    negatives = negative_generator_v2(
        anchor,
        current_df,
        query_df,
        all_d,
        num_neg=config.num_neg,
    )

    encoded = [tokenize(anchor, tokenizer, "A")]
    encoded.extend(tokenize(sent, tokenizer, "P") for sent in positives)
    encoded.extend(tokenize(sent, tokenizer, "N") for sent in negatives)

    def _pad(key, fill):
        # pad_sequence stacks along dim 0 (time); transpose to (batch, time).
        stacked = pad_sequence([item[key] for item in encoded], padding_value=fill)
        return stacked.transpose(0, 1)

    return {
        "tags": torch.tensor([item["tag"] for item in encoded]),
        "input_ids": _pad("input_ids", 0),
        "attention_mask": _pad("attention_mask", 0),
        "mlm_ids": _pad("mlm_ids", 0),
        "mlm_labels": _pad("mlm_labels", -100),
    }
def create_dataloader(
    dataset, tokenizer, shuffle, current_df, query_df, dictionary, all_d
):
    """Wrap ``dataset`` in a DataLoader whose collate step generates the
    positive/negative examples for each anchor via ``collate_func``."""

    def _collate(samples):
        # Bind the generation context (dataframes, dictionary) into the
        # collate call; DataLoader only passes the raw sample list.
        return collate_func(samples, tokenizer, current_df, query_df, dictionary, all_d)

    return DataLoader(
        dataset,
        batch_size=config.batch_size,
        shuffle=shuffle,
        num_workers=1,
        collate_fn=_collate,
    )
class CLDataModule(pl.LightningDataModule):
    """Lightning data module pairing train/val sentence DataFrames with the
    contrastive collate pipeline (anchor + generated positives/negatives)."""

    def __init__(
        self,
        train_df,
        val_df,
        tokenizer,
        query_df,
        dictionary,
        all_d,
    ):
        super().__init__()
        self.train_df = train_df
        self.val_df = val_df
        self.tokenizer = tokenizer
        self.query_df = query_df
        self.dictionary = dictionary
        self.all_d = all_d

    def setup(self, stage=None):
        # Wrap the raw DataFrames once; loaders are built per epoch.
        self.train_dataset = CLDataset(self.train_df)
        self.val_dataset = CLDataset(self.val_df)

    def _loader(self, dataset, frame, shuffle):
        # Shared loader construction; `frame` feeds the positive/negative
        # generators inside the collate function.
        return create_dataloader(
            dataset,
            self.tokenizer,
            shuffle=shuffle,
            current_df=frame,
            query_df=self.query_df,
            dictionary=self.dictionary,
            all_d=self.all_d,
        )

    def train_dataloader(self):
        return self._loader(self.train_dataset, self.train_df, True)

    def val_dataloader(self):
        return self._loader(self.val_dataset, self.val_df, False)
if __name__ == "__main__":
    # Smoke-test entry point: load the query data, build the data module,
    # and pull a single training batch.
    query_df = pd.read_csv(
        "/home/sunx/data/aiiih/projects/sunx/ccf_fuzzy_diag/data_proc/query_df.csv"
    )
    for column in ("concepts", "codes"):
        # CSV stores these columns as stringified Python lists.
        query_df[column] = query_df[column].apply(literal_eval)
    query_df["codes"] = query_df["codes"].apply(
        lambda codes: [code for code in codes if code is not None]
    )
    train_df, val_df = train_test_split(query_df, test_size=config.split_ratio)

    tokenizer = AutoTokenizer.from_pretrained("emilyalsentzer/Bio_ClinicalBERT")

    all_d = pd.read_csv(
        "/home/sunx/data/aiiih/projects/sunx/ccf_fuzzy_diag/data_proc/query_all_d.csv"
    )
    for column in ("synonyms", "ancestors", "finding_sites", "morphology"):
        all_d[column] = all_d[column].apply(literal_eval)
    dictionary = dict(zip(all_d["concept"], all_d["synonyms"]))

    d = CLDataModule(train_df, val_df, tokenizer, query_df, dictionary, all_d)
    d.setup()
    train = d.train_dataloader()
    b = next(iter(train))