ziyadbastaili commited on
Commit
7c96c8f
1 Parent(s): cac6dc4

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +121 -0
app.py ADDED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import time
4
+ from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
5
+ from transformers import (AlbertConfig, AlbertForQuestionAnswering, AlbertTokenizer, squad_convert_examples_to_features)
6
+ from transformers.data.processors.squad import SquadResult, SquadV2Processor, SquadExample
7
+ from transformers.data.metrics.squad_metrics import compute_predictions_logits
8
+
9
+ model_name_or_path = "ktrapeznikov/albert-xlarge-v2-squad-v2"
10
+
11
+ output_dir = ""
12
+
13
+ # Config
14
+ n_best_size = 1
15
+ max_answer_length = 30
16
+ do_lower_case = True
17
+ null_score_diff_threshold = 0.0
18
+
19
+ def to_list(tensor):
20
+ return tensor.detach().cpu().tolist()
21
+
22
+ # Setup model
23
+ config_class, model_class, tokenizer_class = (AlbertConfig, AlbertForQuestionAnswering, AlbertTokenizer)
24
+ config = config_class.from_pretrained(model_name_or_path)
25
+ tokenizer = tokenizer_class.from_pretrained(model_name_or_path, do_lower_case=True)
26
+ model = model_class.from_pretrained(model_name_or_path, config=config)
27
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
28
+ model.to(device)
29
+ processor = SquadV2Processor()
30
+
31
+ def run_prediction(question, context_text):
32
+ """Setup function to compute predictions"""
33
+ examples = []
34
+ question_texts = [question]
35
+ for i, question_text in enumerate(question_texts):
36
+ example = SquadExample(
37
+ qas_id=str(i),
38
+ question_text=question_text,
39
+ context_text=context_text,
40
+ answer_text=None,
41
+ start_position_character=None,
42
+ title="Predict",
43
+ is_impossible=False,
44
+ answers=None,
45
+ )
46
+
47
+ examples.append(example)
48
+
49
+ features, dataset = squad_convert_examples_to_features(
50
+ examples=examples,
51
+ tokenizer=tokenizer,
52
+ max_seq_length=384,
53
+ doc_stride=128,
54
+ max_query_length=64,
55
+ is_training=False,
56
+ return_dataset="pt",
57
+ threads=1,
58
+ )
59
+
60
+ eval_sampler = SequentialSampler(dataset)
61
+ eval_dataloader = DataLoader(dataset, sampler=eval_sampler, batch_size=10)
62
+
63
+ all_results = []
64
+
65
+ for batch in eval_dataloader:
66
+ model.eval()
67
+ batch = tuple(t.to(device) for t in batch)
68
+
69
+ with torch.no_grad():
70
+ inputs = {
71
+ "input_ids": batch[0],
72
+ "attention_mask": batch[1],
73
+ "token_type_ids": batch[2],
74
+ }
75
+
76
+ example_indices = batch[3]
77
+
78
+ outputs = model(**inputs)
79
+
80
+ for i, example_index in enumerate(example_indices):
81
+ eval_feature = features[example_index.item()]
82
+ unique_id = int(eval_feature.unique_id)
83
+
84
+ output = [to_list(output[i]) for output in outputs]
85
+
86
+ start_logits, end_logits = output
87
+ result = SquadResult(unique_id, start_logits, end_logits)
88
+ all_results.append(result)
89
+
90
+ predictions = compute_predictions_logits(
91
+ examples,
92
+ features,
93
+ all_results,
94
+ n_best_size,
95
+ max_answer_length,
96
+ do_lower_case,
97
+ False,
98
+ False,
99
+ False,
100
+ False, # verbose_logging
101
+ True, # version_2_with_negative
102
+ null_score_diff_threshold,
103
+ tokenizer,
104
+ )
105
+ answer = "empty"
106
+ for key in predictions.keys():
107
+ answer=predictions[key]
108
+ break
109
+ return answer
110
+
111
+
112
+ context = "4/5/2022 · In connection with the closing, Helix Acquisition Corp changed its name to MoonLake Immunotherapeutics (“MoonLake” or the “Company”). Beginning April 6, 2022, MoonLake’s shares will trade on the Nasdaq Stock Market..."
113
+ questions = ["Helix Acquisition Corp change its name to"]
114
+ title = 'Question Answering demo with Albert QA transformer and gradio'
115
+
116
+ # Run method
117
+
118
+
119
+
120
+ gr.Interface(run_prediction,inputs=[gr.inputs.Textbox(lines=7, default=context, label="Context"), gr.inputs.Textbox(lines=2, default=question, label="Question")],
121
+ outputs=[gr.outputs.Textbox(type="auto",label="Answer")],title = title,theme = "peach").launch()