"""
Data pre-processing: build vocabularies and binarize training data.
"""

import logging
import os
import shutil
import sys
from collections import Counter
from itertools import zip_longest
from multiprocessing import Pool

from fairseq import options, tasks, utils
from fairseq.binarizer import Binarizer
from fairseq.data import indexed_dataset
from fairseq.file_chunker_utils import find_offsets

logging.basicConfig(
    format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    level=os.environ.get("LOGLEVEL", "INFO").upper(),
    stream=sys.stdout,
)
logger = logging.getLogger("fairseq_cli.preprocess")
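# Typical invocation via the fairseq-preprocess CLI (paths and language pair
# are illustrative):
#   fairseq-preprocess --source-lang de --target-lang en \
#       --trainpref data/train --validpref data/valid --testpref data/test \
#       --destdir data-bin --workers 8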


def main(args):
    utils.import_user_module(args)

    os.makedirs(args.destdir, exist_ok=True)

    logger.addHandler(
        logging.FileHandler(
            filename=os.path.join(args.destdir, "preprocess.log"),
        )
    )
    logger.info(args)

    assert (
        args.dataset_impl != "huffman"
    ), "preprocessing.py doesn't support Huffman yet, use HuffmanCodeBuilder directly."

    task = tasks.get_task(args.task)
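    # The task (the preprocessing parser defaults to "translation") supplies
    # the dictionary-building and dictionary-loading helpers used below.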

    def train_path(lang):
        return "{}{}".format(args.trainpref, ("." + lang) if lang else "")

    def file_name(prefix, lang):
        fname = prefix
        if lang is not None:
            fname += ".{lang}".format(lang=lang)
        return fname

    def dest_path(prefix, lang):
        return os.path.join(args.destdir, file_name(prefix, lang))

    def dict_path(lang):
        return dest_path("dict", lang) + ".txt"

    def build_dictionary(filenames, src=False, tgt=False):
        assert src ^ tgt
        return task.build_dictionary(
            filenames,
            workers=args.workers,
            threshold=args.thresholdsrc if src else args.thresholdtgt,
            nwords=args.nwordssrc if src else args.nwordstgt,
            padding_factor=args.padding_factor,
        )
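    # build_dictionary counts tokens across the given files; `threshold` drops
    # symbols seen fewer times than that, `nwords` caps the vocabulary size,
    # and `padding_factor` rounds the final size up to a multiple (which can
    # help GPU throughput).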

    target = not args.only_source

    if not args.srcdict and os.path.exists(dict_path(args.source_lang)):
        raise FileExistsError(dict_path(args.source_lang))
    if target and not args.tgtdict and os.path.exists(dict_path(args.target_lang)):
        raise FileExistsError(dict_path(args.target_lang))

    if args.joined_dictionary:
        assert (
            not args.srcdict or not args.tgtdict
        ), "cannot use both --srcdict and --tgtdict with --joined-dictionary"

        if args.srcdict:
            src_dict = task.load_dictionary(args.srcdict)
        elif args.tgtdict:
            src_dict = task.load_dictionary(args.tgtdict)
        else:
            assert (
                args.trainpref
            ), "--trainpref must be set if --srcdict is not specified"
            src_dict = build_dictionary(
                {train_path(lang) for lang in [args.source_lang, args.target_lang]},
                src=True,
            )
        tgt_dict = src_dict
    else:
        if args.srcdict:
            src_dict = task.load_dictionary(args.srcdict)
        else:
            assert (
                args.trainpref
            ), "--trainpref must be set if --srcdict is not specified"
            src_dict = build_dictionary([train_path(args.source_lang)], src=True)

        if target:
            if args.tgtdict:
                tgt_dict = task.load_dictionary(args.tgtdict)
            else:
                assert (
                    args.trainpref
                ), "--trainpref must be set if --tgtdict is not specified"
                tgt_dict = build_dictionary([train_path(args.target_lang)], tgt=True)
        else:
            tgt_dict = None
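    # Dictionaries are persisted as plain text: one "<symbol> <count>" pair
    # per line, in descending frequency order.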

    src_dict.save(dict_path(args.source_lang))
    if target and tgt_dict is not None:
        tgt_dict.save(dict_path(args.target_lang))

    if args.dict_only:
        return

    def make_binary_dataset(vocab, input_prefix, output_prefix, lang, num_workers):
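        # Binarize one split of one language into an indexed dataset. With
        # num_workers > 1, the input file is split at line boundaries into
        # byte-offset chunks: this process handles the first chunk while pool
        # workers handle the rest, whose partial outputs are merged back in.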
        logger.info("[{}] Dictionary: {} types".format(lang, len(vocab)))
        n_seq_tok = [0, 0]
        replaced = Counter()

        def merge_result(worker_result):
            replaced.update(worker_result["replaced"])
            n_seq_tok[0] += worker_result["nseq"]
            n_seq_tok[1] += worker_result["ntok"]

        input_file = "{}{}".format(
            input_prefix, ("." + lang) if lang is not None else ""
        )
        offsets = find_offsets(input_file, num_workers)
        (first_chunk, *more_chunks) = zip(offsets, offsets[1:])
        pool = None
        if num_workers > 1:
            pool = Pool(processes=num_workers - 1)
            for worker_id, (start_offset, end_offset) in enumerate(
                more_chunks, start=1
            ):
                prefix = "{}{}".format(output_prefix, worker_id)
                pool.apply_async(
                    binarize,
                    (
                        args,
                        input_file,
                        vocab,
                        prefix,
                        lang,
                        start_offset,
                        end_offset,
                    ),
                    callback=merge_result,
                )
            pool.close()

        ds = indexed_dataset.make_builder(
            dataset_dest_file(args, output_prefix, lang, "bin"),
            impl=args.dataset_impl,
            vocab_size=len(vocab),
        )
        merge_result(
            Binarizer.binarize(
                input_file,
                vocab,
                lambda t: ds.add_item(t),
                offset=first_chunk[0],
                end=first_chunk[1],
            )
        )
        if num_workers > 1:
            pool.join()
            for worker_id in range(1, num_workers):
                prefix = "{}{}".format(output_prefix, worker_id)
                temp_file_path = dataset_dest_prefix(args, prefix, lang)
                ds.merge_file_(temp_file_path)
                os.remove(indexed_dataset.data_file_path(temp_file_path))
                os.remove(indexed_dataset.index_file_path(temp_file_path))

        ds.finalize(dataset_dest_file(args, output_prefix, lang, "idx"))

        logger.info(
            "[{}] {}: {} sents, {} tokens, {:.3}% replaced by {}".format(
                lang,
                input_file,
                n_seq_tok[0],
                n_seq_tok[1],
                100 * sum(replaced.values()) / n_seq_tok[1],
                vocab.unk_word,
            )
        )

    def make_binary_alignment_dataset(input_prefix, output_prefix, num_workers):
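        # Same chunked scheme as make_binary_dataset, but for word-alignment
        # files: each line holds "srcpos-tgtpos" pairs parsed by
        # utils.parse_alignment, and only the sequence count is tracked.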
        nseq = [0]

        def merge_result(worker_result):
            nseq[0] += worker_result["nseq"]

        input_file = input_prefix
        offsets = find_offsets(input_file, num_workers)
        (first_chunk, *more_chunks) = zip(offsets, offsets[1:])
        pool = None
        if num_workers > 1:
            pool = Pool(processes=num_workers - 1)
            for worker_id, (start_offset, end_offset) in enumerate(
                more_chunks, start=1
            ):
                prefix = "{}{}".format(output_prefix, worker_id)
                pool.apply_async(
                    binarize_alignments,
                    (
                        args,
                        input_file,
                        utils.parse_alignment,
                        prefix,
                        start_offset,
                        end_offset,
                    ),
                    callback=merge_result,
                )
            pool.close()

        ds = indexed_dataset.make_builder(
            dataset_dest_file(args, output_prefix, None, "bin"), impl=args.dataset_impl
        )

        merge_result(
            Binarizer.binarize_alignments(
                input_file,
                utils.parse_alignment,
                lambda t: ds.add_item(t),
                offset=first_chunk[0],
                end=first_chunk[1],
            )
        )
        if num_workers > 1:
            pool.join()
            for worker_id in range(1, num_workers):
                prefix = "{}{}".format(output_prefix, worker_id)
                temp_file_path = dataset_dest_prefix(args, prefix, None)
                ds.merge_file_(temp_file_path)
                os.remove(indexed_dataset.data_file_path(temp_file_path))
                os.remove(indexed_dataset.index_file_path(temp_file_path))

        ds.finalize(dataset_dest_file(args, output_prefix, None, "idx"))

        logger.info("[alignments] {}: parsed {} alignments".format(input_file, nseq[0]))

    def make_dataset(vocab, input_prefix, output_prefix, lang, num_workers=1):
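        # "raw" keeps the split as plain text (simply copied into destdir);
        # every other --dataset-impl goes through binarization.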
        if args.dataset_impl == "raw":
            output_text_file = dest_path(
                output_prefix + ".{}-{}".format(args.source_lang, args.target_lang),
                lang,
            )
            shutil.copyfile(file_name(input_prefix, lang), output_text_file)
        else:
            make_binary_dataset(vocab, input_prefix, output_prefix, lang, num_workers)

    def make_all(lang, vocab):
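        # Process the train prefix plus any comma-separated list of valid/test
        # prefixes; extra prefixes get numbered outputs (valid1, valid2, ...).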
        if args.trainpref:
            make_dataset(vocab, args.trainpref, "train", lang, num_workers=args.workers)
        if args.validpref:
            for k, validpref in enumerate(args.validpref.split(",")):
                outprefix = "valid{}".format(k) if k > 0 else "valid"
                make_dataset(
                    vocab, validpref, outprefix, lang, num_workers=args.workers
                )
        if args.testpref:
            for k, testpref in enumerate(args.testpref.split(",")):
                outprefix = "test{}".format(k) if k > 0 else "test"
                make_dataset(vocab, testpref, outprefix, lang, num_workers=args.workers)

    def make_all_alignments():
        if args.trainpref and os.path.exists(args.trainpref + "." + args.align_suffix):
            make_binary_alignment_dataset(
                args.trainpref + "." + args.align_suffix,
                "train.align",
                num_workers=args.workers,
            )
        if args.validpref and os.path.exists(args.validpref + "." + args.align_suffix):
            make_binary_alignment_dataset(
                args.validpref + "." + args.align_suffix,
                "valid.align",
                num_workers=args.workers,
            )
        if args.testpref and os.path.exists(args.testpref + "." + args.align_suffix):
            make_binary_alignment_dataset(
                args.testpref + "." + args.align_suffix,
                "test.align",
                num_workers=args.workers,
            )

    make_all(args.source_lang, src_dict)
    if target:
        make_all(args.target_lang, tgt_dict)
    if args.align_suffix:
        make_all_alignments()

    logger.info("Wrote preprocessed data to {}".format(args.destdir))

    if args.alignfile:
        assert args.trainpref, "--trainpref must be set if --alignfile is specified"
        src_file_name = train_path(args.source_lang)
        tgt_file_name = train_path(args.target_lang)
        freq_map = {}
        with open(args.alignfile, "r", encoding="utf-8") as align_file:
            with open(src_file_name, "r", encoding="utf-8") as src_file:
                with open(tgt_file_name, "r", encoding="utf-8") as tgt_file:
                    for a, s, t in zip_longest(align_file, src_file, tgt_file):
                        si = src_dict.encode_line(s, add_if_not_exist=False)
                        ti = tgt_dict.encode_line(t, add_if_not_exist=False)
                        ai = list(map(lambda x: tuple(x.split("-")), a.split()))
                        for sai, tai in ai:
                            srcidx = si[int(sai)]
                            tgtidx = ti[int(tai)]
                            if srcidx != src_dict.unk() and tgtidx != tgt_dict.unk():
                                assert srcidx != src_dict.pad()
                                assert srcidx != src_dict.eos()
                                assert tgtidx != tgt_dict.pad()
                                assert tgtidx != tgt_dict.eos()

                                if srcidx not in freq_map:
                                    freq_map[srcidx] = {}
                                if tgtidx not in freq_map[srcidx]:
                                    freq_map[srcidx][tgtidx] = 1
                                else:
                                    freq_map[srcidx][tgtidx] += 1

        align_dict = {}
        for srcidx in freq_map.keys():
            align_dict[srcidx] = max(freq_map[srcidx], key=freq_map[srcidx].get)

        with open(
            os.path.join(
                args.destdir,
                "alignment.{}-{}.txt".format(args.source_lang, args.target_lang),
            ),
            "w",
            encoding="utf-8",
        ) as f:
            for k, v in align_dict.items():
                print("{} {}".format(src_dict[k], tgt_dict[v]), file=f)


def binarize(args, filename, vocab, output_prefix, lang, offset, end, append_eos=True):
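    # Worker entry point: binarize one byte-range chunk of `filename` into a
    # temporary dataset under `output_prefix`; the parent process merges and
    # deletes these chunk files afterwards.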
    ds = indexed_dataset.make_builder(
        dataset_dest_file(args, output_prefix, lang, "bin"),
        impl=args.dataset_impl,
        vocab_size=len(vocab),
    )

    def consumer(tensor):
        ds.add_item(tensor)

    res = Binarizer.binarize(
        filename, vocab, consumer, append_eos=append_eos, offset=offset, end=end
    )
    ds.finalize(dataset_dest_file(args, output_prefix, lang, "idx"))
    return res


def binarize_alignments(args, filename, parse_alignment, output_prefix, offset, end):
    ds = indexed_dataset.make_builder(
        dataset_dest_file(args, output_prefix, None, "bin"),
        impl=args.dataset_impl,
        vocab_size=None,
    )

    def consumer(tensor):
        ds.add_item(tensor)

    res = Binarizer.binarize_alignments(
        filename, parse_alignment, consumer, offset=offset, end=end
    )
    ds.finalize(dataset_dest_file(args, output_prefix, None, "idx"))
    return res


def dataset_dest_prefix(args, output_prefix, lang):
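    # Destination path stem, e.g. "<destdir>/train.de-en.de" for a de-en pair
    # (language pair illustrative), or just "<destdir>/train" for
    # --only-source data with no language given.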
    base = "{}/{}".format(args.destdir, output_prefix)
    if lang is not None:
        lang_part = ".{}-{}.{}".format(args.source_lang, args.target_lang, lang)
    elif args.only_source:
        lang_part = ""
    else:
        lang_part = ".{}-{}".format(args.source_lang, args.target_lang)

    return "{}{}".format(base, lang_part)


def dataset_dest_file(args, output_prefix, lang, extension):
    base = dataset_dest_prefix(args, output_prefix, lang)
    return "{}.{}".format(base, extension)


def cli_main():
    parser = options.get_preprocessing_parser()
    args = parser.parse_args()
    main(args)


if __name__ == "__main__":
    cli_main()