import streamlit as st
from transformers import pipeline, AutoTokenizer, AutoModelForSeq2SeqLM
import torch
from PyPDF2 import PdfReader
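# Note: PyPDF2 is in maintenance mode; its maintained successor is the
# pypdf package (from pypdf import PdfReader), which keeps the same
# PdfReader API used here and would be a drop-in replacement.
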
# Function to extract text from a PDF file
def extract_text_from_pdf(pdf_file):
    try:
        reader = PdfReader(pdf_file)
        text = ""
        for page in reader.pages:
            # extract_text() can return None for pages without a text layer
            text += page.extract_text() or ""
        return text
    except Exception:
        return ""

# Function to extract text from a text file
def extract_text_from_txt(txt_file):
    # Replace undecodable bytes instead of raising on non-UTF-8 files
    return txt_file.read().decode("utf-8", errors="replace")

# Function to split text into chunks of at most max_chunk_size words
def split_text(text, max_chunk_size=512):
    words = text.split()
    for i in range(0, len(words), max_chunk_size):
        yield " ".join(words[i:i + max_chunk_size])

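# Note: word-based chunking is only a rough proxy for the model's
# 1024-token input limit, since tokenizers usually emit more tokens than
# words. A token-accurate variant (a sketch, reusing the tokenizer loaded
# in load_summarizer below) could look like:
#
#   def split_text_by_tokens(text, tokenizer, max_tokens=900):
#       ids = tokenizer.encode(text, add_special_tokens=False)
#       for i in range(0, len(ids), max_tokens):
#           yield tokenizer.decode(ids[i:i + max_tokens])
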
def summarize_chunk(chunk, summarizer):
    input_length = len(chunk.split())
    # Scale the summary length to the chunk length
    max_length = min(130, int(0.3 * input_length))
    min_length = max(30, int(0.1 * input_length))
    # Guard: on very short chunks min_length can exceed max_length
    if max_length <= min_length:
        max_length = min_length + 10
    return summarizer(chunk, max_length=max_length, min_length=min_length, do_sample=False)[0]["summary_text"]

# Cache the pipeline so the model is loaded once, not on every Streamlit rerun
@st.cache_resource
def load_summarizer():
    device = 0 if torch.cuda.is_available() else -1
    model_name = "sshleifer/distilbart-cnn-12-6"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
    return pipeline("summarization", model=model, tokenizer=tokenizer, device=device)

def main():
    summarizer = load_summarizer()
    st.subheader("Generate Summary for PDF or Text Files")
    uploaded_file = st.file_uploader("Upload a PDF or Text file", type=["pdf", "txt"])
    user_text = st.text_area("Or write your text here:")
    if uploaded_file or user_text.strip():
        try:
            text = ""
            if uploaded_file:
                if uploaded_file.type == "application/pdf":
                    text = extract_text_from_pdf(uploaded_file)
                elif uploaded_file.type == "text/plain":
                    text = extract_text_from_txt(uploaded_file)
                else:
                    st.error("Unsupported file type.")
            else:
                text = user_text.strip()
            if not text:
                st.warning("No text to summarize.")
            else:
                chunks = list(split_text(text))
                summaries = [summarize_chunk(chunk, summarizer) for chunk in chunks]
                summary = " ".join(summaries)
                st.subheader("Summary")
                st.write(summary)
                if st.button("Upload another file"):
                    st.rerun()
        except Exception as e:
            st.error(f"An error occurred: {e}")

if __name__ == "__main__":
    main()
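
# To run locally (assuming this file is saved as app.py):
#   pip install streamlit transformers torch PyPDF2
#   streamlit run app.py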