#!/usr/bin/env python3
""" | |
Copyright 2019 Brian Thompson | |
Licensed under the Apache License, Version 2.0 (the "License"); | |
you may not use this file except in compliance with the License. | |
You may obtain a copy of the License at | |
https://www.apache.org/licenses/LICENSE-2.0 | |
Unless required by applicable law or agreed to in writing, software | |
distributed under the License is distributed on an "AS IS" BASIS, | |
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
See the License for the specific language governing permissions and | |
limitations under the License. | |
""" | |
import argparse
import sys
from collections import defaultdict

import numpy as np

from dp_utils import read_alignments
""" | |
Faster implementation of lax and strict precision and recall, based on | |
https://www.aclweb.org/anthology/W11-4624/. | |
""" | |
def _precision(goldalign, testalign):
    """
    Count strict and lax true/false positives of testalign against goldalign.

    Each alignment is a pair (src_ids, tgt_ids) of sentence-index sequences.
    Alignments that are empty on both sides are dropped, and duplicate
    alignments are counted once (both inputs are converted to sets).

    A test alignment is a strict true positive when the identical pair is
    present in gold; every other test alignment is a strict false positive.
    A non-identical test alignment is still a lax true positive when some gold
    alignment contains at least one of its source ids and at least one of its
    target ids; otherwise it is a lax false positive.

    Args:
        goldalign: iterable of (src_ids, tgt_ids) reference alignments.
        testalign: iterable of (src_ids, tgt_ids) alignments to evaluate.

    Returns:
        np.ndarray of dtype int32: [tpstrict, fpstrict, tplax, fplax].
    """
    # convert to sets, remove alignments empty on both sides
    testalign = {(tuple(x), tuple(y)) for x, y in testalign if len(x) or len(y)}
    goldalign = {(tuple(x), tuple(y)) for x, y in goldalign if len(x) or len(y)}

    # map every gold source sentence idx to all target sentence idxs
    # it shares a gold alignment with
    src_id_to_gold_tgt_ids = defaultdict(set)
    for gold_src, gold_tgt in goldalign:
        if not gold_tgt:
            # a deletion contributes no source/target pairing
            continue
        for gold_src_id in gold_src:
            src_id_to_gold_tgt_ids[gold_src_id].update(gold_tgt)

    tpstrict = 0  # true positive strict counter
    tplax = 0     # true positive lax counter
    fpstrict = 0  # false positive strict counter
    fplax = 0     # false positive lax counter

    for test_src, test_target in testalign:
        if (test_src, test_target) in goldalign:
            # strict match (which is also a lax match)
            tpstrict += 1
            tplax += 1
            continue

        # anything that is not an exact match is a strict false positive
        fpstrict += 1

        # Collect the gold targets reachable from any of the test source ids.
        # .get() keeps the lookup from inserting empty entries for source ids
        # that never occur in gold.
        reachable_tgt_ids = set()
        for src_test_id in test_src:
            reachable_tgt_ids.update(src_id_to_gold_tgt_ids.get(src_test_id, ()))

        # partial overlap on both the source and the target side: lax match
        if reachable_tgt_ids.intersection(test_target):
            tplax += 1
        else:
            fplax += 1

    return np.array([tpstrict, fpstrict, tplax, fplax], dtype=np.int32)
def _ratio_or_default(hits, misses, default):
    """Return hits / (hits + misses), or default when the denominator is 0."""
    total = hits + misses
    if total == 0:
        return default
    return hits / float(total)


def _f1_or_default(precision, recall, default):
    """Return the harmonic mean of precision and recall, or default when both are 0."""
    if (precision + recall) == 0:
        return default
    return 2 * (precision * recall) / (precision + recall)


def score_multiple(gold_list, test_list, value_for_div_by_0=0.0):
    """
    Compute strict/lax precision, recall and F1 over several gold/test pairs.

    Counts are accumulated over all pairs before any ratio is taken, so the
    scores are micro-averaged across files.

    Args:
        gold_list: iterable of gold alignments, one entry per file; each entry
            is an iterable of (src_ids, tgt_ids) pairs.
        test_list: iterable of test alignments, paired positionally with
            gold_list (pairing stops at the shorter of the two).
        value_for_div_by_0: value reported for any score whose denominator
            is zero.

    Returns:
        dict with keys recall_strict, recall_lax, precision_strict,
        precision_lax, f1_strict and f1_lax.
    """
    # accumulate counts for all gold/test files; int64 so that summing the
    # per-file int32 counts cannot silently wrap around
    pcounts = np.zeros(4, dtype=np.int64)
    rcounts = np.zeros(4, dtype=np.int64)

    for goldalign, testalign in zip(gold_list, test_list):
        pcounts += _precision(goldalign=goldalign, testalign=testalign)
        # recall is precision with no insertion/deletion and swap args
        test_no_del = [(x, y) for x, y in testalign if len(x) and len(y)]
        gold_no_del = [(x, y) for x, y in goldalign if len(x) and len(y)]
        rcounts += _precision(goldalign=test_no_del, testalign=gold_no_del)

    # Compute results
    # pcounts: tpstrict, fpstrict, tplax, fplax
    # rcounts: tpstrict, fnstrict, tplax, fnlax (roles swapped, so the
    #          "false positives" are gold alignments missing from test)
    pstrict = _ratio_or_default(pcounts[0], pcounts[1], value_for_div_by_0)
    plax = _ratio_or_default(pcounts[2], pcounts[3], value_for_div_by_0)
    rstrict = _ratio_or_default(rcounts[0], rcounts[1], value_for_div_by_0)
    rlax = _ratio_or_default(rcounts[2], rcounts[3], value_for_div_by_0)

    fstrict = _f1_or_default(pstrict, rstrict, value_for_div_by_0)
    flax = _f1_or_default(plax, rlax, value_for_div_by_0)

    result = dict(recall_strict=rstrict,
                  recall_lax=rlax,
                  precision_strict=pstrict,
                  precision_lax=plax,
                  f1_strict=fstrict,
                  f1_lax=flax)
    return result
def log_final_scores(res):
    """
    Print strict/lax precision, recall and F1 as a small table on stderr.

    Args:
        res: mapping with the keys precision_strict, precision_lax,
            recall_strict, recall_lax, f1_strict and f1_lax; each value is
            formatted with three decimals.

    Raises:
        KeyError: if one of the expected keys is missing from res.
    """
    # The border is one leading space plus 33 dashes, i.e. it spans the
    # columns between the outer pipes of a 35-character row. Rows are built
    # with explicit widths so they always line up with it.
    border = ' ' + '-' * 33
    row = '| {:<11} | {:>7} | {:>7} |'

    print(border, file=sys.stderr)
    print(row.format('', 'Strict', 'Lax'), file=sys.stderr)
    for label, prefix in (('Precision', 'precision'), ('Recall', 'recall'), ('F1', 'f1')):
        strict = format(res[prefix + '_strict'], '.3f')
        lax = format(res[prefix + '_lax'], '.3f')
        print(row.format(label, strict, lax), file=sys.stderr)
    print(border, file=sys.stderr)
def main():
    """
    Command-line entry point: score test alignment files against gold files.

    Parses -t/--test and -g/--gold (one or more files each, paired by
    position), reads every file with read_alignments, scores all pairs
    together with score_multiple and prints the result table to stderr.
    A differing number of gold and test files is reported as a usage error.
    """
    # The text must be passed as description=; the first positional
    # parameter of ArgumentParser is prog, which would put the whole sentence
    # into the usage line as the program name.
    parser = argparse.ArgumentParser(
        description='Compute strict/lax precision and recall for one or more pairs of gold/test alignments',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    parser.add_argument('-t', '--test', type=str, nargs='+', required=True,
                        help='one or more test alignment files')

    parser.add_argument('-g', '--gold', type=str, nargs='+', required=True,
                        help='one or more gold alignment files')

    args = parser.parse_args()

    if len(args.test) != len(args.gold):
        # usage error: print the usage message and exit with status 2
        # instead of dumping a traceback
        parser.error('number of gold/test files must be the same')

    # NOTE(review): read_alignments comes from dp_utils; presumably it returns
    # a list of (src_ids, tgt_ids) pairs per file - confirm against dp_utils.
    gold_list = [read_alignments(x) for x in args.gold]
    test_list = [read_alignments(x) for x in args.test]

    res = score_multiple(gold_list=gold_list, test_list=test_list)
    log_final_scores(res)
# Run the scorer only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()