File size: 4,219 Bytes
4d50603
 
 
 
 
 
 
47c0211
 
 
 
705fc84
4d50603
47c0211
 
 
 
 
 
 
245d478
 
4d50603
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1c48272
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
import os
import subprocess
import tempfile

import matplotlib
matplotlib.use('agg')

from PIL import Image


import gradio
import benepar
import spacy
import nltk
from nltk.tree import Tree
from nltk.draw.tree import TreeView

from huggingface_hub import hf_hub_url, cached_download

from weakly_supervised_parser.tree.evaluate import calculate_F1_for_spans, tree_to_spans
from weakly_supervised_parser.inference import Predictor
from weakly_supervised_parser.model.trainer import InsideOutsideStringClassifier

from weakly_supervised_parser.model.span_classifier import LightningModel


if __name__ == "__main__":
    nltk.download('stopwords')
    benepar.download('benepar_en3')

    nlp = spacy.load("en_core_web_md")
    nlp.add_pipe("benepar", config={"model": "benepar_en3"})

    # inside_model = InsideOutsideStringClassifier(model_name_or_path="roberta-base", max_seq_length=256)
    fetch_url_inside_model = hf_hub_url(repo_id="nickil/weakly-supervised-parsing", filename="inside_model.ckpt", revision="main")
    inside_model = LightningModel.load_from_checkpoint(checkpoint_path=cached_download(fetch_url_inside_model))
    # inside_model.load_model(pre_trained_model_path=cached_download(fetch_url_inside_model))

    # outside_model = InsideOutsideStringClassifier(model_name_or_path="roberta-base", max_seq_length=64)
    # outside_model.load_model(pre_trained_model_path=TRAINED_MODEL_PATH + "outside_model.onnx")

    # inside_outside_model = InsideOutsideStringClassifier(model_name_or_path="roberta-base", max_seq_length=256)
    # inside_outside_model.load_model(pre_trained_model_path=TRAINED_MODEL_PATH + "inside_outside_model.onnx")


    def _render_tree_png(parse_string, workdir, stem):
        """Draw a bracketed parse string with NLTK and return it as a PIL image.

        The tree is printed to ``<workdir>/<stem>.ps`` and converted to PNG
        with ImageMagick's ``convert``. The PNG is then copied into memory, so
        the returned image stays valid after ``workdir`` is deleted.

        Raises:
            subprocess.CalledProcessError: If ``convert`` exits non-zero.
            FileNotFoundError: If ``convert`` is not installed.
        """
        ps_path = os.path.join(workdir, f"{stem}.ps")
        png_path = os.path.join(workdir, f"{stem}.png")
        view = TreeView(Tree.fromstring(parse_string))
        try:
            view._cframe.print_to_file(ps_path)
        finally:
            # TreeView opens a Tk window per call; close it so requests don't leak windows.
            view.destroy()
        # Argument list, no shell; check=True surfaces conversion failures here
        # rather than as a confusing error when opening a missing PNG.
        subprocess.run(["convert", ps_path, png_path], check=True)
        with Image.open(png_path) as img:
            # Image.open is lazy; copy() reads the pixels before the temp dir is removed.
            return img.copy()

    def predict(sentence, model):
        """Parse ``sentence`` with the chosen model and score it against benepar.

        Args:
            sentence: English text typed into the demo. Only the first sentence
                found by the spaCy pipeline is parsed and scored.
            model: The Radio choice: ``"inside"``, ``"outside"`` or
                ``"inside-outside"``.

        Returns:
            ``(gold_image, predicted_image, f1_text)``:
            - ``gold_image``: PIL image of the benepar (gold) tree.
            - ``predicted_image``: PIL image of the predicted tree.
            - ``f1_text``: the span F1 between the two trees, formatted with
              two decimals.

        Raises:
            ValueError: If ``sentence`` is blank, ``model`` is not a known
                choice, or the requested model is not loaded by this script.
        """
        if not sentence or not sentence.strip():
            raise ValueError("Please enter a non-empty sentence.")

        # UI choice -> predict_type expected by Predictor.obtain_best_parse.
        predict_types = {"inside": "inside", "outside": "outside", "inside-outside": "inside_outside"}
        if model not in predict_types:
            raise ValueError(f"Unknown model {model!r}; expected one of {sorted(predict_types)}.")

        # Only the inside checkpoint is loaded by this script. The outside and
        # inside-outside classifiers are never defined, so referencing them
        # would crash with a NameError; report the limitation instead.
        loaded_models = {"inside": inside_model}
        if model not in loaded_models:
            raise ValueError(f"The {model!r} model is not available in this demo; please choose 'inside'.")

        first_sentence = next(iter(nlp(sentence).sents), None)
        if first_sentence is None:
            raise ValueError("No sentence could be found in the input.")
        gold_standard = first_sentence._.parse_string

        # Parse the same sentence the gold tree covers, so the F1 compares like with like.
        best_parse = Predictor(sentence=first_sentence.text).obtain_best_parse(
            predict_type=predict_types[model],
            model=loaded_models[model],
            scale_axis=1,
            predict_batch_size=128,
        )
        sentence_f1 = calculate_F1_for_spans(tree_to_spans(gold_standard), tree_to_spans(best_parse))

        # Per-request scratch directory: concurrent requests no longer overwrite
        # each other's .ps/.png files in the working directory.
        with tempfile.TemporaryDirectory() as workdir:
            gold_standard_img = _render_tree_png(gold_standard, workdir, "gold_standard")
            best_parse_img = _render_tree_png(best_parse, workdir, "best_parse")
        return gold_standard_img, best_parse_img, f"{sentence_f1:.2f}"


    # Gradio UI: a sentence and a model choice in; gold tree, predicted tree and
    # F1 out. Uses the gradio.inputs/gradio.outputs component API that the rest
    # of this script was written against.
    iface = gradio.Interface(
        title="Co-training an Unsupervised Constituency Parser with Weak Supervision",
        description="Demo for the repository - [weakly-supervised-parsing](https://github.com/Nickil21/weakly-supervised-parsing) (ACL Findings 2022)",
        theme="default",
        article="""<h4 class='text-lg font-semibold my-2'>Note</h4>
        - We use a strong supervised parsing model `benepar_en3` which is based on T5-small to compute the gold parse.<br>
        - Only the `inside` model checkpoint is loaded in this demo; the `outside` and `inside-outside` options are currently unavailable.<br>
        - Sentence F1 score corresponds to the macro F1 score.
        """,
        allow_flagging="never",
        fn=predict,
        inputs=[
            gradio.inputs.Textbox(label="Sentence", placeholder="Enter a sentence in English", lines=2),
            gradio.inputs.Radio(["inside", "outside", "inside-outside"], default="inside", label="Choose Model"),
        ],
        outputs=[
            gradio.outputs.Image(label="Gold Parse Tree"),
            gradio.outputs.Image(label="Predicted Parse Tree"),
            gradio.outputs.Textbox(label="F1 score"),
        ],
        # Both examples use "inside", the only model this script loads. An
        # example pointing at an unloaded model always fails, and may break
        # startup if Gradio pre-computes (caches) example outputs.
        examples=[
            ["Russia 's war on Ukraine unsettles investors expecting carve-out deal uptick for 2022 .", "inside"],
            ["Bitcoin community under pressure to cut energy use .", "inside"],
        ],
    )

    iface.launch()