OOlajide committed
Commit fae4d2b
1 Parent(s): 7f21271

Update app.py

Files changed (1)
app.py +5 -14
app.py CHANGED
@@ -1,6 +1,6 @@
 import torch
 import streamlit as st
-from transformers import pipeline, AutoTokenizer, AutoModelForQuestionAnswering
+from transformers import pipeline

 st.set_page_config(page_title="Common NLP Tasks")
 st.title("Common NLP Tasks")
@@ -14,19 +14,10 @@ option = st.sidebar.radio('', ['Extractive question answering', 'Text summarizat

 @st.cache(show_spinner=False, allow_output_mutation=True)
 def question_answerer(context, question):
-    tokenizer = AutoTokenizer.from_pretrained("deepset/roberta-base-squad2")
-    model = AutoModelForQuestionAnswering.from_pretrained("deepset/roberta-base-squad2")
-    inputs = tokenizer(question, context, add_special_tokens=True, return_tensors="pt")
-    input_ids = inputs["input_ids"].tolist()[0]
-    outputs = model(**inputs)
-    answer_start_scores = outputs.start_logits
-    answer_end_scores = outputs.end_logits
-    # Get the most likely beginning of answer with the argmax of the score
-    answer_start = torch.argmax(answer_start_scores)
-    # Get the most likely end of answer with the argmax of the score
-    answer_end = torch.argmax(answer_end_scores) + 1
-    answer = tokenizer.convert_tokens_to_string(tokenizer.convert_ids_to_tokens(input_ids[answer_start:answer_end]))
-    return answer
+    model_name = "deepset/roberta-base-squad2"
+    qa_pipeline = pipeline(model=model_name, tokenizer=model_name, revision="v1.0", task="question-answering")
+    result = qa_pipeline(question=question, context=context)
+    return result['answer']

 @st.cache(show_spinner=False, allow_output_mutation=True)
 def summarization_model():
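Note on the refactor: the commit swaps the manual tokenize/forward/argmax span extraction for the transformers question-answering pipeline, which handles tokenization and answer-span decoding internally. Because the pipeline is rebuilt on every call inside the cached function, one possible refinement is to cache the pipeline object itself so the model loads only once per session. The sketch below illustrates that idea; the qa_pipeline helper name is an assumption for illustration, not part of the commit.

import streamlit as st
from transformers import pipeline

# Sketch (assumed refinement, not in this commit): cache the pipeline
# object so the model is downloaded and loaded only once per session.
@st.cache(show_spinner=False, allow_output_mutation=True)
def qa_pipeline():
    model_name = "deepset/roberta-base-squad2"
    return pipeline(task="question-answering", model=model_name,
                    tokenizer=model_name, revision="v1.0")

def question_answerer(context, question):
    # The QA pipeline takes the question and context directly and
    # returns a dict with 'answer', 'score', 'start', and 'end'.
    result = qa_pipeline()(question=question, context=context)
    return result["answer"]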