import streamlit as st
from peft import PeftModel
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
import torch

# Hugging Face Hub identifiers: the base checkpoint, and the PEFT adapter repo
# (which also supplies the tokenizer).
BASE_MODEL_NAME = "t5-small"
ADAPTER_REPO = "Lakshan2003/finetuned-t5-xsum"


@st.cache_resource
def load_model():
    """Load the PEFT model and tokenizer once and cache them.

    The model is placed on the GPU (when available) and switched to eval mode
    here, exactly once. The object returned through ``st.cache_resource`` is
    reused across reruns and sessions, so it should not be mutated (e.g.
    moved between devices) on every summarization request.

    Returns:
        tuple: ``(peft_model, tokenizer)``.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    base_model = AutoModelForSeq2SeqLM.from_pretrained(BASE_MODEL_NAME)
    peft_model = PeftModel.from_pretrained(base_model, ADAPTER_REPO)
    peft_model = peft_model.to(device)
    peft_model.eval()  # inference only: disable dropout explicitly
    tokenizer = AutoTokenizer.from_pretrained(ADAPTER_REPO)
    return peft_model, tokenizer


def generate_summary(
    text,
    model,
    tokenizer,
    max_length=128,
    min_length=30,
    *,
    max_input_length=512,
    num_beams=4,
):
    """Generate a summary of *text* using the PEFT model.

    Args:
        text: Source text to summarize. Empty or whitespace-only text returns
            ``""`` without running the model.
        model: Seq2seq model exposing ``generate`` (e.g. from ``load_model``).
            Inputs are sent to the device the model's parameters are on; the
            model itself is not moved.
        tokenizer: Tokenizer matching ``model``.
        max_length: Maximum summary length, in tokens.
        min_length: Minimum summary length, in tokens; capped at
            ``max_length`` when larger.
        max_input_length: The prefixed input is truncated to this many tokens.
        num_beams: Beam width for beam search.

    Returns:
        str: The decoded summary with special tokens removed.
    """
    if not text or not text.strip():
        return ""

    # Follow the model's current device rather than moving the (cached,
    # shared) model on every call.
    device = next(model.parameters()).device

    # A minimum above the maximum is unsatisfiable (the app's sliders allow
    # min up to 100 while max can be as low as 50), so cap it.
    min_length = min(min_length, max_length)

    # T5-style models are prompted with a task prefix.
    input_text = "summarize: " + text

    encoded = tokenizer(
        input_text,
        return_tensors="pt",
        max_length=max_input_length,
        truncation=True,
    )
    encoded = {name: tensor.to(device) for name, tensor in encoded.items()}

    with torch.no_grad():
        output_ids = model.generate(
            input_ids=encoded["input_ids"],
            attention_mask=encoded["attention_mask"],
            max_length=max_length,
            min_length=min_length,
            num_beams=num_beams,
            length_penalty=2.0,
            early_stopping=True,
            no_repeat_ngram_size=3,
        )

    # A single input was encoded, so the first returned sequence is the summary.
    return tokenizer.decode(output_ids[0], skip_special_tokens=True)


def main(): st.set_page_config( page_title="SummarizeAI Pro", page_icon="✨", layout="wide" ) # Custom CSS st.markdown(""" """, unsafe_allow_html=True) # App title and subtitle st.markdown("
Transform lengthy text into concise, meaningful summaries with AI
", unsafe_allow_html=True) # Load model and tokenizer with st.spinner("Loading model... (this may take a few moments)"): model, tokenizer = load_model() # Input text area text = st.text_area( "📝 Enter your text below:", height=200, placeholder="Paste your text here and let SummarizeAI Pro work its magic..." ) # Create three columns for better layout col1, col2, col3 = st.columns([1, 1, 1]) with col1: max_length = st.slider("Maximum summary length", 50, 250, 128) with col2: min_length = st.slider("Minimum summary length", 10, 100, 30) with col3: st.markdown("Made with ❤️ by Lakshan Cooray