# bert/bert_dataset.py
import itertools
import random

import torch
from torch.utils.data import Dataset
from transformers import BertTokenizer

from data import get_data

# WordPiece vocabulary trained for this repo (see the bert-it-1/ directory)
tokenizer = BertTokenizer.from_pretrained("bert-it-1/bert-it-vocab.txt")
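# If the local vocab file is unavailable, a stock tokenizer such as
# BertTokenizer.from_pretrained("bert-base-uncased") should be a drop-in
# substitute (an assumption: the special tokens [CLS]/[SEP]/[PAD]/[MASK]
# used below exist in both vocabularies).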


class BERTDataset(Dataset):
    '''Movie-dialogue sentence pairs prepared for BERT pretraining:
    masked language modelling (MLM) plus next sentence prediction (NSP).'''

    # NOTE: the default data_pair is built once, when this module is imported
    def __init__(self,
                 tokenizer: BertTokenizer = tokenizer,
                 data_pair: list = get_data('datasets/movie_conversations.txt', 'datasets/movie_lines.txt'),
                 seq_len: int = 128) -> None:
        super().__init__()
        self.tokenizer = tokenizer
        self.seq_len = seq_len
        self.corpus_lines = len(data_pair)
        self.lines = data_pair

    def __len__(self):
        return self.corpus_lines
    def __getitem__(self, item):
        # Step 1: get a sentence pair, either negative or positive (saved as is_next_label)
        t1, t2, is_next_label = self.get_sent(item)

        # Step 2: replace random words in each sentence with mask / random tokens
        t1_random, t1_label = self.random_word(t1)
        t2_random, t2_label = self.random_word(t2)

        # Step 3: add CLS and SEP around sentence 1 and SEP after sentence 2;
        # the corresponding label positions are filled with the [PAD] id
        # (treated as "not masked")
        t1 = [self.tokenizer.vocab['[CLS]']] + t1_random + [self.tokenizer.vocab['[SEP]']]
        t2 = t2_random + [self.tokenizer.vocab['[SEP]']]
        t1_label = [self.tokenizer.vocab['[PAD]']] + t1_label + [self.tokenizer.vocab['[PAD]']]
        t2_label = t2_label + [self.tokenizer.vocab['[PAD]']]

        # Step 4: combine sentence 1 and 2 into one input, truncate to seq_len
        # (note: truncation can drop the trailing [SEP]), then pad to seq_len
        segment_label = ([1 for _ in range(len(t1))] + [2 for _ in range(len(t2))])[:self.seq_len]
        bert_input = (t1 + t2)[:self.seq_len]
        bert_label = (t1_label + t2_label)[:self.seq_len]
        padding = [self.tokenizer.vocab['[PAD]'] for _ in range(self.seq_len - len(bert_input))]
        bert_input.extend(padding)
        bert_label.extend(padding)
        segment_label.extend(padding)

        output = {"bert_input": bert_input,
                  "bert_label": bert_label,
                  "segment_label": segment_label,
                  "is_next": is_next_label}
        return {key: torch.tensor(value) for key, value in output.items()}
    def random_word(self, sentence):
        tokens = sentence.split()
        output = []
        output_label = []

        # BERT MLM: 15% of the input tokens are selected; of those,
        # 80% become [MASK], 10% become a random token, 10% stay unchanged
        for token in tokens:
            prob = random.random()
            # tokenize the word; strip the [CLS]/[SEP] the tokenizer adds
            token_id = self.tokenizer(token)['input_ids'][1:-1]

            if prob < 0.15:
                # rescale so the sub-cases below split 80/10/10
                prob /= 0.15

                if prob < 0.8:
                    # 80% chance: replace every sub-token with [MASK]
                    output.extend(self.tokenizer.vocab['[MASK]'] for _ in token_id)
                elif prob < 0.9:
                    # 10% chance: replace every sub-token with a random vocab id
                    output.extend(random.randrange(len(self.tokenizer.vocab)) for _ in token_id)
                else:
                    # 10% chance: keep the original sub-tokens
                    output.extend(token_id)
                # the label is always the original sub-tokens
                output_label.extend(token_id)
            else:
                # unselected tokens pass through; label 0 marks "not masked"
                output.extend(token_id)
                output_label.extend(0 for _ in token_id)

        assert len(output) == len(output_label)
        return output, output_label
    def get_sent(self, index):
        '''return a sentence pair and its NSP label (1 = actual next sentence)'''
        t1, t2 = self.get_corpus_line(index)

        # 50/50 positive vs. negative pair, for next sentence prediction
        if random.random() > 0.5:
            return t1, t2, 1
        else:
            return t1, self.get_random_line(), 0

    def get_corpus_line(self, item):
        '''return the sentence pair at the given index'''
        return self.lines[item][0], self.lines[item][1]

    def get_random_line(self):
        '''return a random second sentence (rarely, this may coincide with the
        true next sentence, making the negative label noisy)'''
        return self.lines[random.randrange(len(self.lines))][1]
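

# Minimal usage sketch (not part of the original file): builds the dataset with
# the repo's default corpus paths and batches it with a standard DataLoader.
# It assumes datasets/movie_conversations.txt and datasets/movie_lines.txt exist.
if __name__ == "__main__":
    from torch.utils.data import DataLoader

    dataset = BERTDataset(seq_len=128)
    loader = DataLoader(dataset, batch_size=32, shuffle=True)

    batch = next(iter(loader))
    # "bert_input", "bert_label", "segment_label" -> LongTensor (batch_size, seq_len)
    # "is_next"                                   -> LongTensor (batch_size,)
    for key, value in batch.items():
        print(key, tuple(value.shape))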