import torch
import streamlit as st
from transformers import PegasusForConditionalGeneration, AutoTokenizer

model_name = 'google/pegasus-cnn_dailymail'
device = 'cuda' if torch.cuda.is_available() else 'cpu'

@st.cache(allow_output_mutation=True)
def load_model():
    # Streamlit re-runs the whole script on every interaction, so cache the
    # tokenizer and model to avoid reloading the weights each time.
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = PegasusForConditionalGeneration.from_pretrained(model_name).to(device)
    return tokenizer, model

tokenizer, model = load_model()

def summarize(passage):
    # Tokenize the passage, truncating to the model's maximum input length.
    # Note: passage is passed directly; " ".join(passage) on a string would
    # insert a space between every character.
    batch = tokenizer(passage, truncation=True, padding='longest',
                      return_tensors="pt").to(device)
    # Generate summary token IDs, then decode them back to text.
    generated = model.generate(**batch)
    summary = tokenizer.batch_decode(generated, skip_special_tokens=True)
    # batch_decode returns a list; return the single decoded summary string.
    return summary[0]
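
# A minimal sketch of the Streamlit UI wiring the snippet above implies but
# leaves out. The widget labels and layout here are assumptions, not part of
# the original code.
st.title("Pegasus Article Summarizer")
text = st.text_area("Paste an article to summarize:", height=300)
if st.button("Summarize") and text.strip():
    with st.spinner("Generating summary..."):
        st.write(summarize(text))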