czq committed on
Commit
9a92410
1 Parent(s): 0868831
app.py ADDED
@@ -0,0 +1,115 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Visual question-answering (QA) demo built with gradio
2
+
3
+ import gradio as gr
4
+ from transformers import pipeline, AutoTokenizer, AutoModelForQuestionAnswering, BigBirdForQuestionAnswering, BigBirdConfig, PreTrainedModel, BigBirdTokenizer
5
+ import torch
6
+ from torch import nn
7
+ from transformers.models.big_bird.modeling_big_bird import BigBirdOutput, BigBirdIntermediate
8
+
9
class BigBirdNullHead(nn.Module):
    """Two-way classification head assembled from BigBird feed-forward parts.

    The head chains dropout, a ``BigBirdIntermediate`` layer, a
    ``BigBirdOutput`` layer and a final linear projection to 2 logits.
    """

    def __init__(self, config):
        """Build the sub-layers from ``config``.

        Args:
            config: Configuration object providing ``hidden_dropout_prob``
                and ``hidden_size``; it is also handed unchanged to
                ``BigBirdIntermediate`` and ``BigBirdOutput``.
        """
        super().__init__()
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.intermediate = BigBirdIntermediate(config)
        self.output = BigBirdOutput(config)
        self.qa_outputs = nn.Linear(config.hidden_size, 2)

    def forward(self, encoder_output):
        """Map ``encoder_output`` to 2-way logits.

        Args:
            encoder_output: Tensor whose last dimension is
                ``config.hidden_size`` (assumed; the caller in this file
                passes the backbone's ``pooler_output``).

        Returns:
            The logits produced by the final linear layer.
        """
        dropped = self.dropout(encoder_output)
        expanded = self.intermediate(dropped)
        # The second argument is the *un-dropped* input; BigBirdOutput
        # presumably uses it as the residual branch.
        projected = self.output(expanded, encoder_output)
        return self.qa_outputs(projected)
25
+
26
+
27
import os

# Directory holding `pytorch_model.bin` and the tokenizer files.
# Resolution order:
#   1. the MODEL_PATH environment variable, when set and non-empty;
#   2. the original development location, when it exists on this machine;
#   3. the `checkpoint-epoch-best` directory that sits next to this script.
_DEV_MODEL_PATH = '/data1/chenzq/demo/checkpoint-epoch-best'
_LOCAL_MODEL_PATH = os.path.join(
    os.path.dirname(os.path.abspath(__file__)), 'checkpoint-epoch-best'
)
model_path = os.environ.get('MODEL_PATH') or (
    _DEV_MODEL_PATH if os.path.isdir(_DEV_MODEL_PATH) else _LOCAL_MODEL_PATH
)
28
+
29
class BigBirdForQuestionAnsweringWithNull(PreTrainedModel):
    """BigBird span-extraction model paired with a null-answer classifier.

    Wraps a pretrained ``BigBirdForQuestionAnswering`` (loaded with its
    pooling layer) and feeds the pooled output to a ``BigBirdNullHead``.
    """

    def __init__(self, config, model_id):
        """Load the backbone and create the extra heads.

        Args:
            config: Model configuration, passed to ``PreTrainedModel`` and
                to the backbone's ``from_pretrained`` call.
            model_id: Name or path handed to
                ``BigBirdForQuestionAnswering.from_pretrained``.
        """
        super().__init__(config)
        self.bertqa = BigBirdForQuestionAnswering.from_pretrained(
            model_id, config=self.config, add_pooling_layer=True
        )
        backbone_config = self.bertqa.config
        self.null_classifier = BigBirdNullHead(backbone_config)
        width = backbone_config.hidden_size
        # Not used by forward() in this file; kept so the module set (and
        # therefore the state-dict keys) is unchanged.
        self.contrastive_mlp = nn.Sequential(
            nn.Linear(width, width),
        )

    def forward(self, **kwargs):
        """Run the backbone and the null classifier.

        Args:
            **kwargs: Keyword arguments forwarded to the backbone. In
                training mode they must also contain ``is_impossible``,
                which is removed before the backbone call and used as the
                target of the null classifier.

        Returns:
            Training mode: ``outputs.to_tuple()`` of the backbone output,
            with the null-classifier cross-entropy added to its loss.
            Eval mode: ``(start_logits, end_logits, null_logits)``.
        """
        # Only the training path consumes the extra label.
        null_labels = kwargs.pop('is_impossible') if self.training else None

        outputs = self.bertqa(**kwargs)
        null_logits = self.null_classifier(outputs.pooler_output)

        if not self.training:
            return (outputs.start_logits, outputs.end_logits, null_logits)

        null_loss = nn.CrossEntropyLoss()(null_logits, null_labels)
        outputs.loss = outputs.loss + null_loss
        return outputs.to_tuple()
59
model_id = 'vasudevgupta/bigbird-roberta-natural-questions'
config = BigBirdConfig.from_pretrained(model_id)
model = BigBirdForQuestionAnsweringWithNull(config, model_id)

# Use the GPU when one is available; otherwise fall back to CPU so the demo
# still starts on hosts without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# map_location selects the device the checkpoint tensors are loaded onto.
model.load_state_dict(torch.load(model_path + '/pytorch_model.bin', map_location=device))
model.to(device)
model.eval()

tokenizer = BigBirdTokenizer.from_pretrained(model_path)
68
+
69
def main(question, context):
    """Answer ``question`` from ``context`` with the module-level model.

    Args:
        question: Question text.
        context: Passage the answer span is extracted from.

    Returns:
        The decoded answer span, or ``"No Answer"`` when the null
        classifier's argmax is non-zero or the predicted span is empty
        (end position before start position).
    """
    # Encode the input
    text = question + " [SEP] " + context
    inputs = tokenizer(text, max_length=4096, truncation=True, return_tensors="pt")
    # Follow whatever device the model lives on instead of assuming CUDA.
    inputs = inputs.to(model.device)

    # Predict the answer; inference does not need an autograd graph.
    with torch.no_grad():
        outputs = model(**inputs)
    start_scores = outputs[0]
    end_scores = outputs[1]
    null_scores = outputs[2]

    # Decode the answer
    if null_scores.argmax().item():
        return "No Answer"

    answer_start = torch.argmax(start_scores).item()
    answer_end = torch.argmax(end_scores).item() + 1
    if answer_end <= answer_start:
        # Start and end are predicted independently, so the span can be
        # inverted; report that explicitly rather than returning "".
        return "No Answer"

    span_ids = inputs['input_ids'][0][answer_start:answer_end]
    return tokenizer.convert_tokens_to_string(tokenizer.convert_ids_to_tokens(span_ids))
88
+
89
+
90
+
91
# NOTE(review): the original indentation was lost in this view; the nesting
# of the layout containers below is reconstructed — confirm against the
# real file before relying on the exact placement of the widgets.
with gr.Blocks() as demo:
    gr.Markdown("""# Question Answerer!""")

    with gr.Row():
        with gr.Column():
            # options = gr.inputs.Radio(["vasudevgupta/bigbird-roberta-natural-questions", "vasudevgupta/bigbird-roberta-natural-questions"], label="Model")
            text1 = gr.Textbox(
                label="Question",
                lines=1,
                value="Who does Cristiano Ronaldo play for?",
            )
            text2 = gr.Textbox(
                label="Context",
                lines=3,
                value="Cristiano Ronaldo is a player for Manchester United",
            )
        output = gr.Textbox()

    # Clicking the button sends (question, context) to main() and shows the
    # returned string in the output textbox.
    b1 = gr.Button("Ask Question!")
    b1.click(main, inputs=[text1, text2], outputs=output)
    # gr.Markdown("""#### powered by [Tassle](https://bit.ly/3LXMklV)""")


if __name__ == "__main__":
    demo.launch(share=True)
checkpoint-epoch-best/config.json ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_name_or_path": "vasudevgupta/bigbird-roberta-natural-questions",
3
+ "architectures": [
4
+ "BigBirdForQuestionAnsweringWithNull"
5
+ ],
6
+ "attention_probs_dropout_prob": 0.1,
7
+ "attention_type": "block_sparse",
8
+ "block_size": 64,
9
+ "bos_token_id": 1,
10
+ "classifier_dropout": null,
11
+ "eos_token_id": 2,
12
+ "hidden_act": "gelu_fast",
13
+ "hidden_dropout_prob": 0.1,
14
+ "hidden_size": 768,
15
+ "initializer_range": 0.02,
16
+ "intermediate_size": 3072,
17
+ "layer_norm_eps": 1e-12,
18
+ "max_position_embeddings": 4096,
19
+ "model_type": "big_bird",
20
+ "num_attention_heads": 12,
21
+ "num_hidden_layers": 12,
22
+ "num_random_blocks": 3,
23
+ "pad_token_id": 0,
24
+ "position_embedding_type": "absolute",
25
+ "rescale_embeddings": false,
26
+ "sep_token_id": 66,
27
+ "torch_dtype": "float32",
28
+ "transformers_version": "4.18.0",
29
+ "type_vocab_size": 2,
30
+ "use_bias": true,
31
+ "use_cache": true,
32
+ "vocab_size": 50358
33
+ }
checkpoint-epoch-best/pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f56923a86a3827a575c0ba614808b6a7b58cf76a631f0e8f5529cc73583973a6
3
+ size 550147981
checkpoint-epoch-best/special_tokens_map.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"bos_token": {"content": "</s>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true}, "eos_token": {"content": "<s>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true}, "unk_token": {"content": "<unk>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true}, "sep_token": {"content": "[SEP]", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true}, "pad_token": {"content": "<pad>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true}, "cls_token": {"content": "[CLS]", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true}, "mask_token": {"content": "[MASK]", "single_word": false, "lstrip": true, "rstrip": false, "normalized": true}}
checkpoint-epoch-best/spiece.model ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:fdc81e1fc9d42e0c08b86d5b280d05d7c5e9747c4231c648f2b56b8e1d893c82
3
+ size 845731
checkpoint-epoch-best/tokenizer_config.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"bos_token": {"content": "</s>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true, "__type": "AddedToken"}, "eos_token": {"content": "<s>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true, "__type": "AddedToken"}, "unk_token": {"content": "<unk>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true, "__type": "AddedToken"}, "pad_token": {"content": "<pad>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true, "__type": "AddedToken"}, "sep_token": {"content": "[SEP]", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true, "__type": "AddedToken"}, "mask_token": {"content": "[MASK]", "single_word": false, "lstrip": true, "rstrip": false, "normalized": true, "__type": "AddedToken"}, "cls_token": {"content": "[CLS]", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true, "__type": "AddedToken"}, "sp_model_kwargs": {}, "model_max_length": 4096, "name_or_path": "vasudevgupta/bigbird-roberta-natural-questions", "special_tokens_map_file": "/home/vasu/.cache/huggingface/transformers/400be7e354ea6eb77319bcc7fa34899ec9fa2e3aff0fa677f6eb7e45a01b1548.75b358ecb30fa6b001d9d87bfde336c02d9123e7a8f5b90cc890d0f6efc3d4a3", "tokenizer_file": null, "tokenizer_class": "BigBirdTokenizer"}