JustinLin610
update
10b0761
#!/usr/bin/env python
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import argparse
import fileinput
from tqdm import tqdm
def main():
    """Extract back-translation pairs from the stdout of ``fairseq-generate``.

    Reads generator output from the files named on the command line (or from
    stdin when none are given, via ``fileinput``). Writes two line-aligned
    files, ``<output>.<srclang>`` and ``<output>.<tgtlang>``:

    * ``S-*`` lines supply the target side: the second tab-separated field.
    * ``H-*`` lines supply the source side: the third tab-separated field
      (index 2, i.e. after the id and, presumably, the score column).

    Only the first ``H-*`` line following each ``S-*`` line is kept; any
    further hypotheses for the same source are ignored. Pairs can be filtered
    with ``--minlen``, ``--maxlen`` and ``--ratio``. Lengths are counted in
    single-space-separated tokens.
    """
    parser = argparse.ArgumentParser(
        description=(
            "Extract back-translations from the stdout of fairseq-generate. "
            "If there are multiple hypotheses for a source, we only keep the first one. "
        )
    )
    parser.add_argument("--output", required=True, help="output prefix")
    parser.add_argument(
        "--srclang", required=True, help="source language (extracted from H-* lines)"
    )
    parser.add_argument(
        "--tgtlang", required=True, help="target language (extracted from S-* lines)"
    )
    parser.add_argument("--minlen", type=int, help="min length filter")
    parser.add_argument("--maxlen", type=int, help="max length filter")
    parser.add_argument("--ratio", type=float, help="ratio filter")
    parser.add_argument("files", nargs="*", help="input files")
    args = parser.parse_args()

    def validate(src, tgt):
        """Return True if the (src, tgt) pair passes every enabled filter."""
        # Token counts. Special-case "" because "".split(" ") yields one empty
        # token rather than zero.
        srclen = len(src.split(" ")) if src != "" else 0
        tgtlen = len(tgt.split(" ")) if tgt != "" else 0
        shorter, longer = min(srclen, tgtlen), max(srclen, tgtlen)
        if args.minlen is not None and shorter < args.minlen:
            return False
        if args.maxlen is not None and longer > args.maxlen:
            return False
        if args.ratio is not None:
            # An empty side makes the length ratio infinite (or undefined when
            # both sides are empty): reject instead of dividing by zero.
            if shorter == 0 or longer / shorter > args.ratio:
                return False
        return True

    def safe_index(toks, index, default):
        """Return ``toks[index]``, or ``default`` if the field is missing."""
        # A missing field happens e.g. when the text column is empty and the
        # trailing tab was removed by rstrip().
        try:
            return toks[index]
        except IndexError:
            return default

    src_path = args.output + "." + args.srclang
    tgt_path = args.output + "." + args.tgtlang
    # Decode inputs explicitly as UTF-8 instead of the platform default.
    # Stdin (no files, or "-") keeps sys.stdin's own encoding.
    inputs = fileinput.input(args.files, openhook=fileinput.hook_encoded("utf-8"))
    with inputs, open(src_path, "w", encoding="utf-8") as src_h, open(
        tgt_path, "w", encoding="utf-8"
    ) as tgt_h:
        # Target text of the most recent S-* line still waiting for its first
        # hypothesis. None means no pending source: none seen yet, or its
        # first hypothesis has already been consumed.
        tgt = None
        for line in tqdm(inputs):
            if line.startswith("S-"):
                tgt = safe_index(line.rstrip().split("\t"), 1, "")
            elif line.startswith("H-") and tgt is not None:
                src = safe_index(line.rstrip().split("\t"), 2, "")
                if validate(src, tgt):
                    print(src, file=src_h)
                    print(tgt, file=tgt_h)
                # Drop later hypotheses for the same source.
                tgt = None
if __name__ == "__main__":
    # Execute the command-line tool only when run as a script, not on import.
    main()