|
import argparse |
|
import torch |
|
import glob |
|
import os |
|
import numpy as np |
|
|
|
|
|
class MMapIndexDataset():
    """Read-only dataset over a flat memmapped token buffer plus an offsets index.

    Expects two files sharing a common prefix:
      * ``<datapath>.npy`` — (N, 2) array of [start, end) offsets, one row per sequence
      * ``<datapath>.bin`` — all token ids concatenated, stored as C ``long``

    Both files stay on disk (memory-mapped), so arbitrarily large datasets
    can be indexed without loading them into RAM.
    """

    def __init__(self, datapath):
        # Offsets table, memory-mapped read-only.
        self.idxfp = np.load(datapath + '.npy', mmap_mode='r')
        # Flat token buffer, memory-mapped read-only.
        self.binfp = np.memmap(datapath + '.bin', dtype='long', mode='r')

    def __len__(self):
        """Number of stored sequences."""
        return self.idxfp.shape[0]

    def __getitem__(self, idx):
        """Return the idx-th sequence as a zero-copy slice of the memmap."""
        start, end = self.idxfp[idx]
        return self.binfp[start:end]
|
|
|
|
|
def convert_py_to_npy(input_tensor, bin_out, idx_out):
    """Serialize a list of 1-D integer sequences into memmap-readable files.

    Writes two artifacts with the layout MMapIndexDataset expects:
      * ``idx_out`` (+ '.npy', appended by np.save): (N, 2) int64 array of
        [start, end) offsets, one row per input sequence
      * ``bin_out``: all token ids concatenated into one flat binary file

    Args:
        input_tensor: sized iterable of 1-D int sequences (e.g. torch tensors).
        bin_out: output path for the flat .bin data file.
        idx_out: output path prefix for the offsets file ('.npy' is added
            automatically by np.save).
    """
    # First pass: compute [start, end) offsets for every sequence.
    idx = torch.empty(len(input_tensor), 2, dtype=torch.long)
    start = 0
    for i, seq in enumerate(input_tensor):  # `seq`, not `input` (builtin shadow)
        idx[i] = torch.tensor([start, start + len(seq)])
        start += len(seq)
    np.save(idx_out, idx)

    if start == 0:
        # np.memmap cannot create a zero-length mapping; just emit an empty file.
        open(bin_out, 'wb').close()
        return

    # NOTE(review): dtype 'long' is the platform C long (32-bit on Windows,
    # 64-bit on typical Linux). Kept as-is for compatibility with the reader,
    # which opens the file with the same dtype — confirm if Windows matters.
    binfp = np.memmap(bin_out, dtype='long', mode='w+', shape=(start,))

    # Second pass: vectorized slice assignment per sequence instead of a
    # per-element Python loop (same result, far less interpreter overhead).
    start = 0
    for seq in input_tensor:
        n = len(seq)
        binfp[start:start + n] = np.asarray(seq)
        start += n
    binfp.flush()  # make sure all pages hit disk before the memmap is released
|
|
|
|
|
if __name__ == '__main__':
    # CLI entry point: split a torch-saved dict of token-id lists into one
    # .bin/.npy memmap pair per key, written next to the input .pt file.
    parser = argparse.ArgumentParser(description="Text infilling.")
    parser.add_argument('--data_path', type=str,
                        default='/cognitive_comp/gaoxinyu/data/wudao')
    args = parser.parse_args()

    # Keys inside the saved dict that each get their own memmap pair.
    process_key = [
        'incorrect_input_ids_list',
        'label_ids_list',
        'target_ids_list',
    ]

    if not os.path.exists(args.data_path):
        print(
            f'Please create the synthetic datafile {args.data_path} with create_synthetic_data.py.')
    else:
        print(f'''Loading data from {args.data_path}''')
        # NOTE(review): torch.load unpickles arbitrary objects — only run
        # against trusted files.
        data_dict = torch.load(args.data_path)
        for key in process_key:
            # Insert '_<key>[.bin]' before the trailing '.pt' of the input path.
            path_parts = args.data_path.rsplit('.pt', 1)
            bin_out = ('_' + key + '.bin').join(path_parts)
            idx_out = ('_' + key).join(path_parts)
            convert_py_to_npy(data_dict[key], bin_out, idx_out)
|
|