OOlajide commited on
Commit
4a26e59
1 Parent(s): 1ae1e57

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +16 -6
app.py CHANGED
@@ -1,3 +1,4 @@
 
1
  import streamlit as st
2
  from transformers import pipeline, AutoTokenizer, AutoModelForQuestionAnswering
3
 
@@ -12,10 +13,20 @@ st.sidebar.header('What will you like to do?')
12
  option = st.sidebar.radio('', ['Extractive question answering', 'Text summarization', 'Text generation', 'Sentiment analysis'])
13
 
14
  @st.cache(show_spinner=False, allow_output_mutation=True)
15
- def question_model():
16
  tokenizer = AutoTokenizer.from_pretrained("deepset/roberta-base-squad2")
17
- question_answerer = AutoModelForQuestionAnswering.from_pretrained("deepset/roberta-base-squad2")
18
- return question_answerer
 
 
 
 
 
 
 
 
 
 
19
 
20
  @st.cache(show_spinner=False, allow_output_mutation=True)
21
  def summarization_model():
@@ -41,10 +52,9 @@ if option == 'Extractive question answering':
41
  question = st.text_input(label='Enter your question')
42
  button = st.button('Get answer')
43
  if button:
44
- question_answerer = question_model()
45
  with st.spinner(text="Getting answer..."):
46
- answer = question_answerer(context=context, question=question)
47
- st.write(answer["answer"])
48
  elif source == "I want to upload a file":
49
  uploaded_file = st.file_uploader("Choose a .txt file to upload", type=["txt"])
50
  question = st.text_input(label='Enter your question')
 
1
+ import torch
2
  import streamlit as st
3
  from transformers import pipeline, AutoTokenizer, AutoModelForQuestionAnswering
4
 
 
13
  option = st.sidebar.radio('', ['Extractive question answering', 'Text summarization', 'Text generation', 'Sentiment analysis'])
14
 
15
  @st.cache(show_spinner=False, allow_output_mutation=True)
16
+ def question_answerer(context, question):
17
  tokenizer = AutoTokenizer.from_pretrained("deepset/roberta-base-squad2")
18
+ model = AutoModelForQuestionAnswering.from_pretrained("deepset/roberta-base-squad2")
19
+ inputs = tokenizer(question, context, add_special_tokens=True, return_tensors="pt")
20
+ input_ids = inputs["input_ids"].tolist()[0]
21
+ outputs = model(**inputs)
22
+ answer_start_scores = outputs.start_logits
23
+ answer_end_scores = outputs.end_logits
24
+ # Get the most likely beginning of answer with the argmax of the score
25
+ answer_start = torch.argmax(answer_start_scores)
26
+ # Get the most likely end of answer with the argmax of the score
27
+ answer_end = torch.argmax(answer_end_scores) + 1
28
+ answer = tokenizer.convert_tokens_to_string(tokenizer.convert_ids_to_tokens(input_ids[answer_start:answer_end]))
29
+ return answer
30
 
31
  @st.cache(show_spinner=False, allow_output_mutation=True)
32
  def summarization_model():
 
52
  question = st.text_input(label='Enter your question')
53
  button = st.button('Get answer')
54
  if button:
 
55
  with st.spinner(text="Getting answer..."):
56
+ answer = question_answerer(context, question)
57
+ st.write(answer)
58
  elif source == "I want to upload a file":
59
  uploaded_file = st.file_uploader("Choose a .txt file to upload", type=["txt"])
60
  question = st.text_input(label='Enter your question')