File size: 7,901 Bytes
47c0211
 
 
 
 
 
 
 
 
 
245d478
 
47c0211
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
245d478
47c0211
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
245d478
 
 
 
47c0211
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
from argparse import ArgumentParser
from loguru import logger

from weakly_supervised_parser.settings import TRAINED_MODEL_PATH
from weakly_supervised_parser.utils.prepare_dataset import DataLoaderHelper
from weakly_supervised_parser.utils.populate_chart import PopulateCKYChart
from weakly_supervised_parser.tree.evaluate import calculate_F1_for_spans, tree_to_spans
from weakly_supervised_parser.model.trainer import InsideOutsideStringClassifier
from weakly_supervised_parser.settings import PTB_TEST_SENTENCES_WITHOUT_PUNCTUATION_PATH, PTB_TEST_GOLD_WITHOUT_PUNCTUATION_ALIGNED_PATH

from weakly_supervised_parser.model.span_classifier import LightningModel


class Predictor:
    """Produces the best parse tree for a single sentence via CKY chart inference."""

    def __init__(self, sentence):
        # Raw sentence string and its whitespace-tokenized form.
        self.sentence = sentence
        self.sentence_list = sentence.split()

    def obtain_best_parse(self, predict_type, model, scale_axis, predict_batch_size, return_df=False):
        """Fill the CKY chart for this sentence and extract the best parse.

        Args:
            predict_type: scoring mode forwarded to the chart (e.g. "inside"/"outside").
            model: trained span classifier used to score chart cells.
            scale_axis: softmax scaling option forwarded to chart filling.
            predict_batch_size: inference batch size forwarded to chart filling.
            return_df: when True, also return the span-score DataFrame.

        Returns:
            best_parse string, or (best_parse, df) when return_df is True.
        """
        unique_tokens_flag, span_scores, df = PopulateCKYChart(sentence=self.sentence).fill_chart(predict_type=predict_type,
                                                                                                  model=model,
                                                                                                  scale_axis=scale_axis,
                                                                                                  predict_batch_size=predict_batch_size)

        if unique_tokens_flag:
            # Degenerate case signalled by the chart helper: emit a flat parse where
            # each token is wrapped in its own (S ...) node.
            best_parse = "(S " + " ".join(["(S " + item + ")" for item in self.sentence_list]) + ")"
            # BUG FIX: original was logger.info("BEST PARSE", best_parse) — loguru
            # formats the message with str.format, so without a {} placeholder the
            # parse value was silently dropped from the log output.
            logger.info("BEST PARSE {}", best_parse)
        else:
            best_parse = PopulateCKYChart(sentence=self.sentence).best_parse_tree(span_scores)
        if return_df:
            return best_parse, df
        return best_parse


def process_test_sample(index, sentence, gold_file_path, predict_type, model, scale_axis, predict_batch_size, return_df=False):
    """Parse one test sentence, log its span-level F1 against gold, and return the parse.

    Sentences whose F1 falls below 25.0 are logged at warning level so weak
    parses stand out in the run log; all others are logged at info level.

    Args:
        index: position of the sentence in the gold file.
        sentence: raw sentence string to parse.
        gold_file_path: path to the aligned gold-standard trees.
        predict_type: scoring mode forwarded to the predictor.
        model: trained span classifier.
        scale_axis: softmax scaling option forwarded to inference.
        predict_batch_size: inference batch size.
        return_df: when True, also return the span-score DataFrame.
    """
    predictor = Predictor(sentence=sentence)
    best_parse, df = predictor.obtain_best_parse(
        predict_type=predict_type,
        model=model,
        scale_axis=scale_axis,
        predict_batch_size=predict_batch_size,
        return_df=True,
    )
    gold_trees = DataLoaderHelper(input_file_object=gold_file_path)
    sentence_f1 = calculate_F1_for_spans(tree_to_spans(gold_trees[index]), tree_to_spans(best_parse))
    log = logger.warning if sentence_f1 < 25.0 else logger.info
    log(f"Index: {index} <> F1: {sentence_f1:.2f}")
    return (best_parse, df) if return_df else best_parse


def process_co_train_test_sample(index, sentence, gold_file_path, inside_model, outside_model, return_df=False):
    """Co-training evaluation for one sentence: combine inside/outside scores and parse.

    Scores every span with the inside and outside models separately, multiplies
    the two score columns element-wise, fills the CKY chart from the combined
    scores, and extracts the best parse. Logs span-level F1 against the gold
    tree (warning level below 25.0, info otherwise).
    """
    _, inside_scores = PopulateCKYChart(sentence=sentence).compute_scores(predict_type="inside", model=inside_model, return_df=True)
    _, outside_scores = PopulateCKYChart(sentence=sentence).compute_scores(predict_type="outside", model=outside_model, return_df=True)
    combined = inside_scores.copy()
    # Element-wise product of the two models' span scores.
    combined["scores"] = inside_scores["scores"] * outside_scores["scores"]
    _, span_scores, combined = PopulateCKYChart(sentence=sentence).fill_chart(data=combined)
    best_parse = PopulateCKYChart(sentence=sentence).best_parse_tree(span_scores)
    gold_trees = DataLoaderHelper(input_file_object=gold_file_path)
    sentence_f1 = calculate_F1_for_spans(tree_to_spans(gold_trees[index]), tree_to_spans(best_parse))
    log = logger.warning if sentence_f1 < 25.0 else logger.info
    log(f"Index: {index} <> F1: {sentence_f1:.2f}")
    return best_parse


def main():
    """CLI entry point: parse the PTB test sentences with the selected model and write trees.

    Exactly one of the --use_* flags selects the model variant; the resulting
    best parse for every test sentence is written, one per line, to --save_path.
    """
    parser = ArgumentParser(description="Inference Pipeline for the Inside Outside String Classifier", add_help=True)

    group = parser.add_mutually_exclusive_group(required=True)
    group.add_argument("--use_inside", action="store_true", help="Whether to predict using inside model")
    group.add_argument("--use_inside_self_train", action="store_true", help="Whether to predict using inside model with self-training")
    group.add_argument("--use_outside", action="store_true", help="Whether to predict using outside model")
    group.add_argument("--use_inside_outside_co_train", action="store_true", help="Whether to predict using inside-outside model with co-training")

    parser.add_argument("--model_name_or_path", type=str, default="roberta-base", help="Path to the model identifier from huggingface.co/models")
    parser.add_argument("--save_path", type=str, required=True, help="Path to save the final trees")
    # BUG FIX: argparse delivers CLI values as strings, so the original
    # choices=[None, 1] (with no type converter) could never validate an explicit
    # value: "1" not in [None, 1]. With type=int and choices=[1], passing
    # "--scale_axis 1" now works; omitting the flag still yields None.
    parser.add_argument(
        "--scale_axis",
        type=int,
        choices=[1],
        default=None,
        help="Whether to scale axis globally (None) or sequentially (1) across batches during softmax computation",
    )
    parser.add_argument("--predict_batch_size", type=int, help="Batch size during inference")
    parser.add_argument(
        "--inside_max_seq_length", default=256, type=int, help="The maximum total input sequence length after tokenization for the inside model"
    )
    parser.add_argument(
        "--outside_max_seq_length", default=64, type=int, help="The maximum total input sequence length after tokenization for the outside model"
    )

    args = parser.parse_args()

    if args.use_inside_outside_co_train:
        # NOTE(review): the original used bare filenames here, unlike every other
        # branch which prefixes TRAINED_MODEL_PATH; made consistent. Confirm the
        # co-trained .onnx checkpoints actually live under TRAINED_MODEL_PATH.
        inside_model = InsideOutsideStringClassifier(model_name_or_path=args.model_name_or_path, max_seq_length=args.inside_max_seq_length)
        inside_model.load_model(pre_trained_model_path=TRAINED_MODEL_PATH + "inside_model_co_trained.onnx")

        outside_model = InsideOutsideStringClassifier(model_name_or_path=args.model_name_or_path, max_seq_length=args.outside_max_seq_length)
        outside_model.load_model(pre_trained_model_path=TRAINED_MODEL_PATH + "outside_model_co_trained.onnx")
    else:
        # The mutually exclusive group is required, so exactly one flag is set.
        if args.use_inside:
            pre_trained_model_path = TRAINED_MODEL_PATH + "inside_model.ckpt"
            predict_type = "inside"
        elif args.use_inside_self_train:
            pre_trained_model_path = TRAINED_MODEL_PATH + "inside_model_self_trained.onnx"
            predict_type = "inside"
        else:  # args.use_outside
            pre_trained_model_path = TRAINED_MODEL_PATH + "outside_model.onnx"
            predict_type = "outside"
        model = LightningModel.load_from_checkpoint(checkpoint_path=pre_trained_model_path)

    with open(args.save_path, "w") as out_file:
        test_sentences = DataLoaderHelper(input_file_object=PTB_TEST_SENTENCES_WITHOUT_PUNCTUATION_PATH).read_lines()
        test_gold_file_path = PTB_TEST_GOLD_WITHOUT_PUNCTUATION_ALIGNED_PATH
        for test_index, test_sentence in enumerate(test_sentences):
            if args.use_inside_outside_co_train:
                best_parse = process_co_train_test_sample(
                    test_index, test_sentence, test_gold_file_path, inside_model=inside_model, outside_model=outside_model
                )
            else:
                best_parse = process_test_sample(test_index, test_sentence, test_gold_file_path, predict_type=predict_type, model=model,
                                                 scale_axis=args.scale_axis, predict_batch_size=args.predict_batch_size)

            out_file.write(best_parse + "\n")


# Script entry point: run inference only when executed directly, not on import.
if __name__ == "__main__":
    main()