JustinLin610's picture
first commit
ee21b96
raw
history blame
No virus
3.47 kB
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import argparse
import math
import os
import subprocess
import sys
import tempfile
from collections import defaultdict
from itertools import combinations
def read_translations(path, n_repeats):
segment_counter = 0
segment_translations = []
translations = defaultdict(list)
for line in open(path):
segment_translations.append(" ".join(line.split()))
if len(segment_translations) == n_repeats:
translations[segment_counter] = segment_translations
segment_translations = []
segment_counter += 1
return translations
def generate_input(translations, n_repeats):
_, ref_path = tempfile.mkstemp()
_, mt_path = tempfile.mkstemp()
ref_fh = open(ref_path, "w")
mt_fh = open(mt_path, "w")
for segid in sorted(translations.keys()):
assert len(translations[segid]) == n_repeats
indexes = combinations(range(n_repeats), 2)
for idx1, idx2 in indexes:
mt_fh.write(translations[segid][idx1].strip() + "\n")
ref_fh.write(translations[segid][idx2].strip() + "\n")
sys.stderr.write("\nSaved translations to %s and %s" % (ref_path, mt_path))
return ref_path, mt_path
def run_meteor(ref_path, mt_path, metric_path, lang="en"):
_, out_path = tempfile.mkstemp()
subprocess.call(
[
"java",
"-Xmx2G",
"-jar",
metric_path,
mt_path,
ref_path,
"-p",
"0.5 0.2 0.6 0.75", # default parameters, only changed alpha to give equal weight to P and R
"-norm",
"-l",
lang,
],
stdout=open(out_path, "w"),
)
os.remove(ref_path)
os.remove(mt_path)
sys.stderr.write("\nSaved Meteor output to %s" % out_path)
return out_path
def read_output(meteor_output_path, n_repeats):
n_combinations = math.factorial(n_repeats) / (
math.factorial(2) * math.factorial(n_repeats - 2)
)
raw_scores = []
average_scores = []
for line in open(meteor_output_path):
if not line.startswith("Segment "):
continue
score = float(line.strip().split("\t")[1])
raw_scores.append(score)
if len(raw_scores) == n_combinations:
average_scores.append(sum(raw_scores) / n_combinations)
raw_scores = []
os.remove(meteor_output_path)
return average_scores
def main():
parser = argparse.ArgumentParser()
parser.add_argument("-i", "--infile")
parser.add_argument("-n", "--repeat_times", type=int)
parser.add_argument("-m", "--meteor")
parser.add_argument("-o", "--output")
args = parser.parse_args()
translations = read_translations(args.infile, args.repeat_times)
sys.stderr.write("\nGenerating input for Meteor...")
ref_path, mt_path = generate_input(translations, args.repeat_times)
sys.stderr.write("\nRunning Meteor...")
out_path = run_meteor(ref_path, mt_path, args.meteor)
sys.stderr.write("\nReading output...")
scores = read_output(out_path, args.repeat_times)
sys.stderr.write("\nWriting results...")
with open(args.output, "w") as o:
for scr in scores:
o.write("{}\n".format(scr))
o.close()
if __name__ == "__main__":
main()