from __future__ import absolute_import, division, print_function
import logging
import os
import json
import random
import glob
import re
import torch
import tqdm
import torch.utils.data
logger = logging.getLogger(__name__)
class Seq2seqDatasetForBert(torch.utils.data.Dataset):
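    """Seq2seq training dataset over pre-tokenized (source, target) features.

    Each item provides padded source ids, target ids, and "pseudo" target ids in
    which every target token is either kept, replaced by a random vocabulary id,
    or replaced by the [MASK] id, controlled by `keep_prob` and `random_prob`.
    """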
def __init__(
self, features, max_source_len, max_target_len,
vocab_size, cls_id, sep_id, pad_id, mask_id,
random_prob, keep_prob, offset, num_training_instances,
span_len=1, span_prob=1.0):
self.features = features
self.max_source_len = max_source_len
self.max_target_len = max_target_len
self.offset = offset
if offset > 0:
logger.info(" **** Set offset %d in Seq2seqDatasetForBert **** ", offset)
self.cls_id = cls_id
self.sep_id = sep_id
self.pad_id = pad_id
self.random_prob = random_prob
self.keep_prob = keep_prob
self.mask_id = mask_id
self.vocab_size = vocab_size
self.num_training_instances = num_training_instances
self.span_len = span_len
self.span_prob = span_prob
def __len__(self):
return int(self.num_training_instances)
def __trunk(self, ids, max_len):
if len(ids) > max_len - 1:
ids = ids[:max_len - 1]
ids = ids + [self.sep_id]
return ids
def __pad(self, ids, max_len):
if len(ids) < max_len:
return ids + [self.pad_id] * (max_len - len(ids))
else:
assert len(ids) == max_len
return ids
def __getitem__(self, idx):
idx = (self.offset + idx) % len(self.features)
feature = self.features[idx]
source_ids = self.__trunk([self.cls_id] + feature["source_ids"], self.max_source_len)
target_ids = self.__trunk(feature["target_ids"], self.max_target_len)
pseudo_ids = []
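        # Build the pseudo target sequence: keep the original token with
        # probability keep_prob, sample a random vocabulary id with probability
        # random_prob, otherwise use the [MASK] id.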
for tk_id in target_ids:
p = random.random()
if p < self.keep_prob:
pseudo_ids.append(tk_id)
elif p < self.keep_prob + self.random_prob:
pseudo_ids.append(random.randint(0, self.vocab_size - 1))
else:
pseudo_ids.append(self.mask_id)
num_source_tokens = len(source_ids)
num_target_tokens = len(target_ids)
source_ids = self.__pad(source_ids, self.max_source_len)
target_ids = self.__pad(target_ids, self.max_target_len)
pseudo_ids = self.__pad(pseudo_ids, self.max_target_len)
if self.span_len > 1:
span_ids = []
span_id = 1
while len(span_ids) < num_target_tokens:
p = random.random()
if p < self.span_prob:
span_len = random.randint(2, self.span_len)
span_len = min(span_len, num_target_tokens - len(span_ids))
else:
span_len = 1
span_ids.extend([span_id] * span_len)
span_id += 1
span_ids = self.__pad(span_ids, self.max_target_len)
return source_ids, target_ids, pseudo_ids, num_source_tokens, num_target_tokens, span_ids
else:
return source_ids, target_ids, pseudo_ids, num_source_tokens, num_target_tokens
# Dataset variant that supports 2-D (layout) input ids.
class Seq2seqDatasetForLayoutlm(torch.utils.data.Dataset):
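    """Seq2seq dataset with optional 2-D layout inputs.

    When `layout_flag` is True, every source/target element is a 5-element list
    [token_id, x0, y0, x1, y1]; otherwise items look like those produced by
    Seq2seqDatasetForBert.  In both modes an extra `target_index` is returned
    that maps each target token back to its source position.
    """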
def __init__(
self, features, max_source_len, max_target_len,
vocab_size, cls_id, sep_id, pad_id, mask_id,
random_prob, keep_prob, offset, num_training_instances, layout_flag=True,
span_len=1, span_prob=1.0):
self.layout_flag = layout_flag
self.features = features
self.max_source_len = max_source_len
self.max_target_len = max_target_len
self.offset = offset
if offset > 0:
            logger.info(" **** Set offset %d in Seq2seqDatasetForLayoutlm **** ", offset)
self.cls_id = cls_id
self.sep_id = sep_id
self.pad_id = pad_id
self.random_prob = random_prob
self.keep_prob = keep_prob
self.mask_id = mask_id
self.vocab_size = vocab_size
self.num_training_instances = num_training_instances
self.span_len = span_len
self.span_prob = span_prob
self.index_sp_id = 0
def __len__(self):
return int(self.num_training_instances)
def __clip_index(self, ids):
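        """Reset target indices that point beyond the truncated source to the
        special index 0 (rather than clamping them to the last position)."""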
replace_value = 0
for i in range(len(ids)):
if ids[i] > self.max_source_len - 1:
ids[i] = replace_value
return ids
def __trunk(self, ids, max_len, simple=False, value=None):
trunk_value = value if value is not None else self.sep_id
if len(ids) > max_len - 1:
ids = ids[:max_len - 1]
if simple:
ids = ids + [trunk_value]
else:
ids = ids + [[trunk_value, 1000, 1000, 1000, 1000]]
return ids
def __pad(self, ids, max_len, simple=False, value=None):
pad_value = value if value is not None else self.pad_id
if len(ids) < max_len:
if simple:
return ids + [pad_value] * (max_len - len(ids))
else:
return ids + [[pad_value, 0, 0, 0, 0]] * (max_len - len(ids))
else:
assert len(ids) == max_len
return ids
def __getitem__(self, idx):
if self.layout_flag:
return self.__getitem_layout__(idx)
else:
return self.__getitem_bert__(idx)
def __getitem_bert__(self, idx):
idx = (self.offset + idx) % len(self.features)
feature = self.features[idx]
source_ids = self.__trunk([self.cls_id] + feature["source_ids"], self.max_source_len, simple=True)
target_ids = self.__trunk(feature["target_ids"], self.max_target_len, simple=True)
target_index = self.__trunk(feature['target_index'], self.max_target_len, simple=True, value=self.index_sp_id)
pseudo_ids = []
for tk_id in target_ids:
p = random.random()
if p < self.keep_prob:
pseudo_ids.append(tk_id)
elif p < self.keep_prob + self.random_prob:
pseudo_ids.append(random.randint(0, self.vocab_size - 1))
else:
pseudo_ids.append(self.mask_id)
num_source_tokens = len(source_ids)
num_target_tokens = len(target_ids)
source_ids = self.__pad(source_ids, self.max_source_len, simple=True)
target_ids = self.__pad(target_ids, self.max_target_len, simple=True)
pseudo_ids = self.__pad(pseudo_ids, self.max_target_len, simple=True)
target_index = self.__pad(target_index, self.max_target_len, simple=True, value=self.index_sp_id)
target_index = self.__clip_index(target_index)
if self.span_len > 1:
span_ids = []
span_id = 1
while len(span_ids) < num_target_tokens:
p = random.random()
if p < self.span_prob:
span_len = random.randint(2, self.span_len)
span_len = min(span_len, num_target_tokens - len(span_ids))
else:
span_len = 1
span_ids.extend([span_id] * span_len)
span_id += 1
            # span ids are plain ints, so pad them in "simple" mode
            span_ids = self.__pad(span_ids, self.max_target_len, simple=True)
return source_ids, target_ids, pseudo_ids, num_source_tokens, num_target_tokens, span_ids, target_index
else:
return source_ids, target_ids, pseudo_ids, num_source_tokens, num_target_tokens, target_index
def __getitem_layout__(self, idx):
        # TODO: how should the position embeddings of random / masked tokens be initialized?
        # Current simple solution: only the text id is replaced; the box is zeroed out.
idx = (self.offset + idx) % len(self.features)
feature = self.features[idx]
source_ids = self.__trunk([[self.cls_id, 0, 0, 0, 0]] + feature["source_ids"], self.max_source_len)
target_ids = self.__trunk(feature["target_ids"], self.max_target_len)
target_index = self.__trunk(feature['target_index'], self.max_target_len, simple=True, value=self.index_sp_id)
pseudo_ids = []
for tk_id in target_ids:
p = random.random()
if p < self.keep_prob:
pseudo_ids.append(tk_id)
elif p < self.keep_prob + self.random_prob:
                pseudo_ids.append([random.randint(0, self.vocab_size - 1)] + [0, 0, 0, 0])  # alternative: keep the box via tk_id[1:]
else:
                pseudo_ids.append([self.mask_id] + [0, 0, 0, 0])  # alternative: keep the box via tk_id[1:]
num_source_tokens = len(source_ids)
num_target_tokens = len(target_ids)
source_ids = self.__pad(source_ids, self.max_source_len)
target_ids = self.__pad(target_ids, self.max_target_len)
pseudo_ids = self.__pad(pseudo_ids, self.max_target_len)
target_index = self.__pad(target_index, self.max_target_len, simple=True, value=self.index_sp_id)
target_index = self.__clip_index(target_index)
if self.span_len > 1:
span_ids = []
span_id = 1
while len(span_ids) < num_target_tokens:
p = random.random()
if p < self.span_prob:
span_len = random.randint(2, self.span_len)
span_len = min(span_len, num_target_tokens - len(span_ids))
else:
span_len = 1
span_ids.extend([span_id] * span_len)
span_id += 1
            # span ids are plain ints, so pad them in "simple" mode
            span_ids = self.__pad(span_ids, self.max_target_len, simple=True)
return source_ids, target_ids, pseudo_ids, num_source_tokens, num_target_tokens, span_ids, target_index
else:
return source_ids, target_ids, pseudo_ids, num_source_tokens, num_target_tokens, target_index
def batch_list_to_batch_tensors(batch):
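    """Collate a list of dataset items into a list of batched LongTensors."""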
batch_tensors = []
for x in zip(*batch):
if isinstance(x[0], torch.Tensor):
batch_tensors.append(torch.stack(x))
else:
batch_tensors.append(torch.tensor(x, dtype=torch.long))
return batch_tensors
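# Illustrative wiring sketch (not part of the original pipeline): how these pieces
# are typically combined.  `tokenizer`, `train_file`, and the hyper-parameter
# values below are placeholders supplied by the caller.
#
#   features = load_and_cache_examples(
#       train_file, tokenizer, local_rank=-1, cached_features_file=None)
#   dataset = Seq2seqDatasetForBert(
#       features, max_source_len=464, max_target_len=48,
#       vocab_size=tokenizer.vocab_size, cls_id=tokenizer.cls_token_id,
#       sep_id=tokenizer.sep_token_id, pad_id=tokenizer.pad_token_id,
#       mask_id=tokenizer.mask_token_id, random_prob=0.1, keep_prob=0.1,
#       offset=0, num_training_instances=len(features))
#   loader = torch.utils.data.DataLoader(
#       dataset, batch_size=8, collate_fn=batch_list_to_batch_tensors)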
def get_max_epoch_model(output_dir):
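    """Return the largest checkpoint id for which both model.<id>.bin and
    optim.<id>.bin exist in output_dir, or None if there is no complete pair."""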
fn_model_list = glob.glob(os.path.join(output_dir, "model.*.bin"))
fn_optim_list = glob.glob(os.path.join(output_dir, "optim.*.bin"))
if (not fn_model_list) or (not fn_optim_list):
return None
both_set = set([int(os.path.basename(fn).split('.')[1]) for fn in fn_model_list]
) & set([int(os.path.basename(fn).split('.')[1]) for fn in fn_optim_list])
if both_set:
return max(both_set)
else:
return None
def load_and_cache_examples(
example_file, tokenizer, local_rank, cached_features_file, shuffle=True):
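    """Load "src"/"tgt" pairs from a JSON-lines file, tokenize them (or reuse
    pre-tokenized lists), convert to ids, and optionally cache with torch.save.
    In distributed training only rank 0 builds the features; the other ranks
    wait on a barrier and then read the cache.
    """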
    # Make sure only the first process in distributed training processes the dataset; the others will use the cache.
if local_rank not in [-1, 0]:
torch.distributed.barrier()
if cached_features_file is not None and os.path.exists(cached_features_file):
logger.info("Loading features from cached file %s", cached_features_file)
features = torch.load(cached_features_file)
else:
logger.info("Creating features from dataset file at %s", example_file)
examples = []
with open(example_file, mode="r", encoding="utf-8") as reader:
for i, line in enumerate(reader):
                # NOTE: only the first 100 lines are read (this looks like a debugging cap).
                if i == 100:
                    break
examples.append(json.loads(line))
features = []
for example in tqdm.tqdm(examples):
if isinstance(example["src"], list):
source_tokens = example["src"]
target_tokens = example["tgt"]
else:
source_tokens = tokenizer.tokenize(example["src"])
target_tokens = tokenizer.tokenize(example["tgt"])
features.append({
"source_ids": tokenizer.convert_tokens_to_ids(source_tokens),
"target_ids": tokenizer.convert_tokens_to_ids(target_tokens),
})
if shuffle:
random.shuffle(features)
if local_rank in [-1, 0] and cached_features_file is not None:
logger.info("Saving features into cached file %s", cached_features_file)
torch.save(features, cached_features_file)
    # Rank 0 has finished processing the dataset; release the barrier for the other ranks.
if local_rank == 0:
torch.distributed.barrier()
return features
def load_and_cache_line_order_examples(
example_path, tokenizer, local_rank, cached_features_file, max_src_length=1024,
layout_flag=True, shuffle=True,
src_shuffle_rate=0,
file_info_flag=False,
):
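    """Build line-ordering features from a JSON-lines file whose records contain
    pre-tokenized `src`/`tgt` rows, a `tgt_index` permutation and a `bleu` score.
    With probability `src_shuffle_rate` the source rows are shuffled and
    `tgt_index` is remapped accordingly.
    """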
    # Make sure only the first process in distributed training processes the dataset; the others will use the cache.
if local_rank not in [-1, 0]:
torch.distributed.barrier()
    # NOTE: the trailing "and False" disables cache loading here, so features are always rebuilt.
    if cached_features_file is not None and os.path.exists(cached_features_file) and False:
logger.info("Loading features from cached file %s", cached_features_file)
features = torch.load(cached_features_file)
else:
logger.info("Creating features from dataset at %s", example_path)
examples = []
with open(example_path, 'r') as layout_reader:
logger.info(f'Start loading {example_path}')
for i, line in enumerate(layout_reader):
examples.append(json.loads(line))
features = []
for layout in tqdm.tqdm(examples):
bleu = layout['bleu']
if random.random() < src_shuffle_rate:
                # NOTE: when the source rows are shuffled, tgt_index must be remapped as well.
src_layout = layout['src']
tgt_index = layout['tgt_index']
source_length = len(src_layout)
shuffle_index = list(range(source_length))
random.shuffle(shuffle_index)
shuffle_layout = ['' for _ in range(source_length)]
for i, j in enumerate(shuffle_index):
# NOTE: map i-th token to j-th token
shuffle_layout[j] = src_layout[i]
shuffle_target_index = [shuffle_index[i] for i in tgt_index]
layout['tgt_index'] = shuffle_target_index
layout['src'] = shuffle_layout
mask = tokenizer.mask_token_id
            # prepend the token id of each row's position (as a string) to the row
            src_ids = [tokenizer.convert_tokens_to_ids([str(pos)])[:1] + row for pos, row in enumerate(layout['src'])]
            tgt_ids = [tokenizer.convert_tokens_to_ids([str(pos)])[:1] + row for pos, row in enumerate(layout['tgt'])]
tgt_index = layout['tgt_index']
feature = {
"source_ids": src_ids,
"target_ids": tgt_ids,
"target_index": tgt_index,
'bleu': bleu
}
if file_info_flag:
file_info = {'original_filename': layout['filename'], 'filename': layout['filename'],
'page_idx': 0}
feature['file_info'] = file_info
features.append(feature)
if shuffle:
random.shuffle(features)
if local_rank in [-1, 0] and cached_features_file is not None:
logger.info("Saving features into cached file %s", cached_features_file)
torch.save(features, cached_features_file)
    # Rank 0 has finished processing the dataset; release the barrier for the other ranks.
if local_rank == 0:
torch.distributed.barrier()
return features
def load_and_cache_layoutlm_examples(
example_path, tokenizer, local_rank, cached_features_file, max_src_length=1024,
layout_flag=True, shuffle=True,
src_shuffle_rate=0,
file_info_flag=False
):
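    """Build LayoutLM seq2seq features from paired *text* / *layout* JSON-lines
    files.  Each feature holds source/target token ids (with bounding boxes when
    `layout_flag` is True), a `target_index` mapping target tokens to source
    positions, and a `bleu` field; results are cached with torch.save using the
    usual rank-0 barrier pattern.
    """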
    # Make sure only the first process in distributed training processes the dataset; the others will use the cache.
if local_rank not in [-1, 0]:
torch.distributed.barrier()
if cached_features_file is not None and os.path.exists(cached_features_file):
logger.info("Loading features from cached file %s", cached_features_file)
features = torch.load(cached_features_file)
else:
logger.info("Creating features from dataset at %s", example_path)
examples = []
if os.path.isdir(example_path):
text_files = glob.glob(f'{example_path}/*text*.json')
            layout_files = [re.sub('text|txt', 'layout', x, count=1) for x in text_files]
else:
text_files = [example_path]
            layout_files = [re.sub('text|txt', 'layout', example_path, count=1)]
for text_file, layout_file in zip(text_files, layout_files):
with open(text_file, mode='r', encoding='utf-8') as text_reader, \
open(layout_file, mode='r', encoding='utf-8') as layout_reader:
logger.info(f'Start loading {text_file}')
for i, (text_line, layout_line) in enumerate(zip(text_reader, layout_reader)):
if (i + 1) % 10000 == 0:
logger.info(f'{i + 1} lines ...')
examples.append((json.loads(text_line), json.loads(layout_line)))
features = []
def tokenize_text_and_layout_src(_text, _layout, _layout_flag):
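            """Tokenize the source words, attach their boxes when layout is on, and
            record for every word position the list of token indices it expands to
            (indices start at 1; 0 is reserved)."""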
ret = []
index_split = {}
words = _text.split()
            # NOTE: source token indices start from 1.  Index 0 is reserved as the
            # special / padding value for target indices (index_sp_id in the datasets
            # above), which also leaves a blank position for the loss ignore index.
new_token_index = 1 # first ordinary index
for i, (word, box) in enumerate(zip(words, _layout)):
                if box[2] < box[0] or box[3] < box[1]:  # skip malformed boxes
continue
tokens = tokenizer.tokenize(word)
tokens = tokenizer.convert_tokens_to_ids(tokens)
new_token_ids = []
for token in tokens:
if _layout_flag:
ret.append([token] + box)
else:
ret.append(token)
new_token_ids.append(new_token_index)
new_token_index += 1
index_split[i] = new_token_ids
return ret, index_split
def tokenize_text_and_layout_tgt(_text, _layout, _index, _index_split, _layout_flag):
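            """Tokenize the target words the same way and translate the word-level
            `_index` into token-level source indices via `_index_split`, clipped to
            max_src_length - 1."""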
ret = []
ret_index = []
words = _text.split()
for word, box, i in zip(words, _layout, _index):
                if box[2] < box[0] or box[3] < box[1]:  # skip malformed boxes
continue
tokens = tokenizer.tokenize(word)
tokens = tokenizer.convert_tokens_to_ids(tokens)
for token, ii in zip(tokens, _index_split[i]):
if _layout_flag:
ret.append([token] + box)
else:
ret.append(token)
ii = min(ii, max_src_length - 1)
ret_index.append(ii)
return ret, ret_index
for text, layout in tqdm.tqdm(examples):
if 'bleu' in text:
bleu = text['bleu']
else:
bleu = 0
if random.random() < src_shuffle_rate:
                # NOTE: when the source is shuffled, tgt_index must be remapped as well.
src_text = text['src']
src_layout = layout['src']
tgt_index = text['tgt_index']
src_text = src_text.split()
source_length = len(src_text)
shuffle_index = list(range(source_length))
random.shuffle(shuffle_index)
shuffle_text = ['' for _ in range(source_length)]
shuffle_layout = ['' for _ in range(source_length)]
for i, j in enumerate(shuffle_index):
# NOTE: map i-th token to j-th token
shuffle_text[j] = src_text[i]
shuffle_layout[j] = src_layout[i]
shuffle_target_index = [shuffle_index[i] for i in tgt_index]
text['src'] = ' '.join(shuffle_text)
text['tgt_index'] = shuffle_target_index
layout['src'] = shuffle_layout
src_ids, src_index_split = tokenize_text_and_layout_src(text['src'], layout['src'],
_layout_flag=layout_flag)
tgt_ids, tgt_index = tokenize_text_and_layout_tgt(text['tgt'], layout['tgt'], text['tgt_index'],
src_index_split, _layout_flag=layout_flag)
feature = {
"source_ids": src_ids,
"target_ids": tgt_ids,
"target_index": tgt_index,
'bleu': bleu
}
if file_info_flag:
file_info = {'original_filename': text['original_filename'], 'filename': text['filename'], 'page_idx': text['page_idx']}
feature['file_info'] = file_info
features.append(feature)
if shuffle:
random.shuffle(features)
if local_rank in [-1, 0] and cached_features_file is not None:
if not os.path.exists(os.path.dirname(cached_features_file)):
os.makedirs(os.path.dirname(cached_features_file))
logger.info("Saving features into cached file %s", cached_features_file)
torch.save(features, cached_features_file)
    # Rank 0 has finished processing the dataset; release the barrier for the other ranks.
if local_rank == 0:
torch.distributed.barrier()
return features
def convert_src_layout_inputs_to_tokens(inputs, converter, max_src_length, layout_flag=True):
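    """Convert cached source id features back to tokens, keeping the box
    coordinates when layout_flag is True, truncated to max_src_length."""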
ret = []
if not layout_flag:
for line in inputs:
ret.append(converter(line["source_ids"])[: max_src_length])
else:
for line in inputs:
raw_text_ids = [x[0] for x in line['source_ids']]
raw_text = converter(raw_text_ids)
new_line = [[t] + x[1:] for t, x in zip(raw_text, line['source_ids'])][: max_src_length]
ret.append(new_line)
return ret
def convert_tgt_layout_inputs_to_tokens(inputs, converter, max_tgt_length, layout_flag=True):
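    """Convert cached target id features back to plain tokens, truncated to
    max_tgt_length."""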
ret = []
if not layout_flag:
for line in inputs:
ret.append(converter(line["target_ids"])[: max_tgt_length])
else:
for line in inputs:
raw_text_ids = [x[0] for x in line['target_ids']]
ret.append(converter(raw_text_ids)[: max_tgt_length])
return ret
def get_tokens_from_src_and_index(src, index, modifier=None):
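    """Look up the source token (text id only) for each predicted index after
    applying `modifier` and clamping to the source length."""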
result = []
for i in index:
i = modifier(i)
i = min(i, len(src) - 1)
if isinstance(src[i], list):
result.append(src[i][0])
else:
result.append(src[i])
return result
def get_layout_from_src_and_index(src, index, modifier=None):
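    """Collect the bounding box of each indexed source token, de-duplicating
    repeated boxes while preserving order."""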
result = []
s = set()
for i in index:
i = modifier(i)
i = min(i, len(src) - 1)
layout = src[i][1:]
if repr(layout) not in s:
result.append(layout)
s.add(repr(layout))
return result
def get_everything_from_src_and_index(src, index, modifier=None):
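    """Return the full source entry (token id plus box) for each predicted index."""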
result = []
for i in index:
i = modifier(i)
i = min(i, len(src) - 1)
result.append(src[i])
return result