# Gradio demo for the weakly-supervised-parsing project, authored by nickil.
# Last change: commit b0ceaae ("update app"); file size 3.46 kB.
import gradio
import benepar
import spacy
import nltk
from nltk.tree import Tree
nltk.download('stopwords')
from huggingface_hub import hf_hub_url, cached_download
from weakly_supervised_parser.tree.evaluate import calculate_F1_for_spans, tree_to_spans
from weakly_supervised_parser.inference import Predictor
from weakly_supervised_parser.model.trainer import InsideOutsideStringClassifier
from weakly_supervised_parser.model.span_classifier import LightningModel
# Gold-standard parser: a spaCy pipeline with the benepar constituency parser attached.
BENEPAR_MODEL = "benepar_en3"
benepar.download(BENEPAR_MODEL)
nlp = spacy.load("en_core_web_md")
nlp.add_pipe("benepar", config={"model": BENEPAR_MODEL})

# Inside-string classifier: a roberta-base wrapper whose weights are the
# inside_model.onnx file from the Hugging Face Hub repo, fetched via cached_download.
fetch_url_inside_model = hf_hub_url(
    repo_id="nickil/weakly-supervised-parsing",
    filename="inside_model.onnx",
    revision="main",
)
inside_model = InsideOutsideStringClassifier(model_name_or_path="roberta-base", max_seq_length=256)
inside_model_path = cached_download(fetch_url_inside_model)
inside_model.load_model(pre_trained_model_path=inside_model_path)

# NOTE: only the inside model is loaded. `outside_model` and `inside_outside_model`
# are not defined in this script, so the "outside" / "inside-outside" UI choices
# cannot run until they are loaded here. The earlier (commented-out) recipe built
# InsideOutsideStringClassifier("roberta-base") with max_seq_length=64 for
# outside_model.onnx and max_seq_length=256 for inside_outside_model.onnx, read
# from a local TRAINED_MODEL_PATH that was never defined.
def predict(sentence, model):
    """Parse *sentence* with the selected weakly supervised model and score it.

    The gold-standard tree comes from the spaCy + benepar pipeline (``nlp``).
    The predicted tree comes from ``Predictor.obtain_best_parse`` using the model
    picked in the UI. Both trees are converted to spans with ``tree_to_spans`` and
    compared with ``calculate_F1_for_spans``.

    Args:
        sentence: Raw text typed by the user. Only the first sentence that spaCy
            segments out is parsed and scored, by both the gold and the model parser.
        model: One of ``"inside"``, ``"outside"`` or ``"inside-outside"``.

    Returns:
        A ``(gold_parse, predicted_parse, f1)`` tuple, where ``f1`` is the
        sentence F1 formatted with two decimals.

    Raises:
        ValueError: If the sentence is empty, ``model`` is not a known choice,
            or the requested model is not loaded by this script.
    """
    if not sentence or not sentence.strip():
        raise ValueError("Please enter a non-empty sentence.")

    # UI choice -> (predict_type expected by Predictor, lazy getter for the model).
    # The getters defer the global lookup, so a model that was never loaded shows up
    # as a clear error below instead of a bare NameError.
    model_choices = {
        "inside": ("inside", lambda: inside_model),
        "outside": ("outside", lambda: outside_model),
        "inside-outside": ("inside_outside", lambda: inside_outside_model),
    }
    if model not in model_choices:
        raise ValueError(f"Unknown model {model!r}; expected one of {list(model_choices)}.")
    predict_type, get_model = model_choices[model]
    try:
        classifier = get_model()
    except NameError:
        raise ValueError(f"The {model!r} model is not loaded in this demo.") from None

    sentences = list(nlp(sentence).sents)
    if not sentences:
        raise ValueError("Please enter a non-empty sentence.")
    first_sentence = sentences[0]
    gold_standard = first_sentence._.parse_string
    # Parse exactly the text the gold tree covers, so the span F1 compares like with like.
    best_parse = Predictor(sentence=first_sentence.text).obtain_best_parse(
        predict_type=predict_type, model=classifier, scale_axis=1, predict_batch_size=128
    )
    sentence_f1 = calculate_F1_for_spans(tree_to_spans(gold_standard), tree_to_spans(best_parse))
    return gold_standard, best_parse, f"{sentence_f1:.2f}"
# Gradio UI: one text box plus a model selector in; gold tree, predicted tree and F1 out.
iface = gradio.Interface(
    title="Co-training an Unsupervised Constituency Parser with Weak Supervision",
    description="Demo for the repository - [weakly-supervised-parsing](https://github.com/Nickil21/weakly-supervised-parsing) (ACL Findings 2022)",
    theme="default",
    article="""<h4 class='text-lg font-semibold my-2'>Note</h4>
- We use a strong supervised parsing model `benepar_en3` which is based on T5-small to compute the gold parse.<br>
- Sentence F1 score corresponds to the macro F1 score.
""",
    allow_flagging="never",
    fn=predict,
    inputs=[
        gradio.inputs.Textbox(label="Sentence", placeholder="Enter a sentence in English"),
        gradio.inputs.Radio(["inside", "outside", "inside-outside"], default="inside", label="Choose Model"),
    ],
    outputs=[
        gradio.outputs.Textbox(label="Gold Parse Tree"),
        gradio.outputs.Textbox(label="Predicted Parse Tree"),
        gradio.outputs.Textbox(label="F1 score"),
    ],
    # Examples use the "inside" model only: it is the only classifier this script
    # loads, so an "outside"/"inside-outside" example would fail when clicked.
    examples=[
        ["Russia 's war on Ukraine unsettles investors expecting carve-out deal uptick for 2022 .", "inside"],
        ["Bitcoin community under pressure to cut energy use .", "inside"],
    ],
)
iface.launch()