"""Streamlit helpers for abstractive summarization with Pegasus (CNN/DailyMail)."""
import logging
from typing import Iterable, List, Union

import streamlit as st
import torch
from transformers import AutoTokenizer, PegasusForConditionalGeneration

logger = logging.getLogger(__name__)

MODEL_NAME = "google/pegasus-cnn_dailymail"

# Run on the GPU when one is available. The model and its input tensors must
# live on the same device, so the model is moved here once at load time.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"


def _resource_cache():
    """Return Streamlit's decorator for caching shared, mutable resources.

    ``st.cache_resource`` supersedes the deprecated
    ``st.cache(allow_output_mutation=True)``; fall back to the old API on
    Streamlit versions that do not provide it.
    """
    cache_resource = getattr(st, "cache_resource", None)
    if cache_resource is not None:
        return cache_resource
    return st.cache(allow_output_mutation=True)


@_resource_cache()
def do_summary(model_name):
    """Load the Pegasus model *model_name* (cached by Streamlit) on ``DEVICE``."""
    return PegasusForConditionalGeneration.from_pretrained(model_name).to(DEVICE)


@_resource_cache()
def do_tokenize(model_name):
    """Load the tokenizer for *model_name* (cached by Streamlit)."""
    return AutoTokenizer.from_pretrained(model_name)


model = do_summary(MODEL_NAME)
tokenizer = do_tokenize(MODEL_NAME)


def summarize(passage: Union[str, Iterable[str]]) -> List[str]:
    """Generate an abstractive summary of *passage*.

    Args:
        passage: The text to summarize. Either a single string, or an
            iterable of text pieces (e.g. sentences) that are joined with
            single spaces before tokenization.

    Returns:
        The decoded summaries as returned by ``tokenizer.batch_decode``,
        one string per generated sequence.
    """
    # A bare string must not go through " ".join, which would insert a space
    # between every character ("abc" -> "a b c").
    txt = passage if isinstance(passage, str) else " ".join(passage)
    # truncation=True clips over-long input to the model's maximum length.
    # Inputs are sent to wherever the model actually lives, which avoids a
    # CPU/GPU device mismatch inside generate().
    batch = tokenizer(
        txt, truncation=True, padding="longest", return_tensors="pt"
    ).to(model.device)
    generated = model.generate(**batch)
    summary = tokenizer.batch_decode(generated, skip_special_tokens=True)
    logger.debug("summ end")
    return summary