CountingMstar commited on
Commit
9a229a7
1 Parent(s): 7d3498d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +59 -4
app.py CHANGED
@@ -1,10 +1,65 @@
1
  import gradio as gr
2
- # from transformers import BertForQuestionAnswering
3
 
4
- # model = BertForQuestionAnswering.from_pretrained("bert-base-uncased")
5
 
6
- def greet(name):
7
- return "Hello " + name + "!!"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
 
9
  iface = gr.Interface(fn=greet, inputs="text", outputs="text")
10
  iface.launch()
 
1
  import gradio as gr
2
+ from transformers import BertForQuestionAnswering
3
 
4
+ model = BertForQuestionAnswering.from_pretrained("bert-base-uncased")
5
 
6
+ def get_prediction(context, question):
7
+ inputs = tokenizer.encode_plus(question, context, return_tensors='pt').to(device)
8
+ outputs = model(**inputs)
9
+
10
+ answer_start = torch.argmax(outputs[0])
11
+ answer_end = torch.argmax(outputs[1]) + 1
12
+
13
+ answer = tokenizer.convert_tokens_to_string(tokenizer.convert_ids_to_tokens(inputs['input_ids'][0][answer_start:answer_end]))
14
+
15
+ return answer
16
+
17
+ def normalize_text(s):
18
+ """Removing articles and punctuation, and standardizing whitespace are all typical text processing steps."""
19
+ import string, re
20
+ def remove_articles(text):
21
+ regex = re.compile(r"\b(a|an|the)\b", re.UNICODE)
22
+ return re.sub(regex, " ", text)
23
+ def white_space_fix(text):
24
+ return " ".join(text.split())
25
+ def remove_punc(text):
26
+ exclude = set(string.punctuation)
27
+ return "".join(ch for ch in text if ch not in exclude)
28
+ def lower(text):
29
+ return text.lower()
30
+
31
+ return white_space_fix(remove_articles(remove_punc(lower(s))))
32
+
33
+ def exact_match(prediction, truth):
34
+ return bool(normalize_text(prediction) == normalize_text(truth))
35
+
36
+ def compute_f1(prediction, truth):
37
+ pred_tokens = normalize_text(prediction).split()
38
+ truth_tokens = normalize_text(truth).split()
39
+
40
+ # if either the prediction or the truth is no-answer then f1 = 1 if they agree, 0 otherwise
41
+ if len(pred_tokens) == 0 or len(truth_tokens) == 0:
42
+ return int(pred_tokens == truth_tokens)
43
+
44
+ common_tokens = set(pred_tokens) & set(truth_tokens)
45
+
46
+ # if there are no common tokens then f1 = 0
47
+ if len(common_tokens) == 0:
48
+ return 0
49
+
50
+ prec = len(common_tokens) / len(pred_tokens)
51
+ rec = len(common_tokens) / len(truth_tokens)
52
+
53
+ return round(2 * (prec * rec) / (prec + rec), 2)
54
+
55
+ def question_answer(context, question):
56
+ prediction = get_prediction(context,question)
57
+ return prediction
58
+
59
+ def greet(texts):
60
+ # for question, answer in texts:
61
+ # question_answer(context, question)
62
+ return texts
63
 
64
  iface = gr.Interface(fn=greet, inputs="text", outputs="text")
65
  iface.launch()