|
# All imports grouped at the top per PEP 8 — they were previously interleaved
# with Streamlit calls, which obscures the module's dependencies.
import streamlit as st
import torch
# NOTE(review): AutoModelForSeq2SeqLM and pipeline are imported but never used
# in this file — kept in case another part of the app relies on them; confirm
# before removing.
from transformers import AutoModelForSeq2SeqLM
from transformers import AutoTokenizer
from transformers import pipeline

from t5_model import T5

st.title("TextPressoMachine")

# Display name -> Hugging Face Hub checkpoint id for each selectable model.
models = {
    "T5 Small": "ZinebSN/t5_ckpt",
    "GPT2": "ZinebSN/GPT2_Summarier",
}

# Let the user pick a model; `model_name` resolves to its Hub checkpoint id.
selected_model = st.radio("Select Model", list(models.keys()))
model_name = models[selected_model]

# NOTE(review): the tokenizer is hard-coded to 't5-small' regardless of the
# radio selection, and `model_name` is not used below (the model is loaded
# from a local checkpoint) — confirm whether the GPT2 option is actually
# wired up, and whether GPT2 needs its own tokenizer.
tokenizer = AutoTokenizer.from_pretrained('t5-small')
|
|
|
|
|
|
|
|
|
|
|
# Local PyTorch Lightning checkpoint produced by fine-tuning (epoch 9).
checkpoint_path='./t5_epoch9.ckpt'

# Restore the fine-tuned T5 wrapper on CPU so the app runs without a GPU.
# NOTE(review): this loads from the local .ckpt file, not from the Hub id
# chosen via the radio button above — confirm that is intentional.
model=T5.load_from_checkpoint(checkpoint_path, map_location=torch.device('cpu'))
|
|
|
|
|
|
|
|
|
|
|
input_text = st.text_area("Input the text to summarize", "")

if st.button("Summarize"):
    st.text("It may take a minute or two.")

    # Tokenize with the T5 task prefix. return_tensors="pt" yields a batched
    # (1, seq_len) tensor, which is the shape Hugging Face `generate` expects;
    # the previous code built an unbatched 1-D tensor via torch.tensor(...),
    # which relied on the model tolerating unbatched input.
    # (Removed: an unused `nwords` word count.)
    # NOTE(review): padding="max_length" pads every input to 600 tokens, which
    # costs compute on short inputs — confirm whether it is required here.
    text_input_ids = tokenizer(
        'summarize: ' + input_text,
        max_length=600,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    ).input_ids

    # Generate the summary token ids, then decode the first (only) sequence.
    output_ids = model.generate(text_input_ids)
    generated_summary = tokenizer.decode(output_ids[0], skip_special_tokens=True)

    st.header("Summary")
    st.markdown(generated_summary)