JustinLin610's picture
first commit
ee21b96
raw
history blame
No virus
6.94 kB
import shutil
import os, sys
from subprocess import check_call, check_output
import glob
import argparse
import shutil
import pathlib
import itertools
def call_output(cmd):
    """Run ``cmd`` through the shell and return what it wrote to stdout.

    The command line is echoed before it runs and the captured output
    (a ``bytes`` object) is echoed afterwards.  A non-zero exit status
    surfaces as ``subprocess.CalledProcessError`` from ``check_output``.
    """
    print(f"Executing: {cmd}")
    captured = check_output(cmd, shell=True)
    print(captured)
    return captured
def call(cmd):
    """Echo ``cmd`` and then run it through the shell.

    Nothing is returned.  ``check_call`` raises
    ``subprocess.CalledProcessError`` when the command exits with a
    non-zero status.
    """
    print(cmd)
    check_call(cmd, shell=True)
# --- Environment-driven configuration --------------------------------------
# Both variables are mandatory; the script exits with status -1 when either
# one is unset or contains only whitespace.
WORKDIR_ROOT = os.environ.get('WORKDIR_ROOT')
if not (WORKDIR_ROOT and WORKDIR_ROOT.strip()):
    print('please specify your working directory root in OS environment variable WORKDIR_ROOT. Exitting..."')
    sys.exit(-1)

SPM_PATH = os.environ.get('SPM_PATH')
if not (SPM_PATH and SPM_PATH.strip()):
    print("Please install sentence piecence from https://github.com/google/sentencepiece and set SPM_PATH pointing to the installed spm_encode.py. Exitting...")
    sys.exit(-1)

# --- Derived paths ---------------------------------------------------------
# SPM_ENCODE is the encoder script that gets run with ``python``; the model
# and vocabulary live directly under the working directory root.
SPM_ENCODE = SPM_PATH
SPM_MODEL = f'{WORKDIR_ROOT}/sentence.bpe.model'
SPM_VOCAB = f'{WORKDIR_ROOT}/dict_250k.txt'

# --- One-time downloads ----------------------------------------------------
# Fetch the sentencepiece model and the dictionary only when they are not
# already present on disk.
if not os.path.exists(SPM_MODEL):
    call(f"wget https://dl.fbaipublicfiles.com/fairseq/models/mbart50/sentence.bpe.model -O {SPM_MODEL}")

if not os.path.exists(SPM_VOCAB):
    call(f"wget https://dl.fbaipublicfiles.com/fairseq/models/mbart50/dict_250k.txt -O {SPM_VOCAB}")
def get_data_size(raw):
    """Return the number of lines in the file ``raw`` as reported by ``wc -l``.

    Args:
        raw: path of the file to count.

    Returns:
        The first field of the ``wc -l`` output converted to ``int``.

    Raises:
        subprocess.CalledProcessError: if ``wc`` exits with a non-zero
            status (for example when the file does not exist).
    """
    from shlex import quote

    # Quote the path so that spaces or shell metacharacters in it cannot
    # split or alter the command; plain paths are passed through unchanged.
    cmd = f'wc -l {quote(str(raw))}'
    ret = call_output(cmd)
    return int(ret.split()[0])
def encode_spm(model, direction, prefix='', splits=['train', 'test', 'valid'], pairs_per_shard=None):
    """Apply the sentencepiece model to the raw parallel files of one direction.

    For every split whose source *and* target raw files both exist under the
    module-level ``RAW_DIR``, the ``SPM_ENCODE`` script is run with ``python``
    and the pieces are written to
    ``{BPE_DIR}/{direction}{prefix}/{split}.bpe.{lang}``.  Splits with a
    missing raw file are skipped silently.

    ``RAW_DIR`` and ``BPE_DIR`` are globals assigned by the ``__main__``
    section of this script; they must be set before this function is called
    with a non-empty ``splits``.

    Args:
        model: path of the sentencepiece model, passed as ``--model``.
        direction: language pair written as ``"<src>-<tgt>"``.
        prefix: text appended to the split name in the raw file names and to
            the direction in the output folder name.
        splits: split names to look for.  The default list is only read,
            never mutated.
        pairs_per_shard: accepted for signature compatibility; not used here.

    Raises:
        ValueError: if ``direction`` does not split into exactly two parts
            on ``-``.
        subprocess.CalledProcessError: propagated from ``call`` when the
            encoder exits with a non-zero status.
    """
    src, tgt = direction.split('-')
    for split in splits:
        src_raw = f'{RAW_DIR}/{split}{prefix}.{direction}.{src}'
        tgt_raw = f'{RAW_DIR}/{split}{prefix}.{direction}.{tgt}'
        if not (os.path.exists(src_raw) and os.path.exists(tgt_raw)):
            continue

        out_dir = f'{BPE_DIR}/{direction}{prefix}'
        # Join the pieces with explicit single spaces so that every option is
        # guaranteed to be separated from the previous argument (a backslash
        # continuation right after the model path would glue it to the next
        # option).
        cmd = " ".join([
            f"python {SPM_ENCODE}",
            f"--model {model}",
            "--output_format=piece",
            f"--inputs {src_raw} {tgt_raw}",
            f"--outputs {out_dir}/{split}.bpe.{src} {out_dir}/{split}.bpe.{tgt}",
        ])
        # ``call`` echoes the command itself, so it is not printed here too.
        call(cmd)
def binarize_(
    bpe_dir,
    databin_dir,
    direction, spm_vocab=SPM_VOCAB,
    splits=['train', 'test', 'valid'],
):
    """Run ``fairseq-preprocess`` on the BPE files of one direction.

    ``databin_dir`` is emptied and recreated first.  The command is only
    executed when at least one requested split has files matching
    ``{bpe_dir}/{split}.bpe*``.

    Args:
        bpe_dir: folder holding ``train.bpe.*`` / ``valid.bpe.*`` /
            ``test.bpe.*`` files.
        databin_dir: destination folder passed as ``--destdir``.  Any
            existing content is removed.
        direction: language pair written as ``"<src>-<tgt>"``.
        spm_vocab: either a single dictionary path (used as a joined
            dictionary) or a ``(src_vocab, tgt_vocab)`` tuple of two paths.
        splits: which of ``'train'``, ``'valid'``, ``'test'`` to consider.
            The default list is only read, never mutated.

    Raises:
        ValueError: if ``direction`` does not split into exactly two parts
            on ``-``.
        subprocess.CalledProcessError: propagated from ``call`` when
            ``fairseq-preprocess`` exits with a non-zero status.
    """
    src, tgt = direction.split('-')

    # Best-effort reset of the output folder: problems are reported but do
    # not abort the run.  ``makedirs`` also creates missing parent folders,
    # which plain ``mkdir`` would refuse to do.
    try:
        shutil.rmtree(databin_dir, ignore_errors=True)
        os.makedirs(databin_dir, exist_ok=True)
    except OSError as error:
        print(error)

    cmds = [
        "fairseq-preprocess",
        f"--source-lang {src} --target-lang {tgt}",
        f"--destdir {databin_dir}/",
        "--workers 8",
    ]

    if isinstance(spm_vocab, tuple):
        # Separate dictionaries for the source and the target side.
        src_vocab, tgt_vocab = spm_vocab
        cmds.extend([
            f"--srcdict {src_vocab}",
            f"--tgtdict {tgt_vocab}",
        ])
    else:
        # One dictionary shared by both sides.
        cmds.extend([
            "--joined-dictionary",
            f"--srcdict {spm_vocab}",
        ])

    # Add a ``--<split>pref`` option for each requested split that actually
    # has BPE files, in the fixed order train, valid, test.
    input_options = []
    for split in ('train', 'valid', 'test'):
        if split in splits and glob.glob(f"{bpe_dir}/{split}.bpe*"):
            input_options.append(f"--{split}pref {bpe_dir}/{split}.bpe")

    if input_options:
        # ``call`` echoes the command itself, so it is not printed here too.
        call(" ".join(cmds + input_options))
def binarize(
    databin_dir,
    direction, spm_vocab=SPM_VOCAB, prefix='',
    splits=['train', 'test', 'valid'],
    pairs_per_shard=None,
):
    """Binarize one direction and place the result under ``databin_dir``.

    The BPE input is read from ``{BPE_DIR}/{direction}{prefix}`` and the
    intermediate output is written to
    ``{BPE_DIR}/{direction}{prefix}_databin`` (``BPE_DIR`` is a global
    assigned by the ``__main__`` section of this script).

    * ``pairs_per_shard is None``: all requested splits are binarized in one
      go and the resulting ``*.bin``, ``*.idx`` and ``dict*`` files are moved
      into ``databin_dir``.
    * otherwise: the non-train splits are binarized once; then every
      ``shard*`` sub-folder of the BPE folder has its train split binarized
      separately, the valid/test files are symlinked into
      ``{databin_dir}/<shard>`` and the shard's own files are moved there.
      Only whether ``pairs_per_shard`` is ``None`` matters; its value is not
      used.

    Args:
        databin_dir: final destination folder.
        direction: language pair written as ``"<src>-<tgt>"``.
        spm_vocab: dictionary path or ``(src, tgt)`` tuple, forwarded to
            ``binarize_``.
        prefix: text appended to the direction in the folder names.
        splits: split names to process.  The default list is only read.
        pairs_per_shard: ``None`` for the unsharded layout, anything else
            for the sharded layout.
    """
    def move_databin_files(from_folder, to_folder):
        """Move binarized data, index and dictionary files into ``to_folder``."""
        try:
            os.makedirs(to_folder, exist_ok=True)
        except OSError as error:
            print(error)
            return
        produced = (
            glob.glob(f"{from_folder}/*.bin")
            + glob.glob(f"{from_folder}/*.idx")
            + glob.glob(f"{from_folder}/dict*")
        )
        for bin_file in produced:
            # Name the destination file explicitly: moving onto a folder
            # fails when a file of the same name is already there, which
            # would leave stale output from an earlier run in place.
            destination = os.path.join(to_folder, os.path.basename(bin_file))
            try:
                shutil.move(bin_file, destination)
            except OSError as error:
                print(error)

    bpe_databin_dir = f"{BPE_DIR}/{direction}{prefix}_databin"
    bpe_dir = f"{BPE_DIR}/{direction}{prefix}"

    if pairs_per_shard is None:
        binarize_(bpe_dir, bpe_databin_dir, direction, spm_vocab=spm_vocab, splits=splits)
        move_databin_files(bpe_databin_dir, databin_dir)
        return

    # Valid and test are not sharded: binarize them once and share them
    # between all shards through symlinks.
    binarize_(
        bpe_dir, bpe_databin_dir, direction,
        spm_vocab=spm_vocab, splits=[s for s in splits if s != "train"])

    for shard_bpe_dir in glob.glob(f"{bpe_dir}/shard*"):
        shard_str = os.path.basename(shard_bpe_dir)
        shard_folder = f"{bpe_databin_dir}/{shard_str}"
        databin_shard_folder = f"{databin_dir}/{shard_str}"
        print(f'working from {shard_folder} to {databin_shard_folder}')
        os.makedirs(databin_shard_folder, exist_ok=True)

        binarize_(
            shard_bpe_dir, shard_folder, direction,
            spm_vocab=spm_vocab, splits=["train"])

        shared_files = (
            glob.glob(f"{bpe_databin_dir}/valid.*")
            + glob.glob(f"{bpe_databin_dir}/test.*")
        )
        for test_data in shared_files:
            filename = os.path.basename(test_data)
            try:
                # A relative link target is resolved against the folder that
                # holds the link, so link to the absolute path instead.
                os.symlink(
                    os.path.abspath(test_data),
                    f"{databin_shard_folder}/{filename}",
                )
            except OSError as error:
                print(error)

        move_databin_files(shard_folder, databin_shard_folder)
def load_langs(path):
    """Read ``path`` and return its lines with surrounding whitespace removed.

    Every line of the file yields one list entry, so blank lines come back
    as empty strings.
    """
    with open(path) as handle:
        return [line.strip() for line in handle]
if __name__ == '__main__':
    # Script entry point: BPE-encode and binarize every translation
    # direction found in the raw data folder.
    parser = argparse.ArgumentParser()
    parser.add_argument("--data_root", default=f"{WORKDIR_ROOT}/ML50")
    parser.add_argument("--raw-folder", default='raw')
    parser.add_argument("--bpe-folder", default='bpe')
    parser.add_argument("--databin-folder", default='databin')
    args = parser.parse_args()

    # RAW_DIR and BPE_DIR are read as module globals by encode_spm() and
    # binarize(), so they have to stay at module level.
    DATA_PATH = args.data_root
    RAW_DIR = f'{DATA_PATH}/{args.raw_folder}'
    BPE_DIR = f'{DATA_PATH}/{args.bpe_folder}'
    DATABIN_DIR = f'{DATA_PATH}/{args.databin_folder}'
    os.makedirs(BPE_DIR, exist_ok=True)

    raw_files = itertools.chain(
        glob.glob(f'{RAW_DIR}/train*'),
        glob.glob(f'{RAW_DIR}/valid*'),
        glob.glob(f'{RAW_DIR}/test*'),
    )

    # Raw files are named "<split>.<src>-<tgt>.<lang>", so the same direction
    # shows up once per split and per language.  Keep each direction once, in
    # first-seen order, so it is not encoded and binarized repeatedly.
    directions = []
    for file_path in raw_files:
        name_parts = os.path.basename(file_path).split('.')
        if len(name_parts) < 3:
            print(f'skipping {file_path}: expected a name like <split>.<src>-<tgt>.<lang>')
            continue
        if name_parts[1] not in directions:
            directions.append(name_parts[1])

    for direction in directions:
        prefix = ""
        splits = ['train', 'valid', 'test']
        # Best-effort reset of this direction's BPE folder; problems are
        # reported but do not stop the run.
        try:
            shutil.rmtree(f'{BPE_DIR}/{direction}{prefix}', ignore_errors=True)
            os.makedirs(f'{BPE_DIR}/{direction}{prefix}', exist_ok=True)
            os.makedirs(DATABIN_DIR, exist_ok=True)
        except OSError as error:
            print(error)
        spm_model, spm_vocab = SPM_MODEL, SPM_VOCAB
        encode_spm(spm_model, direction=direction, splits=splits)
        binarize(DATABIN_DIR, direction, spm_vocab=spm_vocab, splits=splits)