# SummarizeAI_Pro / app.py
# Author: Lakshan2003 — latest revision 021abb6 ("Update app.py")
import html

import streamlit as st
import torch
from peft import PeftModel
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
@st.cache_resource
def load_model(base_model_name="t5-small", adapter_name="Lakshan2003/finetuned-t5-xsum"):
    """Load the PEFT-adapted summarization model and its tokenizer once.

    The result is cached with ``st.cache_resource``, so the expensive
    download/initialization runs only once per argument combination.

    Args:
        base_model_name: Hugging Face id of the base seq2seq model.
        adapter_name: Hugging Face id of the PEFT adapter. The tokenizer is
            loaded from this same repository.

    Returns:
        A ``(peft_model, tokenizer)`` tuple. The model is placed on CUDA when
        available (otherwise CPU) and switched to evaluation mode.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"

    base_model = AutoModelForSeq2SeqLM.from_pretrained(base_model_name)
    peft_model = PeftModel.from_pretrained(base_model, adapter_name)
    # Move the cached model once at load time rather than on every request,
    # and make sure dropout is disabled for inference.
    peft_model = peft_model.to(device)
    peft_model.eval()

    tokenizer = AutoTokenizer.from_pretrained(adapter_name)
    return peft_model, tokenizer
def generate_summary(text, model, tokenizer, max_length=128, min_length=30,
                     max_input_length=512, num_beams=4):
    """Summarize ``text`` with a seq2seq model using beam search.

    The input is prefixed with ``"summarize: "`` (T5's summarization task
    prefix) and truncated to ``max_input_length`` tokens. The first generated
    sequence is decoded without special tokens.

    Args:
        text: The text to summarize.
        model: A model exposing ``.to(device)`` and ``.generate(...)``, e.g.
            the PEFT-wrapped model returned by ``load_model``.
        tokenizer: The matching tokenizer. It is called with
            ``return_tensors="pt"`` and used to ``decode`` the output.
        max_length: Maximum summary length, in tokens.
        min_length: Minimum summary length, in tokens. Values larger than
            ``max_length`` cannot be satisfied, so they are clamped to
            ``max_length``.
        max_input_length: Inputs longer than this many tokens are truncated.
        num_beams: Beam width passed to ``generate``.

    Returns:
        The decoded summary string.
    """
    # Prefer the GPU when present. Module.to() does not copy parameters that
    # already live on ``device``, so repeated calls stay cheap.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = model.to(device)

    # Keep the requested length range consistent (min <= max).
    min_length = min(min_length, max_length)

    encoded = tokenizer(
        "summarize: " + text,
        return_tensors="pt",
        max_length=max_input_length,
        truncation=True,
    )
    encoded = {name: tensor.to(device) for name, tensor in encoded.items()}

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        output_ids = model.generate(
            input_ids=encoded["input_ids"],
            attention_mask=encoded["attention_mask"],
            max_length=max_length,
            min_length=min_length,
            num_beams=num_beams,
            length_penalty=2.0,
            early_stopping=True,
            no_repeat_ngram_size=3,
        )

    return tokenizer.decode(output_ids[0], skip_special_tokens=True)
def _inject_css():
    """Register the CSS classes used by the page title and subtitle."""
    st.markdown("""
<style>
.main-title {
    text-align: center;
    color: #1E88E5;
    font-size: 3rem !important;
    font-weight: 700;
    margin-bottom: 1rem;
}
.subtitle {
    text-align: center;
    color: #424242;
    font-size: 1.2rem !important;
    margin-bottom: 2rem;
}
</style>
""", unsafe_allow_html=True)


def _render_header():
    """Render the centered app title and subtitle."""
    st.markdown("<h1 class='main-title'>✨ SummarizeAI Pro</h1>", unsafe_allow_html=True)
    st.markdown(
        "<p class='subtitle'>Transform lengthy text into concise, meaningful summaries with AI</p>",
        unsafe_allow_html=True,
    )


def _render_summary(text, summary):
    """Show word-count statistics and the generated summary in a styled box."""
    st.markdown("### 📊 Summary Results")
    stat_col1, stat_col2 = st.columns(2)
    with stat_col1:
        st.info(f"📄 Original text: {len(text.split())} words")
    with stat_col2:
        st.info(f"✂️ Summarized text: {len(summary.split())} words")

    st.markdown("### ✨ Generated Summary")
    # The summary is model output derived from user input and is embedded in
    # raw HTML below; escape it so "<", "&", etc. display literally instead of
    # being interpreted as markup.
    safe_summary = html.escape(summary)
    st.markdown(f"""
<div style="
    padding: 20px;
    border-radius: 10px;
    background-color: #f0f2f6;
    border-left: 5px solid #1E88E5;
">
{safe_summary}
</div>
""", unsafe_allow_html=True)


def _render_sidebar():
    """Render the 'About' and 'How to Use' sections in the sidebar."""
    st.sidebar.markdown("## 🎯 About SummarizeAI Pro")
    st.sidebar.markdown("""
SummarizeAI Pro uses advanced AI technology powered by a PEFT-tuned T5 model
to generate accurate and concise summaries while preserving the key points
of your text.
""")
    st.sidebar.markdown("## 📖 How to Use")
    st.sidebar.markdown("""
1. 📝 Paste your text in the input box
2. 🎚️ Adjust summary length with sliders
3. 🚀 Click 'Generate Summary'
4. ✨ Get your AI-powered summary
""")


def _render_footer():
    """Render the centered credits footer."""
    st.markdown("""
<div style='text-align: center; color: #666; padding: 20px;'>
<p>Made with ❀️ by Lakshan Cooray</p>
</div>
""", unsafe_allow_html=True)


def main():
    """Build the SummarizeAI Pro Streamlit page.

    Lays out the header, the text input with min/max length sliders and the
    "Generate Summary" button, the results section, the sidebar help text and
    the footer. The model comes from ``load_model``, and summaries are produced
    by ``generate_summary``.
    """
    # Kept as the first Streamlit command of the script run.
    st.set_page_config(
        page_title="SummarizeAI Pro",
        page_icon="✨",
        layout="wide",
    )
    _inject_css()
    _render_header()

    with st.spinner("Loading model... (this may take a few moments)"):
        model, tokenizer = load_model()

    text = st.text_area(
        "📝 Enter your text below:",
        height=200,
        placeholder="Paste your text here and let SummarizeAI Pro work its magic...",
    )

    # Three equal-width columns: the two length sliders and the action button.
    col1, col2, col3 = st.columns([1, 1, 1])
    with col1:
        max_length = st.slider("Maximum summary length", 50, 250, 128)
    with col2:
        min_length = st.slider("Minimum summary length", 10, 100, 30)
    with col3:
        st.markdown("<br>", unsafe_allow_html=True)  # vertical spacing above the button
        generate_button = st.button("✨ Generate Summary", use_container_width=True)

    if generate_button:
        if not text.strip():
            # Empty or whitespace-only input has nothing to summarize.
            st.warning("⚠️ Please enter some text to summarize.")
        elif min_length > max_length:
            # The slider ranges overlap, so an unsatisfiable range is selectable.
            st.warning("⚠️ Minimum summary length cannot be greater than the maximum summary length.")
        else:
            with st.spinner("✨ AI is crafting your summary..."):
                try:
                    summary = generate_summary(
                        text,
                        model,
                        tokenizer,
                        max_length=max_length,
                        min_length=min_length,
                    )
                except Exception as e:
                    # UI boundary: report the failure instead of crashing the page.
                    st.error(f"🚫 An error occurred: {e}")
                else:
                    _render_summary(text, summary)

    _render_sidebar()
    _render_footer()
if __name__ == "__main__":
    # Launch the app only when this file is executed as the main script,
    # not when it is imported.
    main()