# nickil's picture
# update model ckpt
# 245d478
from argparse import ArgumentParser
from loguru import logger
from weakly_supervised_parser.settings import TRAINED_MODEL_PATH
from weakly_supervised_parser.utils.prepare_dataset import DataLoaderHelper
from weakly_supervised_parser.utils.populate_chart import PopulateCKYChart
from weakly_supervised_parser.tree.evaluate import calculate_F1_for_spans, tree_to_spans
from weakly_supervised_parser.model.trainer import InsideOutsideStringClassifier
from weakly_supervised_parser.settings import PTB_TEST_SENTENCES_WITHOUT_PUNCTUATION_PATH, PTB_TEST_GOLD_WITHOUT_PUNCTUATION_ALIGNED_PATH
from weakly_supervised_parser.model.span_classifier import LightningModel
class Predictor:
    """Produce a bracketed parse for a single sentence using a trained span classifier."""

    def __init__(self, sentence):
        """Store the raw sentence and its whitespace-separated tokens.

        Args:
            sentence: The input sentence; tokens are obtained with ``str.split()``.
        """
        self.sentence = sentence
        self.sentence_list = sentence.split()

    def obtain_best_parse(self, predict_type, model, scale_axis, predict_batch_size, return_df=False):
        """Score all spans with ``model`` via a CKY chart and return the best parse string.

        Args:
            predict_type: Forwarded to ``PopulateCKYChart.fill_chart`` (``main`` passes
                ``"inside"`` or ``"outside"``).
            model: Trained classifier forwarded to ``fill_chart``.
            scale_axis: Forwarded to ``fill_chart``.
            predict_batch_size: Forwarded to ``fill_chart``.
            return_df: When True, also return the third value produced by ``fill_chart``.

        Returns:
            ``best_parse``, or ``(best_parse, df)`` when ``return_df`` is True.
        """
        unique_tokens_flag, span_scores, df = PopulateCKYChart(sentence=self.sentence).fill_chart(
            predict_type=predict_type,
            model=model,
            scale_axis=scale_axis,
            predict_batch_size=predict_batch_size,
        )
        if unique_tokens_flag:
            # Fallback tree: each token becomes its own "(S token)" node under one root "(S ...)".
            best_parse = "(S " + " ".join(f"(S {token})" for token in self.sentence_list) + ")"
            # loguru fills the message with str.format(*args); without a "{}" placeholder the
            # parse itself was silently dropped from the log line.
            logger.info("BEST PARSE: {}", best_parse)
        else:
            best_parse = PopulateCKYChart(sentence=self.sentence).best_parse_tree(span_scores)
        if return_df:
            return best_parse, df
        return best_parse
def process_test_sample(index, sentence, gold_file_path, predict_type, model, scale_axis, predict_batch_size, return_df=False):
    """Parse one test sentence, log its sentence-level F1 against the gold tree, and return the parse.

    Args:
        index: Position of the sentence; used to look up the matching gold tree.
        sentence: The sentence to parse.
        gold_file_path: Passed to ``DataLoaderHelper`` to access the gold trees.
        predict_type, model, scale_axis, predict_batch_size: Forwarded to
            ``Predictor.obtain_best_parse``.
        return_df: When True, also return the data frame produced while filling the chart.

    Returns:
        ``best_parse``, or ``(best_parse, df)`` when ``return_df`` is True.
    """
    predictor = Predictor(sentence=sentence)
    best_parse, df = predictor.obtain_best_parse(
        predict_type=predict_type,
        model=model,
        scale_axis=scale_axis,
        predict_batch_size=predict_batch_size,
        return_df=True,
    )

    gold_trees = DataLoaderHelper(input_file_object=gold_file_path)
    gold_spans = tree_to_spans(gold_trees[index])
    predicted_spans = tree_to_spans(best_parse)
    sentence_f1 = calculate_F1_for_spans(gold_spans, predicted_spans)

    # Sentences scoring below 25.0 are surfaced as warnings so poor parses stand out in the log.
    emit = logger.warning if sentence_f1 < 25.0 else logger.info
    emit(f"Index: {index} <> F1: {sentence_f1:.2f}")

    return (best_parse, df) if return_df else best_parse
def process_co_train_test_sample(index, sentence, gold_file_path, inside_model, outside_model, return_df=False):
    """Parse one test sentence by combining inside and outside model scores, then log its F1.

    Args:
        index: Position of the sentence; used to look up the matching gold tree.
        sentence: The sentence to parse.
        gold_file_path: Passed to ``DataLoaderHelper`` to access the gold trees.
        inside_model: Model used for the ``"inside"`` span scores.
        outside_model: Model used for the ``"outside"`` span scores.
        return_df: When True, also return the data frame produced by ``fill_chart``.
            This flag was previously accepted but ignored.

    Returns:
        ``best_parse``, or ``(best_parse, df)`` when ``return_df`` is True.
    """
    _, df_inside = PopulateCKYChart(sentence=sentence).compute_scores(predict_type="inside", model=inside_model, return_df=True)
    _, df_outside = PopulateCKYChart(sentence=sentence).compute_scores(predict_type="outside", model=outside_model, return_df=True)
    # Combine both views by multiplying per-span scores. The pandas multiplication aligns on
    # the index, so this assumes both frames index the same spans. TODO: confirm in compute_scores.
    df = df_inside.copy()
    df["scores"] = df_inside["scores"] * df_outside["scores"]
    _, span_scores, df = PopulateCKYChart(sentence=sentence).fill_chart(data=df)
    best_parse = PopulateCKYChart(sentence=sentence).best_parse_tree(span_scores)
    gold_standard = DataLoaderHelper(input_file_object=gold_file_path)
    sentence_f1 = calculate_F1_for_spans(tree_to_spans(gold_standard[index]), tree_to_spans(best_parse))
    # Sentences scoring below 25.0 are surfaced as warnings so poor parses stand out in the log.
    if sentence_f1 < 25.0:
        logger.warning(f"Index: {index} <> F1: {sentence_f1:.2f}")
    else:
        logger.info(f"Index: {index} <> F1: {sentence_f1:.2f}")
    if return_df:
        return best_parse, df
    return best_parse
def _build_arg_parser():
    """Return the ArgumentParser describing this script's command-line interface."""
    parser = ArgumentParser(description="Inference Pipeline for the Inside Outside String Classifier", add_help=True)
    group = parser.add_mutually_exclusive_group(required=True)
    group.add_argument("--use_inside", action="store_true", help="Whether to predict using inside model")
    group.add_argument("--use_inside_self_train", action="store_true", help="Whether to predict using inside model with self-training")
    group.add_argument("--use_outside", action="store_true", help="Whether to predict using outside model")
    group.add_argument("--use_inside_outside_co_train", action="store_true", help="Whether to predict using inside-outside model with co-training")
    parser.add_argument("--model_name_or_path", type=str, default="roberta-base", help="Path to the model identifier from huggingface.co/models")
    parser.add_argument("--save_path", type=str, required=True, help="Path to save the final trees")
    # type=int is required for "--scale_axis 1" to work: without it argparse checks the raw
    # string "1" against choices [None, 1] and rejects it. None remains reachable as the default.
    parser.add_argument(
        "--scale_axis",
        type=int,
        choices=[None, 1],
        default=None,
        help="Whether to scale axis globally (None) or sequentially (1) across batches during softmax computation",
    )
    parser.add_argument("--predict_batch_size", type=int, help="Batch size during inference")
    parser.add_argument(
        "--inside_max_seq_length", default=256, type=int, help="The maximum total input sequence length after tokenization for the inside model"
    )
    parser.add_argument(
        "--outside_max_seq_length", default=64, type=int, help="The maximum total input sequence length after tokenization for the outside model"
    )
    return parser


def _load_single_model(pre_trained_model_path, model_name_or_path, max_seq_length):
    """Load one trained classifier, choosing the loader from the file extension.

    Files ending in ``.onnx`` are loaded through ``InsideOutsideStringClassifier.load_model``,
    the same loader the co-training mode uses. Any other path is treated as a Lightning
    checkpoint and loaded with ``LightningModel.load_from_checkpoint``.
    """
    if pre_trained_model_path.endswith(".onnx"):
        model = InsideOutsideStringClassifier(model_name_or_path=model_name_or_path, max_seq_length=max_seq_length)
        model.load_model(pre_trained_model_path=pre_trained_model_path)
        return model
    return LightningModel.load_from_checkpoint(checkpoint_path=pre_trained_model_path)


def main():
    """Parse every PTB test sentence and write one predicted tree per line to ``--save_path``.

    Exactly one mode flag (``--use_inside``, ``--use_inside_self_train``, ``--use_outside``,
    ``--use_inside_outside_co_train``) selects which model file(s) under ``TRAINED_MODEL_PATH``
    are loaded. The ``process_*`` helpers also log each sentence's F1 against the aligned gold
    trees.
    """
    args = _build_arg_parser().parse_args()

    if args.use_inside_outside_co_train:
        # The co-trained paths are prefixed with TRAINED_MODEL_PATH like every other mode.
        # Previously they were bare file names, resolved against the current working directory.
        inside_model = _load_single_model(
            TRAINED_MODEL_PATH + "inside_model_co_trained.onnx", args.model_name_or_path, args.inside_max_seq_length
        )
        outside_model = _load_single_model(
            TRAINED_MODEL_PATH + "outside_model_co_trained.onnx", args.model_name_or_path, args.outside_max_seq_length
        )
    else:
        if args.use_inside:
            model_file, max_seq_length, predict_type = "inside_model.ckpt", args.inside_max_seq_length, "inside"
        elif args.use_inside_self_train:
            model_file, max_seq_length, predict_type = "inside_model_self_trained.onnx", args.inside_max_seq_length, "inside"
        else:  # --use_outside: the mode group is required and mutually exclusive.
            model_file, max_seq_length, predict_type = "outside_model.onnx", args.outside_max_seq_length, "outside"
        # NOTE(review): the ONNX and Lightning loaders return different model types. This assumes
        # PopulateCKYChart.fill_chart accepts both; confirm.
        model = _load_single_model(TRAINED_MODEL_PATH + model_file, args.model_name_or_path, max_seq_length)

    # Read the inputs before opening the output, so a read failure does not truncate save_path.
    test_sentences = DataLoaderHelper(input_file_object=PTB_TEST_SENTENCES_WITHOUT_PUNCTUATION_PATH).read_lines()
    test_gold_file_path = PTB_TEST_GOLD_WITHOUT_PUNCTUATION_ALIGNED_PATH
    with open(args.save_path, "w") as out_file:
        for test_index, test_sentence in enumerate(test_sentences):
            if args.use_inside_outside_co_train:
                best_parse = process_co_train_test_sample(
                    test_index, test_sentence, test_gold_file_path, inside_model=inside_model, outside_model=outside_model
                )
            else:
                best_parse = process_test_sample(
                    test_index,
                    test_sentence,
                    test_gold_file_path,
                    predict_type=predict_type,
                    model=model,
                    scale_axis=args.scale_axis,
                    predict_batch_size=args.predict_batch_size,
                )
            out_file.write(best_parse + "\n")
if __name__ == "__main__":  # only run the pipeline when executed as a script, not on import
    main()