Seungah Son commited on
Commit
28ef2f3
1 Parent(s): 5f82246

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +31 -0
app.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from transformers import BertForQuestionAnswering
3
+ from transformers import BertTokenizerFast
4
+ import torch
5
+
6
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
7
+ tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')
8
+ model = BertForQuestionAnswering.from_pretrained("CountingMstar/ai-tutor-bert-model").to(device)
9
+
10
+ def get_prediction(context, question):
11
+ inputs = tokenizer.encode_plus(question, context, return_tensors='pt').to(device)
12
+ outputs = model(**inputs)
13
+
14
+ answer_start = torch.argmax(outputs.start_logits)
15
+ answer_end = torch.argmax(outputs.end_logits) + 1
16
+
17
+ answer = tokenizer.convert_tokens_to_string(tokenizer.convert_ids_to_tokens(inputs['input_ids'][0][answer_start:answer_end]))
18
+
19
+ return answer
20
+
21
+ def question_answer(context, question):
22
+ prediction = get_prediction(context, question)
23
+ return prediction
24
+
25
+ iface = gr.Interface(
26
+ fn=question_answer,
27
+ inputs=[gr.Textbox("Context"), gr.Textbox("Question")],
28
+ outputs=gr.Textbox("Answer")
29
+ )
30
+
31
+ iface.launch()