# brian-challenge / app.py
# Author: Christian Koch
# Commit cd3659c: "further improvements, implement question generator"
# (page residue: raw history blame | No virus | 3.83 kB)
import streamlit as st
from transformers import T5ForConditionalGeneration, T5Tokenizer
from fill_in_summary import FillInSummary
from paraphrase import PegasusParaphraser
import question_gen as q
# Sample passages used to pre-fill the "Input Text" areas so each task can be
# tried without typing anything. `default_text` feeds the question generator,
# summarization and fill-in-the-blank forms; `default_text2` (one sentence)
# feeds the paraphrasing form.
default_text = (
    "Apple was founded as Apple Computer Company on April 1, 1976, by Steve Jobs, Steve Wozniak and Ronald "
    "Wayne to develop and sell Wozniak's Apple I personal computer. It was incorporated by Jobs and "
    "Wozniak as Apple Computer, Inc. in 1977 and the company's next computer, the Apple II became a best "
    "seller. Apple went public in 1980, to instant financial success. The company went on to develop new "
    "computers featuring innovative graphical user interfaces, including the original Macintosh, "
    "announced in a critically acclaimed advertisement, '1984', directed by Ridley Scott. By 1985, "
    "the high cost of its products and power struggles between executives caused problems. Wozniak stepped "
    "back from Apple amicably, while Jobs resigned to found NeXT, taking some Apple employees with him."
)
default_text2 = (
    "The board of directors instructed Sculley to contain Jobs and his ability to launch expensive forays "
    "into untested products."
)
# Page setup. set_page_config is kept as the very first Streamlit call, which
# older Streamlit releases require.
st.set_page_config(layout="centered")
st.title('Question Generator by Eddevs')
st.write('Please select the task you want to do.')

# The chosen label drives which task form is rendered below.
task_options = ['Question Generator', 'Paraphrasing', 'Summarization', 'Fill in the blank']
select = st.selectbox('Type', task_options)
def _has_input(text):
    """Return True if *text* contains non-whitespace content; otherwise warn the user.

    Used so an empty submission is not sent to the models.
    """
    if text.strip():
        return True
    st.warning('Please enter some input text.')
    return False


def _render_question_generator():
    """Render the question-generation form and, on submit, show the generated questions."""
    with st.form("question_gen"):
        left_column, right_column = st.columns(2)
        # Lower bound 1: asking for zero questions or zero beams is not a valid setting.
        # NOTE(review): if q.get_question forwards these as num_return_sequences /
        # num_beams to transformers beam search, it also needs num_seq <= beams — confirm.
        num_seq = left_column.slider('Question Count', 1, 10, 3)
        beams = right_column.slider('Beams', 1, 10, 5)
        max_length = st.slider('Max Length', 1, 1024, 300)
        text_input = st.text_area("Input Text", value=default_text)
        submitted = st.form_submit_button("Generate")
    if submitted and _has_input(text_input):
        with st.spinner('Wait for it...'):
            # NOTE(review): the model and tokenizer are reloaded on every submit;
            # consider caching them (e.g. st.cache_resource) if the deployed
            # Streamlit version supports it.
            question_model = T5ForConditionalGeneration.from_pretrained('ramsrigouthamg/t5_squad_v1')
            question_tokenizer = T5Tokenizer.from_pretrained('ramsrigouthamg/t5_squad_v1')
            result = q.get_question(text_input, "", question_model, question_tokenizer,
                                    num_seq, beams, max_length)
        st.write(result)


def _render_summarization():
    """Render the summarization form and, on submit, show the summary."""
    with st.form("summarization"):
        text_input = st.text_area("Input Text", value=default_text)
        submitted = st.form_submit_button("Generate")
    if submitted and _has_input(text_input):
        with st.spinner('Wait for it...'):
            result = FillInSummary().summarize(text_input)
        # Show the computed summary (previously the input text was echoed back).
        st.write(result)


def _render_fill_in_the_blank():
    """Render the fill-in-the-blank form: summarize the input, then blank out named entities."""
    with st.form("fill_in_the_blank"):
        text_input = st.text_area("Input Text", value=default_text)
        submitted = st.form_submit_button("Generate")
    if submitted and _has_input(text_input):
        with st.spinner('Wait for it...'):
            fill = FillInSummary()
            result = fill.summarize(text_input)
            result = fill.blank_ne_out(result)
        st.write(result)


def _render_paraphrasing():
    """Render the paraphrasing form and, on submit, show the paraphrases."""
    with st.form("paraphrasing"):
        left_column, right_column = st.columns(2)
        # Lower bound 1: zero paraphrases is not a meaningful request.
        count = left_column.slider('Count', 1, 10, 3)
        # Temperature kept strictly positive; presumably PegasusParaphraser uses it
        # for sampling, where a temperature of 0 is rejected — confirm.
        temperature = right_column.slider('Temperature', 0.1, 10.0, 1.5)
        text_input = st.text_area("Input Text", value=default_text2)
        submitted = st.form_submit_button("Generate")
    if submitted and _has_input(text_input):
        with st.spinner('Wait for it...'):
            paraphrase_model = PegasusParaphraser(num_return_sequences=count, temperature=temperature)
            result = paraphrase_model.paraphrase(text_input)
        st.write(result)


# Map each selectbox label to the function that renders that task's UI.
_TASK_RENDERERS = {
    "Question Generator": _render_question_generator,
    "Paraphrasing": _render_paraphrasing,
    "Summarization": _render_summarization,
    "Fill in the blank": _render_fill_in_the_blank,
}

_renderer = _TASK_RENDERERS.get(select)
if _renderer is not None:
    _renderer()