Spaces:
Sleeping
Sleeping
| import torch | |
| import spacy, nltk | |
| from nltk.tree import Tree | |
| import numpy as np | |
def collapse_unary_strip_pos(tree, strip_top=True):
    """Strip part-of-speech tags from ``tree`` and merge its unary chains.

    A new tree is built in which every pre-terminal (a node whose only child
    is a word) is replaced by that word. ``collapse_unary`` is then applied to
    the new tree with ``"::"`` as the label join character.

    Args:
        tree: an nltk.tree.Tree object.
        strip_top: controls what happens when the root label is one of
            "TOP", "ROOT", "S1" or "VROOT". If true, a root with a single
            child is replaced by that child, and a root with several children
            gets the empty label. If false, a root with a single child is
            merged into that child (labels joined with "::"), and a root with
            several children is left as is.

    Returns:
        The stripped and collapsed tree.
    """

    def strip_pos(node):
        # A pre-terminal collapses to its word; anything else is rebuilt
        # recursively from its stripped children.
        if len(node) == 1 and isinstance(node[0], str):
            return node[0]
        return nltk.tree.Tree(node.label(), [strip_pos(kid) for kid in node])

    result = strip_pos(tree)
    result.collapse_unary(collapsePOS=True, joinChar="::")

    # Only dummy root labels receive special treatment.
    if result.label() not in ("TOP", "ROOT", "S1", "VROOT"):
        return result

    has_single_child = len(result) == 1
    if strip_top:
        if has_single_child:
            return result[0]
        result.set_label("")
        return result

    if has_single_child:
        child = result[0]
        child.set_label(result.label() + "::" + child.label())
        return child
    return result
| def _get_labeled_spans(tree, spans_out, start): | |
| if isinstance(tree, str): | |
| return start + 1 | |
| assert len(tree) > 1 or isinstance( | |
| tree[0], str | |
| ), "Must call collapse_unary_strip_pos first" | |
| end = start | |
| for child in tree: | |
| end = _get_labeled_spans(child, spans_out, end) | |
| # Spans are returned as closed intervals on both ends | |
| spans_out.append((start, end - 1, tree.label())) | |
| return end | |
def get_labeled_spans(tree):
    """Convert a tree into a list of labeled spans.

    Args:
        tree: an nltk.tree.Tree object

    Returns:
        A list of (span_start, span_end, span_label) tuples. The start and end
        indices are the positions of the first and last words of the span (a
        closed interval). The tree is first passed through
        collapse_unary_strip_pos, which joins the labels of unary chains with
        "::", so e.g. a (S (VP ...)) results in a single span labeled
        "S::VP". A node's span is listed after the spans of its children.
    """
    labeled_spans = []
    _get_labeled_spans(collapse_unary_strip_pos(tree), labeled_spans, start=0)
    return labeled_spans
def padded_chart_from_spans(label_vocab, spans, num_words=64):
    """Build a fixed-size label chart from already extracted labeled spans.

    Args:
        label_vocab: mapping from span label to its integer id.
        spans: iterable of (start, end, label) tuples; ``start`` and ``end``
            are used directly as the row and column of the chart cell.
        num_words: side length of the square chart. Defaults to 64, the size
            that used to be hard-coded.

    Returns:
        A (num_words, num_words) integer numpy array. Every cell holds -100
        except the cells of spans whose label is in ``label_vocab``, which
        hold that label's id.

    Raises:
        IndexError: if a span with a known label indexes past the chart edge.
    """
    # Unlike chart_from_tree, no triangle is zeroed here: every cell that is
    # not explicitly labeled below stays at -100.
    chart = np.full((num_words, num_words), -100, dtype=int)
    for start, end, label in spans:
        # Labels missing from the vocabulary are skipped.
        if label in label_vocab:
            chart[start, end] = label_vocab[label]
    return chart
def chart_from_tree(label_vocab, tree, verbose=False):
    """Build the square label chart for ``tree``.

    Args:
        label_vocab: mapping from span label to its integer id.
        tree: tree object accepted by get_labeled_spans; it must also provide
            ``leaves()``, whose length sets the chart size.
        verbose: if true, the extracted spans are returned alongside the chart.

    Returns:
        The chart, or a ``(chart, spans)`` tuple when ``verbose`` is true.
        The chart is an integer numpy array with one row and column per leaf.
        Cells strictly below the diagonal hold -100, every other cell starts
        at 0, and ``chart[start, end]`` holds the label id of each span whose
        label is in ``label_vocab``.
    """
    labeled_spans = get_labeled_spans(tree)
    size = len(tree.leaves())

    # Strictly-lower triangle (end < start) keeps -100; the diagonal and
    # everything above it is reset to 0 by np.tril.
    chart = np.tril(np.full((size, size), -100, dtype=int), -1)

    for first, last, label in labeled_spans:
        # Previously unseen unary chains can occur in the dev/test sets.
        # They are ignored, so the corresponding chart entry is not marked
        # as a constituent.
        if label in label_vocab:
            chart[first, last] = label_vocab[label]

    return (chart, labeled_spans) if verbose else chart
def pad_charts(charts, padding_value=-100, max_len=64):
    """Stack per-sentence charts into one padded batch array.

    The input text format contains START and END tokens, but the parse charts
    don't. Each chart is therefore written one row and one column in from the
    top-left corner, leaving row 0 and column 0 for START.

    Args:
        charts: sequence of square charts; ``len(chart)`` is taken as the
            chart's side length.
        padding_value: fill value for the strictly-lower triangle of each
            padded chart.
        max_len: side length of every padded chart. Defaults to 64, the size
            that used to be hard-coded.

    Returns:
        A numpy array of shape (len(charts), max_len, max_len). Before the
        charts are copied in, cells strictly below the diagonal hold
        ``padding_value`` and all other cells hold 0.

    Raises:
        ValueError: if a chart does not fit in ``max_len`` after the one-cell
            offset.
    """
    # torch.full is kept for the allocation so the resulting dtype is the same
    # as before; the conversion to numpy is now explicit instead of being a
    # side effect of handing a tensor to np.tril.
    blank = torch.full(
        (len(charts), max_len, max_len),
        padding_value,
    ).numpy()
    # np.tril acts on the last two axes and returns a new array.
    padded_charts = np.tril(blank, -1)

    for i, chart in enumerate(charts):
        chart_size = len(chart)
        if chart_size + 1 > max_len:
            raise ValueError(
                f"chart of size {chart_size} does not fit in max_len={max_len} "
                "once offset by one for START"
            )
        padded_charts[i, 1:chart_size + 1, 1:chart_size + 1] = chart
    return padded_charts