|
|
|
|
|
|
|
|
|
|
|
|
|
import argparse |
|
import fileinput |
|
|
|
from tqdm import tqdm |
|
|
|
|
|
def main():
    """Extract back-translation pairs from ``fairseq-generate`` stdout.

    Reads the generation log from the positional ``files`` (stdin when none
    are given) and writes one sentence per line to ``<output>.<srclang>`` and
    ``<output>.<tgtlang>``.  The text of each S-* line becomes the target side;
    the first H-* line that follows supplies the source side (the
    back-translation).  Further H-* lines before the next S-* line are skipped,
    so only the first hypothesis per source is kept.

    Optional filters (each disabled when its flag is omitted):
      --minlen  drop pairs where either side has fewer tokens than this
      --maxlen  drop pairs where either side has more tokens than this
      --ratio   drop pairs whose longer/shorter token-count ratio exceeds this;
                pairs with an empty side are dropped too, since their ratio
                is undefined
    """
    parser = argparse.ArgumentParser(
        description=(
            "Extract back-translations from the stdout of fairseq-generate. "
            "If there are multiple hypotheses for a source, we only keep the first one. "
        )
    )
    parser.add_argument("--output", required=True, help="output prefix")
    parser.add_argument(
        "--srclang", required=True, help="source language (extracted from H-* lines)"
    )
    parser.add_argument(
        "--tgtlang", required=True, help="target language (extracted from S-* lines)"
    )
    parser.add_argument("--minlen", type=int, help="min length filter")
    parser.add_argument("--maxlen", type=int, help="max length filter")
    parser.add_argument("--ratio", type=float, help="ratio filter")
    parser.add_argument("files", nargs="*", help="input files")
    args = parser.parse_args()

    def num_tokens(text):
        """Count single-space-separated tokens in ``text``; "" has 0 tokens."""
        return len(text.split(" ")) if text != "" else 0

    def validate(src, tgt):
        """Return True if the (src, tgt) pair passes every enabled filter."""
        srclen = num_tokens(src)
        tgtlen = num_tokens(tgt)
        shorter, longer = min(srclen, tgtlen), max(srclen, tgtlen)
        if args.minlen is not None and shorter < args.minlen:
            return False
        if args.maxlen is not None and longer > args.maxlen:
            return False
        if args.ratio is not None:
            # An empty side makes the length ratio undefined; reject the pair
            # instead of dividing by zero.
            if shorter == 0 or longer / shorter > args.ratio:
                return False
        return True

    def safe_index(toks, index, default):
        """Return ``toks[index]``, or ``default`` if ``toks`` is too short."""
        try:
            return toks[index]
        except IndexError:
            return default

    src_path = args.output + "." + args.srclang
    tgt_path = args.output + "." + args.tgtlang
    with open(src_path, "w") as src_h, open(tgt_path, "w") as tgt_h:
        # Target text of the example being processed.  None before the first
        # S-* line (so a leading H-* line is ignored rather than crashing) and
        # reset to None once the example's first hypothesis has been consumed,
        # so later hypotheses for the same example are skipped.
        tgt = None
        with fileinput.input(args.files) as lines:
            for line in tqdm(lines):
                if line.startswith("S-"):
                    # The second tab-separated field of an S-* line is the
                    # text written to the target-language file.
                    tgt = safe_index(line.rstrip().split("\t"), 1, "")
                elif line.startswith("H-") and tgt is not None:
                    # The third tab-separated field of an H-* line is the
                    # hypothesis, written to the source-language file.
                    src = safe_index(line.rstrip().split("\t"), 2, "")
                    if validate(src, tgt):
                        print(src, file=src_h)
                        print(tgt, file=tgt_h)
                    tgt = None
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: main() parses sys.argv and performs the extraction.
    main()
|
|