File size: 4,184 Bytes
1d5604f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
import sys;

import score.core;
from smatch.smatch import get_amr_match;

def tuples(graph, prefix, values, faith = True):
  #
  # mimicry of get_triples() in amr.py
  #
  id = 0;
  mapping = dict();
  instances = [];
  relations = [];
  attributes = [];
  n = 0;
  for node in graph.nodes:
    mapping[node.id] = name = prefix + str(id);
    id += 1;
    if "anchors" in values and node.anchors is not None:
      anchor = score.core.anchor(node);
      if graph.input: anchor = score.core.explode(graph.input, anchor);
      attributes.append(("anchor", name, str(anchor)));
    if "labels" in values and node.label is not None:
      instance = node.label;
    else:
      instance = "__()_{}__".format(prefix, n);
      n += 1;
    instances.append(("instance", name, instance));
    if "tops" in values and node.is_top:
      #
      # the native SMATCH code (wrongly, i believe) ties the top property to
      # the node label (see https://github.com/cfmrp/mtool/issues/12).  we get
      # to choose whether to faithfully replicate those scores or not.
      #
      attributes.append(("TOP", name,
                         node.label if node.label and faith else ""));
    if "properties" in values and node.properties and node.values:
      for property, value in zip(node.properties, node.values):
        attributes.append((property, name, value));
  for edge in graph.edges:
    if "edges" in values:
      relations.append((edge.lab, mapping[edge.src], mapping[edge.tgt]));
    if "attributes" in values:
      if edge.attributes and edge.values:
        for attribute, value in zip(edge.attributes, edge.values):
          relations.append((str((attribute, value)),
                            mapping[edge.src], mapping[edge.tgt]));
  return instances, attributes, relations, n;

def smatch(gold, system, limit = 20, values = {}, trace = 0, faith = True):
  gprefix = "g"; sprefix = "s";
  ginstances, gattributes, grelations, gn \
    = tuples(gold, gprefix, values, faith);
  sinstances, sattributes, srelations, sn \
    = tuples(system, sprefix, values, faith);
  if trace > 1:
    print("gold instances [{}]: {}\ngold attributes [{}]: {}\n"
          "gold relations [{}]: {}"
          "".format(len(ginstances), ginstances,
                    len(gattributes), gattributes,
                    len(grelations), grelations),
          file = sys.stderr);
    print("system instances [{}]: {}\nsystem attributes [{}]: {}\n"
          "system relations [{}]: {}"
          "".format(len(sinstances), sinstances,
                    len(sattributes), sattributes,
                    len(srelations), srelations),
          file = sys.stderr);
  correct, gold, system, mapping \
    = get_amr_match(None, None, gold.id, limit = limit,
                    instance1 = ginstances, attributes1 = gattributes,
                    relation1 = grelations, prefix1 = gprefix,
                    instance2 = sinstances, attributes2 = sattributes,
                    relation2 = srelations, prefix2 = sprefix);
  return correct, gold - gn, system - sn, mapping;

def evaluate(golds, systems, format = "json", limit = 20,
             values = {}, trace = 0):
  if limit is None or not limit > 0: limit = 20;
  if trace > 1: print("RRHC limit: {}".format(limit), file = sys.stderr);
  tg = ts = tc = n = 0;
  scores = dict() if trace else None;
  for gold, system in score.core.intersect(golds, systems):
    id = gold.id;
    correct, gold, system, mapping \
      = smatch(gold, system, limit, values, trace);
    tg += gold; ts += system; tc += correct;
    n += 1;
    if trace:
      if id in scores:
        print("smatch.evaluate(): duplicate graph identifier: {}"
              "".format(id), file = sys.stderr);
      scores[id] = {"g": gold, "s": system, "c": correct};
      if trace > 1:
        p, r, f = score.core.fscore(gold, system, correct);
        print("G: {}; S: {}; C: {}; P: {}; R: {}; F: {}"
              "".format(gold, system, correct, p, r, f), file = sys.stderr);
    
  p, r, f = score.core.fscore(tg, ts, tc);
  result = {"n": n, "g": tg, "s": ts, "c": tc, "p": p, "r": r, "f": f};
  if trace: result["scores"] = scores;
  return result;