# A QA visualization demo built with Gradio
import gradio as gr
import torch
from torch import nn
from transformers import (
    BigBirdConfig,
    BigBirdForQuestionAnswering,
    BigBirdTokenizer,
    PreTrainedModel,
)
from transformers.models.big_bird.modeling_big_bird import BigBirdIntermediate, BigBirdOutput


class BigBirdNullHead(nn.Module):
    """Classification head that predicts whether a question is answerable."""

    def __init__(self, config):
        super().__init__()
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.intermediate = BigBirdIntermediate(config)
        self.output = BigBirdOutput(config)
        self.qa_outputs = nn.Linear(config.hidden_size, 2)

    def forward(self, encoder_output):
        hidden_states = self.dropout(encoder_output)
        hidden_states = self.intermediate(hidden_states)
        hidden_states = self.output(hidden_states, encoder_output)
        logits = self.qa_outputs(hidden_states)
        return logits


model_path = '/data1/chenzq/demo/checkpoint-epoch-best'


class BigBirdForQuestionAnsweringWithNull(PreTrainedModel):
    """BigBird QA model extended with a null head for unanswerable questions."""

    def __init__(self, config, model_id):
        super().__init__(config)
        self.bertqa = BigBirdForQuestionAnswering.from_pretrained(
            model_id, config=self.config, add_pooling_layer=True
        )
        self.null_classifier = BigBirdNullHead(self.bertqa.config)
        self.contrastive_mlp = nn.Sequential(
            nn.Linear(self.bertqa.config.hidden_size, self.bertqa.config.hidden_size),
        )

    def forward(self, **kwargs):
        if self.training:
            null_labels = kwargs.pop('is_impossible')
            outputs = self.bertqa(**kwargs)
            null_logits = self.null_classifier(outputs.pooler_output)
            loss_fct = nn.CrossEntropyLoss()
            null_loss = loss_fct(null_logits, null_labels)
            # Total loss = span-extraction loss + answerability loss
            outputs.loss = outputs.loss + null_loss
            return outputs.to_tuple()
        else:
            outputs = self.bertqa(**kwargs)
            null_logits = self.null_classifier(outputs.pooler_output)
            return (outputs.start_logits, outputs.end_logits, null_logits)


model_id = 'vasudevgupta/bigbird-roberta-natural-questions'
config = BigBirdConfig.from_pretrained(model_id)
model = BigBirdForQuestionAnsweringWithNull(config, model_id)
model.to('cuda')
model.eval()
# map_location specifies which device the weights are loaded onto
model.load_state_dict(torch.load(model_path + '/pytorch_model.bin', map_location='cuda'))
tokenizer = BigBirdTokenizer.from_pretrained(model_path)


def main(question, context):
    # Encode the input; the tokenizer recognizes "[SEP]" as a special token
    text = question + " [SEP] " + context
    inputs = tokenizer(text, max_length=4096, truncation=True, return_tensors="pt")
    inputs = inputs.to('cuda')
    # Predict the answer span and the answerability logits
    with torch.no_grad():
        start_scores, end_scores, null_scores = model(**inputs)
    # Decode the answer
    is_impossible = null_scores.argmax().item()
    if is_impossible:
        return "No Answer"
    answer_start = torch.argmax(start_scores)
    answer_end = torch.argmax(end_scores) + 1
    answer = tokenizer.convert_tokens_to_string(
        tokenizer.convert_ids_to_tokens(inputs['input_ids'][0][answer_start:answer_end])
    )
    return answer


with gr.Blocks() as demo:
    gr.Markdown("""# Question Answerer!""")
    with gr.Row():
        with gr.Column():
            # options = gr.inputs.Radio(["vasudevgupta/bigbird-roberta-natural-questions", "vasudevgupta/bigbird-roberta-natural-questions"], label="Model")
            text1 = gr.Textbox(
                label="Question",
                lines=1,
                value="Who does Cristiano Ronaldo play for?",
            )
            text2 = gr.Textbox(
                label="Context",
                lines=3,
                value="Cristiano Ronaldo is a player for Manchester United",
            )
            output = gr.Textbox()
            b1 = gr.Button("Ask Question!")
            b1.click(main, inputs=[text1, text2], outputs=output)
    # gr.Markdown("""#### powered by [Tassle](https://bit.ly/3LXMklV)""")
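# Quick sanity check without launching the UI (a sketch; assumes the
# fine-tuned checkpoint at model_path and a CUDA device are available):
#
#   answer = main(
#       "Who does Cristiano Ronaldo play for?",
#       "Cristiano Ronaldo is a player for Manchester United",
#   )
#   print(answer)  # a span copied from the context, or "No Answer"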
if __name__ == "__main__":
    # share=True additionally exposes a temporary public *.gradio.live URL
    demo.launch(share=True)