""" This file contains functions for loading various needed data """ import json import torch import random import logging import os from random import random as rand from torch.utils.data import Dataset from torch.utils.data import DataLoader logger = logging.getLogger(__name__) local_file = os.path.split(__file__)[-1] logging.basicConfig( format='%(asctime)s : %(filename)s : %(funcName)s : %(levelname)s : %(message)s', level=logging.INFO) def load_acronym_kb(kb_path='acronym_kb.json'): f = open(kb_path, encoding='utf8') acronym_kb = json.load(f) for key, values in acronym_kb.items(): values = [v for v, s in values] acronym_kb[key] = values logger.info('loaded acronym dictionary successfully, in total there are [{a}] acronyms'.format(a=len(acronym_kb))) return acronym_kb def get_candidate(acronym_kb, short_term, can_num=10): return acronym_kb[short_term][:can_num] def load_data(path): data = list() for line in open(path, encoding='utf8'): row = json.loads(line) data.append(row) return data def load_dataset(data_path): all_short_term, all_long_term, all_context = list(), list(), list() for line in open(data_path, encoding='utf8'): obj = json.loads(line) short_term, long_term, context = obj['short_term'], obj['long_term'], ' '.join(obj['tokens']) all_short_term.append(short_term) all_long_term.append(long_term) all_context.append(context) return {'short_term': all_short_term, 'long_term': all_long_term, 'context':all_context} def load_pretrain(data_path): all_short_term, all_long_term, all_context = list(), list(), list() cnt = 0 for line in open(data_path, encoding='utf8'): cnt += 1 # row = line.strip().split('\t') # if len(row) != 3:continue if cnt>200:continue obj = json.loads(line) short_term, long_term, context = obj['short_term'], obj['long_term'], ' '.join(obj['tokens']) all_short_term.append(short_term) all_long_term.append(long_term) all_context.append(context) return {'short_term': all_short_term, 'long_term': all_long_term, 'context': all_context} class TextData(Dataset): def __init__(self, data): self.all_short_term = data['short_term'] self.all_long_term = data['long_term'] self.all_context = data['context'] def __len__(self): return len(self.all_short_term) def __getitem__(self, idx): return self.all_short_term[idx], self.all_long_term[idx], self.all_context[idx] def random_negative(target, elements): flag, result = True, '' while flag: temp = random.choice(elements) if temp != target: result = temp flag = False return result class SimpleLoader(): def __init__(self, batch_size, tokenizer, kb, shuffle=True): self.batch_size = batch_size self.shuffle = shuffle self.tokenizer = tokenizer self.kb = kb def collate_fn(self, batch_data): pos_tag, neg_tag = 0, 1 batch_short_term, batch_long_term, batch_context = list(zip(*batch_data)) batch_short_term, batch_long_term, batch_context = list(batch_short_term), list(batch_long_term), list(batch_context) batch_negative, batch_label, batch_label_neg = list(), list(), list() for index in range(len(batch_short_term)): short_term, long_term, context = batch_short_term[index], batch_long_term[index], batch_context[index] batch_label.append(pos_tag) candidates = [v[0] for v in self.kb[short_term]] if len(candidates) == 1: batch_negative.append(long_term) batch_label_neg.append(pos_tag) continue negative = random_negative(long_term, candidates) batch_negative.append(negative) batch_label_neg.append(neg_tag) prompt = batch_context + batch_context long_terms = batch_long_term + batch_negative label = batch_label + batch_label_neg encoding = 
self.tokenizer(prompt, long_terms, return_tensors="pt", padding=True, truncation=True) label = torch.LongTensor(label) return encoding, label def __call__(self, data_path): dataset = load_dataset(data_path=data_path) dataset = TextData(dataset) train_iterator = DataLoader(dataset=dataset, batch_size=self.batch_size // 2, shuffle=self.shuffle, collate_fn=self.collate_fn) return train_iterator def mask_subword(subword_sequences, prob=0.15, masked_prob=0.8, VOCAB_SIZE=30522): PAD, CLS, SEP, MASK, BLANK = 0, 101, 102, 103, -100 masked_labels = list() for sentence in subword_sequences: labels = [BLANK for _ in range(len(sentence))] original = sentence[:] end = len(sentence) if PAD in sentence: end = sentence.index(PAD) for pos in range(end): if sentence[pos] in (CLS, SEP): continue if rand() > prob: continue if rand() < masked_prob: # 80% sentence[pos] = MASK elif rand() < 0.5: # 10% sentence[pos] = random.randint(0, VOCAB_SIZE-1) labels[pos] = original[pos] masked_labels.append(labels) return subword_sequences, masked_labels class AcroBERTLoader(): def __init__(self, batch_size, tokenizer, kb, shuffle=True, masked_prob=0.15, hard_num=2): self.batch_size = batch_size self.shuffle = shuffle self.tokenizer = tokenizer self.masked_prob = masked_prob self.hard_num = hard_num self.kb = kb self.all_long_terms = list() for vs in self.kb.values(): self.all_long_terms.extend(list(vs)) def select_negative(self, target): selected, flag, max_time = None, True, 10 if target in self.kb: long_term_candidates = self.kb[target] if len(long_term_candidates) == 1: long_term_candidates = self.all_long_terms else: long_term_candidates = self.all_long_terms attempt = 0 while flag and attempt < max_time: attempt += 1 selected = random.choice(long_term_candidates) if selected != target: flag = False if attempt == max_time: selected = random.choice(self.all_long_terms) return selected def collate_fn(self, batch_data): batch_short_term, batch_long_term, batch_context = list(zip(*batch_data)) pos_samples, neg_samples, masked_pos_samples = list(), list(), list() for _ in range(self.hard_num): temp_pos_samples = [batch_long_term[index] + ' [SEP] ' + batch_context[index] for index in range(len(batch_long_term))] neg_long_terms = [self.select_negative(st) for st in batch_short_term] temp_neg_samples = [neg_long_terms[index] + ' [SEP] ' + batch_context[index] for index in range(len(batch_long_term))] temp_masked_pos_samples = [batch_long_term[index] + ' [SEP] ' + batch_context[index] for index in range(len(batch_long_term))] pos_samples.extend(temp_pos_samples) neg_samples.extend(temp_neg_samples) masked_pos_samples.extend(temp_masked_pos_samples) return pos_samples, masked_pos_samples, neg_samples def __call__(self, data_path): dataset = load_pretrain(data_path=data_path) logger.info('loaded dataset, sample = {a}'.format(a=len(dataset['short_term']))) dataset = TextData(dataset) train_iterator = DataLoader(dataset=dataset, batch_size=self.batch_size // (2 * self.hard_num), shuffle=self.shuffle, collate_fn=self.collate_fn) return train_iterator
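

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustration only): how the two loaders might be
# driven from a HuggingFace tokenizer. The checkpoint name and the data / KB
# paths below are placeholder assumptions, not files shipped with this module.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    from transformers import AutoTokenizer  # assumed available for this sketch

    tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')  # assumed checkpoint
    acronym_kb = load_acronym_kb('acronym_kb.json')                 # hypothetical path

    # Fine-tuning style batches: (encoding, label) pairs for sequence-pair classification.
    simple_loader = SimpleLoader(batch_size=32, tokenizer=tokenizer, kb=acronym_kb)
    # for encoding, label in simple_loader('train_data.jsonl'):    # hypothetical path
    #     ...

    # Pre-training style batches: raw positive / masked-positive / negative strings.
    acrobert_loader = AcroBERTLoader(batch_size=32, tokenizer=tokenizer, kb=acronym_kb)
    # for pos, masked_pos, neg in acrobert_loader('pretrain_data.jsonl'):  # hypothetical path
    #     ...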