""" | |
This file contains functions for loading various needed data | |
""" | |
import json | |
import torch | |
import random | |
import logging | |
import os | |
from random import random as rand | |
from torch.utils.data import Dataset | |
from torch.utils.data import DataLoader | |
logger = logging.getLogger(__name__) | |
local_file = os.path.split(__file__)[-1] | |
logging.basicConfig( | |
format='%(asctime)s : %(filename)s : %(funcName)s : %(levelname)s : %(message)s', | |
level=logging.INFO) | |
def load_acronym_kb(kb_path='acronym_kb.json'):
    # The KB maps each acronym to a list of (long_form, score) pairs; keep only the long forms.
    with open(kb_path, encoding='utf8') as f:
        acronym_kb = json.load(f)
    for key, values in acronym_kb.items():
        acronym_kb[key] = [v for v, s in values]
    logger.info('loaded acronym dictionary successfully, in total there are [{a}] acronyms'.format(a=len(acronym_kb)))
    return acronym_kb

def get_candidate(acronym_kb, short_term, can_num=10):
    # Return at most `can_num` long-form candidates for the given acronym.
    return acronym_kb[short_term][:can_num]

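# Illustrative sketch (not called anywhere in this module): how load_acronym_kb and
# get_candidate are meant to be combined. The acronym 'CNN' below is a hypothetical
# placeholder and is not guaranteed to exist in the shipped KB file.
def _example_candidate_lookup():
    acronym_kb = load_acronym_kb(kb_path='acronym_kb.json')
    if 'CNN' in acronym_kb:
        # prints up to 10 long-form candidates in the order stored in the KB
        print(get_candidate(acronym_kb, 'CNN', can_num=10))
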
def load_data(path):
    # Read a JSON-lines file into a list of dicts.
    data = list()
    with open(path, encoding='utf8') as f:
        for line in f:
            data.append(json.loads(line))
    return data

def load_dataset(data_path):
    # Each line holds one sample: an acronym, its gold long form, and the tokenized context.
    all_short_term, all_long_term, all_context = list(), list(), list()
    with open(data_path, encoding='utf8') as f:
        for line in f:
            obj = json.loads(line)
            short_term, long_term, context = obj['short_term'], obj['long_term'], ' '.join(obj['tokens'])
            all_short_term.append(short_term)
            all_long_term.append(long_term)
            all_context.append(context)
    return {'short_term': all_short_term, 'long_term': all_long_term, 'context': all_context}

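# Illustrative sketch: the JSON-lines layout load_dataset expects, inferred from the keys
# it reads ('short_term', 'long_term', 'tokens'). The file name and the concrete values
# below are made-up placeholders.
def _example_write_dataset_line(path='example_dataset.jsonl'):
    row = {'short_term': 'CNN',
           'long_term': 'convolutional neural network',
           'tokens': ['we', 'train', 'a', 'CNN', 'on', 'images']}
    with open(path, 'w', encoding='utf8') as f:
        f.write(json.dumps(row) + '\n')
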
def load_pretrain(data_path):
    # Same layout as load_dataset, but capped at the first 200 samples for pre-training runs.
    all_short_term, all_long_term, all_context = list(), list(), list()
    cnt = 0
    with open(data_path, encoding='utf8') as f:
        for line in f:
            cnt += 1
            # row = line.strip().split('\t')
            # if len(row) != 3: continue
            if cnt > 200:
                break
            obj = json.loads(line)
            short_term, long_term, context = obj['short_term'], obj['long_term'], ' '.join(obj['tokens'])
            all_short_term.append(short_term)
            all_long_term.append(long_term)
            all_context.append(context)
    return {'short_term': all_short_term, 'long_term': all_long_term, 'context': all_context}

class TextData(Dataset):
    """Thin Dataset wrapper over the dict produced by load_dataset / load_pretrain."""
    def __init__(self, data):
        self.all_short_term = data['short_term']
        self.all_long_term = data['long_term']
        self.all_context = data['context']

    def __len__(self):
        return len(self.all_short_term)

    def __getitem__(self, idx):
        return self.all_short_term[idx], self.all_long_term[idx], self.all_context[idx]

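# Illustrative sketch: wrapping the dict returned by load_dataset in TextData so it can be
# handed to a torch DataLoader. The file name is a hypothetical placeholder.
def _example_text_data(path='example_dataset.jsonl'):
    data = load_dataset(path)
    dataset = TextData(data)
    short_term, long_term, context = dataset[0]
    return short_term, long_term, context
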
def random_negative(target, elements):
    # Draw a random element different from `target`. Assumes `elements` contains at least
    # one value that is not equal to `target`, otherwise this loop would never terminate.
    flag, result = True, ''
    while flag:
        temp = random.choice(elements)
        if temp != target:
            result = temp
            flag = False
    return result

class SimpleLoader():
    def __init__(self, batch_size, tokenizer, kb, shuffle=True):
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.tokenizer = tokenizer
        self.kb = kb

    def collate_fn(self, batch_data):
        pos_tag, neg_tag = 0, 1
        batch_short_term, batch_long_term, batch_context = list(zip(*batch_data))
        batch_short_term, batch_long_term, batch_context = list(batch_short_term), list(batch_long_term), list(batch_context)
        batch_negative, batch_label, batch_label_neg = list(), list(), list()
        for index in range(len(batch_short_term)):
            short_term, long_term, context = batch_short_term[index], batch_long_term[index], batch_context[index]
            batch_label.append(pos_tag)
            # self.kb is expected to map each acronym to (long_form, score) pairs, so v[0] is the long form.
            candidates = [v[0] for v in self.kb[short_term]]
            if len(candidates) == 1:
                # No alternative long form exists: reuse the gold one and keep the positive label.
                batch_negative.append(long_term)
                batch_label_neg.append(pos_tag)
                continue
            negative = random_negative(long_term, candidates)
            batch_negative.append(negative)
            batch_label_neg.append(neg_tag)
        # Pair every context with its gold long form and with the sampled negative long form.
        prompt = batch_context + batch_context
        long_terms = batch_long_term + batch_negative
        label = batch_label + batch_label_neg
        encoding = self.tokenizer(prompt, long_terms, return_tensors="pt", padding=True, truncation=True)
        label = torch.LongTensor(label)
        return encoding, label

    def __call__(self, data_path):
        dataset = load_dataset(data_path=data_path)
        dataset = TextData(dataset)
        # Each sample yields one positive and one negative pair, so halve the batch size here.
        train_iterator = DataLoader(dataset=dataset, batch_size=self.batch_size // 2, shuffle=self.shuffle,
                                    collate_fn=self.collate_fn)
        return train_iterator

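# Illustrative sketch: building a SimpleLoader around a Hugging Face tokenizer and the raw
# acronym KB. Assumes the 'transformers' package is installed and that the KB file stores
# (long_form, score) pairs per acronym, which is what collate_fn above indexes; file names
# are hypothetical placeholders.
def _example_simple_loader(data_path='example_dataset.jsonl', kb_path='acronym_kb.json'):
    from transformers import AutoTokenizer
    tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
    with open(kb_path, encoding='utf8') as f:
        raw_kb = json.load(f)  # keep the (long_form, score) pairs unstripped
    loader = SimpleLoader(batch_size=16, tokenizer=tokenizer, kb=raw_kb)
    for encoding, label in loader(data_path):
        # encoding: tokenized (context, long_form) pairs; label: 0 = positive, 1 = negative
        break
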
def mask_subword(subword_sequences, prob=0.15, masked_prob=0.8, VOCAB_SIZE=30522):
    # BERT-style masked-language-model corruption: each non-special, non-padding token is
    # selected with probability `prob`; a selected token is replaced by [MASK] 80% of the
    # time, by a random vocabulary id 10% of the time, and left unchanged 10% of the time.
    # The ids below (PAD=0, CLS=101, SEP=102, MASK=103) follow the bert-base-uncased vocabulary.
    PAD, CLS, SEP, MASK, BLANK = 0, 101, 102, 103, -100
    masked_labels = list()
    for sentence in subword_sequences:
        labels = [BLANK for _ in range(len(sentence))]
        original = sentence[:]
        end = len(sentence)
        if PAD in sentence:
            end = sentence.index(PAD)
        for pos in range(end):
            if sentence[pos] in (CLS, SEP):
                continue
            if rand() > prob:
                continue
            if rand() < masked_prob:    # 80%: replace with [MASK]
                sentence[pos] = MASK
            elif rand() < 0.5:          # 10%: replace with a random token id
                sentence[pos] = random.randint(0, VOCAB_SIZE - 1)
            labels[pos] = original[pos]  # record the true id; the loss is computed only at selected positions
        masked_labels.append(labels)
    return subword_sequences, masked_labels

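# Illustrative sketch: applying mask_subword to ids produced by a BERT tokenizer. Assumes
# the 'transformers' package is installed; the sample sentence is a made-up placeholder.
def _example_mask_subword():
    from transformers import AutoTokenizer
    tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
    ids = tokenizer(['deep learning for acronyms'], padding=True)['input_ids']
    masked_ids, labels = mask_subword(ids, prob=0.15, masked_prob=0.8)
    return masked_ids, labels
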
class AcroBERTLoader():
    def __init__(self, batch_size, tokenizer, kb, shuffle=True, masked_prob=0.15, hard_num=2):
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.tokenizer = tokenizer
        self.masked_prob = masked_prob
        self.hard_num = hard_num
        self.kb = kb
        # Flattened pool of every long form in the KB, used for random (non-hard) negatives.
        self.all_long_terms = list()
        for vs in self.kb.values():
            self.all_long_terms.extend(list(vs))

    def select_negative(self, target):
        # Sample a negative long form for the acronym `target`: prefer expansions of the same
        # acronym (hard negatives) and fall back to the global pool when only one exists.
        selected, flag, max_time = None, True, 10
        if target in self.kb:
            long_term_candidates = self.kb[target]
            if len(long_term_candidates) == 1:
                long_term_candidates = self.all_long_terms
        else:
            long_term_candidates = self.all_long_terms
        attempt = 0
        while flag and attempt < max_time:
            attempt += 1
            selected = random.choice(long_term_candidates)
            if selected != target:
                flag = False
        if attempt == max_time:
            # After max_time attempts, draw from the global pool instead.
            selected = random.choice(self.all_long_terms)
        return selected

    def collate_fn(self, batch_data):
        batch_short_term, batch_long_term, batch_context = list(zip(*batch_data))
        pos_samples, neg_samples, masked_pos_samples = list(), list(), list()
        # For every sample, build `hard_num` positive and negative '<long form> [SEP] <context>'
        # strings; the masked copies are identical at this point (masking is expected downstream).
        for _ in range(self.hard_num):
            temp_pos_samples = [batch_long_term[index] + ' [SEP] ' + batch_context[index] for index in range(len(batch_long_term))]
            neg_long_terms = [self.select_negative(st) for st in batch_short_term]
            temp_neg_samples = [neg_long_terms[index] + ' [SEP] ' + batch_context[index] for index in range(len(batch_long_term))]
            temp_masked_pos_samples = [batch_long_term[index] + ' [SEP] ' + batch_context[index] for index in range(len(batch_long_term))]
            pos_samples.extend(temp_pos_samples)
            neg_samples.extend(temp_neg_samples)
            masked_pos_samples.extend(temp_masked_pos_samples)
        return pos_samples, masked_pos_samples, neg_samples

    def __call__(self, data_path):
        dataset = load_pretrain(data_path=data_path)
        logger.info('loaded dataset, sample = {a}'.format(a=len(dataset['short_term'])))
        dataset = TextData(dataset)
        # Each sample expands into 2 * hard_num strings, so shrink the DataLoader batch size accordingly.
        train_iterator = DataLoader(dataset=dataset, batch_size=self.batch_size // (2 * self.hard_num), shuffle=self.shuffle,
                                    collate_fn=self.collate_fn)
        return train_iterator

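# Illustrative sketch: end-to-end pre-training batch construction with AcroBERTLoader.
# File names are hypothetical placeholders; the KB here is the one returned by
# load_acronym_kb (plain long-form strings), which is what select_negative samples from.
def _example_acrobert_loader(data_path='example_pretrain.jsonl', kb_path='acronym_kb.json'):
    from transformers import AutoTokenizer
    tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
    kb = load_acronym_kb(kb_path)
    loader = AcroBERTLoader(batch_size=16, tokenizer=tokenizer, kb=kb, hard_num=2)
    for pos_samples, masked_pos_samples, neg_samples in loader(data_path):
        # each element is a '<long form> [SEP] <context>' string; masking happens downstream
        break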