#!/usr/bin/env python3
# coding=utf-8


class AbstractParser:
    """Decode a model prediction dict into graph nodes, edges and anchors.

    Label strings are looked up in ``dataset.label_field.vocab.itos`` (nodes) and
    ``dataset.edge_label_field.vocab.itos`` (edges).
    """

    def __init__(self, dataset):
        # Only the two vocabularies named in the class docstring are read below.
        self.dataset = dataset

    def create_nodes(self, prediction):
        """Create one node dict per predicted label.

        :param prediction: mapping with ``"labels"`` (one label index per node) and
            ``"anchors"`` (indexable by node position).
        :return: list of ``{"id": i, "label": str}`` dicts, in prediction order.
        """
        return [
            {"id": i, "label": self.label_to_str(label, prediction["anchors"][i], prediction)}
            for i, label in enumerate(prediction["labels"])
        ]

    def label_to_str(self, label, anchors, prediction):
        """Map a node label index to its string via the node-label vocabulary.

        The index is shifted down by one before lookup -- presumably index 0 is a
        reserved class in the model output; TODO confirm.  ``anchors`` and
        ``prediction`` are unused by this implementation.
        """
        return self.dataset.label_field.vocab.itos[label - 1]

    def create_edges(self, prediction, nodes):
        """Select directed edges from the ``"edge presence"`` score matrix.

        Pairs are visited in descending score order.  Every pair scoring >= 0.5
        becomes an edge.  Below that threshold a pair is only taken when it joins
        two still-separate components (Kruskal-style), and the scan stops as soon
        as all nodes are in one (weakly) connected component.

        :param prediction: mapping with ``"edge presence"`` (at least N x N scores;
            assumed to be a torch tensor -- TODO confirm) and ``"edge labels"``.
        :param nodes: the node list; only its length N is used.
        :return: list of edge dicts as built by ``create_edge``.
        """
        N = len(nodes)
        if N == 0:
            return []  # nothing to connect; also avoids `indices // 0` below

        scores = prediction["edge presence"][:N, :N]
        _, indices = scores.reshape(-1).sort(descending=True)
        sources, targets = indices // N, indices % N

        # Each node points at the (shared) set object of its current component.
        components = [{n} for n in range(N)]
        n_components = N

        edges = []
        # Visit every pair: truncating the scan could drop edges scoring >= 0.5
        # or stop before the graph is connected.
        for source, target in zip(sources.tolist(), targets.tolist()):
            p = scores[source, target]
            connected = components[source] is components[target]
            if p < 0.5:
                if n_components <= 1:
                    break  # already connected: no below-threshold edge is needed
                if connected:
                    continue  # would not merge two components
            self.create_edge(source, target, prediction, edges, nodes)
            if not connected:
                merged = components[source]
                for n in components[target]:
                    merged.add(n)
                    components[n] = merged
                n_components -= 1
        return edges

    def create_edge(self, source, target, prediction, edges, nodes):
        """Append a labelled ``source -> target`` edge dict to ``edges`` in place.

        ``nodes`` is unused by this implementation.
        """
        label = self.get_edge_label(prediction, source, target)
        edge = {"source": source, "target": target, "label": label}
        edges.append(edge)

    def create_anchors(
        self,
        prediction,
        nodes,
        join_contiguous=True,
        at_least_one=False,
        single_anchor=False,
        mode="anchors",
    ):
        """Attach anchor spans to every node, in place.

        For node ``i``, every token whose ``prediction[mode][i]`` score is >= 0.5
        is selected and mapped to its row of ``prediction["token intervals"]``.
        With ``at_least_one``, the threshold is lowered to the node's best score,
        so at least one token is always selected.

        :param join_contiguous: merge touching/overlapping spans into one span.
        :param single_anchor: when several spans are selected, collapse them into
            a single span from the smallest start to the largest end.
        :param mode: prediction key to read scores from; also the node key written.
        :return: ``nodes`` (same list), each with ``node[mode]`` set to a list of
            ``{"from": int, "to": int}`` dicts sorted by ``"from"``.
        """
        for i, node in enumerate(nodes):
            scores = prediction[mode][i]
            threshold = min(0.5, scores.max().item()) if at_least_one else 0.5
            selected = (scores >= threshold).nonzero(as_tuple=False).squeeze(-1)
            intervals = prediction["token intervals"][selected, :]

            if single_anchor and len(intervals) > 1:
                start = min(a[0].item() for a in intervals)
                end = max(a[1].item() for a in intervals)
                node[mode] = [{"from": start, "to": end}]
                continue

            anchors = sorted(
                ({"from": f.item(), "to": t.item()} for f, t in intervals),
                key=lambda a: a["from"],
            )
            if join_contiguous and len(anchors) > 1:
                anchors = self._join_contiguous_anchors(anchors)
            node[mode] = anchors
        return nodes

    @staticmethod
    def _join_contiguous_anchors(anchors):
        """Merge ``from``-sorted, non-empty ``anchors`` whose spans touch or overlap.

        The next span is merged when it starts at or before the end of the current
        merged span; a strictly positive gap starts a new span.
        """
        joined = []
        start, end = anchors[0]["from"], anchors[0]["to"]
        for anchor in anchors[1:]:
            if anchor["from"] > end:
                joined.append({"from": start, "to": end})
                start, end = anchor["from"], anchor["to"]
            else:
                # max(): a span nested inside the current one must not shrink it.
                end = max(end, anchor["to"])
        joined.append({"from": start, "to": end})
        return joined

    def get_edge_label(self, prediction, source, target):
        """Return the edge-label string for ``prediction["edge labels"][source, target]``.

        Unlike node labels, no index shift is applied before the vocabulary lookup.
        """
        return self.dataset.edge_label_field.vocab.itos[prediction["edge labels"][source, target].item()]