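"""Streamlit document summarizer.

Extracts text from an uploaded PDF, DOCX, or TXT file, a webpage URL, or pasted
text, splits it into segments, summarizes each segment with BART
(facebook/bart-large-cnn), and combines the results into an executive summary
plus frequency-based keywords.
"""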
import streamlit as st
import torch
from transformers import BartTokenizer, BartForConditionalGeneration
import os
import tempfile
import time
import nltk
import re
from bs4 import BeautifulSoup
import requests
import PyPDF2
from docx import Document
from nltk.tokenize import sent_tokenize
from nltk.corpus import stopwords
# Set page configuration
st.set_page_config(
    page_title="Document Summarizer",
    page_icon="📄",
    layout="wide"
)
# Download NLTK data at startup
@st.cache_resource
def download_nltk_data():
    # 'punkt' powers sent_tokenize and 'stopwords' the keyword filter;
    # note: newer NLTK releases may also require the 'punkt_tab' resource (version-dependent)
    nltk.download('punkt')
    nltk.download('stopwords')
    return True
# Initialize NLTK
nltk_initialized = download_nltk_data()
# Custom CSS
st.markdown("""
<style>
.main-header {font-size: 2rem; font-weight: bold; margin-bottom: 1rem;}
.summary-box {background-color: #e3e8f0; padding: 1rem; border-radius: 0.5rem; margin-top: 1rem; border: 1px solid #b0bec5; color: #263238;}
.keyword-chip {background-color: #0277bd; padding: 5px 10px; margin: 5px; border-radius: 15px; display: inline-block; color: white; font-weight: 500;}
.section-header {color: #01579b; font-weight: bold; margin-top: 1.5rem; margin-bottom: 0.5rem;}
</style>
""", unsafe_allow_html=True)
# Document processing functions
class DocumentProcessor:
    def __init__(self):
        self.stopwords = set(stopwords.words('english'))
    def extract_text_from_pdf(self, file_bytes):
        """Extract text from PDF bytes"""
        try:
            with tempfile.NamedTemporaryFile(delete=False, suffix='.pdf') as temp_file:
                temp_file.write(file_bytes)
                temp_path = temp_file.name
            text = ""
            with open(temp_path, 'rb') as file:
                reader = PyPDF2.PdfReader(file)
                for page in reader.pages:
                    # extract_text() can return None for image-only pages
                    text += (page.extract_text() or "") + "\n"
            # Clean up temp file
            os.unlink(temp_path)
            return self.clean_text(text)
        except Exception as e:
            st.error(f"Error extracting text from PDF: {str(e)}")
            return ""
    def extract_text_from_docx(self, file_bytes):
        """Extract text from DOCX bytes"""
        try:
            with tempfile.NamedTemporaryFile(delete=False, suffix='.docx') as temp_file:
                temp_file.write(file_bytes)
                temp_path = temp_file.name
            doc = Document(temp_path)
            text = "\n".join([para.text for para in doc.paragraphs])
            # Clean up temp file
            os.unlink(temp_path)
            return self.clean_text(text)
        except Exception as e:
            st.error(f"Error extracting text from DOCX: {str(e)}")
            return ""
    def extract_text_from_txt(self, file_bytes):
        """Extract text from TXT bytes"""
        try:
            text = file_bytes.decode('utf-8')
            return self.clean_text(text)
        except UnicodeDecodeError:
            try:
                text = file_bytes.decode('latin-1')
                return self.clean_text(text)
            except Exception as e:
                st.error(f"Error decoding text file: {str(e)}")
                return ""
    def extract_text_from_url(self, url):
        """Extract text from URL"""
        try:
            if not url.startswith(('http://', 'https://')):
                url = 'https://' + url
            response = requests.get(url, timeout=10)
            response.raise_for_status()
            soup = BeautifulSoup(response.text, 'html.parser')
            # Remove non-content elements
            for tag in soup(['script', 'style', 'header', 'footer', 'nav']):
                tag.decompose()
            # Get main content
            paragraphs = soup.find_all('p')
            text = ' '.join([p.get_text().strip() for p in paragraphs])
            return self.clean_text(text)
        except Exception as e:
            st.error(f"Error extracting text from URL: {str(e)}")
            return ""
    def clean_text(self, text):
        """Clean extracted text"""
        # Replace multiple newlines and spaces
        text = re.sub(r'\n+', '\n', text)
        text = re.sub(r'\s+', ' ', text)
        return text.strip()
    def segment_document(self, text, max_length=1000):
        """Break document into segments"""
        if not text:
            return []
        sentences = sent_tokenize(text)
        segments = []
        current_segment = []
        current_length = 0
        for sentence in sentences:
            sentence_length = len(sentence)
            # If adding this sentence exceeds max length, start a new segment
            if current_length + sentence_length > max_length and current_segment:
                segments.append(' '.join(current_segment))
                current_segment = []
                current_length = 0
            current_segment.append(sentence)
            current_length += sentence_length
        # Add the last segment
        if current_segment:
            segments.append(' '.join(current_segment))
        return segments
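    # Note: max_length above counts characters (~1000 per segment); together with the
    # 1024-token truncation applied in generate_summary, this keeps each chunk within
    # BART's input window.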
    def extract_keywords(self, text, top_n=8):
        """Extract keywords from text"""
        # Simple frequency-based extraction
        words = re.findall(r'\b[a-zA-Z]{3,}\b', text.lower())
        words = [word for word in words if word not in self.stopwords]
        # Count frequencies
        word_freq = {}
        for word in words:
            word_freq[word] = word_freq.get(word, 0) + 1
        # Sort by frequency
        sorted_words = sorted(word_freq.items(), key=lambda x: x[1], reverse=True)
        keywords = [word for word, _ in sorted_words[:top_n]]
        return keywords
    def process_document(self, file=None, url=None, text=None):
        """Process document from various sources"""
        document_text = ""
        if file:
            file_type = file.type
            file_bytes = file.getvalue()
            if "pdf" in file_type:
                document_text = self.extract_text_from_pdf(file_bytes)
            elif "word" in file_type or file.name.endswith('.docx'):
                document_text = self.extract_text_from_docx(file_bytes)
            elif "text" in file_type or file.name.endswith('.txt'):
                document_text = self.extract_text_from_txt(file_bytes)
            else:
                st.error(f"Unsupported file type: {file_type}")
                return None
        elif url:
            document_text = self.extract_text_from_url(url)
        elif text:
            document_text = text
        if not document_text:
            st.error("Could not extract text from the document")
            return None
        # Process the document
        segments = self.segment_document(document_text)
        keywords = self.extract_keywords(document_text)
        return {
            "full_text": document_text,
            "segments": segments,
            "keywords": keywords
        }
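# Illustrative use of DocumentProcessor outside the Streamlit UI (sketch only):
#   processor = DocumentProcessor()
#   doc = processor.process_document(text="Some long document ...")
#   # doc -> {"full_text": ..., "segments": [...], "keywords": [...]}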
# Model loading - cached so the checkpoint is only downloaded once
@st.cache_resource
def load_model():
    # facebook/bart-large-cnn is a standard summarization checkpoint; it runs on CPU
    # but is fairly heavy, so the generation settings below are kept conservative
    model_name = "facebook/bart-large-cnn"
    tokenizer = BartTokenizer.from_pretrained(model_name)
    model = BartForConditionalGeneration.from_pretrained(model_name)
    return tokenizer, model
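# Note (assumption): a distilled checkpoint such as "sshleifer/distilbart-cnn-12-6"
# could be swapped in for model_name above as a drop-in replacement (same tokenizer
# and generation API) for faster CPU inference, at some cost in summary quality.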
def generate_summary(text, tokenizer, model, max_length=150, min_length=40):
    """Generate summary with BART"""
    if not text or len(text.strip()) == 0:
        return ""
    # Tokenize with truncation
    inputs = tokenizer(text, max_length=1024, truncation=True, return_tensors="pt")
    # Generate summary with fewer beams to keep CPU inference fast
    with torch.no_grad():
        summary_ids = model.generate(
            inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            num_beams=2,  # Reduced for CPU
            min_length=min_length,
            max_length=max_length,
            length_penalty=2.0,
            early_stopping=True,
        )
    summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
    return summary
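# Illustrative call (sketch, assuming the cached model has already been loaded):
#   tokenizer, model = load_model()
#   summary = generate_summary(long_text, tokenizer, model, max_length=150, min_length=40)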
def main():
    st.markdown("<div class='main-header'>📄 Document Summarizer</div>", unsafe_allow_html=True)
    st.write("Upload a document, enter a URL, or paste text to get an AI-powered summary")
    # Initialize processor
    processor = DocumentProcessor()
    # Input tabs
    tab1, tab2, tab3 = st.tabs(["📂 Upload", "🔗 URL", "✏️ Text"])
    with tab1:
        uploaded_file = st.file_uploader("Upload a document", type=["pdf", "docx", "txt"])
        if uploaded_file:
            st.session_state.input_type = "file"
            st.session_state.file = uploaded_file
    with tab2:
        url = st.text_input("Enter a webpage URL")
        if url:
            st.session_state.input_type = "url"
            st.session_state.url = url
    with tab3:
        text = st.text_area("Paste your text", height=200)
        if text:
            st.session_state.input_type = "text"
            st.session_state.text = text
    # Process button
    if st.button("Generate Summary", type="primary"):
        with st.spinner("Processing document..."):
            # Get input based on selected tab
            doc_input = None
            if hasattr(st.session_state, 'input_type'):
                if st.session_state.input_type == "file" and hasattr(st.session_state, 'file'):
                    doc_input = processor.process_document(file=st.session_state.file)
                elif st.session_state.input_type == "url" and hasattr(st.session_state, 'url'):
                    doc_input = processor.process_document(url=st.session_state.url)
                elif st.session_state.input_type == "text" and hasattr(st.session_state, 'text'):
                    doc_input = processor.process_document(text=st.session_state.text)
            if not doc_input:
                st.error("Please provide a document, URL, or text")
                return
            # Load model (cached)
            try:
                tokenizer, model = load_model()
                st.session_state.model_loaded = True
            except Exception as e:
                st.error(f"Error loading model: {str(e)}")
                return
            # Process segments
            progress_bar = st.progress(0)
            segments = doc_input["segments"]
            segment_summaries = []
            for i, segment in enumerate(segments):
                # Update progress
                progress = (i + 1) / max(1, len(segments))
                progress_bar.progress(progress)
                if len(segment.split()) > 30:
                    # Summarize segment
                    segment_summary = generate_summary(
                        segment, tokenizer, model,
                        max_length=150, min_length=40
                    )
                    segment_summaries.append(segment_summary)
                else:
                    # Short segment, use as is
                    segment_summaries.append(segment)
            # Create final summary
            combined_text = " ".join(segment_summaries)
            if len(combined_text.split()) > 100:
                final_summary = generate_summary(
                    combined_text, tokenizer, model,
                    max_length=250, min_length=100
                )
            else:
                final_summary = combined_text
            # Calculate stats
            original_words = len(doc_input["full_text"].split())
            summary_words = len(final_summary.split())
            compression = (1 - (summary_words / original_words)) * 100 if original_words > 0 else 0
            # Save results
            st.session_state.result = {
                "final_summary": final_summary,
                "segment_summaries": segment_summaries,
                "keywords": doc_input["keywords"],
                "stats": {
                    "original_words": original_words,
                    "summary_words": summary_words,
                    "compression": compression,
                    "num_segments": len(segments)
                }
            }
            # Remove progress bar when done
            progress_bar.empty()
    # Display results if available
    if hasattr(st.session_state, 'result') and st.session_state.result:
        result = st.session_state.result
        st.markdown("<div class='section-header'>📝 Executive Summary</div>", unsafe_allow_html=True)
        st.markdown(f"<div class='summary-box'>{result['final_summary']}</div>", unsafe_allow_html=True)
        # Display metrics
        col1, col2, col3 = st.columns(3)
        with col1:
            st.metric("Original Length", f"{result['stats']['original_words']} words")
        with col2:
            st.metric("Summary Length", f"{result['stats']['summary_words']} words")
        with col3:
            st.metric("Compression", f"{result['stats']['compression']:.1f}%")
        # Display keywords
        st.markdown("<div class='section-header'>🔑 Key Topics</div>", unsafe_allow_html=True)
        keyword_html = ""
        for keyword in result['keywords']:
            keyword_html += f'<span class="keyword-chip">{keyword}</span>'
        st.markdown(f"<div>{keyword_html}</div>", unsafe_allow_html=True)
        # Detailed summaries (in expander to save space)
        with st.expander("View Detailed Segment Summaries"):
            for i, summary in enumerate(result['segment_summaries']):
                st.markdown(f"**Segment {i+1}**")
                st.markdown(f"<div class='summary-box'>{summary}</div>", unsafe_allow_html=True)
                st.markdown("---")
        # Download option
        summary_text = f"# Executive Summary\n\n{result['final_summary']}\n\n"
        summary_text += f"## Key Topics\n\n{', '.join(result['keywords'])}\n\n"
        summary_text += "## Detailed Segment Summaries\n\n"
        for i, summary in enumerate(result['segment_summaries']):
            summary_text += f"### Segment {i+1}\n\n{summary}\n\n"
        st.download_button(
            "Download Summary",
            summary_text,
            file_name="document_summary.md",
            mime="text/markdown"
        )
if __name__ == "__main__":
    main()
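# Suggested local setup (package names inferred from the imports above; exact versions are an assumption):
#   pip install streamlit torch transformers nltk beautifulsoup4 requests PyPDF2 python-docx
# Run the app with:
#   streamlit run app.py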