Spaces:
Running
Running
import html

import streamlit as st
import torch
from peft import PeftModel
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
@st.cache_resource(show_spinner=False)
def load_model(base_model_name="t5-small",
               adapter_name="Lakshan2003/finetuned-t5-xsum"):
    """Load the PEFT-adapted seq2seq model and its tokenizer, cached.

    Streamlit re-executes the whole script on every widget interaction, so
    without ``st.cache_resource`` the base model and the adapter would be
    loaded again on every slider move or button click. The decorator keeps
    one instance per distinct ``(base_model_name, adapter_name)`` pair.
    ``show_spinner=False`` because the caller already wraps this call in its
    own ``st.spinner``.

    Args:
        base_model_name: Hub id or local path of the base seq2seq model.
        adapter_name: Hub id or local path of the PEFT adapter. The tokenizer
            is loaded from the same location.

    Returns:
        A ``(peft_model, tokenizer)`` tuple.
    """
    base_model = AutoModelForSeq2SeqLM.from_pretrained(base_model_name)
    peft_model = PeftModel.from_pretrained(base_model, adapter_name)
    tokenizer = AutoTokenizer.from_pretrained(adapter_name)
    return peft_model, tokenizer
def generate_summary(text, model, tokenizer, max_length=128, min_length=30,
                     *, num_beams=4, max_input_length=512):
    """Generate a summary of ``text`` with a seq2seq (PEFT-tuned T5) model.

    Args:
        text: Source text. The T5 task prefix ``"summarize: "`` is prepended
            before tokenization.
        model: Model exposing ``.to(device)`` and ``.generate(...)``, e.g. the
            PEFT model returned by ``load_model``.
        tokenizer: Tokenizer matching ``model``.
        max_length: Maximum generated length, in tokens.
        min_length: Minimum generated length, in tokens. Clamped to
            ``max_length`` so the two bounds never contradict each other.
        num_beams: Beam-search width.
        max_input_length: Input token budget. Longer inputs are truncated.

    Returns:
        The decoded summary string, with special tokens removed.
    """
    # The UI sliders allow min (up to 100) > max (down to 50). Passing that
    # straight to generate() gives it contradictory length bounds, so cap it.
    min_length = min(min_length, max_length)

    # Use the GPU when one is available. `.to()` returns the moved model.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = model.to(device)

    inputs = tokenizer(
        "summarize: " + text,
        return_tensors="pt",
        max_length=max_input_length,
        truncation=True,
    )
    inputs = {name: tensor.to(device) for name, tensor in inputs.items()}

    # Inference only: no autograd bookkeeping needed.
    with torch.no_grad():
        output_ids = model.generate(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            max_length=max_length,
            min_length=min_length,
            num_beams=num_beams,
            length_penalty=2.0,
            early_stopping=True,
            no_repeat_ngram_size=3,
        )

    # Batch size is 1, so decode the first (only) sequence.
    return tokenizer.decode(output_ids[0], skip_special_tokens=True)
# NOTE(review): the emoji below were reconstructed from a mis-encoded
# (ISO-8859-7 mojibake) copy of this file. ✨ ⚠️ 🚫 ❤️ 🎯 decode
# unambiguously; 📝 📊 📄 ✂️ 📖 📋 🎚️ 🚀 are best guesses - confirm.


def _inject_styles():
    """Inject the CSS classes used by the page title and subtitle."""
    st.markdown("""
<style>
.main-title {
    text-align: center;
    color: #1E88E5;
    font-size: 3rem !important;
    font-weight: 700;
    margin-bottom: 1rem;
}
.subtitle {
    text-align: center;
    color: #424242;
    font-size: 1.2rem !important;
    margin-bottom: 2rem;
}
</style>
""", unsafe_allow_html=True)


def _render_summary(text, summary):
    """Show word-count statistics and the summary inside a styled box.

    The summary is HTML-escaped before being interpolated into raw HTML
    (rendered with ``unsafe_allow_html=True``). It is model output derived
    from user input, so it may contain markup.
    """
    st.markdown("### 📊 Summary Results")
    stat_col1, stat_col2 = st.columns(2)
    with stat_col1:
        st.info(f"📄 Original text: {len(text.split())} words")
    with stat_col2:
        st.info(f"✂️ Summarized text: {len(summary.split())} words")

    st.markdown("### ✨ Generated Summary")
    safe_summary = html.escape(summary)
    st.markdown(
        '<div style="padding: 20px; border-radius: 10px; '
        'background-color: #f0f2f6; border-left: 5px solid #1E88E5;">'
        f"{safe_summary}</div>",
        unsafe_allow_html=True,
    )


def _render_sidebar():
    """Fill the sidebar with the static "About" and "How to Use" sections."""
    st.sidebar.markdown("## 🎯 About SummarizeAI Pro")
    st.sidebar.markdown(
        "SummarizeAI Pro uses advanced AI technology powered by a PEFT-tuned "
        "T5 model\nto generate accurate and concise summaries while "
        "preserving the key points\nof your text."
    )
    st.sidebar.markdown("## 📖 How to Use")
    st.sidebar.markdown(
        "1. 📋 Paste your text in the input box\n"
        "2. 🎚️ Adjust summary length with sliders\n"
        "3. 🚀 Click 'Generate Summary'\n"
        "4. ✨ Get your AI-powered summary"
    )


def _render_footer():
    """Render the centered footer credit line."""
    st.markdown(
        "<div style='text-align: center; color: #666; padding: 20px;'>"
        "<p>Made with ❤️ by Lakshan Cooray</p>"
        "</div>",
        unsafe_allow_html=True,
    )


def main():
    """Build the SummarizeAI Pro Streamlit page.

    Renders the header, the text input, the length sliders and the generate
    button. After a click with non-blank input, it summarizes the text and
    shows the result. The sidebar and footer are rendered last.
    """
    # Page config is issued before any other st.* output.
    st.set_page_config(
        page_title="SummarizeAI Pro",
        page_icon="✨",
        layout="wide",
    )
    _inject_styles()

    st.markdown("<h1 class='main-title'>✨ SummarizeAI Pro</h1>",
                unsafe_allow_html=True)
    st.markdown("<p class='subtitle'>Transform lengthy text into concise, "
                "meaningful summaries with AI</p>",
                unsafe_allow_html=True)

    with st.spinner("Loading model... (this may take a few moments)"):
        model, tokenizer = load_model()

    text = st.text_area(
        "📝 Enter your text below:",
        height=200,
        placeholder="Paste your text here and let SummarizeAI Pro work its magic...",
    )

    col1, col2, col3 = st.columns([1, 1, 1])
    with col1:
        max_length = st.slider("Maximum summary length", 50, 250, 128)
    with col2:
        min_length = st.slider("Minimum summary length", 10, 100, 30)
    with col3:
        st.markdown("<br>", unsafe_allow_html=True)  # align button with sliders
        generate_button = st.button("✨ Generate Summary", use_container_width=True)

    if generate_button:
        # Whitespace-only input is treated the same as empty input.
        if not text.strip():
            st.warning("⚠️ Please enter some text to summarize.")
        else:
            # The slider ranges overlap, so min can exceed max. Tell the
            # user, and cap it instead of sending contradictory bounds.
            if min_length > max_length:
                st.warning("⚠️ Minimum length exceeds maximum length; "
                           "using the maximum for both.")
                min_length = max_length
            with st.spinner("✨ AI is crafting your summary..."):
                try:
                    summary = generate_summary(text, model, tokenizer,
                                               max_length=max_length,
                                               min_length=min_length)
                except Exception as e:  # UI boundary: report any model failure
                    st.error(f"🚫 An error occurred: {e}")
                else:
                    _render_summary(text, summary)

    _render_sidebar()
    _render_footer()
if __name__ == "__main__":
    # Script entry point: `streamlit run <this file>` executes it as __main__.
    main()