from collections import Counter

import nltk
import pandas as pd

from weakly_supervised_parser.tree.evaluate import tree_to_spans


class Tree(object):
    def __init__(self, label, children, word):
        self.label = label
        self.children = children
        self.word = word

    def __str__(self):
        return self.linearize()

    def linearize(self):
        if not self.children:
            return f"({self.label} {self.word})"
        return f"({self.label} {' '.join(c.linearize() for c in self.children)})"

    def spans(self, start=0):
        # Return this node's (start, end) span followed by the spans of its descendants.
        if not self.children:
            return [(start, start + 1)]
        span_list = []
        position = start
        for c in self.children:
            cspans = c.spans(start=position)
            span_list.extend(cspans)
            # The first entry of cspans covers the whole child, so its end
            # marks where the next child starts.
            position = cspans[0][1]
        return [(start, position)] + span_list

    def spans_labels(self, start=0):
        # Same as spans(), but each span also carries its constituent label.
        if not self.children:
            return [(start, start + 1, self.label)]
        span_list = []
        position = start
        for c in self.children:
            cspans = c.spans_labels(start=position)
            span_list.extend(cspans)
            position = cspans[0][1]
        return [(start, position, self.label)] + span_list


def extract_sentence(sentence):
    t = nltk.Tree.fromstring(sentence)
    return " ".join(item[0] for item in t.pos())


def get_constituents(sample_string, want_spans_mapping=False, whole_sentence=True, labels=False):
    t = nltk.Tree.fromstring(sample_string)
    spans = tree_to_spans(t, keep_labels=True)
    if want_spans_mapping:
        # Map each (i, j) span to its number of occurrences in the tree.
        return dict(Counter(item[1] for item in spans))

    sentence = extract_sentence(sample_string).split()

    labeled_constituents_lst = []
    constituents = []
    for span in spans:
        labeled_constituents = {}
        labeled_constituents["labels"] = span[0]
        i, j = span[1][0], span[1][1]
        constituents.append(" ".join(sentence[i:j]))
        labeled_constituents["constituent"] = " ".join(sentence[i:j])
        labeled_constituents_lst.append(labeled_constituents)

    # Add the original sentence as a constituent
    if whole_sentence:
        constituents = constituents + [" ".join(sentence)]

    if labels:
        return labeled_constituents_lst
    return constituents


def get_distituents(sample_string):
    sentence = extract_sentence(sample_string).split()

    def get_all_combinations(sentence_text):
        # Enumerate all spans of length 2 .. N-1 over the sentence.
        L = sentence_text.split()
        N = len(L)
        out = []
        for n in range(2, N):
            for i in range(N - n + 1):
                out.append((i, i + n))
        return out

    combinations = get_all_combinations(extract_sentence(sample_string))
    constituents = list(get_constituents(sample_string, want_spans_mapping=True).keys())
    # Distituents are the spans that do not occur as constituents in the tree.
    spans = [item for item in combinations if item not in constituents]

    distituents = []
    for span in spans:
        i, j = span[0], span[1]
        distituents.append(" ".join(sentence[i:j]))
    return distituents


def get_leaves(tree):
    if not tree.children:
        return [tree]
    leaves = []
    for c in tree.children:
        leaves.extend(get_leaves(c))
    return leaves


def unlinearize(string):
    """
    (TOP (S (NP (PRP He)) (VP (VBD was) (ADJP (JJ right))) (. .)))
    """
    tokens = string.replace("(", " ( ").replace(")", " ) ").split()

    def read_tree(start):
        if tokens[start + 2] != "(":
            # Leaf node: "( label word )"
            return Tree(tokens[start + 1], None, tokens[start + 2]), start + 4
        i = start + 2
        children = []
        while tokens[i] != ")":
            tree, i = read_tree(i)
            children.append(tree)
        return Tree(tokens[start + 1], children, None), i + 1

    tree, _ = read_tree(0)
    return tree


def recall_by_label(gold_standard, best_parse):
    correct = {}
    total = {}
    for tree1, tree2 in zip(gold_standard, best_parse):
        try:
            leaves1, leaves2 = get_leaves(tree1["tree"]), get_leaves(tree2["tree"])
            for l1, l2 in zip(leaves1, leaves2):
                assert l1.word.lower() == l2.word.lower(), f"{l1.word} =/= {l2.word}"
            spanlabels = tree1["tree"].spans_labels()
            spans = tree2["tree"].spans()
            for (i, j, label) in spanlabels:
                # Skip single-word spans, which are trivially correct.
                if j - i != 1:
                    if label not in correct:
                        correct[label] = 0
                        total[label] = 0
                    if (i, j) in spans:
                        correct[label] += 1
                    total[label] += 1
        except Exception as e:
            print(e)
    acc = {}
    for label in total.keys():
        acc[label] = correct[label] / total[label]
    return acc


def label_recall_output(gold_standard, best_parse):
    best_parse_trees = []
    gold_standard_trees = []
    for t1, t2 in zip(gold_standard, best_parse):
        gold_standard_trees.append({"tree": unlinearize(t1)})
        best_parse_trees.append({"tree": unlinearize(t2)})
    dct = recall_by_label(gold_standard=gold_standard_trees, best_parse=best_parse_trees)
    labels = ["SBAR", "NP", "VP", "PP", "ADJP", "ADVP"]
    rows = [(label, f"{recall * 100:.2f}") for label, recall in dct.items() if label in labels]
    df = pd.DataFrame([recall for _, recall in rows], index=[label for label, _ in rows], columns=["recall"])
    df_out = df.reindex(labels)
    return df_out


if __name__ == "__main__":
    from weakly_supervised_parser.utils.prepare_dataset import PTBDataset
    from weakly_supervised_parser.settings import PTB_TEST_GOLD_WITHOUT_PUNCTUATION_ALIGNED_PATH, PTB_SAVE_TREES_PATH

    best_parse = PTBDataset(PTB_SAVE_TREES_PATH + "inside_model_predictions.txt").retrieve_all_sentences()
    gold_standard = PTBDataset(PTB_TEST_GOLD_WITHOUT_PUNCTUATION_ALIGNED_PATH).retrieve_all_sentences()

    print(label_recall_output(gold_standard, best_parse))