# nickil's picture
# Upload weakly_supervised_parser/utils/populate_chart.py
# bbadcd9
import pandas as pd
import numpy as np
import logging
from datasets.utils import set_progress_bar_enabled
from weakly_supervised_parser.utils.prepare_dataset import NGramify
from weakly_supervised_parser.utils.create_inside_outside_strings import InsideOutside
from weakly_supervised_parser.model.trainer import InsideOutsideStringPredictor
from weakly_supervised_parser.utils.cky_algorithm import get_best_parse
from weakly_supervised_parser.utils.distant_supervision import RuleBasedHeuristic
from weakly_supervised_parser.utils.prepare_dataset import PTBDataset
from weakly_supervised_parser.settings import PTB_TRAIN_SENTENCES_WITHOUT_PUNCTUATION_PATH
from weakly_supervised_parser.model.data_module_loader import DataModule
from weakly_supervised_parser.model.span_classifier import LightningModel
# Import-time side effects: reduce third-party console output for the whole process.
# Disable Dataset.map progress bar
set_progress_bar_enabled(False)
# Raise the "pytorch_lightning" logger threshold to WARNING.
logging.getLogger("pytorch_lightning").setLevel(logging.WARNING)
# Hard-coded lower-case token list used for membership tests in
# PopulateCKYChart.compute_scores (the "is_stop" check).
# NOTE(review): presumably a frozen copy of what the commented-out computation
# below produced, so that importing this module does not need the PTB corpus
# - confirm. Despite the name, the literal holds far more than 100 entries and
# contains duplicates (e.g. 'not', 'there'); this is harmless for `in` checks.
# ptb = PTBDataset(data_path=PTB_TRAIN_SENTENCES_WITHOUT_PUNCTUATION_PATH)
# ptb_top_100_common = [item.lower() for item in RuleBasedHeuristic(corpus=ptb.retrieve_all_sentences()).get_top_tokens(top_most_common_ptb=100)]
ptb_top_100_common = ['this', 'myself', 'shouldn', 'not', 'analysts', 'same', 'mightn', 'we', 'american', 'the', 'another', 'until', "aren't", 'when', 'if', 'am', 'over', 'ma', 'as', 'of', 'with', 'even', 'couldn', 'not', "needn't", 'where', 'there', 'isn', 'however', 'my', 'sales', 'here', 'at', 'yours', 'into', 'wouldn', 'officials', 'no', "hasn't", 'to', 'wasn', 'any', 'ours', 'out', 'each', "wasn't", 'is', 'and', 'me', 'off', 'once', "it's", 'they', 'most', 'also', 'through', 'hasn', 'our', 'or', 'after', "weren't", 'about', 'mr.', 'first', 'haven', 'needn', 'have', "isn't", 'now', "didn't", 'on', 'theirs', 'these', 'before', 'there', 'was', 'which', 'those', 'having', 'do', 'most', 'own', 'among', 'because', 'for', "should've", "shan't", 'so', 'being', 'few', 'too', 'to', 'at', 'people', 'her', 'meanwhile', 'both', 'down', 'doesn', 'below', 'mustn', 'an', 'two', 'more', 'japanese', 'ford', "you'd", 'about', 'but', 'doing', 'itself', 've', 'under', 'what', 'again', 'then', 'your', 'himself', 'now', 'against', 'just', 'does', 'net', "couldn't", 'that', 'he', 'revenue', 'because', 'yesterday', 'them', 'i', 'their', 'all', 'under', 'up', "haven't", 'while', "won't", 'it', 'more', 'it', 'ain', 'him', 'still', 'a', 'he', 'despite', 'should', 'during', 'nor', "shouldn't", 'such', "doesn't", 'are', "that'll", 'since', 'yourselves', 'such', 'those', 'after', 'weren', "you're", 'd', 'like', 'did', 'hadn', 'themselves', 'its', 'but', 'been', 's', "don't", 'these', 'they', 'this', 'his', "mightn't", 'moreover', 'how', 'new', 'above', 'ourselves', 'so', 'why', 'between', 'their', 'general', "wouldn't", 'who', 'i', 'in', 'don', 'shan', 'u.s.', 'ibm', 'separately', 'had', 'you', 'federal', 'if', 'our', 'and', 'only', 'y', 'many', 'one', 'no', 'though', 'won', 'last', 'from', 'each', 'traders', 'john', 'further', 'hers', 'both', "you've", "you'll", 'that', 'all', 'its', 'only', 'here', 'according', "mustn't", 'while', 'in', 'what', 'didn', 'when', 'some', 'on', 'can', 
'yourself', 'herself', 'than', 'with', 'has', 'she', 'during', 'will', 'of', 'thus', 'you', 'very', 'o', 'investors', 'a', 'ms.', 'japan', 'were', 'the', 'we', 'm', 'as', 'll', 'be', 'by', 'other', 'yet', 'whom', 'some', 'indeed', 'other', "she's", "hadn't", 'by', 'earlier', 'for', 'instead', 'she', 'an', 't', 're', 'his', 'then', 'aren', 'although']
# Lower-case prefix that PopulateCKYChart.compute_scores looks for at the start
# of two-token inside strings.
# NOTE(review): presumably the frozen result of the commented-out call below - confirm.
# ptb_most_common_first_token = RuleBasedHeuristic(corpus=ptb.retrieve_all_sentences()).augment_using_most_frequent_starting_token(N=1)[0][0].lower()
ptb_most_common_first_token = "the"
# Module-level Trainer, created once at import time and used by
# PopulateCKYChart.compute_scores for trainer.predict() on the "inside" path.
# Its progress bar is disabled.
# NOTE(review): max_epochs presumably only matters for fitting, which this file
# never does with this trainer - confirm against the Lightning docs.
from pytorch_lightning import Trainer
trainer = Trainer(accelerator="auto", enable_progress_bar=False, max_epochs=-1)
class PopulateCKYChart:
    """Score the spans of one sentence and lay the scores out as a CKY chart.

    Attributes (all set in ``__init__``):
        sentence: the raw sentence string.
        sentence_list: ``sentence`` split on whitespace.
        sentence_length: number of whitespace-separated tokens.
        span_scores: zero-initialised float array of shape
            ``(sentence_length + 1, sentence_length + 1)``; ``fill_chart``
            writes the score of span ``(i, j)`` into cell ``[i, j]``.
        all_spans: spans returned by ``NGramify.generate_ngrams``; element
            ``[0]`` of each span is used as its ``(i, j)`` key.
    """

    def __init__(self, sentence):
        """Tokenise ``sentence`` and enumerate its candidate spans."""
        self.sentence = sentence
        self.sentence_list = sentence.split()
        self.sentence_length = len(self.sentence_list)
        self.span_scores = np.zeros((self.sentence_length + 1, self.sentence_length + 1), dtype=float)
        self.all_spans = NGramify(self.sentence).generate_ngrams(single_span=True, whole_span=True)

    def _capitalised_without_common_tokens(self):
        """Return True when every token is title/upper case and none is in ``ptb_top_100_common``."""
        tokens = self.sentence.split()
        is_upper_or_title = all(item.istitle() or item.isupper() for item in tokens)
        is_stop = any(item.lower() in ptb_top_100_common for item in tokens)
        return is_upper_or_title and not is_stop

    def compute_scores(self, model, predict_type, scale_axis, predict_batch_size, chunks=128):
        """Score every span in ``self.all_spans`` with ``model``.

        Args:
            model: for ``"inside"`` it is handed to ``trainer.predict``; for
                ``"outside"`` its ``predict_proba`` method is called.
            predict_type: ``"inside"`` or ``"outside"``.
            scale_axis: forwarded to ``model.predict_proba`` ("outside" only).
            predict_batch_size: forwarded to ``model.predict_proba``
                ("outside" only).
            chunks: currently unused; kept so existing callers keep working.

        Returns:
            ``(flags, data)``. ``data`` is a DataFrame with the columns
            ``inside_sentence``, ``outside_sentence``, ``span``, the raw
            ``inside_scores`` or ``outside_scores`` and a ``scores`` column
            that copies them. ``flags`` is always ``False`` for "outside"; for
            "inside" it is True when every token of the sentence is title or
            upper case and none is in ``ptb_top_100_common``.

        Raises:
            ValueError: if ``predict_type`` is neither "inside" nor "outside"
                (previously this surfaced as an ``UnboundLocalError``).
        """
        if predict_type not in ("inside", "outside"):
            raise ValueError(f"predict_type must be 'inside' or 'outside', got {predict_type!r}")

        inside_strings = []
        outside_strings = []
        for span in self.all_spans:
            _, inside_string, outside_string = InsideOutside(sentence=self.sentence).create_inside_outside_matrix(span)
            inside_strings.append(inside_string)
            outside_strings.append(outside_string)

        data = pd.DataFrame(
            {
                "inside_sentence": inside_strings,
                "outside_sentence": outside_strings,
                "span": [span[0] for span in self.all_spans],
            }
        )

        if predict_type == "inside":
            test_dataloader = DataModule(
                model_name_or_path="roberta-base",
                train_df=None,
                eval_df=None,
                test_df=data.rename(columns={"inside_sentence": "sentence"})[["sentence"]],
            )
            # NOTE(review): only element [0] of the predict() result is used.
            # If predict() returns one entry per batch and the DataModule
            # yields more than one batch, later scores are dropped and the
            # column assignment below fails on a length mismatch - confirm.
            inside_scores = list(trainer.predict(model, dataloaders=test_dataloader)[0])
            data["inside_scores"] = inside_scores

            # Heuristic override: force the score to 1 for two-token inside
            # strings that start with ptb_most_common_first_token and whose
            # last token is not among RuleBasedHeuristic().get_top_tokens().
            # NOTE(review): startswith() is a string-prefix test, so "the"
            # also matches "then", "they", "there", ... - confirm whether a
            # first-token equality check was intended.
            lowered = data["inside_sentence"].str.lower()
            data.loc[
                (lowered.str.startswith(ptb_most_common_first_token))
                & (lowered.str.split().str.len() == 2)
                & (~lowered.str.split().str[-1].isin(RuleBasedHeuristic().get_top_tokens())),
                "inside_scores",
            ] = 1

            flags = self._capitalised_without_common_tokens()
            data["scores"] = data["inside_scores"]
        else:
            outside_scores = list(
                model.predict_proba(
                    spans=data.rename(columns={"outside_sentence": "sentence"})[["sentence"]],
                    scale_axis=scale_axis,
                    predict_batch_size=predict_batch_size,
                )[:, 1]
            )
            data["outside_scores"] = outside_scores
            flags = False
            data["scores"] = data["outside_scores"]

        return flags, data

    def fill_chart(self, model, predict_type, scale_axis, predict_batch_size, data=None):
        """Write the per-span scores into ``self.span_scores``.

        Args:
            model, predict_type, scale_axis, predict_batch_size: forwarded to
                ``compute_scores`` when ``data`` is None.
            data: optional pre-computed DataFrame with at least the columns
                ``span`` and ``scores``. When given, no model call is made.

        Returns:
            ``(flags, span_scores, data)``. When ``data`` is supplied, ``flags``
            is derived the same way ``compute_scores`` derives it: from the
            sentence for "inside", ``False`` otherwise. Previously this path
            raised ``UnboundLocalError`` because ``flags`` was never assigned.

        Raises:
            ValueError: if a span of ``self.all_spans`` has no row in ``data``.
        """
        if data is None:
            flags, data = self.compute_scores(model, predict_type, scale_axis, predict_batch_size)
        elif predict_type == "inside":
            flags = self._capitalised_without_common_tokens()
        else:
            flags = False

        # One pass over the frame instead of a DataFrame scan per span.
        # If a span occurs in several rows the last row wins.
        score_by_span = dict(zip(data["span"], data["scores"]))
        for span in self.all_spans:
            key = span[0]
            start, end = key
            # Only cells above the diagonal inside the chart are populated.
            if not 0 <= start < end <= self.sentence_length:
                continue
            if key not in score_by_span:
                raise ValueError(f"no score found for span {key!r}")
            self.span_scores[start, end] = score_by_span[key]
        return flags, self.span_scores, data

    def best_parse_tree(self, span_scores):
        """Return ``get_best_parse`` for the sentence given a filled chart.

        The last row and first column of ``span_scores`` are dropped before
        the call, so the score of span ``(i, j)`` ends up at ``[i, j - 1]``.
        """
        span_scores_cky_format = span_scores[:-1, 1:]
        best_parse = get_best_parse(sentence=[self.sentence_list], spans=span_scores_cky_format)
        return best_parse