Spaces:
Runtime error
Runtime error
#!/usr/bin/env python | |
# -*- coding: utf-8 -*- | |
# Author: Rico Sennrich | |
"""Compute chrF3 for machine translation evaluation | |
Reference: | |
Maja Popović (2015). chrF: character n-gram F-score for automatic MT evaluation. In Proceedings of the Tenth Workshop on Statistical Machine Translation, pages 392–395, Lisbon, Portugal.
""" | |
from __future__ import print_function, unicode_literals, division | |
import sys | |
import codecs | |
import io | |
import argparse | |
from collections import defaultdict | |
# hack for python2/3 compatibility | |
from io import open | |
argparse.open = open | |
def create_parser():
    """Build the command-line argument parser for the chrF scorer.

    Options defined:
        --ref / -r   reference file, opened for reading (required)
        --hyp        hypothesis file, opened for reading (default: stdin)
        --beta / -b  float beta of the F-score (default: 3)
        --ngram / -n maximum character n-gram order (default: 6)
        --space / -s flag: keep spaces when extracting n-grams
        --precision  flag: also report precision
        --recall     flag: also report recall

    Returns:
        The configured ``argparse.ArgumentParser``.
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.RawDescriptionHelpFormatter,
        # The previous text ("learn BPE-based word segmentation") described
        # a different tool; --help now describes what this script does.
        description="compute chrF (character n-gram F-score) "
                    "for machine translation evaluation")

    parser.add_argument(
        '--ref', '-r', type=argparse.FileType('r'), required=True,
        metavar='PATH',
        help="Reference file")
    parser.add_argument(
        '--hyp', type=argparse.FileType('r'), metavar='PATH',
        # sys.stdin is looked up when this function is called, so any
        # re-wrapping of stdin must happen before the parser is created.
        default=sys.stdin,
        help="Hypothesis file (default: stdin).")
    parser.add_argument(
        '--beta', '-b', type=float, default=3,
        metavar='FLOAT',
        help="beta parameter (default: '%(default)s')")
    parser.add_argument(
        '--ngram', '-n', type=int, default=6,
        metavar='INT',
        help="ngram order (default: '%(default)s')")
    parser.add_argument(
        '--space', '-s', action='store_true',
        help="take spaces into account (default: '%(default)s')")
    parser.add_argument(
        '--precision', action='store_true',
        help="report precision (default: '%(default)s')")
    parser.add_argument(
        '--recall', action='store_true',
        help="report recall (default: '%(default)s')")

    return parser
def extract_ngrams(words, max_length=4, spaces=False):
    """Count the character n-grams of a line for orders 1..max_length.

    Args:
        words: the input line (a string).
        max_length: highest n-gram order to extract.
        spaces: when false, all whitespace is removed before counting;
            when true, only leading/trailing whitespace is stripped.

    Returns:
        A nested defaultdict ``counts[order - 1][ngram] -> occurrences``,
        where ``ngram`` is a tuple of characters. Orders longer than the
        text produce no entry.
    """
    text = words.strip() if spaces else ''.join(words.split())
    size = len(text)

    counts = defaultdict(lambda: defaultdict(int))
    for order in range(1, max_length + 1):
        # range() is empty once the window no longer fits in the text.
        for begin in range(size - order + 1):
            counts[order - 1][tuple(text[begin:begin + order])] += 1
    return counts
def get_correct(ngrams_ref, ngrams_test, correct, total):
    """Accumulate hypothesis n-gram totals and clipped matches per order.

    For every n-gram of the hypothesis, its count is added to
    ``total[rank]``; if the n-gram also occurs in the reference, the
    smaller of the two counts is added to ``correct[rank]``.

    Args:
        ngrams_ref: reference counts, indexed ``[rank][ngram]``.
        ngrams_test: hypothesis counts, indexed ``[rank][ngram]``.
        correct: per-rank list of matched counts, updated in place.
        total: per-rank list of hypothesis counts, updated in place.

    Returns:
        The same ``(correct, total)`` objects that were passed in.
    """
    for rank, hyp_counts in ngrams_test.items():
        for chain, hyp_count in hyp_counts.items():
            total[rank] += hyp_count
            ref_counts = ngrams_ref[rank]
            if chain in ref_counts:
                correct[rank] += min(hyp_count, ref_counts[chain])
    return correct, total
def f1(correct, total_hyp, total_ref, max_length, beta=3, smooth=0):
    """Compute the F-beta score from per-order n-gram statistics.

    Precision and recall are computed per n-gram order and averaged over
    ``max_length`` orders. An order whose (smoothed) hypothesis or
    reference total is zero contributes nothing to the sums but is still
    counted in the average.

    Args:
        correct: per-order counts of matched n-grams.
        total_hyp: per-order counts of hypothesis n-grams.
        total_ref: per-order counts of reference n-grams.
        max_length: number of n-gram orders to average over.
        beta: weight of recall relative to precision.
        smooth: constant added to every numerator and denominator.

    Returns:
        A tuple ``(score, precision, recall)``. The score is 0.0 when
        precision and recall are both zero.
    """
    precision = 0
    recall = 0
    for i in range(max_length):
        if total_hyp[i] + smooth and total_ref[i] + smooth:
            precision += (correct[i] + smooth) / (total_hyp[i] + smooth)
            recall += (correct[i] + smooth) / (total_ref[i] + smooth)

    precision /= max_length
    recall /= max_length

    denominator = (beta**2 * precision) + recall
    if not denominator:
        # No overlap at all (or empty input): the harmonic mean is
        # undefined, so report 0 instead of raising ZeroDivisionError.
        return 0.0, precision, recall
    return (1 + beta**2) * (precision * recall) / denominator, precision, recall
def main(args):
    """Score a hypothesis stream against a reference stream and print chrF.

    The two streams are read in lockstep, one line each per segment.
    N-gram statistics are accumulated over all segments and a single
    corpus-level score is printed, optionally followed by precision and
    recall.

    Args:
        args: parsed options providing ``ref`` and ``hyp`` (readable text
            streams), ``ngram``, ``space``, ``beta``, ``precision`` and
            ``recall``.
    """
    correct = [0] * args.ngram
    total = [0] * args.ngram
    total_ref = [0] * args.ngram
    missing = 0

    for line in args.ref:
        line2 = args.hyp.readline()
        if not line2:
            # readline() yields '' only at end of stream (a blank line is
            # '\n'), so the hypothesis is shorter than the reference.
            missing += 1

        ngrams_ref = extract_ngrams(line, max_length=args.ngram, spaces=args.space)
        ngrams_test = extract_ngrams(line2, max_length=args.ngram, spaces=args.space)

        get_correct(ngrams_ref, ngrams_test, correct, total)

        for rank in ngrams_ref:
            total_ref[rank] += sum(ngrams_ref[rank].values())

    if missing:
        # Scores are unchanged (missing lines count as empty hypotheses);
        # the warning only makes the mismatch visible.
        sys.stderr.write(
            'Warning: hypothesis has {0} fewer line(s) than the reference; '
            'missing lines were scored as empty\n'.format(missing))

    chrf, precision, recall = f1(correct, total, total_ref, args.ngram, args.beta)

    print('chrF3: {0:.4f}'.format(chrf))
    if args.precision:
        print('chrPrec: {0:.4f}'.format(precision))
    if args.recall:
        print('chrRec: {0:.4f}'.format(recall))
if __name__ == '__main__':
    # Force UTF-8 on the standard streams on both Python 2 and 3. This has
    # to happen before the parser is built, because create_parser() uses
    # sys.stdin as the default for --hyp.
    if sys.version_info >= (3, 0):
        sys.stdin = io.TextIOWrapper(sys.stdin.buffer, encoding='utf-8')
        sys.stderr = io.TextIOWrapper(sys.stderr.buffer, encoding='utf-8')
        sys.stdout = io.TextIOWrapper(
            sys.stdout.buffer,
            encoding='utf-8',
            write_through=True,
            line_buffering=True)
    else:
        # Python 2: wrap the byte streams with UTF-8 codec reader/writers.
        sys.stderr = codecs.getwriter('UTF-8')(sys.stderr)
        sys.stdout = codecs.getwriter('UTF-8')(sys.stdout)
        sys.stdin = codecs.getreader('UTF-8')(sys.stdin)

    parser = create_parser()
    args = parser.parse_args()
    main(args)