# NOTE(review): stray build-log artifacts ("Spaces:" / "Build error") were
# pasted above the imports; converted to a comment so the module parses.
import nltk | |
from collections import Counter | |
from weakly_supervised_parser.tree.evaluate import tree_to_spans | |
class Tree(object):
    """A minimal constituency-tree node.

    A node is either an internal node (``children`` is a non-empty list of
    ``Tree`` and ``word`` is None) or a leaf (``children`` is falsy and
    ``word`` holds the terminal token).
    """

    def __init__(self, label, children, word):
        self.label = label        # constituent or POS label, e.g. "NP"
        self.children = children  # list of Tree for internal nodes, None for leaves
        self.word = word          # terminal token for leaves, None otherwise

    def __str__(self):
        return self.linearize()

    def linearize(self):
        """Render the subtree as a bracketed PTB-style string."""
        if self.children:
            inner = " ".join(child.linearize() for child in self.children)
            return f"({self.label} {inner})"
        return f"({self.label} {self.word})"

    def spans(self, start=0):
        """Return ``(start, end)`` spans for this node and all descendants.

        This node's span comes first, followed by each child's spans in
        pre-order; a leaf occupies exactly one position.
        """
        if not self.children:
            return [(start, start + 1)]
        collected = []
        cursor = start
        for child in self.children:
            child_spans = child.spans(start=cursor)
            collected.extend(child_spans)
            # The first entry of a child's result covers that whole child.
            cursor = child_spans[0][1]
        return [(start, cursor)] + collected

    def spans_labels(self, start=0):
        """Like :meth:`spans`, but each tuple also carries the node label."""
        if not self.children:
            return [(start, start + 1, self.label)]
        collected = []
        cursor = start
        for child in self.children:
            child_spans = child.spans_labels(start=cursor)
            collected.extend(child_spans)
            cursor = child_spans[0][1]
        return [(start, cursor, self.label)] + collected
def extract_sentence(sentence):
    """Return the space-joined terminal words of a bracketed tree string."""
    parsed = nltk.Tree.fromstring(sentence)
    words = [word for word, _tag in parsed.pos()]
    return " ".join(words)
def get_constituents(sample_string, want_spans_mapping=False, whole_sentence=True, labels=False):
    """Collect constituent information from a bracketed tree string.

    Depending on the flags, returns one of:
    - ``want_spans_mapping=True``: dict mapping each (i, j) span to its count;
    - ``labels=True``: list of ``{"labels": ..., "constituent": ...}`` dicts;
    - otherwise: list of constituent word sequences, with the whole sentence
      appended when ``whole_sentence`` is True.
    """
    tree = nltk.Tree.fromstring(sample_string)
    labeled_spans = tree_to_spans(tree, keep_labels=True)
    if want_spans_mapping:
        # Count how many times each (i, j) span occurs in the tree.
        return dict(Counter(entry[1] for entry in labeled_spans))
    words = extract_sentence(sample_string).split()
    constituents = []
    labeled_records = []
    for label, (i, j) in labeled_spans:
        text = " ".join(words[i:j])
        constituents.append(text)
        labeled_records.append({"labels": label, "constituent": text})
    # Add original sentence
    if whole_sentence:
        constituents.append(" ".join(words))
    if labels:
        return labeled_records
    return constituents
def get_distituents(sample_string):
    """Return the word sequences of every span that is NOT a constituent.

    Enumerates every contiguous span of length 2..N-1 over the sentence
    (never single words, never the full sentence) and keeps those absent
    from the tree's constituent spans.
    """
    # Extract the sentence once; the original recomputed it for the combinations.
    sentence_text = extract_sentence(sample_string)
    words = sentence_text.split()
    n = len(words)

    # All candidate spans, shortest first, left to right (original iteration order).
    candidate_spans = [
        (i, i + length)
        for length in range(2, n)
        for i in range(n - length + 1)
    ]

    # Set membership is O(1); the original scanned a list per candidate (O(n^2) overall).
    constituent_spans = set(get_constituents(sample_string, want_spans_mapping=True))

    return [" ".join(words[i:j]) for i, j in candidate_spans if (i, j) not in constituent_spans]
def get_leaves(tree):
    """Return the leaf nodes of ``tree`` in left-to-right order."""
    if tree.children:
        result = []
        for child in tree.children:
            result += get_leaves(child)
        return result
    # A node without children is itself a leaf.
    return [tree]
def unlinearize(string):
    """Parse a bracketed tree string back into a :class:`Tree`.

    Example input::

        (TOP (S (NP (PRP He)) (VP (VBD was) (ADJP (JJ right))) (. .)))
    """
    # Pad parentheses so a plain split() tokenizes the bracketing.
    tokens = string.replace("(", " ( ").replace(")", " ) ").split()

    def parse(pos):
        # Invariant: tokens[pos] == "(" and tokens[pos + 1] is the node label.
        label = tokens[pos + 1]
        if tokens[pos + 2] == "(":
            # Internal node: consume children until the matching ")".
            pos += 2
            kids = []
            while tokens[pos] != ")":
                node, pos = parse(pos)
                kids.append(node)
            return Tree(label, kids, None), pos + 1
        # Leaf: four tokens — "(", label, word, ")".
        return Tree(label, None, tokens[pos + 2]), pos + 4

    root, _ = parse(0)
    return root
def recall_by_label(gold_standard, best_parse):
    """Compute per-label recall of predicted spans against gold spans.

    Both arguments are aligned sequences of ``{"tree": Tree}`` dicts.
    Single-word spans (j - i == 1) are skipped. A pair whose leaves do not
    match (or that fails for any other reason) is reported and skipped so
    one bad pair does not abort the whole evaluation.

    Returns a dict mapping label -> recall in [0, 1]; labels never seen on
    a multi-word gold span are absent.
    """
    correct = {}
    total = {}
    for gold, predicted in zip(gold_standard, best_parse):
        try:
            gold_leaves = get_leaves(gold["tree"])
            predicted_leaves = get_leaves(predicted["tree"])
            # Sanity check that the two trees cover the same token sequence.
            for g, p in zip(gold_leaves, predicted_leaves):
                assert g.word.lower() == p.word.lower(), f"{g.word} =/= {p.word}"
            # Hoist into a set: the original did an O(n) list scan per gold span.
            predicted_spans = set(predicted["tree"].spans())
            for i, j, label in gold["tree"].spans_labels():
                if j - i == 1:
                    continue  # ignore trivial single-word spans
                if label not in total:
                    correct[label] = 0
                    total[label] = 0
                if (i, j) in predicted_spans:
                    correct[label] += 1
                total[label] += 1
        except Exception as e:  # best-effort: report and skip misaligned pairs
            print(e)
    return {label: correct[label] / total[label] for label in total}
def label_recall_output(gold_standard, best_parse):
    """Build a one-column DataFrame of per-label recall percentages.

    ``gold_standard`` and ``best_parse`` are aligned sequences of bracketed
    tree strings. The result has a fixed row order (SBAR, NP, VP, PP, ADJP,
    ADVP), a single "recall" column of formatted percentage strings, and NaN
    for labels absent from the recall dict.
    """
    # BUGFIX: pandas was only imported inside the __main__ guard, so calling
    # this function from another module raised NameError on `pd`.
    import pandas as pd

    gold_trees = []
    predicted_trees = []
    for gold_str, predicted_str in zip(gold_standard, best_parse):
        gold_trees.append({"tree": unlinearize(gold_str)})
        predicted_trees.append({"tree": unlinearize(predicted_str)})

    recall = recall_by_label(gold_standard=gold_trees, best_parse=predicted_trees)

    label_order = ["SBAR", "NP", "VP", "PP", "ADJP", "ADVP"]
    # Format each tracked label's recall as a percentage string, e.g. "87.50".
    formatted = {label: f"{value * 100:.2f}" for label, value in recall.items() if label in label_order}
    df = pd.DataFrame({"recall": pd.Series(formatted, dtype=object)})
    return df.reindex(label_order)
if __name__ == "__main__":
    import pandas as pd
    from weakly_supervised_parser.utils.prepare_dataset import PTBDataset
    from weakly_supervised_parser.settings import PTB_TEST_GOLD_WITHOUT_PUNCTUATION_ALIGNED_PATH, PTB_SAVE_TREES_PATH

    # Aligned gold trees and the predictions written by the inside model.
    gold_sentences = PTBDataset(PTB_TEST_GOLD_WITHOUT_PUNCTUATION_ALIGNED_PATH).retrieve_all_sentences()
    predicted_sentences = PTBDataset(PTB_SAVE_TREES_PATH + "inside_model_predictions.txt").retrieve_all_sentences()

    print(label_recall_output(gold_sentences, predicted_sentences))