CountingMstar commited on
Commit
32a4ae2
1 Parent(s): 89057d0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +18 -18
app.py CHANGED
@@ -10,20 +10,20 @@ tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')
10
  model = BertForQuestionAnswering.from_pretrained("CountingMstar/ai-tutor-bert-model")
11
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
12
 
13
- # def get_prediction(context, question):
14
- # inputs = tokenizer.encode_plus(question, context, return_tensors='pt').to(device)
15
- # outputs = model(**inputs)
16
 
17
- # answer_start = torch.argmax(outputs[0])
18
- # answer_end = torch.argmax(outputs[1]) + 1
19
 
20
- # answer = tokenizer.convert_tokens_to_string(tokenizer.convert_ids_to_tokens(inputs['input_ids'][0][answer_start:answer_end]))
21
 
22
- # return answer
23
 
24
- # def question_answer(context, question):
25
- # prediction = get_prediction(context,question)
26
- # return prediction
27
 
28
  def split(text):
29
  context, question = '', ''
@@ -44,14 +44,14 @@ def split(text):
44
 
45
  return context[:-2], question[1:]
46
 
47
- # def greet(texts):
48
- # context, question = split(texts)
49
- # answer = question_answer(context, question)
50
- # return answer
51
- def greet(text):
52
- context, question = split(text)
53
- # answer = question_answer(context, question)
54
- return context
55
 
56
  iface = gr.Interface(fn=greet, inputs="text", outputs="text")
57
  iface.launch()
 
10
  model = BertForQuestionAnswering.from_pretrained("CountingMstar/ai-tutor-bert-model")
11
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
12
 
13
+ def get_prediction(context, question):
14
+ inputs = tokenizer.encode_plus(question, context, return_tensors='pt').to(device)
15
+ outputs = model(**inputs)
16
 
17
+ answer_start = torch.argmax(outputs[0])
18
+ answer_end = torch.argmax(outputs[1]) + 1
19
 
20
+ answer = tokenizer.convert_tokens_to_string(tokenizer.convert_ids_to_tokens(inputs['input_ids'][0][answer_start:answer_end]))
21
 
22
+ return answer
23
 
24
+ def question_answer(context, question):
25
+ prediction = get_prediction(context,question)
26
+ return prediction
27
 
28
  def split(text):
29
  context, question = '', ''
 
44
 
45
  return context[:-2], question[1:]
46
 
47
+ def greet(texts):
48
+ context, question = split(texts)
49
+ answer = question_answer(context, question)
50
+ return answer
51
+ # def greet(text):
52
+ # context, question = split(text)
53
+ # # answer = question_answer(context, question)
54
+ # return context
55
 
56
  iface = gr.Interface(fn=greet, inputs="text", outputs="text")
57
  iface.launch()