|
import streamlit as st |
|
from transformers import T5TokenizerFast, T5ForConditionalGeneration |
|
import nltk |
|
import math |
|
import torch |
|
|
|
# Hugging Face Hub id of the fine-tuned end-to-end question-generation model.
model_name = "abokbot/t5-end2end-questions-generation"

st.header("Generate questions for short Wikipedia-like articles")

# Placeholder shown while the model loads; cleared once loading completes.
st_model_load = st.text('Loading question generator model...')
|
|
|
# st.cache(allow_output_mutation=True) was deprecated in Streamlit 1.18 and
# later removed.  st.cache_resource is the modern equivalent for caching
# unserializable global resources such as models; fall back to the legacy
# decorator when running on an older Streamlit version.
if hasattr(st, "cache_resource"):
    _cache_model = st.cache_resource
else:
    _cache_model = st.cache(allow_output_mutation=True)


@_cache_model
def load_model():
    """Load the tokenizer and question-generation model, cached across reruns.

    Returns:
        tuple: (tokenizer, model) — a ``T5TokenizerFast`` for the base "t5-base"
        vocabulary and the fine-tuned ``T5ForConditionalGeneration`` checkpoint.
    """
    print("Loading model...")
    tokenizer = T5TokenizerFast.from_pretrained("t5-base")
    model = T5ForConditionalGeneration.from_pretrained(model_name)
    # Fetch nltk's sentence-tokenizer data up front so no request pays the
    # download cost mid-interaction.  NOTE(review): nltk does not appear to be
    # used anywhere else in this file — confirm the download is still needed.
    nltk.download('punkt')
    print("Model loaded!")
    return tokenizer, model
|
|
|
# Load (or fetch from cache) the tokenizer/model, then clear the placeholder.
tokenizer, model = load_model()
st.success('Model loaded!')
st_model_load.text("")

# Persist the article text across Streamlit reruns so the text area keeps its
# content after the generate button triggers a rerun.
if 'text' not in st.session_state:
    st.session_state.text = ""
st_text_area = st.text_area('Text to generate the questions for', value=st.session_state.text, height=500)
|
|
|
def generate_questions():
    """Run the model on the text-area content and store questions in session state.

    Reads the module-level ``st_text_area`` value, persists it to
    ``st.session_state.text``, and writes the resulting list of question
    strings to ``st.session_state.questions``.
    """
    # Persist the text so it survives the rerun triggered by the button click.
    st.session_state.text = st_text_area

    generator_args = {
        "max_length": 256,
        "num_beams": 4,
        "length_penalty": 1.5,
        "no_repeat_ngram_size": 3,
        "early_stopping": True,
    }
    # Task prefix the end2end model was fine-tuned with; the trailing </s>
    # mirrors its training input format.
    input_string = "generate questions: " + st_text_area + " </s>"
    # Truncate to T5's 512-token input limit so very long articles don't
    # overflow the model's positional range.
    input_ids = tokenizer.encode(
        input_string, return_tensors="pt", truncation=True, max_length=512
    )
    res = model.generate(input_ids, **generator_args)
    output = tokenizer.batch_decode(res, skip_special_tokens=True)
    # The model emits questions as one string separated by "?"; re-attach the
    # question mark to each fragment.  Filter on the *stripped* fragment: the
    # previous check (`question != ""`) let whitespace-only fragments through,
    # which became bare "?" entries after strip() + "?".
    output = [
        fragment.strip() + "?"
        for fragment in output[0].split("?")
        if fragment.strip()
    ]

    st.session_state.questions = output
|
|
|
|
|
# The on_click callback runs generate_questions before the script reruns,
# so the questions are already in session state when the page re-renders.
st_generate_button = st.button('Generate questions', on_click=generate_questions)

# Ensure the questions list exists on the very first run.
if 'questions' not in st.session_state:
    st.session_state.questions = []
|
|
|
# Render the generated questions, one bolded markdown line each.
if st.session_state.questions:
    with st.container():
        st.subheader("Generated questions")
        for question in st.session_state.questions:
            st.markdown(f"__{question}__")
|
|