#!/usr/bin/env python
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import argparse
import fileinput

try:
    from tqdm import tqdm
except ImportError:
    # tqdm is only used to draw a progress bar around the input iterator, so
    # fall back to plain iteration when it is not installed.
    def tqdm(iterable, *args, **kwargs):
        """Return *iterable* unchanged (no-progress-bar stand-in for tqdm)."""
        return iterable


def _num_tokens(text):
    """Return the number of single-space-separated tokens in *text*.

    The empty string counts as 0 tokens (a plain ``"".split(" ")`` would
    report 1).
    """
    return len(text.split(" ")) if text != "" else 0


def _is_valid_pair(src, tgt, minlen=None, maxlen=None, ratio=None):
    """Return True if the (src, tgt) pair passes every enabled length filter.

    Args:
        src: source-side sentence.
        tgt: target-side sentence.
        minlen: if not None, reject pairs where either side has fewer tokens.
        maxlen: if not None, reject pairs where either side has more tokens.
        ratio: if not None, reject pairs whose longer/shorter token-count
            ratio exceeds this value.

    A filter that is None is skipped. When the ratio filter is enabled and
    either side is empty, the ratio is undefined, so the pair is rejected
    instead of raising ZeroDivisionError.
    """
    srclen = _num_tokens(src)
    tgtlen = _num_tokens(tgt)

    if minlen is not None and (srclen < minlen or tgtlen < minlen):
        return False
    if maxlen is not None and (srclen > maxlen or tgtlen > maxlen):
        return False
    if ratio is not None:
        shorter, longer = sorted((srclen, tgtlen))
        if shorter == 0:
            # Empty side: no finite length ratio exists, treat as filtered out.
            return False
        if longer / float(shorter) > ratio:
            return False
    return True


def _field(line, index, default=""):
    """Return tab-separated field *index* of *line*, or *default* if missing.

    Trailing whitespace is stripped before splitting, so a line whose last
    field is empty simply has fewer fields and yields *default*.
    """
    fields = line.rstrip().split("\t")
    try:
        return fields[index]
    except IndexError:
        return default


def main(argv=None):
    """Write parallel files from ``S-*`` / ``H-*`` lines of the input.

    For every ``S-`` line the second tab-separated field is remembered as the
    target sentence; the third field of the first following ``H-`` line is
    taken as the source sentence. Pairs that pass the optional length filters
    are written line-by-line to ``<output>.<srclang>`` and
    ``<output>.<tgtlang>``. Further ``H-`` lines before the next ``S-`` line
    are ignored, as are ``H-`` lines that appear before any ``S-`` line.

    Args:
        argv: command-line arguments to parse; None (the default) makes
            argparse read ``sys.argv[1:]``.
    """
    parser = argparse.ArgumentParser(
        description=(
            "Extract back-translations from the stdout of fairseq-generate. "
            "If there are multiply hypotheses for a source, we only keep the first one. "
        )
    )
    parser.add_argument("--output", required=True, help="output prefix")
    parser.add_argument(
        "--srclang", required=True, help="source language (extracted from H-* lines)"
    )
    parser.add_argument(
        "--tgtlang", required=True, help="target language (extracted from S-* lines)"
    )
    parser.add_argument("--minlen", type=int, help="min length filter")
    parser.add_argument("--maxlen", type=int, help="max length filter")
    parser.add_argument("--ratio", type=float, help="ratio filter")
    parser.add_argument("files", nargs="*", help="input files")
    args = parser.parse_args(argv)

    src_path = args.output + "." + args.srclang
    tgt_path = args.output + "." + args.tgtlang

    with open(src_path, "w") as src_h, open(tgt_path, "w") as tgt_h:
        # An empty file list makes fileinput read stdin; the context manager
        # closes the input stream even if processing fails part-way.
        with fileinput.input(args.files) as lines:
            # Pending target sentence; None means "no S- line waiting for its
            # hypothesis". Must be initialised here so that an H- line
            # arriving before any S- line is skipped instead of crashing.
            tgt = None
            for line in tqdm(lines):
                if line.startswith("S-"):
                    tgt = _field(line, 1)
                elif line.startswith("H-") and tgt is not None:
                    src = _field(line, 2)
                    if _is_valid_pair(
                        src, tgt, args.minlen, args.maxlen, args.ratio
                    ):
                        print(src, file=src_h)
                        print(tgt, file=tgt_h)
                    # Only the first hypothesis per source is kept.
                    tgt = None


if __name__ == "__main__":
    main()