# T5_Small1 / app.py
# Author: ZinebSN (Hugging Face Space)
# Last commit: "Update app.py" (01b15b8)
import streamlit as st
import torch
# The page title is emitted before the transformers / t5_model imports below run.
# NOTE(review): possibly deliberate so the page shows something while those
# (potentially slow) imports load — confirm before moving it below the imports.
st.title("TextPressoMachine")
from transformers import AutoModelForSeq2SeqLM  # NOTE(review): not used by the active code path
from t5_model import T5  # project-local T5 wrapper; loaded via T5.load_from_checkpoint
from transformers import AutoTokenizer
from transformers import pipeline  # NOTE(review): not used by the active code path
# Models offered in the UI: display label -> Hugging Face repo id.
models = {
    "T5 Small": "ZinebSN/t5_ckpt",
    "GPT2": "ZinebSN/GPT2_Summarier",
}
selected_model = st.radio("Select Model", list(models.keys()))
model_name = models[selected_model]

# Local checkpoint of the fine-tuned T5 model (presumably a PyTorch Lightning
# checkpoint, since it is read with T5.load_from_checkpoint). This is the model
# that is actually used below, whichever radio option is selected.
checkpoint_path = './t5_epoch9.ckpt'

# Streamlit re-executes this whole script on every widget interaction, so the
# tokenizer and model would otherwise be reloaded from disk each time. Cache
# them once per server process. st.cache_resource only exists in newer
# Streamlit releases (>= 1.18); on older ones fall back to a no-op decorator,
# which reproduces the previous (uncached) behaviour.
_cache_resource = getattr(st, "cache_resource", lambda fn: fn)


@_cache_resource
def _load_tokenizer():
    """Return the stock ``t5-small`` tokenizer used for encoding and decoding."""
    return AutoTokenizer.from_pretrained('t5-small')


@_cache_resource
def _load_t5_model(path):
    """Load the fine-tuned T5 checkpoint at *path* onto the CPU, ready for inference."""
    loaded = T5.load_from_checkpoint(path, map_location=torch.device('cpu'))
    # Switch off training-only behaviour (e.g. dropout) for inference;
    # harmless if the module is already in eval mode.
    loaded.eval()
    return loaded


tokenizer = _load_tokenizer()

if selected_model != "T5 Small":
    # TODO: the GPT2 summarizer branch existed only as commented-out code and is
    # not implemented. Tell the user rather than silently running T5.
    st.info(f"{selected_model} is not available yet; using T5 Small instead.")
model = _load_t5_model(checkpoint_path)
input_text = st.text_area("Input the text to summarize", "")
if st.button("Summarize"):
    if not input_text.strip():
        # Nothing to summarize: don't run the model on just the task prefix.
        st.warning("Please enter some text to summarize.")
    else:
        # The spinner disappears once generation finishes, unlike a static
        # st.text message that would stay on the page afterwards.
        with st.spinner("It may take a minute or two."):
            # T5-style task prefix; input is truncated/padded to exactly 600 tokens.
            text_input_ids = tokenizer(
                'summarize: ' + input_text,
                max_length=600,
                padding="max_length",
                truncation=True,
            ).input_ids
            # NOTE(review): this passes a 1-D tensor with no batch dimension.
            # That is presumably what t5_model.T5.generate expects (e.g. it adds
            # the batch dim itself) — confirm before changing the shape.
            with torch.no_grad():  # inference only; no autograd bookkeeping needed
                output_ids = model.generate(torch.tensor(text_input_ids))
            generated_summary = tokenizer.decode(output_ids[0], skip_special_tokens=True)
        st.header("Summary")
        st.markdown(generated_summary)