"""TextPressoMachine: Streamlit front end that summarizes text with a fine-tuned T5 checkpoint."""
import streamlit as st
import torch
from transformers import AutoModelForSeq2SeqLM
from transformers import AutoTokenizer
from transformers import pipeline

from t5_model import T5

# Streamlit re-executes this whole script on every widget interaction, so without
# caching the tokenizer and the checkpoint would be reloaded on every click.
# st.cache_resource (Streamlit >= 1.18) keeps them in memory across reruns; on
# older Streamlit versions fall back to loading uncached, as before.
_cache_resource = getattr(st, "cache_resource", lambda func: func)


@_cache_resource
def _load_tokenizer():
    """Return the stock ``t5-small`` tokenizer."""
    return AutoTokenizer.from_pretrained("t5-small")


@_cache_resource
def _load_model(path):
    """Load the T5 Lightning checkpoint at *path* onto the CPU, ready for inference."""
    loaded = T5.load_from_checkpoint(path, map_location=torch.device("cpu"))
    # Disable dropout for generation; the Lightning docs recommend calling eval()
    # after load_from_checkpoint.
    loaded.eval()
    return loaded


st.title("TextPressoMachine")

# Display name -> Hugging Face Hub id.
models = {
    "T5 Small": "ZinebSN/t5_ckpt",
    "GPT2": "ZinebSN/GPT2_Summarier",
}

selected_model = st.radio("Select Model", list(models.keys()))
model_name = models[selected_model]
if selected_model != "T5 Small":
    # TODO: the GPT2 branch is not implemented yet (model_name is currently unused);
    # every selection runs the local T5 checkpoint, so say so instead of silently
    # substituting it.
    st.warning(f"{selected_model} is not available yet; using T5 Small instead.")

tokenizer = _load_tokenizer()
checkpoint_path = './t5_epoch9.ckpt'
model = _load_model(checkpoint_path)

input_text = st.text_area("Input the text to summarize", "")

if st.button("Summarize"):
    if not input_text.strip():
        # Nothing to summarize; avoid running generation on a bare task prefix.
        st.warning("Please enter some text to summarize.")
    else:
        st.text("It may take a minute or two.")
        # Prepend the T5 task prefix and pad/truncate to a fixed 600 tokens.
        text_input_ids = tokenizer(
            'summarize: ' + input_text,
            max_length=600,
            padding="max_length",
            truncation=True,
        ).input_ids
        # NOTE(review): generate receives a 1-D tensor of ids and no attention mask;
        # this relies on T5.generate adding the batch dimension and handling the
        # padding itself — confirm against t5_model.T5.
        with torch.no_grad():
            output_ids = model.generate(torch.tensor(text_input_ids))
        generated_summary = tokenizer.decode(output_ids[0], skip_special_tokens=True)
        st.header("Summary")
        st.markdown(generated_summary)