# app.py — Hugging Face Space file uploaded by nickil ("Upload app.py",
# commit 4d50603, 4.23 kB; page scan: no virus).
import os
import subprocess
import tempfile

import matplotlib
matplotlib.use('agg')
from PIL import Image
import gradio
import benepar
import spacy
import nltk
from nltk.tree import Tree
from nltk.draw.tree import TreeView
from huggingface_hub import hf_hub_url, cached_download

from weakly_supervised_parser.tree.evaluate import calculate_F1_for_spans, tree_to_spans
from weakly_supervised_parser.inference import Predictor
from weakly_supervised_parser.model.trainer import InsideOutsideStringClassifier
from weakly_supervised_parser.model.span_classifier import LightningModel
# Hugging Face Hub location of the trained inside-string classifier checkpoint.
_HUB_REPO_ID = "nickil/weakly-supervised-parsing"
_INSIDE_CHECKPOINT = "inside_model.ckpt"

# Maps each model name offered in the UI to the `predict_type` passed to
# `Predictor.obtain_best_parse`.
_PREDICT_TYPES = {
    "inside": "inside",
    "outside": "outside",
    "inside-outside": "inside_outside",
}

# Populated by `_load_resources()` when the app is run as a script.
# NOTE: only the inside model is loaded. Loading of the outside and
# inside-outside models was commented out in this app, so selecting them used
# to fail with a NameError.
# TODO: register them in `_loaded_models` (keyed by the UI names above) once
# their checkpoints are available.
nlp = None
_loaded_models = {}


def _load_resources():
    """Download parser data and load the spaCy/benepar pipeline and models.

    Sets the module-level ``nlp`` (spaCy ``en_core_web_md`` with the
    ``benepar_en3`` component appended) and registers the inside model under
    the key ``"inside"`` in ``_loaded_models``.
    """
    global nlp
    nltk.download('stopwords')
    benepar.download('benepar_en3')
    nlp = spacy.load("en_core_web_md")
    nlp.add_pipe("benepar", config={"model": "benepar_en3"})
    fetch_url_inside_model = hf_hub_url(repo_id=_HUB_REPO_ID, filename=_INSIDE_CHECKPOINT, revision="main")
    _loaded_models["inside"] = LightningModel.load_from_checkpoint(
        checkpoint_path=cached_download(fetch_url_inside_model)
    )


def _render_tree(parse_string, workdir, stem):
    """Render a bracketed parse string to PNG and return it as a PIL image.

    The tree is drawn with NLTK's Tk-based ``TreeView`` and exported to
    PostScript. It is then converted with the ImageMagick ``convert``
    executable, which is assumed to be on PATH as in the original code.

    Args:
        parse_string: bracketed tree string accepted by ``Tree.fromstring``.
        workdir: directory in which the intermediate .ps/.png files are written.
        stem: base filename (without extension) for those files.

    Raises:
        subprocess.CalledProcessError: if ``convert`` exits with an error.
    """
    ps_path = os.path.join(workdir, stem + ".ps")
    png_path = os.path.join(workdir, stem + ".png")
    view = TreeView(Tree.fromstring(parse_string))
    try:
        view._cframe.print_to_file(ps_path)
    finally:
        # Close the Tk window; otherwise each call leaves one open.
        view.destroy()
    # List-form args (no shell) and check=True, so a failed conversion raises
    # here instead of surfacing later as a confusing missing-file error.
    subprocess.run(["convert", ps_path, png_path], check=True)
    with Image.open(png_path) as img:
        # copy() forces the pixel data into memory, so the temporary file can
        # be deleted when the caller's directory is cleaned up.
        return img.copy()


def predict(sentence, model):
    """Parse a sentence with the selected model and score it against benepar.

    Args:
        sentence: English text entered by the user. Only its first sentence,
            as segmented by spaCy, is parsed and scored.
        model: UI model name, one of "inside", "outside", "inside-outside".

    Returns:
        A tuple ``(gold_tree_image, predicted_tree_image, f1_text)``, where
        ``f1_text`` is the sentence F1 formatted with two decimals.

    Raises:
        ValueError: if the model name is unknown, the model is not loaded in
            this demo, or the input contains no sentence.
    """
    if model not in _PREDICT_TYPES:
        raise ValueError(f"Unknown model {model!r}; choose one of {list(_PREDICT_TYPES)}.")
    if not sentence or not sentence.strip():
        raise ValueError("Please enter a sentence.")
    classifier = _loaded_models.get(model)
    if classifier is None:
        raise ValueError(
            f"The {model!r} model is not available in this demo; loaded models: {sorted(_loaded_models)}."
        )

    # benepar parses per sentence. Predict on that same first sentence so the
    # gold and predicted trees cover the same text when the user enters
    # several sentences.
    first_sentence = next(iter(nlp(sentence).sents), None)
    if first_sentence is None:
        raise ValueError("Please enter a sentence.")
    gold_standard = first_sentence._.parse_string
    best_parse = Predictor(sentence=first_sentence.text.strip()).obtain_best_parse(
        predict_type=_PREDICT_TYPES[model], model=classifier, scale_axis=1, predict_batch_size=128
    )
    sentence_f1 = calculate_F1_for_spans(tree_to_spans(gold_standard), tree_to_spans(best_parse))

    # Use a per-request temporary directory: writing fixed filenames into the
    # CWD let concurrent requests overwrite each other's images.
    with tempfile.TemporaryDirectory() as workdir:
        gold_standard_img = _render_tree(gold_standard, workdir, "gold_standard")
        best_parse_img = _render_tree(best_parse, workdir, "best_parse")
    return gold_standard_img, best_parse_img, f"{sentence_f1:.2f}"


def _build_interface():
    """Create the Gradio interface that wraps `predict`."""
    # NOTE(review): gradio.inputs / gradio.outputs are the legacy component
    # namespaces (removed in Gradio 4); migrate to gradio.Textbox etc. when
    # upgrading Gradio.
    return gradio.Interface(
        title="Co-training an Unsupervised Constituency Parser with Weak Supervision",
        description="Demo for the repository - [weakly-supervised-parsing](https://github.com/Nickil21/weakly-supervised-parsing) (ACL Findings 2022)",
        theme="default",
        article="""<h4 class='text-lg font-semibold my-2'>Note</h4>
- We use a strong supervised parsing model `benepar_en3` which is based on T5-small to compute the gold parse.<br>
- Sentence F1 score corresponds to the macro F1 score.
""",
        allow_flagging="never",
        fn=predict,
        inputs=[
            gradio.inputs.Textbox(label="Sentence", placeholder="Enter a sentence in English", lines=2),
            gradio.inputs.Radio(["inside", "outside", "inside-outside"], default="inside", label="Choose Model"),
        ],
        outputs=[
            gradio.outputs.Image(label="Gold Parse Tree"),
            gradio.outputs.Image(label="Predicted Parse Tree"),
            gradio.outputs.Textbox(label="F1 score"),
        ],
        examples=[
            # Both examples use the inside model because it is the only one
            # loaded. An "inside-outside" example would always fail, and
            # Gradio may execute examples at startup when caching them.
            ["Russia 's war on Ukraine unsettles investors expecting carve-out deal uptick for 2022 .", "inside"],
            ["Bitcoin community under pressure to cut energy use .", "inside"],
        ],
    )


def main():
    """Load models and launch the demo with a public share link."""
    _load_resources()
    iface = _build_interface()
    iface.launch(share=True)


if __name__ == "__main__":
    main()