#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""Replace ``<unk-N>`` placeholder tokens in target sequences with the word
found at position ``N`` of the corresponding source sequence."""

import argparse
import re
import sys


class OOVIndexError(IndexError):
    """Raised when an ``<unk-N>`` tag points past the end of its source line.

    Attributes are kept so the command-line entry point can print a
    diagnostic showing both sequences and the offending position.
    """

    def __init__(self, pos, source_seq, target_seq):
        super().__init__(
            "A <unk-N> tag in the target sequence refers to a position that is "
            "outside the source sequence. Most likely there was a mismatch in "
            "provided source and target sequences. Otherwise this would mean that "
            "the pointing mechanism somehow attended to a position that is past "
            "the actual sequence end."
        )
        # Position N taken from the offending <unk-N> tag.
        self.source_pos = pos
        # The raw (unstripped) source and target lines that triggered the error.
        self.source_seq = source_seq
        self.target_seq = target_seq


def replace_oovs(source_in, target_in, target_out):
    """Replaces <unk-N> tokens in the target text with the corresponding word in
    the source text.

    Args:
        source_in: iterable of source lines (e.g. an open text file).
        target_in: iterable of target lines, paired line-by-line with
            ``source_in``. Pairing uses ``zip``, so extra lines in the longer
            input are ignored.
        target_out: object with a ``write`` method that receives one output
            line (terminated by ``"\\n"``) per input pair.

    Raises:
        OOVIndexError: if a tag ``<unk-N>`` has ``N`` greater than or equal to
            the number of whitespace-separated tokens in the source line.
    """
    # Only whole tokens of the exact form <unk-DIGITS> are placeholders; the
    # captured digits are the 0-based index into the source tokens.
    oov_re = re.compile(r"^<unk-([0-9]+)>$")

    for source_seq, target_seq in zip(source_in, target_in):
        target_seq_out = []

        pos_to_word = source_seq.strip().split()
        for token in target_seq.strip().split():
            m = oov_re.match(token)
            if m:
                pos = int(m.group(1))
                if pos >= len(pos_to_word):
                    raise OOVIndexError(pos, source_seq, target_seq)
                token_out = pos_to_word[pos]
            else:
                token_out = token
            target_seq_out.append(token_out)
        target_out.write(" ".join(target_seq_out) + "\n")


def main():
    """Parse command-line arguments and rewrite the target file."""
    parser = argparse.ArgumentParser(
        description="Replaces <unk-N> tokens in target sequences with words from "
        "the corresponding position in the source sequence."
    )
    parser.add_argument(
        "--source", type=str, help="text file with source sequences", required=True
    )
    parser.add_argument(
        "--target", type=str, help="text file with target sequences", required=True
    )
    parser.add_argument(
        "--target-out",
        type=str,
        help="where to write target sequences without <unk-N> entries",
        required=True,
    )
    args = parser.parse_args()

    # All three arguments are required, so each file is opened exactly once
    # here and closed deterministically by the context managers.
    with open(args.source, "r", encoding="utf-8") as source_in, open(
        args.target, "r", encoding="utf-8"
    ) as target_in, open(args.target_out, "w", encoding="utf-8") as target_out:
        replace_oovs(source_in, target_in, target_out)


if __name__ == "__main__":
    try:
        main()
    except OOVIndexError as e:
        print(e, file=sys.stderr)
        print("Source sequence:", e.source_seq.strip(), file=sys.stderr)
        print("Target sequence:", e.target_seq.strip(), file=sys.stderr)
        print(
            "Source sequence length:",
            len(e.source_seq.strip().split()),
            file=sys.stderr,
        )
        print("The offending tag points to:", e.source_pos, file=sys.stderr)
        sys.exit(2)