import streamlit as st
import torch
from transformers import BartTokenizer, BartForConditionalGeneration
import os
import tempfile
import time
import nltk
import re
from bs4 import BeautifulSoup
import requests
import PyPDF2
from docx import Document
from nltk.tokenize import sent_tokenize
from nltk.corpus import stopwords
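# Third-party dependencies implied by the imports above (pip package names):
# streamlit, torch, transformers, nltk, beautifulsoup4, requests, PyPDF2, python-docx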
# Set page configuration
st.set_page_config(
    page_title="Document Summarizer",
    page_icon="📄",
    layout="wide"
)
# Download NLTK data at startup
def download_nltk_data():
    nltk.download('punkt')
    nltk.download('stopwords')
    return True
# Initialize NLTK
nltk_initialized = download_nltk_data()
# Custom CSS
st.markdown("""
<style>
.main-header {font-size: 2rem; font-weight: bold; margin-bottom: 1rem;}
.summary-box {background-color: #e3e8f0; padding: 1rem; border-radius: 0.5rem; margin-top: 1rem; border: 1px solid #b0bec5; color: #263238;}
.keyword-chip {background-color: #0277bd; padding: 5px 10px; margin: 5px; border-radius: 15px; display: inline-block; color: white; font-weight: 500;}
.section-header {color: #01579b; font-weight: bold; margin-top: 1.5rem; margin-bottom: 0.5rem;}
</style>
""", unsafe_allow_html=True)
# Document processing functions
class DocumentProcessor:
    def __init__(self):
        self.stopwords = set(stopwords.words('english'))
    def extract_text_from_pdf(self, file_bytes):
        """Extract text from PDF bytes"""
        try:
            with tempfile.NamedTemporaryFile(delete=False, suffix='.pdf') as temp_file:
                temp_file.write(file_bytes)
                temp_path = temp_file.name
            text = ""
            with open(temp_path, 'rb') as file:
                reader = PyPDF2.PdfReader(file)
                for page in reader.pages:
                    text += page.extract_text() + "\n"
            # Clean up temp file
            os.unlink(temp_path)
            return self.clean_text(text)
        except Exception as e:
            st.error(f"Error extracting text from PDF: {str(e)}")
            return ""
    def extract_text_from_docx(self, file_bytes):
        """Extract text from DOCX bytes"""
        try:
            with tempfile.NamedTemporaryFile(delete=False, suffix='.docx') as temp_file:
                temp_file.write(file_bytes)
                temp_path = temp_file.name
            doc = Document(temp_path)
            text = "\n".join([para.text for para in doc.paragraphs])
            # Clean up temp file
            os.unlink(temp_path)
            return self.clean_text(text)
        except Exception as e:
            st.error(f"Error extracting text from DOCX: {str(e)}")
            return ""
    def extract_text_from_txt(self, file_bytes):
        """Extract text from TXT bytes"""
        try:
            text = file_bytes.decode('utf-8')
            return self.clean_text(text)
        except UnicodeDecodeError:
            try:
                text = file_bytes.decode('latin-1')
                return self.clean_text(text)
            except Exception as e:
                st.error(f"Error decoding text file: {str(e)}")
                return ""
    def extract_text_from_url(self, url):
        """Extract text from URL"""
        try:
            if not url.startswith(('http://', 'https://')):
                url = 'https://' + url
            response = requests.get(url, timeout=10)
            soup = BeautifulSoup(response.text, 'html.parser')
            # Remove non-content elements
            for tag in soup(['script', 'style', 'header', 'footer', 'nav']):
                tag.decompose()
            # Get main content
            paragraphs = soup.find_all('p')
            text = ' '.join([p.get_text().strip() for p in paragraphs])
            return self.clean_text(text)
        except Exception as e:
            st.error(f"Error extracting text from URL: {str(e)}")
            return ""
    def clean_text(self, text):
        """Clean extracted text"""
        # Replace multiple newlines and spaces
        text = re.sub(r'\n+', '\n', text)
        text = re.sub(r'\s+', ' ', text)
        return text.strip()
    def segment_document(self, text, max_length=1000):
        """Break document into segments"""
        if not text:
            return []
        sentences = sent_tokenize(text)
        segments = []
        current_segment = []
        current_length = 0
        for sentence in sentences:
            sentence_length = len(sentence)
            # If adding this sentence exceeds max length, start a new segment
            if current_length + sentence_length > max_length and current_segment:
                segments.append(' '.join(current_segment))
                current_segment = []
                current_length = 0
            current_segment.append(sentence)
            current_length += sentence_length
        # Add the last segment
        if current_segment:
            segments.append(' '.join(current_segment))
        return segments
    def extract_keywords(self, text, top_n=8):
        """Extract keywords from text"""
        # Simple frequency-based extraction
        words = re.findall(r'\b[a-zA-Z]{3,}\b', text.lower())
        words = [word for word in words if word not in self.stopwords]
        # Count frequencies
        word_freq = {}
        for word in words:
            word_freq[word] = word_freq.get(word, 0) + 1
        # Sort by frequency
        sorted_words = sorted(word_freq.items(), key=lambda x: x[1], reverse=True)
        keywords = [word for word, _ in sorted_words[:top_n]]
        return keywords
    def process_document(self, file=None, url=None, text=None):
        """Process document from various sources"""
        document_text = ""
        if file:
            file_type = file.type
            file_bytes = file.getvalue()
            if "pdf" in file_type:
                document_text = self.extract_text_from_pdf(file_bytes)
            elif "word" in file_type or file.name.endswith('.docx'):
                document_text = self.extract_text_from_docx(file_bytes)
            elif "text" in file_type or file.name.endswith('.txt'):
                document_text = self.extract_text_from_txt(file_bytes)
            else:
                st.error(f"Unsupported file type: {file_type}")
                return None
        elif url:
            document_text = self.extract_text_from_url(url)
        elif text:
            document_text = text
        if not document_text:
            st.error("Could not extract text from the document")
            return None
        # Process the document
        segments = self.segment_document(document_text)
        keywords = self.extract_keywords(document_text)
        return {
            "full_text": document_text,
            "segments": segments,
            "keywords": keywords
        }
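# Standalone usage sketch (not part of the app flow): DocumentProcessor can also be
# exercised outside Streamlit, e.g.
#   processor = DocumentProcessor()
#   result = processor.process_document(text="Some long article text ...")
#   result["segments"]; result["keywords"]
# Note that the st.error calls only render inside a running Streamlit app.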
# Model loading - cached so the checkpoint is downloaded and loaded only once
@st.cache_resource
def load_model():
    # BART fine-tuned for CNN/DailyMail summarization; a smaller checkpoint
    # could be substituted if CPU inference is too slow
    model_name = "facebook/bart-large-cnn"
    tokenizer = BartTokenizer.from_pretrained(model_name)
    model = BartForConditionalGeneration.from_pretrained(model_name)
    return tokenizer, model
def generate_summary(text, tokenizer, model, max_length=150, min_length=40):
    """Generate summary with BART"""
    if not text or len(text.strip()) == 0:
        return ""
    # Tokenize with truncation (BART's input limit is 1024 tokens)
    inputs = tokenizer(text, max_length=1024, truncation=True, return_tensors="pt")
    # Generate the summary with fewer beams to keep CPU inference fast
    with torch.no_grad():
        summary_ids = model.generate(
            inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            num_beams=2,  # reduced for CPU
            min_length=min_length,
            max_length=max_length,
            length_penalty=2.0,
            early_stopping=True,
        )
    summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
    return summary
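# Usage sketch (assuming the model has been loaded):
#   tokenizer, model = load_model()
#   short = generate_summary("Long article text ...", tokenizer, model, max_length=120, min_length=30)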
def main():
    st.markdown("<div class='main-header'>📄 Document Summarizer</div>", unsafe_allow_html=True)
    st.write("Upload a document, enter a URL, or paste text to get an AI-powered summary")
    # Initialize processor
    processor = DocumentProcessor()
    # Input tabs
    tab1, tab2, tab3 = st.tabs(["📄 Upload", "🔗 URL", "✏️ Text"])
    with tab1:
        uploaded_file = st.file_uploader("Upload a document", type=["pdf", "docx", "txt"])
        if uploaded_file:
            st.session_state.input_type = "file"
            st.session_state.file = uploaded_file
    with tab2:
        url = st.text_input("Enter a webpage URL")
        if url:
            st.session_state.input_type = "url"
            st.session_state.url = url
    with tab3:
        text = st.text_area("Paste your text", height=200)
        if text:
            st.session_state.input_type = "text"
            st.session_state.text = text
    # Process button
    if st.button("Generate Summary", type="primary"):
        with st.spinner("Processing document..."):
            # Get input based on selected tab
            doc_input = None
            if hasattr(st.session_state, 'input_type'):
                if st.session_state.input_type == "file" and hasattr(st.session_state, 'file'):
                    doc_input = processor.process_document(file=st.session_state.file)
                elif st.session_state.input_type == "url" and hasattr(st.session_state, 'url'):
                    doc_input = processor.process_document(url=st.session_state.url)
                elif st.session_state.input_type == "text" and hasattr(st.session_state, 'text'):
                    doc_input = processor.process_document(text=st.session_state.text)
            if not doc_input:
                st.error("Please provide a document, URL, or text")
                return
            # Load model (cached)
            try:
                tokenizer, model = load_model()
                st.session_state.model_loaded = True
            except Exception as e:
                st.error(f"Error loading model: {str(e)}")
                return
            # Process segments
            progress_bar = st.progress(0)
            segments = doc_input["segments"]
            segment_summaries = []
            for i, segment in enumerate(segments):
                # Update progress
                progress = (i + 1) / max(1, len(segments))
                progress_bar.progress(progress)
                if len(segment.split()) > 30:
                    # Summarize segment
                    segment_summary = generate_summary(
                        segment, tokenizer, model,
                        max_length=150, min_length=40
                    )
                    segment_summaries.append(segment_summary)
                else:
                    # Short segment, use as is
                    segment_summaries.append(segment)
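            # Second pass: the per-segment summaries are concatenated and summarized
            # again, so long documents get a hierarchical (map-reduce style) summary.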
            # Create final summary
            combined_text = " ".join(segment_summaries)
            if len(combined_text.split()) > 100:
                final_summary = generate_summary(
                    combined_text, tokenizer, model,
                    max_length=250, min_length=100
                )
            else:
                final_summary = combined_text
            # Calculate stats
            original_words = len(doc_input["full_text"].split())
            summary_words = len(final_summary.split())
            compression = (1 - (summary_words / original_words)) * 100 if original_words > 0 else 0
            # Save results
            st.session_state.result = {
                "final_summary": final_summary,
                "segment_summaries": segment_summaries,
                "keywords": doc_input["keywords"],
                "stats": {
                    "original_words": original_words,
                    "summary_words": summary_words,
                    "compression": compression,
                    "num_segments": len(segments)
                }
            }
            # Remove progress bar when done
            progress_bar.empty()
    # Display results if available
    if hasattr(st.session_state, 'result') and st.session_state.result:
        result = st.session_state.result
st.markdown("<div class='section-header'>π Executive Summary</div>", unsafe_allow_html=True) | |
st.markdown(f"<div class='summary-box'>{result['final_summary']}</div>", unsafe_allow_html=True) | |
# Display metrics | |
col1, col2, col3 = st.columns(3) | |
with col1: | |
st.metric("Original Length", f"{result['stats']['original_words']} words") | |
with col2: | |
st.metric("Summary Length", f"{result['stats']['summary_words']} words") | |
with col3: | |
st.metric("Compression", f"{result['stats']['compression']:.1f}%") | |
# Display keywords | |
st.markdown("<div class='section-header'>π Key Topics</div>", unsafe_allow_html=True) | |
        keyword_html = ""
        for keyword in result['keywords']:
            keyword_html += f'<span class="keyword-chip">{keyword}</span>'
        st.markdown(f"<div>{keyword_html}</div>", unsafe_allow_html=True)
        # Detailed summaries (in expander to save space)
        with st.expander("View Detailed Segment Summaries"):
            for i, summary in enumerate(result['segment_summaries']):
                st.markdown(f"**Segment {i+1}**")
                st.markdown(f"<div class='summary-box'>{summary}</div>", unsafe_allow_html=True)
                st.markdown("---")
        # Download option
        summary_text = f"# Executive Summary\n\n{result['final_summary']}\n\n"
        summary_text += f"## Key Topics\n\n{', '.join(result['keywords'])}\n\n"
        summary_text += "## Detailed Segment Summaries\n\n"
        for i, summary in enumerate(result['segment_summaries']):
            summary_text += f"### Segment {i+1}\n\n{summary}\n\n"
        st.download_button(
            "Download Summary",
            summary_text,
            file_name="document_summary.md",
            mime="text/markdown"
        )
if __name__ == "__main__":
    main()
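# To run locally (the filename is an assumption; substitute this file's actual name):
#   streamlit run app.py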