import torch import spacy, nltk from nltk.tree import Tree import numpy as np def collapse_unary_strip_pos(tree, strip_top=True): """Collapse unary chains and strip part of speech tags.""" def strip_pos(tree): if len(tree) == 1 and isinstance(tree[0], str): return tree[0] else: return nltk.tree.Tree(tree.label(), [strip_pos(child) for child in tree]) collapsed_tree = strip_pos(tree) collapsed_tree.collapse_unary(collapsePOS=True, joinChar="::") if collapsed_tree.label() in ("TOP", "ROOT", "S1", "VROOT"): if strip_top: if len(collapsed_tree) == 1: collapsed_tree = collapsed_tree[0] else: collapsed_tree.set_label("") elif len(collapsed_tree) == 1: collapsed_tree[0].set_label( collapsed_tree.label() + "::" + collapsed_tree[0].label()) collapsed_tree = collapsed_tree[0] return collapsed_tree def _get_labeled_spans(tree, spans_out, start): if isinstance(tree, str): return start + 1 assert len(tree) > 1 or isinstance( tree[0], str ), "Must call collapse_unary_strip_pos first" end = start for child in tree: end = _get_labeled_spans(child, spans_out, end) # Spans are returned as closed intervals on both ends spans_out.append((start, end - 1, tree.label())) return end def get_labeled_spans(tree): """Converts a tree into a list of labeled spans. Args: tree: an nltk.tree.Tree object Returns: A list of (span_start, span_end, span_label) tuples. The start and end indices indicate the first and last words of the span (a closed interval). Unary chains are collapsed, so e.g. a (S (VP ...)) will result in a single span labeled "S+VP". """ tree = collapse_unary_strip_pos(tree) spans_out = [] _get_labeled_spans(tree, spans_out, start=0) return spans_out def padded_chart_from_spans(label_vocab, spans, ): num_words = 64 chart = np.full((num_words, num_words), -100, dtype=int) # chart = np.tril(chart, -1) # Now all invalid entries are filled with -100, and valid entries with 0 for start, end, label in spans: if label in label_vocab: chart[start, end] = label_vocab[label] return chart def chart_from_tree(label_vocab, tree, verbose=False): spans = get_labeled_spans(tree) num_words = len(tree.leaves()) chart = np.full((num_words, num_words), -100, dtype=int) chart = np.tril(chart, -1) # Now all invalid entries are filled with -100, and valid entries with 0 # print(tree) for start, end, label in spans: # Previously unseen unary chains can occur in the dev/test sets. # For now, we ignore them and don't mark the corresponding chart # entry as a constituent. # print(start, end, label) if label in label_vocab: chart[start, end] = label_vocab[label] if not verbose: return chart else: return chart, spans def pad_charts(charts, padding_value=-100): """ Our input text format contains START and END, but the parse charts doesn't. NEED TO: update the charts, so that we include these two, and set their span label to 0. :param charts: :param padding_value: :return: """ max_len = 64 padded_charts = torch.full( (len(charts), max_len, max_len), padding_value, ) padded_charts = np.tril(padded_charts, -1) # print(padded_charts[-2:], padded_charts.shape) # print(padded_charts[1]) for i, chart in enumerate(charts): # print(chart, len(chart), len(chart[0])) chart_size = len(chart) padded_charts[i, 1:chart_size+1, 1:chart_size+1] = chart # print(padded_charts[-2:], padded_charts.shape) return padded_charts