Spaces:
Runtime error
Runtime error
#!/usr/bin/env python3 | |
# coding=utf-8 | |
from data.parser.to_mrp.abstract_parser import AbstractParser | |
class NodeCentricParser(AbstractParser):
    """Convert a model prediction into a dict with an id, nodes and edges.

    Edges are only considered from nodes whose label is neither "Source"
    nor "Target" towards nodes whose label is "Source" or "Target"; every
    edge created by this class is unlabeled.
    """

    def parse(self, prediction):
        """Build the output dict for a single prediction.

        Args:
            prediction: Mapping read here via the "id" key (an object
                exposing ``.item()``) and passed on to the helper methods.

        Returns:
            A dict with the keys "id", "nodes" and "edges", in that order.

        Note:
            ``create_nodes`` and ``create_anchors`` are not defined in this
            class; presumably they come from ``AbstractParser`` -- confirm
            there.
        """
        identifier = self.dataset.id_field.vocab.itos[prediction["id"].item()]
        nodes = self.create_nodes(prediction)
        nodes = self.create_anchors(
            prediction, nodes, join_contiguous=True, at_least_one=True
        )
        edges = self.create_edges(prediction, nodes)
        return {"id": identifier, "nodes": nodes, "edges": edges}

    def create_edge(self, source, target, prediction, edges, nodes):
        """Append an unlabeled edge ``source -> target`` to ``edges``.

        ``edges`` is mutated in place and nothing is returned.
        ``prediction`` and ``nodes`` are accepted but not used here.
        """
        edges.append({"source": source, "target": target, "label": None})

    def create_edges(self, prediction, nodes):
        """Collect the edges whose predicted presence score is at least 0.5.

        Args:
            prediction: Mapping whose "edge presence" entry supports 2-D
                indexing as ``[source, target]`` (presumably a score
                matrix -- confirm against the model output).
            nodes: Sequence of dicts, each with a "label" key.

        Returns:
            The list of edges filled in by ``create_edge``.
        """
        node_count = len(nodes)
        # Keep only the top-left node_count x node_count corner of the scores.
        presence = prediction["edge presence"][:node_count, :node_count]

        # Split node indices in one pass: "Source"/"Target" nodes are edge
        # targets, every other node is a potential edge source.
        endpoint_labels = ["Source", "Target"]
        target_ids = []
        source_ids = []
        for index, node in enumerate(nodes):
            if node["label"] in endpoint_labels:
                target_ids.append(index)
            else:
                source_ids.append(index)

        collected = []
        for target_id in target_ids:
            for source_id in source_ids:
                if presence[source_id, target_id] >= 0.5:
                    # Go through create_edge so subclasses can override it.
                    self.create_edge(source_id, target_id, prediction, collected, nodes)
        return collected