#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Author: Rico Sennrich

"""Compute chrF3 for machine translation evaluation

Reference:
Maja Popović (2015). chrF: character n-gram F-score for automatic MT evaluation.
In Proceedings of the Tenth Workshop on Statistical Machine Translation,
pages 392–395, Lisbon, Portugal.
"""

from __future__ import print_function, unicode_literals, division

import sys
import codecs
import io
import argparse

from collections import defaultdict

# hack for python2/3 compatibility
from io import open
argparse.open = open


def create_parser():
    """Build the command-line parser for the chrF scorer.

    Returns:
        An ``argparse.ArgumentParser`` exposing ``--ref`` (required),
        ``--hyp`` (defaults to stdin), ``--beta``, ``--ngram``, ``--space``,
        ``--precision`` and ``--recall``.
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.RawDescriptionHelpFormatter,
        description="compute the chrF character n-gram F-score of a "
                    "hypothesis against a reference")

    parser.add_argument(
        '--ref', '-r', type=argparse.FileType('r'), required=True,
        metavar='PATH',
        help="Reference file")
    parser.add_argument(
        '--hyp', type=argparse.FileType('r'), metavar='PATH',
        default=sys.stdin,
        help="Hypothesis file (default: stdin).")
    parser.add_argument(
        '--beta', '-b', type=float, default=3,
        metavar='FLOAT',
        help="beta parameter (default: '%(default)s')")
    parser.add_argument(
        '--ngram', '-n', type=int, default=6,
        metavar='INT',
        help="ngram order (default: '%(default)s')")
    parser.add_argument(
        '--space', '-s', action='store_true',
        help="take spaces into account (default: '%(default)s')")
    parser.add_argument(
        '--precision', action='store_true',
        help="report precision (default: '%(default)s')")
    parser.add_argument(
        '--recall', action='store_true',
        help="report recall (default: '%(default)s')")

    return parser


def extract_ngrams(words, max_length=4, spaces=False):
    """Count the character n-grams of a line of text.

    Args:
        words: the input string (one segment).
        max_length: highest n-gram order to extract.
        spaces: if false, all whitespace is removed before counting; if
            true, only leading/trailing whitespace is stripped.

    Returns:
        A nested defaultdict mapping ``order - 1`` to a mapping from n-gram
        (a tuple of characters) to its count. Orders longer than the text
        have no entry.
    """
    if not spaces:
        words = ''.join(words.split())
    else:
        words = words.strip()

    results = defaultdict(lambda: defaultdict(int))
    text_length = len(words)

    for length in range(max_length):
        # n-grams of size length + 1 can start at positions
        # 0 .. text_length - length - 1 (empty range if the text is shorter)
        for start_pos in range(text_length - length):
            ngram = tuple(words[start_pos:start_pos + length + 1])
            results[length][ngram] += 1

    return results


def get_correct(ngrams_ref, ngrams_test, correct, total):
    """Accumulate matching and total hypothesis n-gram counts.

    Args:
        ngrams_ref: reference n-gram counts, as returned by extract_ngrams.
        ngrams_test: hypothesis n-gram counts, as returned by extract_ngrams.
        correct: per-order list of matched counts; updated in place.
        total: per-order list of hypothesis n-gram counts; updated in place.

    Returns:
        The (correct, total) lists that were passed in.
    """
    for rank, test_counts in ngrams_test.items():
        # .get avoids inserting empty entries into the reference counts
        ref_counts = ngrams_ref.get(rank)
        for chain, count in test_counts.items():
            total[rank] += count
            if ref_counts and chain in ref_counts:
                # clipped count: a hypothesis n-gram cannot match more
                # often than it occurs in the reference
                correct[rank] += min(count, ref_counts[chain])

    return correct, total


def f1(correct, total_hyp, total_ref, max_length, beta=3, smooth=0):
    """Compute the F-beta score from per-order n-gram statistics.

    Precision and recall are computed per n-gram order and averaged over
    ``max_length`` orders. Orders for which the (smoothed) hypothesis or
    reference total is zero contribute nothing to the sums but still count
    in the divisor.

    Args:
        correct: per-order matched n-gram counts.
        total_hyp: per-order hypothesis n-gram counts.
        total_ref: per-order reference n-gram counts.
        max_length: number of n-gram orders to average over.
        beta: weight of recall relative to precision.
        smooth: constant added to every numerator and denominator.

    Returns:
        A tuple (f_score, precision, recall). The score is 0.0 when it is
        undefined (no orders, or precision and recall both zero).
    """
    if max_length < 1:
        # nothing to average over; avoid dividing by zero below
        return 0.0, 0.0, 0.0

    precision = 0
    recall = 0

    for i in range(max_length):
        if total_hyp[i] + smooth and total_ref[i] + smooth:
            precision += (correct[i] + smooth) / (total_hyp[i] + smooth)
            recall += (correct[i] + smooth) / (total_ref[i] + smooth)

    precision /= max_length
    recall /= max_length

    denominator = (beta**2 * precision) + recall
    if not denominator:
        # no overlap at all (or empty input): the score is 0, not an error
        return 0.0, precision, recall

    return (1 + beta**2) * (precision * recall) / denominator, precision, recall


def main(args):
    """Score a hypothesis stream against a reference stream and print chrF.

    For every line of ``args.ref`` one line is read from ``args.hyp``; once
    the hypothesis stream is exhausted, ``readline`` yields an empty string,
    so remaining reference lines are scored against an empty hypothesis.
    Hypothesis lines beyond the end of the reference are not read.

    Args:
        args: parsed arguments with attributes ``ref``, ``hyp``, ``ngram``,
            ``beta``, ``space``, ``precision`` and ``recall``.
    """
    correct = [0] * args.ngram
    total = [0] * args.ngram
    total_ref = [0] * args.ngram

    for line in args.ref:
        line2 = args.hyp.readline()

        ngrams_ref = extract_ngrams(line, max_length=args.ngram, spaces=args.space)
        ngrams_test = extract_ngrams(line2, max_length=args.ngram, spaces=args.space)

        get_correct(ngrams_ref, ngrams_test, correct, total)

        for rank in ngrams_ref:
            total_ref[rank] += sum(ngrams_ref[rank].values())

    chrf, precision, recall = f1(correct, total, total_ref, args.ngram, args.beta)

    print('chrF3: {0:.4f}'.format(chrf))
    if args.precision:
        print('chrPrec: {0:.4f}'.format(precision))
    if args.recall:
        print('chrRec: {0:.4f}'.format(recall))


if __name__ == '__main__':

    # python 2/3 compatibility: force UTF-8 on the standard streams.
    # This happens before the parser is built so that the --hyp default
    # (sys.stdin) refers to the re-wrapped stream.
    if sys.version_info < (3, 0):
        sys.stderr = codecs.getwriter('UTF-8')(sys.stderr)
        sys.stdout = codecs.getwriter('UTF-8')(sys.stdout)
        sys.stdin = codecs.getreader('UTF-8')(sys.stdin)
    else:
        sys.stdin = io.TextIOWrapper(sys.stdin.buffer, encoding='utf-8')
        sys.stderr = io.TextIOWrapper(sys.stderr.buffer, encoding='utf-8')
        sys.stdout = io.TextIOWrapper(
            sys.stdout.buffer, encoding='utf-8',
            write_through=True, line_buffering=True)

    parser = create_parser()
    args = parser.parse_args()

    main(args)