import sys
from ast import literal_eval

import pandas as pd
import torch
import lightning.pytorch as pl
from sklearn.model_selection import train_test_split
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader, Dataset
from transformers import AutoTokenizer

import config

sys.path.append("/home/sunx/data/aiiih/projects/sunx/ccf_fuzzy_diag")
from data_proc.data_gen import (  # only the *_alter / *_v2 variants are used below
    positive_generator,
    positive_generator_alter,
    negative_generator,
    negative_generator_alter,
    negative_generator_random,
    negative_generator_v2,
    get_mentioned_code,
)


def tokenize(text, tokenizer, tag):
    """Tokenize one sentence and build masked-LM inputs/labels for it.

    `tag` marks the sentence's role in the contrastive group:
    "A" = anchor, "P" = positive, "N" = negative.
    """
    inputs = tokenizer(
        text,
        return_token_type_ids=False,
        truncation=True,  # guard against sequences longer than the model limit
        return_tensors="pt",
    )
    # Drop the batch dimension added by return_tensors="pt".
    inputs["input_ids"] = inputs["input_ids"][0]
    inputs["attention_mask"] = inputs["attention_mask"][0]

    # BERT special tokens: 101 = [CLS], 102 = [SEP], 0 = [PAD]; never mask these.
    tokens_to_ignore = torch.tensor([101, 102, 0])
    valid_tokens = inputs["input_ids"][
        ~torch.isin(inputs["input_ids"], tokens_to_ignore)
    ]
    # Sample token *values* to mask. Note that every occurrence of a sampled
    # value gets masked, so the effective rate can exceed config.mask_pct.
    num_of_token_to_mask = int(len(valid_tokens) * config.mask_pct)
    token_to_mask = valid_tokens[
        torch.randperm(valid_tokens.size(0))[:num_of_token_to_mask]
    ]
    # 103 = [MASK]; labels keep the original id at masked positions and -100
    # (ignored by the MLM loss) everywhere else.
    masked = torch.isin(inputs["input_ids"], token_to_mask)
    inputs["mlm_ids"] = torch.where(masked, torch.tensor(103), inputs["input_ids"])
    inputs["mlm_labels"] = torch.where(
        masked, inputs["input_ids"], torch.tensor(-100)
    )

    inputs["tag"] = {"A": 0, "P": 1, "N": 2}[tag]
    return inputs


class CLDataset(Dataset):
    """Wraps a DataFrame of sentences; each item is one candidate anchor."""

    def __init__(self, data: pd.DataFrame):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        return self.data.iloc[index].sentences


def collate_func(batch, tokenizer, current_df, query_df, dictionary, all_d):
    """Expand a single-anchor batch into anchor + positives + negatives.

    Assumes batch_size == 1, so the lone sentence in `batch` is the anchor.
    """
    anchor = batch[0]
    positives = positive_generator_alter(
        anchor,
        current_df,
        dictionary,
        num_pos=config.num_pos,
    )
    negatives = negative_generator_v2(
        anchor,
        current_df,
        query_df,
        all_d,
        num_neg=config.num_neg,
    )

    inputs = [tokenize(anchor, tokenizer, "A")]
    inputs += [tokenize(pos, tokenizer, "P") for pos in positives]
    inputs += [tokenize(neg, tokenizer, "N") for neg in negatives]

    tags = torch.tensor([d["tag"] for d in inputs])

    # Pad every field to the longest sequence in the group; batch_first=True
    # replaces the original pad-then-transpose step.
    def pad(key, value):
        return pad_sequence(
            [d[key] for d in inputs], batch_first=True, padding_value=value
        )

    return {
        "tags": tags,
        "input_ids": pad("input_ids", 0),
        "attention_mask": pad("attention_mask", 0),
        "mlm_ids": pad("mlm_ids", 0),
        "mlm_labels": pad("mlm_labels", -100),  # -100 keeps padding out of the loss
    }


def create_dataloader(
    dataset, tokenizer, shuffle, current_df, query_df, dictionary, all_d
):
    # The lambda closes over the generator inputs; this works under the default
    # "fork" worker start method on Linux but would not pickle under "spawn".
    return DataLoader(
        dataset,
        batch_size=config.batch_size,
        shuffle=shuffle,
        num_workers=1,
        collate_fn=lambda batch: collate_func(
            batch, tokenizer, current_df, query_df, dictionary, all_d
        ),
    )
class CLDataModule(pl.LightningDataModule):
    """LightningDataModule that serves contrastive batches for train/val."""

    def __init__(self, train_df, val_df, tokenizer, query_df, dictionary, all_d):
        super().__init__()
        self.train_df = train_df
        self.val_df = val_df
        self.tokenizer = tokenizer
        self.query_df = query_df
        self.dictionary = dictionary
        self.all_d = all_d

    def setup(self, stage=None):
        self.train_dataset = CLDataset(self.train_df)
        self.val_dataset = CLDataset(self.val_df)

    def train_dataloader(self):
        return create_dataloader(
            self.train_dataset,
            self.tokenizer,
            shuffle=True,
            current_df=self.train_df,
            query_df=self.query_df,
            dictionary=self.dictionary,
            all_d=self.all_d,
        )

    def val_dataloader(self):
        return create_dataloader(
            self.val_dataset,
            self.tokenizer,
            shuffle=False,
            current_df=self.val_df,
            query_df=self.query_df,
            dictionary=self.dictionary,
            all_d=self.all_d,
        )


if __name__ == "__main__":
    query_df = pd.read_csv(
        "/home/sunx/data/aiiih/projects/sunx/ccf_fuzzy_diag/data_proc/query_df.csv"
    )
    # List-valued columns round-trip through CSV as strings; parse them back.
    query_df["concepts"] = query_df["concepts"].apply(literal_eval)
    query_df["codes"] = query_df["codes"].apply(literal_eval)
    query_df["codes"] = query_df["codes"].apply(
        lambda x: [val for val in x if val is not None]
    )
    train_df, val_df = train_test_split(query_df, test_size=config.split_ratio)

    tokenizer = AutoTokenizer.from_pretrained("emilyalsentzer/Bio_ClinicalBERT")

    all_d = pd.read_csv(
        "/home/sunx/data/aiiih/projects/sunx/ccf_fuzzy_diag/data_proc/query_all_d.csv"
    )
    for col in ("synonyms", "ancestors", "finding_sites", "morphology"):
        all_d[col] = all_d[col].apply(literal_eval)
    dictionary = dict(zip(all_d["concept"], all_d["synonyms"]))

    d = CLDataModule(train_df, val_df, tokenizer, query_df, dictionary, all_d)
    d.setup()
    train = d.train_dataloader()
    # Pull a single batch to eyeball the collated structure.
    for batch in train:
        b = batch
        break
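    # Illustrative sanity check (not part of the original script): every
    # collated tensor should share a first dimension of 1 + num_pos + num_neg
    # rows, ordered anchor, then positives, then negatives, matching the
    # tag values 0/1/2 produced by tokenize().
    print({k: tuple(v.shape) for k, v in b.items()})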