File size: 4,213 Bytes
9a92410
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
# 使用gradio开发QA的可视化demo

import gradio as gr
from transformers import pipeline, AutoTokenizer, AutoModelForQuestionAnswering, BigBirdForQuestionAnswering, BigBirdConfig, PreTrainedModel, BigBirdTokenizer
import torch
from torch import nn
from transformers.models.big_bird.modeling_big_bird import BigBirdOutput, BigBirdIntermediate

class BigBirdNullHead(nn.Module):
    """Answerability head: predicts whether the question has no answer.

    Applies dropout to the pooled encoder output, passes it through a
    BigBird intermediate/output feed-forward pair (with a residual back to
    the pooled output), and projects to two logits
    (index 0 = answerable, index 1 = impossible).
    """

    def __init__(self, config):
        super().__init__()
        # Attribute names are part of the checkpoint's state-dict keys.
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.intermediate = BigBirdIntermediate(config)
        self.output = BigBirdOutput(config)
        self.qa_outputs = nn.Linear(config.hidden_size, 2)

    def forward(self, encoder_output):
        dropped = self.dropout(encoder_output)
        expanded = self.intermediate(dropped)
        # BigBirdOutput combines the transformed features with the
        # original (pre-dropout) pooled output.
        fused = self.output(expanded, encoder_output)
        return self.qa_outputs(fused)


model_path = '/data1/chenzq/demo/checkpoint-epoch-best'  # fine-tuned checkpoint dir (state dict + tokenizer files)

class BigBirdForQuestionAnsweringWithNull(PreTrainedModel):
    """BigBird span-prediction QA model with an auxiliary "no answer" head.

    Wraps ``BigBirdForQuestionAnswering`` for start/end span logits and adds
    a :class:`BigBirdNullHead` on the pooler output that classifies whether
    the question is answerable (SQuAD-2.0-style ``is_impossible``).

    Training mode: expects an ``is_impossible`` label in ``kwargs``; returns
    the backbone's output tuple with the null-classification cross-entropy
    added into the loss.
    Eval mode: returns ``(start_logits, end_logits, null_logits)``.
    """

    def __init__(self, config, model_id):
        super().__init__(config)
        # add_pooling_layer=True: the pooler output feeds the null head.
        self.bertqa = BigBirdForQuestionAnswering.from_pretrained(
            model_id, config=self.config, add_pooling_layer=True)
        self.null_classifier = BigBirdNullHead(self.bertqa.config)
        # NOTE(review): unused in forward(); kept so fine-tuned checkpoints
        # with these weights still load cleanly.
        self.contrastive_mlp = nn.Sequential(
            nn.Linear(self.bertqa.config.hidden_size, self.bertqa.config.hidden_size),
        )

    def forward(self, **kwargs):
        # Strip the auxiliary label before calling the backbone, which does
        # not accept an ``is_impossible`` argument. ``pop`` (vs the original
        # ``del``) tolerates the key being absent and also removes it safely
        # in eval mode if a caller passes it.
        null_labels = kwargs.pop('is_impossible', None)
        outputs = self.bertqa(**kwargs)
        null_logits = self.null_classifier(outputs.pooler_output)

        if self.training:
            loss_fct = nn.CrossEntropyLoss()
            null_loss = loss_fct(null_logits, null_labels)
            # Combine span loss and answerability loss.
            outputs.loss = outputs.loss + null_loss
            return outputs.to_tuple()

        return (outputs.start_logits, outputs.end_logits, null_logits)
# Build the model from the base pretrained BigBird NQ checkpoint, then
# overwrite its weights with the fine-tuned state dict from model_path.
model_id = 'vasudevgupta/bigbird-roberta-natural-questions'
config = BigBirdConfig.from_pretrained(model_id)
model = BigBirdForQuestionAnsweringWithNull(config, model_id)
model.to('cuda')
model.eval()  # inference only

model.load_state_dict(torch.load(model_path+'/pytorch_model.bin', map_location='cuda')) # map_location specifies which device the tensors are loaded onto

tokenizer = BigBirdTokenizer.from_pretrained(model_path)

def main(question, context):
    """Answer ``question`` from ``context``; return the answer string or "No Answer".

    Encodes ``"<question> [SEP] <context>"``, runs the model, and either
    reports the question as unanswerable (per the null head) or decodes the
    highest-scoring answer span back to text.
    """
    # Encode the input; BigBird handles sequences up to 4096 tokens.
    text = question + " [SEP] " + context
    inputs = tokenizer(text, max_length=4096, truncation=True, return_tensors="pt")
    inputs.to('cuda')
    # Fix: run inference under no_grad — the original tracked gradients,
    # wasting GPU memory and compute on every request.
    with torch.no_grad():
        outputs = model(**inputs)
    start_scores = outputs[0]
    end_scores = outputs[1]
    null_scores = outputs[2]
    # Null head verdict: argmax index 1 means "impossible".
    is_impossible = null_scores.argmax().item()
    if is_impossible:
        return "No Answer"
    answer_start = torch.argmax(start_scores)
    answer_end = torch.argmax(end_scores) + 1  # slice end is exclusive
    tokens = tokenizer.convert_ids_to_tokens(inputs['input_ids'][0][answer_start:answer_end])
    return tokenizer.convert_tokens_to_string(tokens)

  

with gr.Blocks() as demo:
    # Minimal QA UI: question + context inputs, one button, one output box.
    gr.Markdown("""# Question Answerer!""")
    with gr.Row():
        with gr.Column():
            question_box = gr.Textbox(
                label="Question",
                lines=1,
                value="Who does Cristiano Ronaldo play for?",
            )
            context_box = gr.Textbox(
                label="Context",
                lines=3,
                value="Cristiano Ronaldo is a player for Manchester United",
            )
            answer_box = gr.Textbox()
            ask_button = gr.Button("Ask Question!")
            # Wire the button to the inference function.
            ask_button.click(main, inputs=[question_box, context_box], outputs=answer_box)


if __name__ == "__main__":

    # share=True publishes a temporary public Gradio link for the demo.
    demo.launch(share=True)