#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""Replace out-of-vocabulary words with position-based ``<unk-N>`` tokens."""

import argparse
from contextlib import ExitStack
from itertools import zip_longest


def replace_oovs(source_in, target_in, vocabulary, source_out, target_out):
    """Replace out-of-vocabulary words in source and target text with <unk-N>.

    N is the zero-based position at which the word first occurs in the
    whitespace-tokenized source sequence. A target word is replaced only when
    it matches an out-of-vocabulary word of the corresponding source line;
    every other target word is copied unchanged.

    Args:
        source_in: Iterable of source lines (e.g. an open text file).
        target_in: Iterable of target lines aligned with ``source_in``, or
            ``None`` when there is no target side.
        vocabulary: Container of in-vocabulary words. A ``list`` or ``tuple``
            is converted to a ``frozenset`` once so lookups stay O(1).
        source_out: Writable text stream receiving one line per source line.
        target_out: Writable text stream receiving one line per source line
            (blank when no target line corresponds), or ``None`` to skip
            writing target output.

    Raises:
        AttributeError: If ``target_in`` yields more lines than ``source_in``,
            because ``zip_longest`` then pads the source side with ``None``.
    """

    def format_unk(pos):
        return "<unk-{}>".format(pos)

    if target_in is None:
        target_in = []

    # Membership is tested once per source token; hashing the vocabulary up
    # front avoids a linear scan for every token.
    if isinstance(vocabulary, (list, tuple)):
        vocabulary = frozenset(vocabulary)

    for source_seq, target_seq in zip_longest(source_in, target_in):
        source_seq_out = []
        target_seq_out = []

        # Maps each OOV word to the position of its first occurrence, so that
        # repeated OOV words share the same <unk-N> token.
        word_to_pos = {}
        for position, token in enumerate(source_seq.strip().split()):
            if token in vocabulary:
                token_out = token
            else:
                oov_pos = word_to_pos.setdefault(token, position)
                token_out = format_unk(oov_pos)
            source_seq_out.append(token_out)
        source_out.write(" ".join(source_seq_out) + "\n")

        if target_seq is not None:
            target_seq_out = [
                format_unk(word_to_pos[token]) if token in word_to_pos else token
                for token in target_seq.strip().split()
            ]
        if target_out is not None:
            target_out.write(" ".join(target_seq_out) + "\n")


def main():
    """Parse command-line arguments and rewrite the given files."""
    parser = argparse.ArgumentParser(
        description="Replaces out-of-vocabulary words in both source and target "
        "sequences with tokens that indicate the position of the word "
        "in the source sequence."
    )
    parser.add_argument(
        "--source", type=str, help="text file with source sequences", required=True
    )
    parser.add_argument(
        "--target", type=str, help="text file with target sequences", default=None
    )
    parser.add_argument("--vocab", type=str, help="vocabulary file", required=True)
    parser.add_argument(
        "--source-out",
        type=str,
        help="where to write source sequences with <unk-N> entries",
        required=True,
    )
    parser.add_argument(
        "--target-out",
        type=str,
        help="where to write target sequences with <unk-N> entries",
        default=None,
    )
    args = parser.parse_args()

    with open(args.vocab, encoding="utf-8") as vocab:
        # A set gives constant-time membership tests in replace_oovs.
        vocabulary = set(vocab.read().splitlines())

    # ExitStack closes every file that was opened, even if a later open() or
    # replace_oovs() raises. Inputs are opened before outputs so that a
    # missing input file does not truncate an existing output file.
    with ExitStack() as stack:
        source_in = stack.enter_context(open(args.source, "r", encoding="utf-8"))
        target_in = (
            stack.enter_context(open(args.target, "r", encoding="utf-8"))
            if args.target is not None
            else None
        )
        source_out = stack.enter_context(
            open(args.source_out, "w", encoding="utf-8")
        )
        target_out = (
            stack.enter_context(open(args.target_out, "w", encoding="utf-8"))
            if args.target_out is not None
            else None
        )
        replace_oovs(source_in, target_in, vocabulary, source_out, target_out)


if __name__ == "__main__":
    main()