nickil committed on
Commit
47c0211
1 Parent(s): fee0b96

add initial files

app.py ADDED
@@ -0,0 +1,64 @@
1
+ import gradio
2
+ import benepar
3
+ import spacy
4
+ import nltk
5
+
6
+ from huggingface_hub import hf_hub_url, cached_download
7
+
8
+ from weakly_supervised_parser.tree.evaluate import calculate_F1_for_spans, tree_to_spans
9
+ from weakly_supervised_parser.inference import Predictor
10
+ from weakly_supervised_parser.model.trainer import InsideOutsideStringClassifier
11
+
12
+ benepar.download('benepar_en3')
13
+
14
+ nlp = spacy.load("en_core_web_md")
15
+ nlp.add_pipe("benepar", config={"model": "benepar_en3"})
16
+
17
+ inside_model = InsideOutsideStringClassifier(model_name_or_path="roberta-base", max_seq_length=256)
18
+ fetch_url_inside_model = hf_hub_url(repo_id="nickil/weakly-supervised-parsing", filename="inside_model.onnx", revision="main")
19
+ inside_model.load_model(pre_trained_model_path=cached_download(fetch_url_inside_model))
20
+
21
+ # outside_model = InsideOutsideStringClassifier(model_name_or_path="roberta-base", max_seq_length=64)
22
+ # outside_model.load_model(pre_trained_model_path=TRAINED_MODEL_PATH + "outside_model.onnx")
23
+
24
+ # inside_outside_model = InsideOutsideStringClassifier(model_name_or_path="roberta-base", max_seq_length=256)
25
+ # inside_outside_model.load_model(pre_trained_model_path=TRAINED_MODEL_PATH + "inside_outside_model.onnx")
26
+
27
+
28
+ def predict(sentence, model):
29
+ gold_standard = list(nlp(sentence).sents)[0]._.parse_string
30
+ if model == "inside":
31
+ best_parse = Predictor(sentence=sentence).obtain_best_parse(predict_type="inside", model=inside_model, scale_axis=1, predict_batch_size=128)
32
+ elif model == "outside":
33
+ best_parse = Predictor(sentence=sentence).obtain_best_parse(predict_type="outside", model=outside_model, scale_axis=1, predict_batch_size=128)
34
+ elif model == "inside-outside":
35
+ best_parse = Predictor(sentence=sentence).obtain_best_parse(predict_type="inside_outside", model=inside_outside_model, scale_axis=1, predict_batch_size=128)
36
+ sentence_f1 = calculate_F1_for_spans(tree_to_spans(gold_standard), tree_to_spans(best_parse))
37
+ return gold_standard, best_parse, sentence_f1
38
+
39
+
40
+ iface = gradio.Interface(
41
+ title="Co-training an Unsupervised Constituency Parser with Weak Supervision",
42
+ description="Demo for the repository - [weakly-supervised-parsing](https://github.com/Nickil21/weakly-supervised-parsing) (ACL Findings 2022)",
43
+ theme="default",
44
+ article="""<h4 class='text-lg font-semibold my-2'>Note</h4>
45
+ - We use a strong supervised parsing model, `benepar_en3` (based on T5-small), to compute the gold parse.<br>
46
+ - Sentence F1 score corresponds to the macro F1 score.
47
+ """,
48
+ allow_flagging="never",
49
+ fn=predict,
50
+ inputs=[
51
+ gradio.inputs.Textbox(label="Sentence", placeholder="Enter a sentence in English"),
52
+ gradio.inputs.Radio(["inside", "outside", "inside-outside"], default="inside", label="Choose Model"),
53
+ ],
54
+ outputs=[
55
+ gradio.outputs.Textbox(label="Gold Parse Tree"),
56
+ gradio.outputs.Textbox(label="Predicted Parse Tree"),
57
+ gradio.outputs.Textbox(label="F1 score"),
58
+ ],
59
+ examples=[
60
+ ["Russia 's war on Ukraine unsettles investors expecting carve-out deal uptick for 2022 .", "inside-outside"],
61
+ ["Bitcoin community under pressure to cut energy use .", "inside"],
62
+ ],
63
+ )
64
+ iface.launch(share=True)
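Only the inside model is loaded above, while the Radio input (and the first example) also offers "outside" and "inside-outside"; those branches of predict would currently hit undefined names. A minimal sketch of loading the remaining models the same way, assuming outside_model.onnx and inside_outside_model.onnx are also published in the nickil/weakly-supervised-parsing Hub repo (filenames are taken from the commented-out lines and are not guaranteed to exist):

    outside_model = InsideOutsideStringClassifier(model_name_or_path="roberta-base", max_seq_length=64)
    fetch_url_outside_model = hf_hub_url(repo_id="nickil/weakly-supervised-parsing", filename="outside_model.onnx", revision="main")
    outside_model.load_model(pre_trained_model_path=cached_download(fetch_url_outside_model))

    inside_outside_model = InsideOutsideStringClassifier(model_name_or_path="roberta-base", max_seq_length=256)
    fetch_url_inside_outside_model = hf_hub_url(repo_id="nickil/weakly-supervised-parsing", filename="inside_outside_model.onnx", revision="main")
    inside_outside_model.load_model(pre_trained_model_path=cached_download(fetch_url_inside_outside_model))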
requirements.txt ADDED
@@ -0,0 +1,2 @@
1
+ spacy==3.1.4
2
+ benepar==0.2.0
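The two pins above only cover computing the gold parse; the app and the training/inference code also import gradio, nltk, torch, transformers, pytorch-lightning, torchmetrics, onnxruntime, scipy, datasets, loguru, scikit-learn, numpy, pandas, and huggingface_hub. A fuller requirements sketch along those lines, with versions left unpinned because the commit does not specify them:

    spacy==3.1.4
    benepar==0.2.0
    gradio
    nltk
    torch
    transformers
    pytorch-lightning
    torchmetrics
    onnxruntime
    scipy
    datasets
    loguru
    scikit-learn
    numpy
    pandas
    huggingface_hub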
weakly_supervised_parser/__init__.py ADDED
File without changes
weakly_supervised_parser/inference.py ADDED
@@ -0,0 +1,145 @@
1
+ from argparse import ArgumentParser
2
+ from loguru import logger
3
+
4
+ from weakly_supervised_parser.settings import TRAINED_MODEL_PATH
5
+ from weakly_supervised_parser.utils.prepare_dataset import DataLoaderHelper
6
+ from weakly_supervised_parser.utils.populate_chart import PopulateCKYChart
7
+ from weakly_supervised_parser.tree.evaluate import calculate_F1_for_spans, tree_to_spans
8
+ from weakly_supervised_parser.model.trainer import InsideOutsideStringClassifier
9
+ from weakly_supervised_parser.settings import PTB_TEST_SENTENCES_WITHOUT_PUNCTUATION_PATH, PTB_TEST_GOLD_WITHOUT_PUNCTUATION_ALIGNED_PATH
10
+
11
+
12
+ class Predictor:
13
+ def __init__(self, sentence):
14
+ self.sentence = sentence
15
+ self.sentence_list = sentence.split()
16
+
17
+ def obtain_best_parse(self, predict_type, model, scale_axis, predict_batch_size, return_df=False):
18
+ unique_tokens_flag, span_scores, df = PopulateCKYChart(sentence=self.sentence).fill_chart(predict_type=predict_type,
19
+ model=model,
20
+ scale_axis=scale_axis,
21
+ predict_batch_size=predict_batch_size)
22
+
23
+ if unique_tokens_flag:
24
+ best_parse = "(S " + " ".join(["(S " + item + ")" for item in self.sentence_list]) + ")"
25
+ logger.info(f"BEST PARSE: {best_parse}")
26
+ else:
27
+ best_parse = PopulateCKYChart(sentence=self.sentence).best_parse_tree(span_scores)
28
+ if return_df:
29
+ return best_parse, df
30
+ return best_parse
31
+
32
+
33
+ def process_test_sample(index, sentence, gold_file_path, predict_type, model, scale_axis, predict_batch_size, return_df=False):
34
+ best_parse, df = Predictor(sentence=sentence).obtain_best_parse(predict_type=predict_type,
35
+ model=model,
36
+ scale_axis=scale_axis,
37
+ predict_batch_size=predict_batch_size,
38
+ return_df=True)
39
+ gold_standard = DataLoaderHelper(input_file_object=gold_file_path)
40
+ sentence_f1 = calculate_F1_for_spans(tree_to_spans(gold_standard[index]), tree_to_spans(best_parse))
41
+ if sentence_f1 < 25.0:
42
+ logger.warning(f"Index: {index} <> F1: {sentence_f1:.2f}")
43
+ else:
44
+ logger.info(f"Index: {index} <> F1: {sentence_f1:.2f}")
45
+ if return_df:
46
+ return best_parse, df
47
+ else:
48
+ return best_parse
49
+
50
+
51
+ def process_co_train_test_sample(index, sentence, gold_file_path, inside_model, outside_model, return_df=False):
52
+ _, df_inside = PopulateCKYChart(sentence=sentence).compute_scores(predict_type="inside", model=inside_model)
53
+ _, df_outside = PopulateCKYChart(sentence=sentence).compute_scores(predict_type="outside", model=outside_model)
54
+ df = df_inside.copy()
55
+ df["scores"] = df_inside["scores"] * df_outside["scores"]
56
+ _, span_scores, df = PopulateCKYChart(sentence=sentence).fill_chart(data=df)
57
+ best_parse = PopulateCKYChart(sentence=sentence).best_parse_tree(span_scores)
58
+ gold_standard = DataLoaderHelper(input_file_object=gold_file_path)
59
+ sentence_f1 = calculate_F1_for_spans(tree_to_spans(gold_standard[index]), tree_to_spans(best_parse))
60
+ if sentence_f1 < 25.0:
61
+ logger.warning(f"Index: {index} <> F1: {sentence_f1:.2f}")
62
+ else:
63
+ logger.info(f"Index: {index} <> F1: {sentence_f1:.2f}")
64
+ return best_parse
65
+
66
+
67
+ def main():
68
+ parser = ArgumentParser(description="Inference Pipeline for the Inside Outside String Classifier", add_help=True)
69
+
70
+ group = parser.add_mutually_exclusive_group(required=True)
71
+
72
+ group.add_argument("--use_inside", action="store_true", help="Whether to predict using inside model")
73
+
74
+ group.add_argument("--use_inside_self_train", action="store_true", help="Whether to predict using inside model with self-training")
75
+
76
+ group.add_argument("--use_outside", action="store_true", help="Whether to predict using outside model")
77
+
78
+ group.add_argument("--use_inside_outside_co_train", action="store_true", help="Whether to predict using inside-outside model with co-training")
79
+
80
+ parser.add_argument("--model_name_or_path", type=str, default="roberta-base", help="Path to the model identifier from huggingface.co/models")
81
+
82
+ parser.add_argument("--save_path", type=str, required=True, help="Path to save the final trees")
83
+
84
+ parser.add_argument("--scale_axis", type=int, choices=[1], default=None, help="Whether to scale axis globally (None, the default) or sequentially (1) across batches during softmax computation")
85
+
86
+ parser.add_argument("--predict_batch_size", type=int, help="Batch size during inference")
87
+
88
+ parser.add_argument(
89
+ "--inside_max_seq_length", default=256, type=int, help="The maximum total input sequence length after tokenization for the inside model"
90
+ )
91
+
92
+ parser.add_argument(
93
+ "--outside_max_seq_length", default=64, type=int, help="The maximum total input sequence length after tokenization for the outside model"
94
+ )
95
+
96
+ args = parser.parse_args()
97
+
98
+ if args.use_inside:
99
+ pre_trained_model_path = TRAINED_MODEL_PATH + "inside_model.onnx"
100
+ max_seq_length = args.inside_max_seq_length
101
+
102
+ if args.use_inside_self_train:
103
+ pre_trained_model_path = TRAINED_MODEL_PATH + "inside_model_self_trained.onnx"
104
+ max_seq_length = args.inside_max_seq_length
105
+
106
+ if args.use_outside:
107
+ pre_trained_model_path = TRAINED_MODEL_PATH + "outside_model.onnx"
108
+ max_seq_length = args.outside_max_seq_length
109
+
110
+ if args.use_inside_outside_co_train:
111
+ inside_pre_trained_model_path = "inside_model_co_trained.onnx"
112
+ inside_model = InsideOutsideStringClassifier(model_name_or_path=args.model_name_or_path, max_seq_length=args.inside_max_seq_length)
113
+ inside_model.load_model(pre_trained_model_path=inside_pre_trained_model_path)
114
+
115
+ outside_pre_trained_model_path = "outside_model_co_trained.onnx"
116
+ outside_model = InsideOutsideStringClassifier(model_name_or_path=args.model_name_or_path, max_seq_length=args.outside_max_seq_length)
117
+ outside_model.load_model(pre_trained_model_path=outside_pre_trained_model_path)
118
+ else:
119
+ model = InsideOutsideStringClassifier(model_name_or_path=args.model_name_or_path, max_seq_length=max_seq_length)
120
+ model.load_model(pre_trained_model_path=pre_trained_model_path)
121
+
122
+ if args.use_inside or args.use_inside_self_train:
123
+ predict_type = "inside"
124
+
125
+ if args.use_outside:
126
+ predict_type = "outside"
127
+
128
+ with open(args.save_path, "w") as out_file:
130
+ test_sentences = DataLoaderHelper(input_file_object=PTB_TEST_SENTENCES_WITHOUT_PUNCTUATION_PATH).read_lines()
131
+ test_gold_file_path = PTB_TEST_GOLD_WITHOUT_PUNCTUATION_ALIGNED_PATH
132
+ for test_index, test_sentence in enumerate(test_sentences):
133
+ if args.use_inside_outside_co_train:
134
+ best_parse = process_co_train_test_sample(
135
+ test_index, test_sentence, test_gold_file_path, inside_model=inside_model, outside_model=outside_model
136
+ )
137
+ else:
138
+ best_parse = process_test_sample(test_index, test_sentence, test_gold_file_path, predict_type=predict_type, model=model,
139
+ scale_axis=args.scale_axis, predict_batch_size=args.predict_batch_size)
140
+
141
+ out_file.write(best_parse + "\n")
142
+
143
+
144
+ if __name__ == "__main__":
145
+ main()
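For reference, a hypothetical invocation of this script once a trained inside_model.onnx is available under TRAINED_MODEL_PATH (the save path reuses PTB_SAVE_TREES_PATH from settings.py; the batch size is a placeholder):

    python -m weakly_supervised_parser.inference --use_inside --save_path TEMP/predictions/english/inside_model_predictions.txt --predict_batch_size 256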
weakly_supervised_parser/model/__init__.py ADDED
File without changes
weakly_supervised_parser/model/data_module_loader.py ADDED
@@ -0,0 +1,79 @@
1
+ import pandas as pd
2
+
3
+ from transformers import AutoTokenizer
4
+ from torch.utils.data import Dataset, DataLoader
5
+ from pytorch_lightning import LightningDataModule
6
+
7
+
8
+ class PyTorchDataModule(Dataset):
9
+ """PyTorch Dataset class"""
10
+
11
+ def __init__(self, model_name_or_path: str, data: pd.DataFrame, max_seq_length: int = 256):
12
+ """
13
+ Initiates a PyTorch Dataset Module for input data
14
+ """
15
+ self.model_name_or_path = model_name_or_path
16
+ self.tokenizer = AutoTokenizer.from_pretrained(self.model_name_or_path, use_fast=True)
17
+ self.data = data
18
+ self.max_seq_length = max_seq_length
19
+
20
+ def __len__(self):
21
+ """returns length of data"""
22
+ return len(self.data)
23
+
24
+ def __getitem__(self, index: int):
25
+ """returns dictionary of input tensors to feed into the model"""
26
+
27
+ data_row = self.data.iloc[index]
28
+ sentence = data_row["sentence"]
29
+
30
+ sentence_encoding = self.tokenizer(
31
+ sentence,
32
+ max_length=self.max_seq_length,
33
+ padding="max_length",
34
+ truncation=True,
35
+ add_special_tokens=True,
36
+ return_tensors="pt",
37
+ )
38
+
39
+ out = dict(
40
+ sentence=sentence,
41
+ input_ids=sentence_encoding["input_ids"].flatten(),
42
+ attention_mask=sentence_encoding["attention_mask"].flatten(),
43
+ labels=data_row["label"].flatten(),
44
+ )
45
+
46
+ return out
47
+
48
+
49
+ class DataModule(LightningDataModule):
50
+ def __init__(
51
+ self,
52
+ model_name_or_path: str,
53
+ train_df: pd.DataFrame,
54
+ eval_df: pd.DataFrame,
55
+ max_seq_length: int = 256,
56
+ train_batch_size: int = 32,
57
+ eval_batch_size: int = 32,
58
+ num_workers: int = 16,
59
+ **kwargs
60
+ ):
61
+ super().__init__()
62
+ self.model_name_or_path = model_name_or_path
63
+ self.train_df = train_df
64
+ self.eval_df = eval_df
65
+ self.max_seq_length = max_seq_length
66
+ self.train_batch_size = train_batch_size
67
+ self.eval_batch_size = eval_batch_size
68
+ self.num_workers = num_workers
69
+
70
+ def setup(self, stage=None):
71
+
72
+ self.train_dataset = PyTorchDataModule(self.model_name_or_path, self.train_df, self.max_seq_length)
73
+ self.eval_dataset = PyTorchDataModule(self.model_name_or_path, self.eval_df, self.max_seq_length)
74
+
75
+ def train_dataloader(self) -> DataLoader:
76
+ return DataLoader(self.train_dataset, batch_size=self.train_batch_size, shuffle=False, num_workers=self.num_workers, pin_memory=True)
77
+
78
+ def val_dataloader(self) -> DataLoader:
79
+ return DataLoader(self.eval_dataset, batch_size=self.eval_batch_size, shuffle=False, num_workers=self.num_workers, pin_memory=True)
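A small sketch of exercising the data module on its own (toy DataFrames; the sentence/label column names match what PyTorchDataModule expects):

    import pandas as pd
    from weakly_supervised_parser.model.data_module_loader import DataModule

    train_df = pd.DataFrame({"sentence": ["the big dog", "dog big the"], "label": [1, 0]})
    eval_df = pd.DataFrame({"sentence": ["a small cat"], "label": [1]})

    dm = DataModule("roberta-base", train_df=train_df, eval_df=eval_df, max_seq_length=32,
                    train_batch_size=2, eval_batch_size=1, num_workers=0)
    dm.setup()
    batch = next(iter(dm.train_dataloader()))
    print(batch["input_ids"].shape)  # torch.Size([2, 32])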
weakly_supervised_parser/model/span_classifier.py ADDED
@@ -0,0 +1,95 @@
1
+ import torch
2
+ import torchmetrics
3
+ from torch.optim import AdamW
4
+ from pytorch_lightning import LightningModule
5
+ from transformers import AutoConfig, AutoModelForSequenceClassification, get_linear_schedule_with_warmup
6
+
7
+
8
+ class LightningModel(LightningModule):
9
+ def __init__(
10
+ self,
11
+ model_name_or_path: str,
12
+ num_labels: int = 2,
13
+ lr: float = 5e-6,
14
+ train_batch_size: int = 32,
15
+ adam_epsilon=1e-8,
16
+ warmup_steps: int = 0,
17
+ weight_decay: float = 0.0,
18
+ **kwargs
19
+ ):
20
+ super().__init__()
21
+
22
+ self.save_hyperparameters()
23
+
24
+ self.num_labels = num_labels
25
+ self.config = AutoConfig.from_pretrained(model_name_or_path, num_labels=self.num_labels)
26
+ self.model = AutoModelForSequenceClassification.from_pretrained(model_name_or_path, config=self.config)
27
+ self.model.gradient_checkpointing_enable()
28
+ self.lr = lr
29
+ self.train_batch_size = train_batch_size
30
+ self.accuracy = torchmetrics.Accuracy()
31
+ self.f1score = torchmetrics.F1Score(num_classes=2)
32
+ self.mcc = torchmetrics.MatthewsCorrCoef(num_classes=2)
33
+
34
+ def forward(self, input_ids, attention_mask, labels=None):
35
+ return self.model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
36
+
37
+ def training_step(self, batch, batch_idx):
38
+ outputs = self(input_ids=batch["input_ids"], attention_mask=batch["attention_mask"], labels=batch["labels"])
39
+ loss = outputs[0]
40
+ return loss
41
+
42
+ def validation_step(self, batch, batch_idx, dataloader_idx=0):
43
+ outputs = self(input_ids=batch["input_ids"], attention_mask=batch["attention_mask"], labels=batch["labels"])
44
+ val_loss, logits = outputs[:2]
45
+ preds = torch.argmax(logits, axis=1)
46
+ labels = batch["labels"]
47
+ return {"loss": val_loss, "preds": preds, "labels": labels}
48
+
49
+ def validation_epoch_end(self, outputs):
50
+ preds = torch.cat([x["preds"] for x in outputs])
51
+ labels = torch.cat([x["labels"] for x in outputs])
52
+ loss = torch.stack([x["loss"] for x in outputs]).mean()
53
+
54
+ self.log("val_loss", loss, prog_bar=True)
55
+ self.log("val_accuracy", self.accuracy(preds, labels.squeeze()), prog_bar=True)
56
+ self.log("val_f1", self.f1score(preds, labels.squeeze()), prog_bar=True)
57
+ self.log("val_mcc", self.mcc(preds, labels.squeeze()), prog_bar=True)
58
+ return loss
59
+
60
+ def setup(self, stage=None):
61
+ # Get dataloader by calling it - train_dataloader() is called after setup() by default
62
+ train_loader = self.trainer.datamodule.train_dataloader()
63
+
64
+ # Calculate total steps
65
+ tb_size = self.train_batch_size * max(1, self.trainer.gpus)
66
+ ab_size = tb_size * self.trainer.accumulate_grad_batches
67
+ self.total_steps = int((len(train_loader.dataset) / ab_size) * float(self.trainer.max_epochs))
68
+
69
+ def configure_optimizers(self):
70
+ """Prepare optimizer and schedule (linear warmup and decay)"""
71
+ model = self.model
72
+ no_decay = ["bias", "LayerNorm.weight"]
73
+ optimizer_grouped_parameters = [
74
+ {
75
+ "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
76
+ "weight_decay": self.hparams.weight_decay,
77
+ },
78
+ {
79
+ "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
80
+ "weight_decay": 0.0,
81
+ },
82
+ ]
83
+ optimizer = AdamW(
84
+ optimizer_grouped_parameters,
85
+ lr=self.lr,
86
+ eps=self.hparams.adam_epsilon,
87
+ )
88
+
89
+ scheduler = get_linear_schedule_with_warmup(
90
+ optimizer,
91
+ num_warmup_steps=self.hparams.warmup_steps,
92
+ num_training_steps=self.total_steps,
93
+ )
94
+ scheduler = {"scheduler": scheduler, "interval": "step", "frequency": 1}
95
+ return [optimizer], [scheduler]
weakly_supervised_parser/model/trainer.py ADDED
@@ -0,0 +1,128 @@
1
+ import os
2
+ import torch
3
+ import datasets
4
+ import numpy as np
5
+ import pandas as pd
6
+
7
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
8
+
9
+ from pytorch_lightning import Trainer, seed_everything
10
+ from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
11
+ from transformers import AutoTokenizer, logging
12
+
13
+ from onnxruntime import InferenceSession
14
+ from scipy.special import softmax
15
+
16
+ from weakly_supervised_parser.model.data_module_loader import DataModule
17
+ from weakly_supervised_parser.model.span_classifier import LightningModel
18
+
19
+
20
+ # Disable model checkpoint warnings
21
+ logging.set_verbosity_error()
22
+
23
+
24
+ class InsideOutsideStringClassifier:
25
+ def __init__(self, model_name_or_path: str, num_labels: int = 2, max_seq_length: int = 256):
26
+
27
+ self.model_name_or_path = model_name_or_path
28
+ self.num_labels = num_labels
29
+ self.max_seq_length = max_seq_length
30
+
31
+ def fit(
32
+ self,
33
+ train_df: pd.DataFrame,
34
+ eval_df: pd.DataFrame,
35
+ outputdir: str,
36
+ filename: str,
37
+ devices: int = 1,
38
+ enable_progress_bar: bool = True,
39
+ enable_model_summary: bool = False,
40
+ enable_checkpointing: bool = False,
41
+ logger: bool = False,
42
+ accelerator: str = "auto",
43
+ train_batch_size: int = 32,
44
+ eval_batch_size: int = 32,
45
+ learning_rate: float = 5e-6,
46
+ max_epochs: int = 10,
47
+ dataloader_num_workers: int = 16,
48
+ seed: int = 42,
49
+ ):
50
+
51
+ data_module = DataModule(
52
+ model_name_or_path=self.model_name_or_path,
53
+ train_df=train_df,
54
+ eval_df=eval_df,
55
+ max_seq_length=self.max_seq_length,
56
+ train_batch_size=train_batch_size,
57
+ eval_batch_size=eval_batch_size,
58
+ num_workers=dataloader_num_workers,
59
+ )
60
+
61
+ model = LightningModel(
62
+ model_name_or_path=self.model_name_or_path,
63
+ lr=learning_rate,
64
+ num_labels=self.num_labels,
65
+ train_batch_size=train_batch_size,
66
+ eval_batch_size=eval_batch_size,
67
+ )
68
+
69
+ seed_everything(seed, workers=True)
70
+
71
+ callbacks = []
72
+ callbacks.append(EarlyStopping(monitor="val_loss", patience=2, mode="min", check_finite=True))
73
+ # callbacks.append(ModelCheckpoint(monitor="val_loss", dirpath=outputdir, filename=filename, save_top_k=1, save_weights_only=True, mode="min"))
74
+
75
+ trainer = Trainer(
76
+ accelerator=accelerator,
77
+ devices=devices,
78
+ max_epochs=max_epochs,
79
+ callbacks=callbacks,
80
+ enable_progress_bar=enable_progress_bar,
81
+ enable_model_summary=enable_model_summary,
82
+ enable_checkpointing=enable_checkpointing,
83
+ logger=logger,
84
+ )
85
+ trainer.fit(model, data_module)
86
+ trainer.validate(model, data_module.val_dataloader())
87
+
88
+ train_batch = next(iter(data_module.train_dataloader()))
89
+
90
+ model.to_onnx(
91
+ file_path=f"{outputdir}/{filename}.onnx",
92
+ input_sample=(train_batch["input_ids"].cuda(), train_batch["attention_mask"].cuda()),
93
+ export_params=True,
94
+ opset_version=11,
95
+ input_names=["input", "attention_mask"],
96
+ output_names=["output"],
97
+ dynamic_axes={"input": {0: "batch_size"}, "attention_mask": {0: "batch_size"}, "output": {0: "batch_size"}},
98
+ )
99
+
100
+ def load_model(self, pre_trained_model_path):
101
+ self.model = InferenceSession(pre_trained_model_path, providers=["CUDAExecutionProvider", "CPUExecutionProvider"])
102
+ self.tokenizer = AutoTokenizer.from_pretrained(self.model_name_or_path, use_fast=True)
103
+
104
+ def preprocess_function(self, data):
105
+ features = self.tokenizer(
106
+ data["sentence"], max_length=self.max_seq_length, padding="max_length", add_special_tokens=True, truncation=True, return_tensors="np"
107
+ )
108
+ return features
109
+
110
+ def process_spans(self, spans, scale_axis):
111
+ spans_dataset = datasets.Dataset.from_pandas(spans)
112
+ processed = spans_dataset.map(self.preprocess_function, batched=True, batch_size=None)
113
+ inputs = {"input": processed["input_ids"], "attention_mask": processed["attention_mask"]}
114
+ with torch.no_grad():
115
+ return softmax(self.model.run(None, inputs)[0], axis=scale_axis)
116
+
117
+ def predict_proba(self, spans, scale_axis, predict_batch_size):
118
+ if spans.shape[0] > predict_batch_size:
119
+ output = []
120
+ span_batches = np.array_split(spans, spans.shape[0] // predict_batch_size)
121
+ for span_batch in span_batches:
122
+ output.extend(self.process_spans(span_batch, scale_axis))
123
+ return np.vstack(output)
124
+ else:
125
+ return self.process_spans(spans, scale_axis)
126
+
127
+ def predict(self, spans, scale_axis, predict_batch_size):
128
+ return self.predict_proba(spans, scale_axis, predict_batch_size).argmax(axis=1)
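For orientation, a minimal end-to-end sketch of the classifier interface (toy DataFrames and a hypothetical output directory; fit exports the ONNX graph from CUDA tensors, so a GPU is assumed):

    import pandas as pd
    from weakly_supervised_parser.model.trainer import InsideOutsideStringClassifier

    train_df = pd.DataFrame({"sentence": ["the big dog", "dog big the"], "label": [1, 0]})
    eval_df = pd.DataFrame({"sentence": ["a small cat", "cat a small"], "label": [1, 0]})

    clf = InsideOutsideStringClassifier(model_name_or_path="roberta-base", max_seq_length=256)
    clf.fit(train_df=train_df, eval_df=eval_df, outputdir="TEMP", filename="inside_model",
            max_epochs=1, dataloader_num_workers=0)

    clf.load_model(pre_trained_model_path="TEMP/inside_model.onnx")
    probs = clf.predict_proba(spans=eval_df[["sentence"]], scale_axis=1, predict_batch_size=8)
    # probs[:, 1] is the constituent probability for each span string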
weakly_supervised_parser/settings.py ADDED
@@ -0,0 +1,33 @@
1
+ PROJECT_DIR = "weakly_supervised_parser/"
2
+ PTB_TREES_ROOT_DIR = "data/PROCESSED/english/trees/"
3
+ PTB_SENTENCES_ROOT_DIR = "data/PROCESSED/english/sentences/"
4
+
5
+ PTB_TRAIN_SENTENCES_WITH_PUNCTUATION_PATH = PTB_SENTENCES_ROOT_DIR + "ptb-train-sentences-with-punctuation.txt"
6
+ PTB_VALID_SENTENCES_WITH_PUNCTUATION_PATH = PTB_SENTENCES_ROOT_DIR + "ptb-valid-sentences-with-punctuation.txt"
7
+ PTB_TEST_SENTENCES_WITH_PUNCTUATION_PATH = PTB_SENTENCES_ROOT_DIR + "ptb-test-sentences-with-punctuation.txt"
8
+
9
+ PTB_TRAIN_SENTENCES_WITHOUT_PUNCTUATION_PATH = PTB_SENTENCES_ROOT_DIR + "ptb-train-sentences-without-punctuation.txt"
10
+ PTB_VALID_SENTENCES_WITHOUT_PUNCTUATION_PATH = PTB_SENTENCES_ROOT_DIR + "ptb-valid-sentences-without-punctuation.txt"
11
+ PTB_TEST_SENTENCES_WITHOUT_PUNCTUATION_PATH = PTB_SENTENCES_ROOT_DIR + "ptb-test-sentences-without-punctuation.txt"
12
+
13
+ PTB_TRAIN_GOLD_WITH_PUNCTUATION_PATH = PTB_TREES_ROOT_DIR + "ptb-train-gold-with-punctuation.txt"
14
+ PTB_VALID_GOLD_WITH_PUNCTUATION_PATH = PTB_TREES_ROOT_DIR + "ptb-valid-gold-with-punctuation.txt"
15
+ PTB_TEST_GOLD_WITH_PUNCTUATION_PATH = PTB_TREES_ROOT_DIR + "ptb-test-gold-with-punctuation.txt"
16
+
17
+ PTB_TRAIN_GOLD_WITHOUT_PUNCTUATION_PATH = PTB_TREES_ROOT_DIR + "ptb-train-gold-without-punctuation.txt"
18
+ PTB_VALID_GOLD_WITHOUT_PUNCTUATION_PATH = PTB_TREES_ROOT_DIR + "ptb-valid-gold-without-punctuation.txt"
19
+ PTB_TEST_GOLD_WITHOUT_PUNCTUATION_PATH = PTB_TREES_ROOT_DIR + "ptb-test-gold-without-punctuation.txt"
20
+
21
+ PTB_TRAIN_GOLD_WITHOUT_PUNCTUATION_ALIGNED_PATH = PTB_TREES_ROOT_DIR + "ptb-train-gold-without-punctuation-aligned.txt"
22
+ PTB_VALID_GOLD_WITHOUT_PUNCTUATION_ALIGNED_PATH = PTB_TREES_ROOT_DIR + "ptb-valid-gold-without-punctuation-aligned.txt"
23
+ PTB_TEST_GOLD_WITHOUT_PUNCTUATION_ALIGNED_PATH = PTB_TREES_ROOT_DIR + "ptb-test-gold-without-punctuation-aligned.txt"
24
+
25
+ YOON_KIM_TRAIN_GOLD_WITHOUT_PUNCTUATION_PATH = PTB_TREES_ROOT_DIR + "Yoon_Kim/ptb-train-gold-filtered.txt"
26
+ YOON_KIM_VALID_GOLD_WITHOUT_PUNCTUATION_PATH = PTB_TREES_ROOT_DIR + "Yoon_Kim/ptb-valid-gold-filtered.txt"
27
+ YOON_KIM_TEST_GOLD_WITHOUT_PUNCTUATION_PATH = PTB_TREES_ROOT_DIR + "Yoon_Kim/ptb-test-gold-filtered.txt"
28
+
29
+ # Predictions
30
+ PTB_SAVE_TREES_PATH = "TEMP/predictions/english/"
31
+
32
+ # Training
33
+ TRAINED_MODEL_PATH = PROJECT_DIR + "model/TRAINED_MODEL/"
weakly_supervised_parser/tree/__init__.py ADDED
File without changes
weakly_supervised_parser/tree/evaluate.py ADDED
@@ -0,0 +1,221 @@
1
+ import argparse
2
+ import collections
3
+ import os
4
+ import subprocess
5
+
6
+ import nltk
7
+
8
+
9
+ def tree_to_spans(tree, keep_labels=False, keep_leaves=False, keep_whole_span=False):
10
+ if isinstance(tree, str):
11
+ tree = nltk.Tree.fromstring(tree)
12
+
13
+ length = len(tree.pos())
14
+ queue = collections.deque(tree.treepositions())
15
+ stack = [(queue.popleft(), 0)]
16
+ j = 0
17
+ spans = []
18
+ while stack != []:
19
+ (p, i) = stack[-1]
20
+ if not queue or queue[0][:-1] != p:
21
+ if isinstance(tree[p], nltk.tree.Tree):
22
+ if j - i > 1:
23
+ spans.append((tree[p].label(), (i, j)))
24
+ else:
25
+ j = i + 1
26
+ stack.pop()
27
+ else:
28
+ q = queue.popleft()
29
+ stack.append((q, j))
30
+ if not keep_whole_span:
31
+ spans = [span for span in spans if span[1] != (0, length)]
32
+ if not keep_labels:
33
+ spans = [span[1] for span in spans]
34
+ return spans
35
+
36
+
37
+ def test_tree_to_spans():
38
+ assert [(0, 2), (0, 3), (0, 4)] == tree_to_spans("(S (S (S (S (S 1) (S 2)) (S 3)) (S 4)) (S 5))", keep_labels=False)
39
+ assert [] == tree_to_spans("(S 1)", keep_labels=False)
40
+ assert [] == tree_to_spans("(S (S 1) (S 2))", keep_labels=False)
41
+ assert [(1, 3)] == tree_to_spans("(S (S 1) (S (S 2) (S 3)))", keep_labels=False)
42
+ assert [("S", (1, 3))] == tree_to_spans("(S (S 1) (S (S 2) (S 3)))", keep_labels=True)
43
+
44
+
45
+ def get_F1_score_intermediates(gold_spans, pred_spans):
46
+ """Get intermediate results for calculating the F1 score"""
47
+ n_true_positives = 0
48
+ gold_span_counter = collections.Counter(gold_spans)
49
+ pred_span_counter = collections.Counter(pred_spans)
50
+ unique_spans = set(gold_spans + pred_spans)
51
+ for span in unique_spans:
52
+ n_true_positives += min(gold_span_counter[span], pred_span_counter[span])
53
+ return n_true_positives, len(gold_spans), len(pred_spans)
54
+
55
+
56
+ def calculate_F1_score_from_intermediates(n_true_positives, n_golds, n_predictions, precision_recall_f_score=False):
57
+ """Calculate F1 score"""
58
+ if precision_recall_f_score:
59
+ zeros = (0, 0, 0)
60
+ else:
61
+ zeros = 0
62
+ if n_golds == 0:
63
+ return 100 if n_predictions == 0 else zeros
64
+ if n_true_positives == 0 or n_predictions == 0:
65
+ return zeros
66
+ recall = n_true_positives / n_golds
67
+ precision = n_true_positives / n_predictions
68
+ F1 = 2 * precision * recall / (precision + recall)
69
+ if precision_recall_f_score:
70
+ return precision, recall, F1 * 100
71
+ return F1 * 100
72
+
73
+
74
+ def calculate_F1_for_spans(gold_spans, pred_spans, precision_recall_f_score=False):
75
+ # CHANGE THIS LATER
76
+ # gold_spans = list(set(gold_spans))
77
+ ###################################
78
+ tp, n_gold, n_pred = get_F1_score_intermediates(gold_spans, pred_spans)
79
+ if precision_recall_f_score:
80
+ p, r, F1 = calculate_F1_score_from_intermediates(tp, len(gold_spans), len(pred_spans), precision_recall_f_score=precision_recall_f_score)
81
+ return p, r, F1
82
+ F1 = calculate_F1_score_from_intermediates(tp, len(gold_spans), len(pred_spans))
83
+ return F1
84
+
85
+
86
+ def test_calculate_F1_for_spans():
87
+ pred = [(0, 1)]
88
+ gold = [(0, 1)]
89
+ assert calculate_F1_for_spans(gold, pred) == 100
90
+ pred = [(0, 0)]
91
+ gold = [(0, 1)]
92
+ assert calculate_F1_for_spans(gold, pred) == 0
93
+ pred = [(0, 0), (0, 1)]
94
+ gold = [(0, 1), (1, 1)]
95
+ assert calculate_F1_for_spans(gold, pred) == 50
96
+ pred = [(0, 0), (0, 0)]
97
+ gold = [(0, 0), (0, 0), (0, 1)]
98
+ assert calculate_F1_for_spans(gold, pred) == 80
99
+ pred = [(0, 0), (1, 0)]
100
+ gold = [(0, 0), (0, 0), (0, 1)]
101
+ assert calculate_F1_for_spans(gold, pred) == 40
102
+
103
+
104
+ def read_lines_from_file(filepath, len_limit):
105
+ with open(filepath, "r") as f:
106
+ for line in f:
107
+ tree = nltk.Tree.fromstring(line)
108
+ if len_limit is not None and len(tree.pos()) > len_limit:
109
+ continue
110
+ yield line.strip()
111
+
112
+
113
+ def read_spans_from_file(filepath, len_limit):
114
+ for line in read_lines_from_file(filepath, len_limit):
115
+ yield tree_to_spans(line, keep_labels=False, keep_leaves=False, keep_whole_span=False)
116
+
117
+
118
+ def calculate_corpus_level_F1_for_spans(gold_list, pred_list):
119
+ n_true_positives = 0
120
+ n_golds = 0
121
+ n_predictions = 0
122
+ for gold_spans, pred_spans in zip(gold_list, pred_list):
123
+ n_tp, n_g, n_p = get_F1_score_intermediates(gold_spans, pred_spans)
124
+ n_true_positives += n_tp
125
+ n_golds += n_g
126
+ n_predictions += n_p
127
+ F1 = calculate_F1_score_from_intermediates(n_true_positives, n_golds, n_predictions)
128
+ return F1
129
+
130
+
131
+ def calculate_sentence_level_F1_for_spans(gold_list, pred_list):
132
+ f1_scores = []
133
+ for gold_spans, pred_spans in zip(gold_list, pred_list):
134
+ f1 = calculate_F1_for_spans(gold_spans, pred_spans)
135
+ f1_scores.append(f1)
136
+ F1 = sum(f1_scores) / len(f1_scores)
137
+ return F1
138
+
139
+
140
+ def parse_evalb_results_from_file(filepath):
141
+ i_th_score = 0
142
+ score_of_all_length = None
143
+ score_of_length_10 = None
144
+ prefix_of_the_score_line = "Bracketing FMeasure ="
145
+
146
+ with open(filepath, "r") as f:
147
+ for line in f:
148
+ if line.startswith(prefix_of_the_score_line):
149
+ i_th_score += 1
150
+ if i_th_score == 1:
151
+ score_of_all_length = float(line.split()[-1])
152
+ elif i_th_score == 2:
153
+ score_of_length_10 = float(line.split()[-1])
154
+ else:
155
+ raise ValueError("Too many lines for F score")
156
+ return score_of_all_length, score_of_length_10
157
+
158
+
159
+ def execute_evalb(gold_file, pred_file, out_file, len_limit):
160
+ EVALB_PATH = "model/EVALB/"
161
+ subprocess.run("{} -p {} {} {} > {}".format(EVALB_PATH + "/evalb", EVALB_PATH + "unlabelled.prm", gold_file, pred_file, out_file), shell=True)
162
+
163
+
164
+ def calculate_evalb_F1_for_file(gold_file, pred_file, len_limit):
165
+ evalb_out_file = pred_file + ".evalb_out"
166
+ execute_evalb(gold_file, pred_file, evalb_out_file, len_limit)
167
+ F1_len_all, F1_len_10 = parse_evalb_results_from_file(evalb_out_file)
168
+ if len_limit is None:
169
+ return F1_len_all
170
+ elif len_limit == 10:
171
+ return F1_len_10
172
+ else:
173
+ raise ValueError(f"Unexpected len_limit: {len_limit}")
174
+
175
+
176
+ def calculate_sentence_level_F1_for_file(gold_file, pred_file, len_limit):
177
+ gold_list = list(read_spans_from_file(gold_file, len_limit))
178
+ pred_list = list(read_spans_from_file(pred_file, len_limit))
179
+ F1 = calculate_sentence_level_F1_for_spans(gold_list, pred_list)
180
+ return F1
181
+
182
+
183
+ def calculate_corpus_level_F1_for_file(gold_file, pred_file, len_limit):
184
+ gold_list = list(read_spans_from_file(gold_file, len_limit))
185
+ pred_list = list(read_spans_from_file(pred_file, len_limit))
186
+ F1 = calculate_corpus_level_F1_for_spans(gold_list, pred_list)
187
+ return F1
188
+
189
+
190
+ def evaluate_prediction_file(gold_file, pred_file, len_limit):
191
+ corpus_F1 = calculate_corpus_level_F1_for_file(gold_file, pred_file, len_limit)
192
+ sentence_F1 = calculate_sentence_level_F1_for_file(gold_file, pred_file, len_limit)
193
+ # evalb_F1 = calculate_evalb_F1_for_file(gold_file, pred_file, len_limit)
194
+
195
+ print("=====> Evaluation Results <=====")
196
+ print(f"Length constraint: {len_limit}")
197
+ print(f"Micro F1: {corpus_F1:.2f}, Macro F1: {sentence_F1:.2f}") # , evalb_F1))
198
+ print("=====> Evaluation Results <=====")
199
+
200
+
201
+ def parse_args():
202
+ parser = argparse.ArgumentParser()
203
+ parser.add_argument("--gold_file", "-g", help="path to gold file")
204
+ parser.add_argument("--pred_file", "-p", help="path to prediction file")
205
+ parser.add_argument(
206
+ "--len_limit", default=None, type=int, choices=(None, 10, 20, 30, 40, 50, 100), help="length constraint for evaluation, 10 or None"
207
+ )
208
+ args = parser.parse_args()
209
+
210
+ return args
211
+
212
+
213
+ def main():
214
+ args = parse_args()
215
+ evaluate_prediction_file(args.gold_file, args.pred_file, args.len_limit)
216
+
217
+
218
+ if __name__ == "__main__":
219
+ main()
220
+
221
+ # python helper/evaluate.py -g TEMP/preprocessed_dev.txt -p TEMP/pred_dev_m_None.txt
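A self-contained illustration of the span-F1 computation above, using only functions from this file (the bracketed strings are toy trees):

    from weakly_supervised_parser.tree.evaluate import tree_to_spans, calculate_F1_for_spans

    gold = "(S (NP (DT the) (NN dog)) (VP (VBZ barks) (ADVP (RB loudly))))"
    pred = "(S (S (S the) (S dog)) (S (S barks) (S loudly)))"

    gold_spans = tree_to_spans(gold)  # [(0, 2), (2, 4)]
    pred_spans = tree_to_spans(pred)  # [(0, 2), (2, 4)]
    print(calculate_F1_for_spans(gold_spans, pred_spans))  # 100.0 (labels and the whole-sentence span are ignored)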
weakly_supervised_parser/tree/helpers.py ADDED
@@ -0,0 +1,177 @@
1
+ import nltk
2
+ from collections import Counter
3
+ from weakly_supervised_parser.tree.evaluate import tree_to_spans
4
+
5
+
6
+ class Tree(object):
7
+ def __init__(self, label, children, word):
8
+ self.label = label
9
+ self.children = children
10
+ self.word = word
11
+
12
+ def __str__(self):
13
+ return self.linearize()
14
+
15
+ def linearize(self):
16
+ if not self.children:
17
+ return f"({self.label} {self.word})"
18
+ return f"({self.label} {' '.join(c.linearize() for c in self.children)})"
19
+
20
+ def spans(self, start=0):
21
+ if not self.children:
22
+ return [(start, start + 1)]
23
+ span_list = []
24
+ position = start
25
+ for c in self.children:
26
+ cspans = c.spans(start=position)
27
+ span_list.extend(cspans)
28
+ position = cspans[0][1]
29
+ return [(start, position)] + span_list
30
+
31
+ def spans_labels(self, start=0):
32
+ if not self.children:
33
+ return [(start, start + 1, self.label)]
34
+ span_list = []
35
+ position = start
36
+ for c in self.children:
37
+ cspans = c.spans_labels(start=position)
38
+ span_list.extend(cspans)
39
+ position = cspans[0][1]
40
+ return [(start, position, self.label)] + span_list
41
+
42
+
43
+ def extract_sentence(sentence):
44
+ t = nltk.Tree.fromstring(sentence)
45
+ return " ".join(item[0] for item in t.pos())
46
+
47
+
48
+ def get_constituents(sample_string, want_spans_mapping=False, whole_sentence=True, labels=False):
49
+ t = nltk.Tree.fromstring(sample_string)
50
+ if want_spans_mapping:
51
+ spans = tree_to_spans(t, keep_labels=True)
52
+ return dict(Counter(item[1] for item in spans))
53
+ spans = tree_to_spans(t, keep_labels=True)
54
+ sentence = extract_sentence(sample_string).split()
55
+
56
+ labeled_consituents_lst = []
57
+ constituents = []
58
+ for span in spans:
59
+ labeled_consituents = {}
60
+ labeled_consituents["labels"] = span[0]
61
+ i, j = span[1][0], span[1][1]
62
+ constituents.append(" ".join(sentence[i:j]))
63
+ labeled_consituents["constituent"] = " ".join(sentence[i:j])
64
+ labeled_consituents_lst.append(labeled_consituents)
65
+
66
+ # Add original sentence
67
+ if whole_sentence:
68
+ constituents = constituents + [" ".join(sentence)]
69
+
70
+ if labels:
71
+ return labeled_consituents_lst
72
+
73
+ return constituents
74
+
75
+
76
+ def get_distituents(sample_string):
77
+ sentence = extract_sentence(sample_string).split()
78
+
79
+ def get_all_combinations(sentence):
80
+ L = sentence.split()
81
+ N = len(L)
82
+ out = []
83
+ for n in range(2, N):
84
+ for i in range(N - n + 1):
85
+ out.append((i, i + n))
86
+ return out
87
+
88
+ combinations = get_all_combinations(extract_sentence(sample_string))
89
+ constituents = list(get_constituents(sample_string, want_spans_mapping=True).keys())
90
+ spans = [item for item in combinations if item not in constituents]
91
+ distituents = []
92
+ for span in spans:
93
+ i, j = span[0], span[1]
94
+ distituents.append(" ".join(sentence[i:j]))
95
+ return distituents
96
+
97
+
98
+ def get_leaves(tree):
99
+ if not tree.children:
100
+ return [tree]
101
+ leaves = []
102
+ for c in tree.children:
103
+ leaves.extend(get_leaves(c))
104
+ return leaves
105
+
106
+
107
+ def unlinearize(string):
108
+ """
109
+ (TOP (S (NP (PRP He)) (VP (VBD was) (ADJP (JJ right))) (. .)))
110
+ """
111
+ tokens = string.replace("(", " ( ").replace(")", " ) ").split()
112
+
113
+ def read_tree(start):
114
+ if tokens[start + 2] != "(":
115
+ return Tree(tokens[start + 1], None, tokens[start + 2]), start + 4
116
+ i = start + 2
117
+ children = []
118
+ while tokens[i] != ")":
119
+ tree, i = read_tree(i)
120
+ children.append(tree)
121
+ return Tree(tokens[start + 1], children, None), i + 1
122
+
123
+ tree, _ = read_tree(0)
124
+ return tree
125
+
126
+
127
+ def recall_by_label(gold_standard, best_parse):
128
+ correct = {}
129
+ total = {}
130
+ for tree1, tree2 in zip(gold_standard, best_parse):
131
+ try:
132
+ leaves1, leaves2 = get_leaves(tree1["tree"]), get_leaves(tree2["tree"])
133
+ for l1, l2 in zip(leaves1, leaves2):
134
+ assert l1.word.lower() == l2.word.lower(), f"{l1.word} =/= {l2.word}"
135
+ spanlabels = tree1["tree"].spans_labels()
136
+ spans = tree2["tree"].spans()
137
+
138
+ for (i, j, label) in spanlabels:
139
+ if j - i != 1:
140
+ if label not in correct:
141
+ correct[label] = 0
142
+ total[label] = 0
143
+ if (i, j) in spans:
144
+ correct[label] += 1
145
+ total[label] += 1
146
+ except Exception as e:
147
+ print(e)
148
+ acc = {}
149
+ for label in total.keys():
150
+ acc[label] = correct[label] / total[label]
151
+ return acc
152
+
153
+
154
+ def label_recall_output(gold_standard, best_parse):
155
+ best_parse_trees = []
156
+ gold_standard_trees = []
157
+ for t1, t2 in zip(gold_standard, best_parse):
158
+ gold_standard_trees.append({"tree": unlinearize(t1)})
159
+ best_parse_trees.append({"tree": unlinearize(t2)})
160
+
161
+ dct = recall_by_label(gold_standard=gold_standard_trees, best_parse=best_parse_trees)
162
+ labels = ["SBAR", "NP", "VP", "PP", "ADJP", "ADVP"]
163
+ l = [{label: f"{recall * 100:.2f}"} for label, recall in dct.items() if label in labels]
164
+ df = pd.DataFrame([item.values() for item in l], index=[item.keys() for item in l], columns=["recall"])
165
+ df.index = df.index.map(lambda x: list(x)[0])
166
+ df_out = df.reindex(labels)
167
+ return df_out
168
+
169
+
170
+ if __name__ == "__main__":
171
+ import pandas as pd
172
+ from weakly_supervised_parser.utils.prepare_dataset import PTBDataset
173
+ from weakly_supervised_parser.settings import PTB_TEST_GOLD_WITHOUT_PUNCTUATION_ALIGNED_PATH, PTB_SAVE_TREES_PATH
174
+
175
+ best_parse = PTBDataset(PTB_SAVE_TREES_PATH + "inside_model_predictions.txt").retrieve_all_sentences()
176
+ gold_standard = PTBDataset(PTB_TEST_GOLD_WITHOUT_PUNCTUATION_ALIGNED_PATH).retrieve_all_sentences()
177
+ print(label_recall_output(gold_standard, best_parse))
weakly_supervised_parser/utils/__init__.py ADDED
File without changes
weakly_supervised_parser/utils/cky_algorithm.py ADDED
@@ -0,0 +1,91 @@
1
+ import re
2
+ import numpy as np
3
+ from weakly_supervised_parser.tree.helpers import Tree
4
+
5
+
6
+ def CKY(sent_all, prob_s, label_s, verbose=False):
7
+ r"""
8
+ choose tree with maximum expected number of constituents,
9
+ or max \sum_{(i,j) \in tree} p((i,j) is constituent)
10
+ """
11
+
12
+ def backpt_to_tree(sent, backpt, label_table):
13
+ def to_tree(i, j):
14
+ if j - i == 1:
15
+ return Tree(sent[i], None, sent[i])
16
+ else:
17
+ k = backpt[i][j]
18
+ return Tree(label_table[i][j], [to_tree(i, k), to_tree(k, j)], None)
19
+
20
+ return to_tree(0, len(sent))
21
+
22
+ def to_table(value_s, i_s, j_s):
23
+ table = [[None for _ in range(np.max(j_s) + 1)] for _ in range(np.max(i_s) + 1)]
24
+ for value, i, j in zip(value_s, i_s, j_s):
25
+ table[i][j] = value
26
+ return table
27
+
28
+ # produce list of spans to pass to is_constituent, while keeping track of which sentence
29
+ sent_s, i_s, j_s = [], [], []
30
+ idx_all = []
31
+ for sent in sent_all:
32
+ start = len(sent_s)
33
+ for i in range(len(sent)):
34
+ for j in range(i + 1, len(sent) + 1):
35
+ sent_s.append(sent)
36
+ i_s.append(i)
37
+ j_s.append(j)
38
+ idx_all.append((start, len(sent_s)))
39
+
40
+ # feed spans to is_constituent
41
+ # prob_s, label_s = self.is_constituent(sent_s, i_s, j_s, verbose = verbose)
42
+
43
+ # given span probs, perform CKY to get best tree for each sentence.
44
+ tree_all, prob_all = [], []
45
+ for sent, idx in zip(sent_all, idx_all):
46
+ # first, use tables to keep track of things
47
+ k, l = idx
48
+ prob, label = prob_s[k:l], label_s[k:l]
49
+ i, j = i_s[k:l], j_s[k:l]
50
+
51
+ prob_table = to_table(prob, i, j)
52
+ label_table = to_table(label, i, j)
53
+
54
+ # perform cky using scores and backpointers
55
+ score_table = [[None for _ in range(len(sent) + 1)] for _ in range(len(sent))]
56
+ backpt_table = [[None for _ in range(len(sent) + 1)] for _ in range(len(sent))]
57
+ for i in range(len(sent)): # base case: single words
58
+ score_table[i][i + 1] = 1
59
+ for j in range(2, len(sent) + 1):
60
+ for i in range(j - 2, -1, -1):
61
+ best, argmax = -np.inf, None
62
+ for k in range(i + 1, j): # find splitpoint
63
+ score = score_table[i][k] + score_table[k][j]
64
+ if score > best:
65
+ best, argmax = score, k
66
+ score_table[i][j] = best + prob_table[i][j]
67
+ backpt_table[i][j] = argmax
68
+
69
+ tree = backpt_to_tree(sent, backpt_table, label_table)
70
+ tree_all.append(tree)
71
+ prob_all.append(prob_table)
72
+
73
+ return tree_all, prob_all
74
+
75
+
76
+ def get_best_parse(sentence, spans):
77
+ flattened_scores = []
78
+ for i in range(spans.shape[0]):
79
+ for j in range(spans.shape[1]):
80
+ if i > j:
81
+ continue
82
+ else:
83
+ flattened_scores.append(spans[i, j])
84
+ prob_s, label_s = flattened_scores, ["S"] * len(flattened_scores)
85
+ # print(prob_s, label_s)
86
+ trees, _ = CKY(sent_all=sentence, prob_s=prob_s, label_s=label_s)
87
+ s = str(trees[0])
88
+ # Replace previous occurrence of string
89
+ out = re.sub(r"(?<![^\s()])([^\s()]+)(?=\s+\1(?![^\s()]))", "S", s)
90
+ # best_parse = "(ROOT " + out + ")"
91
+ return out # best_parse
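To show how a filled chart is consumed, a toy run of get_best_parse on a 3-token sentence (the scores are made up; the (n+1) x (n+1) chart indexed by span boundaries mirrors PopulateCKYChart.span_scores before its [:-1, 1:] slice):

    import numpy as np
    from weakly_supervised_parser.utils.cky_algorithm import get_best_parse

    sentence = "the old dog".split()
    n = len(sentence)
    chart = np.zeros((n + 1, n + 1))
    chart[0, 2] = 0.9  # "the old" looks like a constituent
    chart[1, 3] = 0.2  # "old dog" less so
    chart[0, 3] = 1.0  # whole sentence

    print(get_best_parse(sentence=[sentence], spans=chart[:-1, 1:]))
    # (S (S (S the) (S old)) (S dog))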
weakly_supervised_parser/utils/create_inside_outside_strings.py ADDED
@@ -0,0 +1,40 @@
1
+ class InsideOutside:
2
+ def __init__(self, sentence):
3
+ self.sentence = sentence.split()
4
+ self.sentence_length = len(self.sentence)
5
+
6
+ def calculate_inside(self, idx_start, idx_end):
7
+ # get inside string
8
+ return self.sentence[idx_start:idx_end]
9
+
10
+ def calculate_outside(self, idx_start, idx_end):
11
+ # get outside string
12
+ if idx_start == 0 and idx_end == self.sentence_length:
13
+ left_outside = ["<s>"] # bos_token roberta # ["[UNK]"]
14
+ right_outside = ["</s>"] # eos_token roberta # ["[UNK]"]
15
+ elif idx_start == 0:
16
+ left_outside = ["<s>"] # ["[UNK]"]
17
+ right_outside = self.sentence[idx_end:]
18
+ elif idx_end == self.sentence_length:
19
+ left_outside = self.sentence[:idx_start]
20
+ right_outside = ["</s>"] # ["[UNK]"]
21
+ else:
22
+ left_outside = self.sentence[:idx_start]
23
+ right_outside = self.sentence[idx_end:]
24
+ return left_outside, right_outside
25
+
26
+ def create_inside_outside_matrix(self, ngram):
27
+ i, j = ngram[0][0], ngram[0][-1]
28
+ inside_string = self.calculate_inside(i, j)
29
+ outside_string = self.calculate_outside(i, j)
30
+ output_dict = {
31
+ "span": ngram[0],
32
+ "inside_string": " ".join(inside_string),
33
+ "left_outside_string": " ".join(outside_string[0]),
34
+ "right_outside_string": " ".join(outside_string[-1]),
35
+ }
36
+ inside_string_template = output_dict["inside_string"]
37
+ outside_string_template = (
38
+ output_dict["left_outside_string"].split()[-1] + " " + "<mask>" + " " + output_dict["right_outside_string"].split()[0]
39
+ )
40
+ return output_dict, inside_string_template, outside_string_template
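To make the inside/outside representations concrete, a quick sketch for one span (the span format ((start, end), tokens) matches what NGramify produces):

    from weakly_supervised_parser.utils.create_inside_outside_strings import InsideOutside

    io = InsideOutside("the quick brown fox jumps")
    out, inside_template, outside_template = io.create_inside_outside_matrix(((1, 3), ["quick", "brown"]))

    print(inside_template)   # quick brown
    print(outside_template)  # the <mask> fox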
weakly_supervised_parser/utils/distant_supervision.py ADDED
@@ -0,0 +1,40 @@
1
+ from collections import defaultdict, Counter
2
+ from nltk.corpus import stopwords
3
+
4
+
5
+ class RuleBasedHeuristic:
6
+ def __init__(self, sentence=None, corpus=None):
7
+ self.sentence = sentence
8
+ self.corpus = corpus
9
+
10
+ def add_contiguous_titlecase_words(self, row):
11
+ matches = []
12
+ dd = defaultdict(list)
13
+ count = 0
14
+ for i, j in zip(row, row[1:]):
15
+ if j[0] - i[0] == 1:
16
+ dd[count].append(i[-1] + " " + j[-1])
17
+ else:
18
+ count += 1
19
+ for key, value in dd.items():
20
+ if len(value) > 1:
21
+ out = value[0]
22
+ inter = ""
23
+ for item in value[1:]:
24
+ inter += " " + item.split()[-1]
25
+ matches.append(out + inter)
26
+ else:
27
+ matches.extend(value)
28
+ return matches
29
+
30
+ def augment_using_most_frequent_starting_token(self, N=1):
31
+ first_token = []
32
+ for sentence in self.corpus:
33
+ first_token.append(sentence.split()[0])
34
+ return Counter(first_token).most_common(N)
35
+
36
+ def get_top_tokens(self, top_most_common_ptb=None):
37
+ out = set(stopwords.words("english"))
38
+ if top_most_common_ptb:
39
+ out.update([token for token, counts in self.augment_using_most_frequent_starting_token(N=top_most_common_ptb)])
40
+ return out
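A brief sketch of the title-case heuristic used later to seed constituents (row is built the same way seed_bootstrap_constituent builds it, from (index, token) pairs of capitalized tokens):

    from weakly_supervised_parser.utils.distant_supervision import RuleBasedHeuristic

    tokens = "Pierre Vinken joined Elsevier N.V. yesterday".split()
    row = [(i, tok) for i, tok in enumerate(tokens) if tok.istitle()]
    print(RuleBasedHeuristic().add_contiguous_titlecase_words(row))
    # ['Pierre Vinken', 'Elsevier N.V.']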
weakly_supervised_parser/utils/populate_chart.py ADDED
@@ -0,0 +1,95 @@
1
+ import pandas as pd
2
+ import numpy as np
3
+
4
+ from datasets.utils import set_progress_bar_enabled
5
+
6
+ from weakly_supervised_parser.utils.prepare_dataset import NGramify
7
+ from weakly_supervised_parser.utils.create_inside_outside_strings import InsideOutside
8
+ from weakly_supervised_parser.utils.cky_algorithm import get_best_parse
9
+ from weakly_supervised_parser.utils.distant_supervision import RuleBasedHeuristic
10
+ from weakly_supervised_parser.utils.prepare_dataset import PTBDataset
11
+ from weakly_supervised_parser.settings import PTB_TRAIN_SENTENCES_WITHOUT_PUNCTUATION_PATH
12
+
13
+ # Disable Dataset.map progress bar
14
+ set_progress_bar_enabled(False)
15
+
16
+ # ptb = PTBDataset(data_path=PTB_TRAIN_SENTENCES_WITHOUT_PUNCTUATION_PATH)
17
+ # ptb_top_100_common = [item.lower() for item in RuleBasedHeuristic(corpus=ptb.retrieve_all_sentences()).get_top_tokens(top_most_common_ptb=100)]
18
+ ptb_top_100_common = ['this', 'myself', 'shouldn', 'not', 'analysts', 'same', 'mightn', 'we', 'american', 'the', 'another', 'until', "aren't", 'when', 'if', 'am', 'over', 'ma', 'as', 'of', 'with', 'even', 'couldn', 'not', "needn't", 'where', 'there', 'isn', 'however', 'my', 'sales', 'here', 'at', 'yours', 'into', 'wouldn', 'officials', 'no', "hasn't", 'to', 'wasn', 'any', 'ours', 'out', 'each', "wasn't", 'is', 'and', 'me', 'off', 'once', "it's", 'they', 'most', 'also', 'through', 'hasn', 'our', 'or', 'after', "weren't", 'about', 'mr.', 'first', 'haven', 'needn', 'have', "isn't", 'now', "didn't", 'on', 'theirs', 'these', 'before', 'there', 'was', 'which', 'those', 'having', 'do', 'most', 'own', 'among', 'because', 'for', "should've", "shan't", 'so', 'being', 'few', 'too', 'to', 'at', 'people', 'her', 'meanwhile', 'both', 'down', 'doesn', 'below', 'mustn', 'an', 'two', 'more', 'japanese', 'ford', "you'd", 'about', 'but', 'doing', 'itself', 've', 'under', 'what', 'again', 'then', 'your', 'himself', 'now', 'against', 'just', 'does', 'net', "couldn't", 'that', 'he', 'revenue', 'because', 'yesterday', 'them', 'i', 'their', 'all', 'under', 'up', "haven't", 'while', "won't", 'it', 'more', 'it', 'ain', 'him', 'still', 'a', 'he', 'despite', 'should', 'during', 'nor', "shouldn't", 'such', "doesn't", 'are', "that'll", 'since', 'yourselves', 'such', 'those', 'after', 'weren', "you're", 'd', 'like', 'did', 'hadn', 'themselves', 'its', 'but', 'been', 's', "don't", 'these', 'they', 'this', 'his', "mightn't", 'moreover', 'how', 'new', 'above', 'ourselves', 'so', 'why', 'between', 'their', 'general', "wouldn't", 'who', 'i', 'in', 'don', 'shan', 'u.s.', 'ibm', 'separately', 'had', 'you', 'federal', 'if', 'our', 'and', 'only', 'y', 'many', 'one', 'no', 'though', 'won', 'last', 'from', 'each', 'traders', 'john', 'further', 'hers', 'both', "you've", "you'll", 'that', 'all', 'its', 'only', 'here', 'according', "mustn't", 'while', 'in', 'what', 'didn', 'when', 'some', 'on', 'can', 'yourself', 'herself', 'than', 'with', 'has', 'she', 'during', 'will', 'of', 'thus', 'you', 'very', 'o', 'investors', 'a', 'ms.', 'japan', 'were', 'the', 'we', 'm', 'as', 'll', 'be', 'by', 'other', 'yet', 'whom', 'some', 'indeed', 'other', "she's", "hadn't", 'by', 'earlier', 'for', 'instead', 'she', 'an', 't', 're', 'his', 'then', 'aren', 'although']
19
+ # ptb_most_common_first_token = RuleBasedHeuristic(corpus=ptb.retrieve_all_sentences()).augment_using_most_frequent_starting_token(N=1)[0][0].lower()
20
+ ptb_most_common_first_token = "the"
21
+
22
+
23
+ class PopulateCKYChart:
24
+ def __init__(self, sentence):
25
+ self.sentence = sentence
26
+ self.sentence_list = sentence.split()
27
+ self.sentence_length = len(sentence.split())
28
+ self.span_scores = np.zeros((self.sentence_length + 1, self.sentence_length + 1), dtype=float)
29
+ self.all_spans = NGramify(self.sentence).generate_ngrams(single_span=True, whole_span=True)
30
+
31
+ def compute_scores(self, model, predict_type, scale_axis=None, predict_batch_size=128, chunks=128):
32
+ inside_strings = []
33
+ outside_strings = []
34
+ inside_scores = []
35
+ outside_scores = []
36
+
37
+ for span in self.all_spans:
38
+ _, inside_string, outside_string = InsideOutside(sentence=self.sentence).create_inside_outside_matrix(span)
39
+ inside_strings.append(inside_string)
40
+ outside_strings.append(outside_string)
41
+
42
+ data = pd.DataFrame({"inside_sentence": inside_strings, "outside_sentence": outside_strings, "span": [span[0] for span in self.all_spans]})
43
+
44
+ if predict_type == "inside":
45
+
46
+ if data.shape[0] > chunks:
47
+ data_chunks = np.array_split(data, data.shape[0] // chunks)
48
+ for data_chunk in data_chunks:
49
+ inside_scores.extend(model.predict_proba(spans=data_chunk.rename(columns={"inside_sentence": "sentence"})[["sentence"]],
50
+ scale_axis=scale_axis,
51
+ predict_batch_size=predict_batch_size)[:, 1])
52
+ else:
53
+ inside_scores.extend(model.predict_proba(spans=data.rename(columns={"inside_sentence": "sentence"})[["sentence"]],
54
+ scale_axis=scale_axis,
55
+ predict_batch_size=predict_batch_size)[:, 1])
56
+
57
+ data["inside_scores"] = inside_scores
58
+ data.loc[
59
+ (data["inside_sentence"].str.lower().str.startswith(ptb_most_common_first_token))
60
+ & (data["inside_sentence"].str.lower().str.split().str.len() == 2)
61
+ & (~data["inside_sentence"].str.lower().str.split().str[-1].isin(RuleBasedHeuristic().get_top_tokens())),
62
+ "inside_scores",
63
+ ] = 1
64
+
65
+ is_upper_or_title = all([item.istitle() or item.isupper() for item in self.sentence.split()])
66
+ is_stop = any([item for item in self.sentence.split() if item.lower() in ptb_top_100_common])
67
+
68
+ flags = is_upper_or_title and not is_stop
69
+
70
+ data["scores"] = data["inside_scores"]
71
+
72
+ elif predict_type == "outside":
73
+ outside_scores.extend(model.predict_proba(spans=data.rename(columns={"outside_sentence": "sentence"})[["sentence"]],
74
+ scale_axis=scale_axis,
75
+ predict_batch_size=predict_batch_size)[:, 1])
76
+ data["outside_scores"] = outside_scores
77
+ flags = False
78
+ data["scores"] = data["outside_scores"]
79
+
80
+ return flags, data
81
+
82
+ def fill_chart(self, model=None, predict_type=None, scale_axis=None, predict_batch_size=None, data=None):
83
+ flags = False
+ if data is None:
84
+ flags, data = self.compute_scores(model, predict_type, scale_axis, predict_batch_size)
85
+ for span in self.all_spans:
86
+ for i in range(0, self.sentence_length):
87
+ for j in range(i + 1, self.sentence_length + 1):
88
+ if span[0] == (i, j):
89
+ self.span_scores[i, j] = data.loc[data["span"] == span[0], "scores"].item()
90
+ return flags, self.span_scores, data
91
+
92
+ def best_parse_tree(self, span_scores):
93
+ span_scores_cky_format = span_scores[:-1, 1:]
94
+ best_parse = get_best_parse(sentence=[self.sentence_list], spans=span_scores_cky_format)
95
+ return best_parse
weakly_supervised_parser/utils/prepare_dataset.py ADDED
@@ -0,0 +1,165 @@
1
+ import csv
2
+ import pandas as pd
3
+
4
+ from sklearn.model_selection import train_test_split
5
+
6
+ from weakly_supervised_parser.utils.process_ptb import punctuation_words, currency_tags_words
7
+ from weakly_supervised_parser.utils.distant_supervision import RuleBasedHeuristic
8
+
9
+
10
+ filterchars = punctuation_words + currency_tags_words
11
+ filterchars = [char for char in filterchars if char not in list(",;-") and char not in "``" and char not in "''"]
12
+
13
+
14
+ class NGramify:
15
+ def __init__(self, sentence):
16
+ self.sentence = sentence.split()
17
+ self.sentence_length = len(self.sentence)
18
+ self.ngrams = []
19
+
20
+ def generate_ngrams(self, single_span=True, whole_span=True):
21
+ # number of substrings possible is N*(N+1)/2
22
+ # exclude substring or spans of length 1 and length N
23
+ if single_span:
24
+ start = 1
25
+ else:
26
+ start = 2
27
+ if whole_span:
28
+ end = self.sentence_length + 1
29
+ else:
30
+ end = self.sentence_length
31
+ for n in range(start, end):
32
+ for i in range(self.sentence_length - n + 1):
33
+ self.ngrams.append(((i, i + n), self.sentence[i : i + n]))
34
+ return self.ngrams
35
+
36
+ def generate_all_possible_spans(self):
37
+ for n in range(2, self.sentence_length):
38
+ for i in range(self.sentence_length - n + 1):
39
+ if i > 0 and (i + n) < self.sentence_length:
40
+ self.ngrams.append(
41
+ (
42
+ (i, i + n),
43
+ " ".join(self.sentence[i : i + n]),
44
+ " ".join(self.sentence[0:i])
45
+ + " ("
46
+ + " ".join(self.sentence[i : i + n])
47
+ + ") "
48
+ + " ".join(self.sentence[i + n : self.sentence_length]),
49
+ )
50
+ )
51
+ elif i == 0:
52
+ self.ngrams.append(
53
+ (
54
+ (i, i + n),
55
+ " ".join(self.sentence[i : i + n]),
56
+ "(" + " ".join(self.sentence[i : i + n]) + ") " + " ".join(self.sentence[i + n : self.sentence_length]),
57
+ )
58
+ )
59
+ elif (i + n) == self.sentence_length:
60
+ self.ngrams.append(
61
+ (
62
+ (i, i + n),
63
+ " ".join(self.sentence[i : i + n]),
64
+ " ".join(self.sentence[0:i]) + " (" + " ".join(self.sentence[i : i + n]) + ")",
65
+ )
66
+ )
67
+ return self.ngrams
68
+
69
+
70
+ class DataLoaderHelper:
71
+ def __init__(self, input_file_object=None, output_file_object=None):
72
+ self.input_file_object = input_file_object
73
+ self.output_file_object = output_file_object
74
+
75
+ def read_lines(self):
76
+ with open(self.input_file_object, "r") as f:
77
+ lines = f.read().splitlines()
78
+ return lines
79
+
80
+ def __getitem__(self, index):
81
+ return self.read_lines()[index]
82
+
83
+ def write_lines(self, keys, values):
84
+ with open(self.output_file_object, "w", newline="\n") as output_file:
85
+ dict_writer = csv.DictWriter(output_file, keys, delimiter="\t")
86
+ dict_writer.writeheader()
87
+ dict_writer.writerows(values)
88
+
89
+
90
+ class PTBDataset:
91
+ def __init__(self, data_path):
92
+ self.data = pd.read_csv(data_path, sep="\t", header=None, names=["sentence"])
93
+ self.data["sentence"] = self.data
94
+
95
+ def __len__(self):
96
+ return len(self.data)
97
+
98
+ def __getitem__(self, index):
99
+ return self.data["sentence"].loc[index]
100
+
101
+ def retrieve_all_sentences(self, N=None):
102
+ if N:
103
+ return self.data["sentence"].iloc[:N].tolist()
104
+ return self.data["sentence"].tolist()
105
+
106
+ def preprocess(self):
107
+ self.data["sentence"] = self.data["sentence"].apply(
108
+ lambda row: " ".join([sentence for sentence in row.split() if sentence not in filterchars])
109
+ )
110
+ return self.data
111
+
112
+ def seed_bootstrap_constituent(self):
113
+ whole_span_slice = self.data["sentence"]
114
+ func = lambda x: RuleBasedHeuristic().add_contiguous_titlecase_words(
115
+ row=[(index, character) for index, character in enumerate(x) if character.istitle() or "'" in character]
116
+ )
117
+ titlecase_matches = [item for sublist in self.data["sentence"].str.split().apply(func).tolist() for item in sublist if len(item.split()) > 1]
118
+ titlecase_matches_df = pd.Series(titlecase_matches)
119
+ titlecase_matches_df = titlecase_matches_df[~titlecase_matches_df.str.split().str[0].str.contains("'")].str.replace("''", "")
120
+ most_frequent_start_token = RuleBasedHeuristic(corpus=self.retrieve_all_sentences()).augment_using_most_frequent_starting_token(N=1)[0][0]
121
+ most_frequent_start_token_df = titlecase_matches_df[titlecase_matches_df.str.startswith(most_frequent_start_token)].str.lower()
122
+ constituent_samples = pd.DataFrame(dict(sentence=pd.concat([whole_span_slice, titlecase_matches_df, most_frequent_start_token_df]), label=1))
123
+ return constituent_samples
124
+
125
+ def seed_bootstrap_distituent(self):
126
+ avg_sent_len = int(self.data["sentence"].str.split().str.len().mean())
127
+ last_but_one_slice = self.data["sentence"].str.split().str[:-1].str.join(" ")
128
+ last_but_two_slice = self.data[self.data["sentence"].str.split().str.len() > avg_sent_len + 10]["sentence"].str.split().str[:-2].str.join(" ")
129
+ last_but_three_slice = (
130
+ self.data[self.data["sentence"].str.split().str.len() > avg_sent_len + 20]["sentence"].str.split().str[:-3].str.join(" ")
131
+ )
132
+ last_but_four_slice = (
133
+ self.data[self.data["sentence"].str.split().str.len() > avg_sent_len + 30]["sentence"].str.split().str[:-4].str.join(" ")
134
+ )
135
+ last_but_five_slice = (
136
+ self.data[self.data["sentence"].str.split().str.len() > avg_sent_len + 40]["sentence"].str.split().str[:-5].str.join(" ")
137
+ )
138
+ last_but_six_slice = self.data[self.data["sentence"].str.split().str.len() > avg_sent_len + 50]["sentence"].str.split().str[:-6].str.join(" ")
139
+ distituent_samples = pd.DataFrame(
140
+ dict(
141
+ sentence=pd.concat(
142
+ [
143
+ last_but_one_slice,
144
+ last_but_two_slice,
145
+ last_but_three_slice,
146
+ last_but_four_slice,
147
+ last_but_five_slice,
148
+ last_but_six_slice,
149
+ ]
150
+ ),
151
+ label=0,
152
+ )
153
+ )
154
+ return distituent_samples
155
+
156
+ def train_validation_split(self, seed, test_size=0.5, shuffle=True):
157
+ self.preprocess()
158
+ bootstrap_constituent_samples = self.seed_bootstrap_constituent()
159
+ bootstrap_distituent_samples = self.seed_bootstrap_distituent()
160
+ df = pd.concat([bootstrap_constituent_samples, bootstrap_distituent_samples], ignore_index=True)
161
+ df = df.drop_duplicates(subset=["sentence"]).dropna(subset=["sentence"])
162
+ df["sentence"] = df["sentence"].str.strip()
163
+ df = df[df["sentence"].str.split().str.len() > 1]
164
+ train, validation = train_test_split(df, test_size=test_size, random_state=seed, shuffle=shuffle)
165
+ return train.head(8000), validation.head(2000)
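Finally, a tiny sketch of the span enumeration that PopulateCKYChart relies on (importing this module also pulls in weakly_supervised_parser.utils.process_ptb, which this commit does not add, so that module is assumed to exist elsewhere in the package):

    from weakly_supervised_parser.utils.prepare_dataset import NGramify

    spans = NGramify("the old dog barks").generate_ngrams(single_span=True, whole_span=True)
    print(spans[0])   # ((0, 1), ['the'])
    print(spans[-1])  # ((0, 4), ['the', 'old', 'dog', 'barks'])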