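"""Gradio demo comparing answer quality and inference time for question answering
with a base BERT-large SQuAD2 model, a pruned variant, and a pruned ONNX
FP16-optimized variant run with ONNX Runtime on CPU."""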
import time

import gradio as gr
import torch
from huggingface_hub import hf_hub_download
from onnxruntime import InferenceSession
from transformers import AutoModelForQuestionAnswering, AutoTokenizer

# Maximum number of tokens the BERT models accept per example
MAX_SEQUENCE_LENGTH = 512

# Hugging Face Hub repositories for each model variant exposed in the demo
models = {
    "Base model": "madlag/bert-large-uncased-whole-word-masking-finetuned-squadv2",
    "Pruned model": "madlag/bert-large-uncased-wwm-squadv2-x2.63-f82.6-d16-hybrid-v1",
    "Pruned ONNX Optimized FP16": "tryolabs/bert-large-uncased-wwm-squadv2-optimized-f16",
}

# The ONNX model is fetched as a local file path; the PyTorch models are loaded once at startup
loaded_models = {
    "Pruned ONNX Optimized FP16": hf_hub_download(
        repo_id=models["Pruned ONNX Optimized FP16"], filename="model.onnx"
    ),
    "Base model": AutoModelForQuestionAnswering.from_pretrained(models["Base model"]),
    "Pruned model": AutoModelForQuestionAnswering.from_pretrained(
        models["Pruned model"]
    ),
}


def run_ort_inference(model_name, inputs):
    # Build an ONNX Runtime CPU session from the downloaded model file and time
    # only the forward pass (session creation is excluded from the measurement).
    sess = InferenceSession(
        loaded_models[model_name], providers=["CPUExecutionProvider"]
    )
    start_time = time.time()
    output = sess.run(None, input_feed=inputs)
    end_time = time.time()
    # output[0] and output[1] are the start and end logits
    return (output[0], output[1]), (end_time - start_time)


def run_normal_hf(model_name, inputs):
    # Run the PyTorch model and time the forward pass; .values() unpacks the
    # start and end logits from the model output.
    start_time = time.time()
    output = loaded_models[model_name](**inputs).values()
    end_time = time.time()
    return output, (end_time - start_time)


def inference(model_name, context, question):
    # The tokenizer is reloaded on every request; inputs longer than
    # MAX_SEQUENCE_LENGTH tokens are truncated so they fit the models.
    tokenizer = AutoTokenizer.from_pretrained(models[model_name])
    if model_name == "Pruned ONNX Optimized FP16":
        # ONNX Runtime expects NumPy inputs passed as a plain dict
        inputs = dict(
            tokenizer(
                question, context, return_tensors="np",
                truncation=True, max_length=MAX_SEQUENCE_LENGTH
            )
        )
        output, inference_time = run_ort_inference(model_name, inputs)
        answer_start_scores, answer_end_scores = torch.tensor(output[0]), torch.tensor(
            output[1]
        )
    else:
        inputs = tokenizer(
            question, context, return_tensors="pt",
            truncation=True, max_length=MAX_SEQUENCE_LENGTH
        )
        output, inference_time = run_normal_hf(model_name, inputs)
        answer_start_scores, answer_end_scores = output

    # Pick the most likely start/end token positions and decode them back to text
    input_ids = inputs["input_ids"].tolist()[0]
    answer_start = torch.argmax(answer_start_scores)
    answer_end = torch.argmax(answer_end_scores) + 1
    answer = tokenizer.convert_tokens_to_string(
        tokenizer.convert_ids_to_tokens(input_ids[answer_start:answer_end])
    )

    return answer, f"{inference_time:.4f}s"


# Gradio inputs: model selector, context text, and question
model_field = gr.Dropdown(
    choices=["Base model", "Pruned model", "Pruned ONNX Optimized FP16"],
    value="Pruned ONNX Optimized FP16",
    label="Model",
)
input_text_field = gr.Textbox(placeholder="Enter the text here", label="Text")
input_question_field = gr.Text(placeholder="Enter the question here", label="Question")

# Gradio outputs: predicted answer and measured inference time
output_model = gr.Text(label="Model output")
output_inference_time = gr.Text(label="Inference time in seconds")


examples = [
    [
        "Pruned ONNX Optimized FP16",
        "The first little pig was very lazy. He didn't want to work at all and he built his house out of straw. The second little pig worked a little bit harder but he was somewhat lazy too and he built his house out of sticks. Then, they sang and danced and played together the rest of the day.",
        "Who worked a little bit harder?",
    ]
]

demo = gr.Interface(
    inference,
    inputs=[model_field, input_text_field, input_question_field],
    outputs=[output_model, output_inference_time],
    examples=examples,
)

demo.launch()