#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""
Scoring script for computing pairwise BLEU and multi-ref BLEU over a set of
candidate hypotheses.

See "Mixture Models for Diverse Machine Translation: Tricks of the Trade"
(Shen et al., 2019).
"""

import argparse
import random
import sys
from itertools import chain

import numpy as np
from sacrebleu import compute_bleu, corpus_bleu as _corpus_bleu


def main():
    """Parse the command line and print the requested BLEU statistics.

    With ``--sys`` the pairwise BLEU among hypotheses is printed (and the
    hypotheses are optionally dumped to ``--output``). With ``--ref`` either
    the multi-reference scores of the system output are printed (when
    ``--sys`` is also given) or the intra-reference scores.
    """
    parser = argparse.ArgumentParser(sys.argv[0])
    parser.add_argument(
        "--sys", nargs="*", default="", metavar="FILE", help="path to system output"
    )
    parser.add_argument("--ref", default="", metavar="FILE", help="path to references")
    parser.add_argument(
        "--output",
        default="",
        metavar="FILE",
        help="print outputs into a pretty format",
    )
    args = parser.parse_args()

    if args.sys:
        src, tgt, hypos, log_probs = load_sys(args.sys)
        print("pairwise BLEU: %.2f" % pairwise(hypos))
        if args.output:
            merge(src, tgt, hypos, log_probs, args.output)

    if args.ref:
        _, _, refs = load_ref(args.ref)
        if args.sys:
            multi_ref(refs, hypos)
        else:
            intra_ref(refs)


def dictolist(d):
    """Return the values of ``d`` as a list ordered by ascending key."""
    return [d[key] for key in sorted(d)]


def load_sys(paths):
    """Load system output from one or more files.

    Only lines starting with ``S-``, ``T-`` or ``D-`` are read; each has the
    form ``<tag>-<id>\\t...``. For ``S-``/``T-`` lines the second tab-separated
    field is the text; for ``D-`` lines the second field is parsed as a float
    log-probability and the third is the hypothesis text.

    Args:
        paths: iterable of file paths; hypotheses for the same id found in
            several files (or several times in one file) are accumulated.

    Returns:
        A tuple ``(src, tgt, hypos, log_probs)`` of lists ordered by sentence
        id, where ``hypos[n]`` and ``log_probs[n]`` are parallel lists.
    """
    src, tgt, hypos, log_probs = {}, {}, {}, {}
    for path in paths:
        with open(path, encoding="utf-8") as f:
            for line in f:
                line = line.rstrip()
                # S: source
                # T: target
                # D: detokenized system output
                if not line.startswith(("S-", "T-", "D-")):
                    continue
                fields = line.split("\t")
                # The id sits between the two-character tag and the first
                # tab. Reading it from the first field stays correct even
                # when rstrip() removed a trailing tab (empty text field).
                i = int(fields[0][2:])
                if line.startswith("S-"):
                    src[i] = fields[1] if len(fields) > 1 else ""
                elif line.startswith("T-"):
                    tgt[i] = fields[1] if len(fields) > 1 else ""
                else:
                    log_prob = float(fields[1])
                    # An empty hypothesis loses its field to rstrip() above.
                    hypo = fields[2] if len(fields) > 2 else ""
                    hypos.setdefault(i, []).append(hypo)
                    log_probs.setdefault(i, []).append(log_prob)
    return dictolist(src), dictolist(tgt), dictolist(hypos), dictolist(log_probs)


def load_ref(path):
    """Load sources, targets and grouped references from a reference file.

    ``S-`` and ``T-`` lines contribute one source / target each. Every run of
    consecutive lines starting with ``R`` forms one group of references. The
    text of a line is its second tab-separated field, right-stripped.

    Lines with any other prefix (blank lines included) are skipped.

    Returns:
        A tuple ``(src, tgt, refs)`` where ``refs`` is a list of reference
        groups in file order.
    """
    with open(path, encoding="utf-8") as f:
        lines = f.readlines()
    src, tgt, refs = [], [], []
    i = 0
    while i < len(lines):
        line = lines[i]
        if line.startswith("S-"):
            src.append(line.split("\t")[1].rstrip())
            i += 1
        elif line.startswith("T-"):
            tgt.append(line.split("\t")[1].rstrip())
            i += 1
        elif line.startswith("R"):
            group = []
            while i < len(lines) and lines[i].startswith("R"):
                group.append(lines[i].split("\t")[1].rstrip())
                i += 1
            refs.append(group)
        else:
            # Unrecognized line: advance, otherwise the loop never ends.
            i += 1
    return src, tgt, refs


def merge(src, tgt, hypos, log_probs, path):
    """Write sources, targets and scored hypotheses to ``path``.

    For every sentence the source and target are written on their own lines,
    followed by an empty line, one ``\\t<log_prob>\\t<hypothesis>`` line per
    hypothesis and a dashed separator line.
    """
    with open(path, "w", encoding="utf-8") as f:
        for s, t, hs, lps in zip(src, tgt, hypos, log_probs):
            f.write(s + "\n")
            f.write(t + "\n")
            f.write("\n")
            for h, lp in zip(hs, lps):
                f.write("\t%f\t%s\n" % (lp, h.strip()))
            f.write("------------------------------------------------------\n")


def corpus_bleu(sys_stream, ref_streams):
    """Return the sacrebleu corpus BLEU score, computed with ``tokenize="none"``."""
    bleu = _corpus_bleu(sys_stream, ref_streams, tokenize="none")
    return bleu.score


def sentence_bleu(hypothesis, reference):
    """Return a smoothed BLEU score of ``hypothesis`` against ``reference``.

    The statistics returned by sacrebleu's ``corpus_bleu`` get one added to
    the counts and totals at indices 1..3 (presumably the 2- to 4-gram
    statistics; confirm against the installed sacrebleu), then the score is
    recomputed with ``compute_bleu`` using ``smooth_method="exp"``.
    """
    bleu = _corpus_bleu(hypothesis, reference)
    for i in range(1, 4):
        bleu.counts[i] += 1
        bleu.totals[i] += 1
    bleu = compute_bleu(
        bleu.counts,
        bleu.totals,
        bleu.sys_len,
        bleu.ref_len,
        smooth_method="exp",
    )
    return bleu.score


def pairwise(sents):
    """Return corpus BLEU over all ordered pairs within each group.

    Args:
        sents: list of groups; within a group every sentence is scored as a
            hypothesis against every other sentence of the same group.
    """
    _ref, _hypo = [], []
    for group in sents:
        for i, ref in enumerate(group):
            for j, hypo in enumerate(group):
                if i != j:
                    _ref.append(ref)
                    _hypo.append(hypo)
    return corpus_bleu(_hypo, [_ref])


def multi_ref(refs, hypos):
    """Print reference coverage and leave-one-out multi-reference BLEU.

    For every hypothesis the reference with the highest ``sentence_bleu`` is
    found; ties are broken with ``random.choice``. The average number of
    distinct references selected per sentence is printed as
    ``#refs covered``. Afterwards corpus BLEU is computed once for each
    held-out reference position and the mean is printed.

    Args:
        refs: list (one entry per sentence) of lists of references.
        hypos: list (one entry per sentence) of lists of hypotheses; must
            have the same length as ``refs``.
    """
    _ref, _hypo = [], []
    ref_cnt = 0
    assert len(refs) == len(hypos), "refs and hypos must cover the same sentences"

    # count number of refs covered
    for rs, hs in zip(refs, hypos):
        covered = set()
        for h in hs:
            scores = [sentence_bleu(h, r) for r in rs]
            j = np.argmax(scores)
            _ref.append(rs[j])
            _hypo.append(h)
            best = [k for k in range(len(rs)) if scores[k] == scores[j]]
            covered.add(random.choice(best))
        ref_cnt += len(covered)
    print("#refs covered: %.2f" % (ref_cnt / len(refs)))

    # transpose refs and hypos
    refs = list(zip(*refs))
    hypos = list(zip(*hypos))

    # compute multi-ref corpus BLEU (leave-one-out to be comparable to intra_ref)
    k = len(hypos)
    m = len(refs)
    flat_hypos = [hypos[j][i] for i in range(len(hypos[0])) for j in range(k)]
    # Repeat every reference k times so it lines up with the k hypotheses
    # of its sentence in flat_hypos.
    duplicated_refs = [[ref for ref in refs_i for _ in range(k)] for refs_i in refs]
    loo_bleus = []
    for held_out_ref in range(m):
        remaining_refs = (
            duplicated_refs[:held_out_ref] + duplicated_refs[held_out_ref + 1 :]
        )
        assert len(remaining_refs) == m - 1
        loo_bleus.append(corpus_bleu(flat_hypos, remaining_refs))
    print("average multi-reference BLEU (leave-one-out): %.2f" % np.mean(loo_bleus))


def intra_ref(refs):
    """Print pairwise and leave-one-out BLEU among the references themselves.

    Each reference position is in turn treated as the hypothesis stream and
    scored against the remaining ``m - 1`` positions; all such streams are
    concatenated and scored with a single ``corpus_bleu`` call.
    """
    print("ref pairwise BLEU: %.2f" % pairwise(refs))
    refs = list(zip(*refs))
    m = len(refs)
    concat_h = []
    concat_rest = [[] for _ in range(m - 1)]
    for i, h in enumerate(refs):
        rest = refs[:i] + refs[i + 1 :]
        concat_h.append(h)
        for j in range(m - 1):
            concat_rest[j].extend(rest[j])
    concat_h = list(chain.from_iterable(concat_h))
    bleu = corpus_bleu(concat_h, concat_rest)
    print("multi-reference BLEU (leave-one-out): %.2f" % bleu)


if __name__ == "__main__":
    main()