# TimeQA_demo/app.py
# A Gradio-based visualization demo for question answering
import gradio as gr
from transformers import BigBirdForQuestionAnswering, BigBirdConfig, PreTrainedModel, BigBirdTokenizer
import torch
from torch import nn
from transformers.models.big_bird.modeling_big_bird import BigBirdOutput, BigBirdIntermediate
class BigBirdNullHead(nn.Module):
    """Binary head that predicts whether the question has no answer in the context."""

    def __init__(self, config):
        super().__init__()
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.intermediate = BigBirdIntermediate(config)
        self.output = BigBirdOutput(config)
        self.qa_outputs = nn.Linear(config.hidden_size, 2)

    def forward(self, encoder_output):
        hidden_states = self.dropout(encoder_output)
        hidden_states = self.intermediate(hidden_states)
        hidden_states = self.output(hidden_states, encoder_output)
        logits = self.qa_outputs(hidden_states)
        return logits
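# A minimal sketch of the null head's contract (illustrative only; it uses a
# default BigBirdConfig rather than the fine-tuned one). The head maps a
# pooled [CLS] vector of shape (batch, hidden_size) to two logits per
# example; index 1 is interpreted below as "no answer":
#
#   cfg = BigBirdConfig()
#   head = BigBirdNullHead(cfg)
#   pooled = torch.randn(1, cfg.hidden_size)
#   logits = head(pooled)                  # shape: (1, 2)
#   is_impossible = logits.argmax(-1).item()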
model_path = '/data1/chenzq/demo/checkpoint-epoch-best'  # local path to the fine-tuned checkpoint
class BigBirdForQuestionAnsweringWithNull(PreTrainedModel):
    def __init__(self, config, model_id):
        super().__init__(config)
        self.bertqa = BigBirdForQuestionAnswering.from_pretrained(
            model_id, config=self.config, add_pooling_layer=True
        )
        self.null_classifier = BigBirdNullHead(self.bertqa.config)
        # Unused at inference time, but kept so the fine-tuned checkpoint's
        # state dict loads without missing-key errors.
        self.contrastive_mlp = nn.Sequential(
            nn.Linear(self.bertqa.config.hidden_size, self.bertqa.config.hidden_size),
        )

    def forward(self, **kwargs):
        if self.training:
            # Pop the answerability labels before forwarding to the base QA model.
            null_labels = kwargs.pop('is_impossible')
            outputs = self.bertqa(**kwargs)
            pooler_output = outputs.pooler_output
            null_logits = self.null_classifier(pooler_output)
            loss_fct = nn.CrossEntropyLoss()
            null_loss = loss_fct(null_logits, null_labels)
            outputs.loss = outputs.loss + null_loss
            return outputs.to_tuple()
        else:
            outputs = self.bertqa(**kwargs)
            pooler_output = outputs.pooler_output
            null_logits = self.null_classifier(pooler_output)
            return (outputs.start_logits, outputs.end_logits, null_logits)
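# Forward contract: in training mode the model expects an extra 'is_impossible'
# label tensor and returns the base QA output tuple with the null loss folded
# into the total loss; in eval mode it returns
# (start_logits, end_logits, null_logits).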
model_id = 'vasudevgupta/bigbird-roberta-natural-questions'
config = BigBirdConfig.from_pretrained(model_id)
model = BigBirdForQuestionAnsweringWithNull(config, model_id)
model.to('cuda')
model.eval()
model.load_state_dict(torch.load(model_path + '/pytorch_model.bin', map_location='cuda'))  # map_location specifies which device the weights are loaded onto
tokenizer = BigBirdTokenizer.from_pretrained(model_path)
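# Note: the device is hard-coded to 'cuda' above. A hypothetical fallback for
# CPU-only machines (an assumption, not part of the original demo):
#
#   device = 'cuda' if torch.cuda.is_available() else 'cpu'
#   model.to(device)
#   model.load_state_dict(torch.load(model_path + '/pytorch_model.bin', map_location=device))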
def main(question, context):
    # Encode the question and context as a single sequence.
    text = question + " [SEP] " + context
    inputs = tokenizer(text, max_length=4096, truncation=True, return_tensors="pt")
    inputs = inputs.to('cuda')
    # Predict the answer span and the answerability logits.
    with torch.no_grad():
        start_scores, end_scores, null_scores = model(**inputs)
    # Decode the answer.
    is_impossible = null_scores.argmax().item()
    if is_impossible:
        return "No Answer"
    else:
        answer_start = torch.argmax(start_scores)
        answer_end = torch.argmax(end_scores) + 1
        answer_tokens = tokenizer.convert_ids_to_tokens(inputs['input_ids'][0][answer_start:answer_end])
        return tokenizer.convert_tokens_to_string(answer_tokens)
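# Calling the handler directly, bypassing the UI. The shown result is only a
# sketch; the actual string depends on the loaded checkpoint:
#
#   main("Who does Cristiano Ronaldo play for?",
#        "Cristiano Ronaldo is a player for Manchester United")
#   # e.g. "manchester united", or "No Answer" if the null head fires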
with gr.Blocks() as demo:
    gr.Markdown("""# Question Answerer!""")
    with gr.Row():
        with gr.Column():
            # options = gr.inputs.Radio(["vasudevgupta/bigbird-roberta-natural-questions", "vasudevgupta/bigbird-roberta-natural-questions"], label="Model")
            text1 = gr.Textbox(
                label="Question",
                lines=1,
                value="Who does Cristiano Ronaldo play for?",
            )
            text2 = gr.Textbox(
                label="Context",
                lines=3,
                value="Cristiano Ronaldo is a player for Manchester United",
            )
            output = gr.Textbox()
            b1 = gr.Button("Ask Question!")
    b1.click(main, inputs=[text1, text2], outputs=output)
    # gr.Markdown("""#### powered by [Tassle](https://bit.ly/3LXMklV)""")
if __name__ == "__main__":
    demo.launch(share=True)