perin / data /parser /from_mrp /labeled_edge_parser.py
Larisa Kolesnichenko
Add the original perin code
1d5604f
#!/usr/bin/env python3
# coding=utf-8
from data.parser.from_mrp.abstract_parser import AbstractParser
import utility.parser_utils as utils
class LabeledEdgeParser(AbstractParser):
    """Parser for MRP graphs whose edges carry labels; all node labels become "Node".

    Sentences that contain any edge without a ``"label"`` key are dropped.
    After edges are built, the first node of every sentence (the sentence's
    single top, asserted to be ``[0]``) is removed from ``sentence["nodes"]``,
    and anchor targets are prepared for the remaining nodes.
    """

    def __init__(self, args, part: str, fields, filter_pred=None, **kwargs):
        """Load and preprocess the training or validation split.

        Args:
            args: configuration object. ``args.training_data`` /
                ``args.validation_data`` give the dataset paths, and
                ``args.encoder`` is passed to ``utils.create_bert_tokens``.
            part: either ``"training"`` or ``"validation"``.
            fields: forwarded unchanged to ``AbstractParser.__init__``.
            filter_pred: optional predicate forwarded to ``AbstractParser.__init__``.
            **kwargs: accepted for signature compatibility and ignored.

        Raises:
            AssertionError: if ``part`` is not a recognised split, or a
                sentence's ``"tops"`` is not ``[0]``.
            ValueError: if no (node, token) pairs remain after filtering, so
                the anchor frequency cannot be computed.
        """
        assert part == "training" or part == "validation"
        path = args.training_data if part == "training" else args.validation_data

        self.data = utils.load_dataset(path)
        utils.anchor_ids_from_intervals(self.data)

        self.node_counter, self.edge_counter, self.no_edge_counter = 0, 0, 0
        anchor_count, n_node_token_pairs = 0, 0

        # Drop every sentence that has at least one unlabeled edge. Iterate over
        # a snapshot because entries are deleted from self.data inside the loop.
        for sentence_id, sentence in list(self.data.items()):
            if any("label" not in edge for edge in sentence["edges"]):
                del self.data[sentence_id]

        # Replace every node label with the constant "Node" and count the nodes.
        for node, _ in utils.node_generator(self.data):
            node["label"] = "Node"
            self.node_counter += 1

        utils.create_bert_tokens(self.data, args.encoder)

        # create edge vectors
        for sentence in self.data.values():
            assert sentence["tops"] == [0], sentence
            N = len(sentence["nodes"])

            # N still includes node 0 at this point, so the ordered-pair count
            # (N * (N - 1)) covers it too.
            edge_count = utils.create_edges(sentence)
            self.edge_counter += edge_count
            self.no_edge_counter += N * (N - 1) - edge_count

            # Remove node 0 (the asserted top); anchor targets cover the rest.
            sentence["nodes"] = sentence["nodes"][1:]
            N = len(sentence["nodes"])
            n_tokens = len(sentence["input"])

            # "anchor edges": [n_nodes, n_tokens, [(node_index, anchor), ...]]
            sentence["anchor edges"] = [N, n_tokens, []]
            sentence["source anchor edges"] = [N, n_tokens, []]  # dummy
            sentence["target anchor edges"] = [N, n_tokens, []]  # dummy
            # "anchored labels": [n_tokens, per-node lists of (anchor, label)]
            sentence["anchored labels"] = [n_tokens, []]
            for i, node in enumerate(sentence["nodes"]):
                anchored_labels = []
                for anchor in node["anchors"]:
                    sentence["anchor edges"][-1].append((i, anchor))
                    anchored_labels.append((anchor, node["label"]))
                sentence["anchored labels"][1].append(anchored_labels)

                anchor_count += len(node["anchors"])
                n_node_token_pairs += n_tokens

            sentence["id"] = [sentence["id"]]

        # Without any node/token pairs (e.g. every sentence was filtered out
        # above), the ratio below is undefined; fail with a clear message
        # instead of a bare ZeroDivisionError.
        if n_node_token_pairs == 0:
            raise ValueError(
                f"No (node, token) pairs found in {path!r} after preprocessing; "
                "cannot compute the anchor frequency."
            )

        # Ratio of anchors to (node, token) pairs over the whole split.
        self.anchor_freq = anchor_count / n_node_token_pairs
        self.source_anchor_freq = self.target_anchor_freq = 0.5  # dummy
        self.input_count = sum(len(sentence["input"]) for sentence in self.data.values())

        super().__init__(fields, self.data, filter_pred)

    @staticmethod
    def node_similarity_key(node):
        """Return a hashable key: the node's label followed by its anchors."""
        return (node["label"], *node["anchors"])