Spaces: demo
app.py
ADDED
@@ -0,0 +1,115 @@
# A visual question-answering (QA) demo built with Gradio.

import gradio as gr
import torch
from torch import nn
from transformers import (
    BigBirdConfig,
    BigBirdForQuestionAnswering,
    BigBirdTokenizer,
    PreTrainedModel,
)
from transformers.models.big_bird.modeling_big_bird import BigBirdIntermediate, BigBirdOutput


class BigBirdNullHead(nn.Module):
    """Binary head that predicts whether the question is unanswerable."""

    def __init__(self, config):
        super().__init__()
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.intermediate = BigBirdIntermediate(config)
        self.output = BigBirdOutput(config)
        self.qa_outputs = nn.Linear(config.hidden_size, 2)

    def forward(self, encoder_output):
        hidden_states = self.dropout(encoder_output)
        hidden_states = self.intermediate(hidden_states)
        hidden_states = self.output(hidden_states, encoder_output)
        logits = self.qa_outputs(hidden_states)
        return logits


model_path = '/data1/chenzq/demo/checkpoint-epoch-best'


class BigBirdForQuestionAnsweringWithNull(PreTrainedModel):
    """BigBird QA model extended with a null-answer ("no answer") classifier."""

    def __init__(self, config, model_id):
        super().__init__(config)
        self.bertqa = BigBirdForQuestionAnswering.from_pretrained(
            model_id, config=self.config, add_pooling_layer=True
        )
        self.null_classifier = BigBirdNullHead(self.bertqa.config)
        self.contrastive_mlp = nn.Sequential(
            nn.Linear(self.bertqa.config.hidden_size, self.bertqa.config.hidden_size),
        )

    def forward(self, **kwargs):
        if self.training:
            # Add the null-classification loss on top of the span-extraction loss.
            null_labels = kwargs.pop('is_impossible')
            outputs = self.bertqa(**kwargs)
            pooler_output = outputs.pooler_output
            null_logits = self.null_classifier(pooler_output)
            loss_fct = nn.CrossEntropyLoss()
            null_loss = loss_fct(null_logits, null_labels)
            outputs.loss = outputs.loss + null_loss
            return outputs.to_tuple()
        else:
            outputs = self.bertqa(**kwargs)
            pooler_output = outputs.pooler_output
            null_logits = self.null_classifier(pooler_output)
            return (outputs.start_logits, outputs.end_logits, null_logits)


model_id = 'vasudevgupta/bigbird-roberta-natural-questions'
config = BigBirdConfig.from_pretrained(model_id)
model = BigBirdForQuestionAnsweringWithNull(config, model_id)
model.to('cuda')
model.eval()

# map_location specifies which device the checkpoint is loaded onto.
model.load_state_dict(torch.load(model_path + '/pytorch_model.bin', map_location='cuda'))

tokenizer = BigBirdTokenizer.from_pretrained(model_path)


def main(question, context):
    # Encode the input.
    text = question + " [SEP] " + context
    inputs = tokenizer(text, max_length=4096, truncation=True, return_tensors="pt")
    inputs = inputs.to('cuda')
    # Predict the answer.
    with torch.no_grad():
        outputs = model(**inputs)
    start_scores, end_scores, null_scores = outputs
    # Decode the answer.
    is_impossible = null_scores.argmax().item()
    if is_impossible:
        return "No Answer"
    answer_start = torch.argmax(start_scores)
    answer_end = torch.argmax(end_scores) + 1
    answer = tokenizer.convert_tokens_to_string(
        tokenizer.convert_ids_to_tokens(inputs['input_ids'][0][answer_start:answer_end])
    )
    return answer


with gr.Blocks() as demo:
    gr.Markdown("""# Question Answerer!""")
    with gr.Row():
        with gr.Column():
            # options = gr.inputs.Radio(["vasudevgupta/bigbird-roberta-natural-questions", "vasudevgupta/bigbird-roberta-natural-questions"], label="Model")
            text1 = gr.Textbox(
                label="Question",
                lines=1,
                value="Who does Cristiano Ronaldo play for?",
            )
            text2 = gr.Textbox(
                label="Context",
                lines=3,
                value="Cristiano Ronaldo is a player for Manchester United",
            )
            output = gr.Textbox()
            b1 = gr.Button("Ask Question!")
            b1.click(main, inputs=[text1, text2], outputs=output)
            # gr.Markdown("""#### powered by [Tassle](https://bit.ly/3LXMklV)""")


if __name__ == "__main__":
    demo.launch(share=True)
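Editor's note (not part of the commit): a minimal smoke-test sketch for the pipeline above, assuming the checkpoint at model_path and a CUDA device are available, so main() can be exercised without launching the Gradio UI.

# Editor's sketch (not in the repo): call main() directly instead of going through the UI.
question = "Who does Cristiano Ronaldo play for?"
context = "Cristiano Ronaldo is a player for Manchester United"
print(main(question, context))  # prints the predicted span, or "No Answer" if the null head fires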
checkpoint-epoch-best/config.json
ADDED
@@ -0,0 +1,33 @@
{
  "_name_or_path": "vasudevgupta/bigbird-roberta-natural-questions",
  "architectures": [
    "BigBirdForQuestionAnsweringWithNull"
  ],
  "attention_probs_dropout_prob": 0.1,
  "attention_type": "block_sparse",
  "block_size": 64,
  "bos_token_id": 1,
  "classifier_dropout": null,
  "eos_token_id": 2,
  "hidden_act": "gelu_fast",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 4096,
  "model_type": "big_bird",
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "num_random_blocks": 3,
  "pad_token_id": 0,
  "position_embedding_type": "absolute",
  "rescale_embeddings": false,
  "sep_token_id": 66,
  "torch_dtype": "float32",
  "transformers_version": "4.18.0",
  "type_vocab_size": 2,
  "use_bias": true,
  "use_cache": true,
  "vocab_size": 50358
}
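Editor's note (not part of the commit): the attention_type, block_size and num_random_blocks fields above control BigBird's block-sparse attention. A sketch of reading them and, for inputs far shorter than max_position_embeddings (4096), loading the model with dense attention instead, assuming the transformers BigBird implementation:

# Editor's sketch (not in the repo).
from transformers import BigBirdConfig, BigBirdForQuestionAnswering

config = BigBirdConfig.from_pretrained("vasudevgupta/bigbird-roberta-natural-questions")
print(config.attention_type, config.block_size, config.num_random_blocks)  # block_sparse 64 3

# For short sequences, BigBird can fall back to ordinary full attention.
dense_model = BigBirdForQuestionAnswering.from_pretrained(
    "vasudevgupta/bigbird-roberta-natural-questions", attention_type="original_full"
)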
checkpoint-epoch-best/pytorch_model.bin
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:f56923a86a3827a575c0ba614808b6a7b58cf76a631f0e8f5529cc73583973a6
size 550147981
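Editor's note (not part of the commit): the file above is a Git LFS pointer, not the ~550 MB checkpoint itself. A sketch of fetching the actual weights with huggingface_hub; the repo id below is a placeholder, since the Space's id is not shown in this diff.

# Editor's sketch (not in the repo); "user/demo" is a placeholder repo id.
from huggingface_hub import hf_hub_download

weights_path = hf_hub_download(
    repo_id="user/demo",
    filename="checkpoint-epoch-best/pytorch_model.bin",
    repo_type="space",
)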
checkpoint-epoch-best/special_tokens_map.json
ADDED
@@ -0,0 +1 @@
{"bos_token": {"content": "</s>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true}, "eos_token": {"content": "<s>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true}, "unk_token": {"content": "<unk>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true}, "sep_token": {"content": "[SEP]", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true}, "pad_token": {"content": "<pad>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true}, "cls_token": {"content": "[CLS]", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true}, "mask_token": {"content": "[MASK]", "single_word": false, "lstrip": true, "rstrip": false, "normalized": true}}
checkpoint-epoch-best/spiece.model
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:fdc81e1fc9d42e0c08b86d5b280d05d7c5e9747c4231c648f2b56b8e1d893c82
size 845731
checkpoint-epoch-best/tokenizer_config.json
ADDED
@@ -0,0 +1 @@
{"bos_token": {"content": "</s>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true, "__type": "AddedToken"}, "eos_token": {"content": "<s>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true, "__type": "AddedToken"}, "unk_token": {"content": "<unk>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true, "__type": "AddedToken"}, "pad_token": {"content": "<pad>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true, "__type": "AddedToken"}, "sep_token": {"content": "[SEP]", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true, "__type": "AddedToken"}, "mask_token": {"content": "[MASK]", "single_word": false, "lstrip": true, "rstrip": false, "normalized": true, "__type": "AddedToken"}, "cls_token": {"content": "[CLS]", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true, "__type": "AddedToken"}, "sp_model_kwargs": {}, "model_max_length": 4096, "name_or_path": "vasudevgupta/bigbird-roberta-natural-questions", "special_tokens_map_file": "/home/vasu/.cache/huggingface/transformers/400be7e354ea6eb77319bcc7fa34899ec9fa2e3aff0fa677f6eb7e45a01b1548.75b358ecb30fa6b001d9d87bfde336c02d9123e7a8f5b90cc890d0f6efc3d4a3", "tokenizer_file": null, "tokenizer_class": "BigBirdTokenizer"}
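Editor's note (not part of the commit): the tokenizer config above pins model_max_length to 4096 and registers "[SEP]" as the separator token, while app.py joins question and context with a hand-written " [SEP] " string. A sketch of the text-pair form, which lets the tokenizer insert its own special tokens; the local path is assumed to match the checkpoint directory in this commit.

# Editor's sketch (not in the repo).
from transformers import BigBirdTokenizer

tokenizer = BigBirdTokenizer.from_pretrained("checkpoint-epoch-best")  # local path assumed

question = "Who does Cristiano Ronaldo play for?"
context = "Cristiano Ronaldo is a player for Manchester United"

# Passing question and context as a text pair lets the tokenizer add its own
# [CLS]/[SEP] markers instead of relying on a manually inserted " [SEP] ".
inputs = tokenizer(question, context, max_length=4096, truncation=True, return_tensors="pt")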