import base64
from langchain.chains.summarize import load_summarize_chain
from langchain.docstore.document import Document
from langchain.document_loaders.pdf import PyMuPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from PyPDF2 import PdfReader
import re
import streamlit as st
from streamlit_tags import st_tags
import sys
import time
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoModelForCausalLM
from transformers import pipeline

# Notes
# https://huggingface.co/docs/transformers/pad_truncation
# https://stackoverflow.com/questions/76431655/langchain-pypdfloader
# https://dev.to/eteimz/understanding-langchains-recursivecharactertextsplitter-2846


# File loader and preprocessor
def file_preprocessing(
    file, skipfirst, skiplast, chunk_size, chunk_overlap, exclude_words
):
    loader = PyMuPDFLoader(file)
    pages = loader.load_and_split()
    # Skip user-specified page(s)
    if (skipfirst == 1) & (skiplast == 0):
        del pages[0]
    elif (skipfirst == 0) & (skiplast == 1):
        del pages[-1]
    elif (skipfirst == 1) & (skiplast == 1):
        del pages[0]
        del pages[-1]
    else:
        pages = pages
    input_text = ""
    for page in pages:
        input_text = input_text + page.page_content
    input_text = re.sub("-\n", "", input_text)
    input_text = re.sub(r"\n", " ", input_text)
    # Initialize a list to store valid sentences
    valid_sentences = []
    # Split the input_text into sentences
    sentences = re.split(r"(?<=[.!?])\s+", input_text)
    # Iterate through each sentence
    for sentence in sentences:
        # Check if any exclude_word is present in the sentence
        if any(word in sentence for word in exclude_words):
            continue  # Skip sentences with exclude_words
        valid_sentences.append(sentence)
    final_input_text = " ".join(valid_sentences)
    print("\n############## New article ##############\n")
    print("Cleaned and formatted input text:\n")
    print(final_input_text)
    print("\nExcluded words: " + str(exclude_words))
    print("\nChunking input text...\n")
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size,  # Number of characters
        chunk_overlap=chunk_overlap,
        length_function=len,
        separators=["\n\n", "\n", " ", ""],  # Default list
    )
    text_chunks = text_splitter.split_text(final_input_text)
    print("Number of chunks: " + str(len(text_chunks)), end="")
    chunks = ""
    for text in text_chunks:
        chunks = chunks + "\n\n" + text
    print(chunks)
    return final_input_text, text_chunks


# Function to count words in the input
def preprocessing_word_count(
    filepath, skipfirst, skiplast, chunk_size, chunk_overlap, exclude_words
):
    final_input_text, text_chunks = file_preprocessing(
        filepath, skipfirst, skiplast, chunk_size, chunk_overlap, exclude_words
    )
    text_length = len(re.findall(r"\w+", final_input_text))
    print("\nInput word count: " f"{text_length:,}")
    print("Chunk size: " f"{chunk_size:,}")
    print("Chunk overlap: %s" % chunk_overlap)
    return final_input_text, text_chunks, text_length


# LLM pipeline for summarization
def llm_pipeline(
    tokenizer, base_model, final_input_text, model_source, minimum_token_number
):
    summarizer = pipeline(
        task="summarization",
        model=base_model,
        tokenizer=tokenizer,
        truncation=True,
    )
    print("Model source: %s" % (model_source))
    print("Summarizing...\n")
    result = summarizer(
        final_input_text,
        min_length=minimum_token_number,
        max_length=tokenizer.model_max_length,
    )
    summary = result[0]["summary_text"]
    print("Summary text:\n")
    print(summary)
    return summary


# Function to count words in the summary
def postprocessing_word_count(summary):
    text_length = len(re.findall(r"\w+", summary))
    print("\nSummary word count: " f"{text_length:,}")
    return text_length


# Function to clean bart summary text
def clean_summary_text(summary):
    # Remove next line
    summary_cleaned_1 = re.sub(r"\n\s+", "", summary)
    # Remove whitespace
    summary_cleaned_2 = summary_cleaned_1.strip()
    # Remove any spaces before punctuation (bart)
    summary_cleaned_3 = re.sub(r"\s+([.,;:)!?](?:\s|$))", r"\1", summary_cleaned_2)
    # Remove any spaces after "("
    summary_cleaned_4 = re.sub(r"\(\s", r"(", summary_cleaned_3)
    # Remove any spaces between the closing parenthesis and other punctuation
    summary_cleaned_5 = re.sub(r"(\))\s+([,.:;?!])", r"\1\2", summary_cleaned_4)
    return summary_cleaned_5


# Function to convert bart summary to sentence case
def convert_to_sentence_case(summary):
    # Split the paragraph into sentences based on '.', '!', or '?'
    sentences = re.split(r"(?<=[.!?])\s+", summary)
    # Convert to sentence case and join the sentences back together
    formatted_sentences = [sentence.capitalize() for sentence in sentences]
    return " ".join(formatted_sentences)


# Function to remove duplicate sentences from the summary
def remove_duplicate_sentences(summary):
    # Split the paragraph into sentences
    sentences = re.split(r"(?<=[.!?])\s+", summary)
    # Initialize a set to store unique sentences
    unique_sentences = set()
    # Initialize a list to store valid sentences
    valid_sentences = []
    # Iterate through each sentence
    for sentence in sentences:
        # Check if the sentence is unique
        if sentence not in unique_sentences:
            unique_sentences.add(sentence)
            valid_sentences.append(sentence)
    # Join the remaining valid sentences to create the final_summary
    final_summary = " ".join(valid_sentences)
    return final_summary


# Function to remove incomplete last sentence from summary
def remove_incomplete_last_sentence(summary):
    # Split the paragraph into sentences based on '.', '!', or '?'
    sentences = re.split(r"(?<=[.!?])\s+", summary)
    # Check if the last sentence lacks punctuation at the end
    if (
        sentences
        and sentences[-1].strip()
        and not sentences[-1].strip().endswith((".", "!", "?"))
    ):
        # Remove the last sentence from the paragraph
        sentences.pop()
    # Join the sentences back together
    return " ".join(sentences)


# Function to display the PDF
@st.cache_data(ttl=60 * 60)
def displayPDF(file):
    with open(file, "rb") as f:
        base64_pdf = base64.b64encode(f.read()).decode("utf-8")
    # Embed pdf in html via an iframe and a base64 data URI
    # (iframe markup is a reconstruction; width/height values are assumed)
    pdf_display = (
        f'<iframe src="data:application/pdf;base64,{base64_pdf}" '
        'width="100%" height="900" type="application/pdf"></iframe>'
    )
    # Display file
    st.markdown(pdf_display, unsafe_allow_html=True)


# Streamlit code
st.set_page_config(layout="wide")


def main():
    st.title("RASA: Research Article Summarization App")
    uploaded_file = st.file_uploader("Upload your PDF file", type=["pdf"])
    if uploaded_file is not None:
        st.subheader("Options")
        col1, col2, col3, col4 = st.columns([1, 1, 1, 2])
        with col1:
            model_source_names = ["Cached model", "Download model"]
            model_source = st.radio(
                "For development:",
                model_source_names,
                help="Defaults to a cached model; downloading will take longer",
            )
        with col2:
            model_names = [
                "T5-Small",
                "BART",
            ]
            selected_model = st.radio(
                "Select a model to use:",
                model_names,
            )
            if selected_model == "BART":
                chunk_size = 800
                chunk_overlap = 80
                checkpoint = "ccdv/lsg-bart-base-16384-pubmed"
                tokenizer = AutoTokenizer.from_pretrained(
                    checkpoint,
                    truncation=True,
                    model_max_length=512,
                    trust_remote_code=True,
                )
                if model_source == "Download model":
                    base_model = AutoModelForSeq2SeqLM.from_pretrained(
                        checkpoint,
                        torch_dtype=torch.float32,
                        trust_remote_code=True,
                    )
                else:
                    base_model = "model_cache/models--ccdv--lsg-bart-base-16384-pubmed/snapshots/4072bc1a7a94e2b4fd860a5fdf1b71d0487dcf15"
            else:
                chunk_size = 1000
                chunk_overlap = 100
                checkpoint = "MBZUAI/LaMini-Flan-T5-77M"
                tokenizer = AutoTokenizer.from_pretrained(
                    checkpoint,
                    truncation=True,
                    legacy=False,
                    model_max_length=512,
                )
                if model_source == "Download model":
                    base_model = AutoModelForSeq2SeqLM.from_pretrained(
                        checkpoint,
                        torch_dtype=torch.float32,
                    )
                else:
                    base_model = "model_cache/models--MBZUAI--LaMini-Flan-T5-77M/snapshots/c5b12d50a2616b9670a57189be20055d1357b474"
        with col3:
            st.write("Skip any pages?")
            skipfirst = st.checkbox(
                "Skip first page", help="Select if your PDF has a cover page"
            )
            skiplast = st.checkbox("Skip last page")
        with col4:
            st.write("Background information (links open in a new window)")
            st.write(
                "Model class: [T5-Small](https://huggingface.co/docs/transformers/main/en/model_doc/t5)"
                "  |  Model: [LaMini-Flan-T5-77M](https://huggingface.co/MBZUAI/LaMini-Flan-T5-77M)"
            )
            st.write(
                "Model class: [BART](https://huggingface.co/docs/transformers/main/en/model_doc/bart)"
                "  |  Model: [lsg-bart-base-16384-pubmed](https://huggingface.co/ccdv/lsg-bart-base-16384-pubmed)"
            )
        exclude_words = st_tags(
            label="Enter word(s) to exclude from the summary:",
            text="Press enter to add",
        )
        col1, col2, col3 = st.columns([1, 1, 5])
        with col1:
            minimum_token_number = st.number_input(
                "Minimum number of tokens",
                value=200,
                step=25,
                min_value=0,
                max_value=512,
                help="Use a larger number of tokens to increase summary length",
            )
        with col3:
            st.subheader("Notes")
            st.write(
                "To remove content from the summary, try copying and pasting the "
                "word(s) to exclude in the box above and summarize again."
            )
            st.write(
                "To lengthen or shorten the summary, increase or decrease the "
                "minimum number of tokens to the left and summarize again."
            )
        if st.button("Summarize"):
            col1, col2 = st.columns(2)
            filepath = "data/" + uploaded_file.name
            with open(filepath, "wb") as temp_file:
                temp_file.write(uploaded_file.read())
            with col1:
                (
                    final_input_text,
                    text_chunks,
                    preprocessing_text_length,
                ) = preprocessing_word_count(
                    filepath,
                    skipfirst,
                    skiplast,
                    chunk_size,
                    chunk_overlap,
                    exclude_words,
                )
                st.info(
                    "Uploaded PDF  |  Number of words: "
                    f"{preprocessing_text_length:,}"
                )
                pdf_viewer = displayPDF(filepath)
            with col2:
                start = time.time()
                with st.spinner("Summarizing..."):
                    summary = llm_pipeline(
                        tokenizer,
                        base_model,
                        final_input_text,
                        model_source,
                        minimum_token_number,
                    )
                    # Count summary words
                    postprocessing_text_length = postprocessing_word_count(summary)
                    end = time.time()
                    duration = end - start
                    print("Duration: " f"{duration:.0f}" + " seconds")
                    st.info(
                        "PDF Summary  |  Number of words: "
                        f"{postprocessing_text_length:,}"
                        + "  |  Summarization time: "
                        f"{duration:.0f}" + " seconds"
                    )
                    if selected_model == "BART":
                        # Use regex to clean the unformatted bart summary
                        summary_cleaned = clean_summary_text(summary)
                        # Convert to sentence case
                        summary_cleaned_sentence_case = convert_to_sentence_case(
                            summary_cleaned
                        )
                        # Remove duplicate sentences
                        summary_cleaned_sentence_case_dedup = (
                            remove_duplicate_sentences(summary_cleaned_sentence_case)
                        )
                        # Remove incomplete last sentence
                        summary_cleaned_final = remove_incomplete_last_sentence(
                            summary_cleaned_sentence_case_dedup
                        )
                        st.success(summary_cleaned_final)
                        with st.expander("Unformatted output"):
                            st.write(summary)
                    else:  # T5 model
                        # Remove duplicate sentences
                        summary_dedup = remove_duplicate_sentences(summary)
                        # Remove incomplete last sentence
                        summary_final = remove_incomplete_last_sentence(summary_dedup)
                        st.success(summary_final)
                        with st.expander("Unformatted output"):
                            st.write(summary)
            col1 = st.columns(1)
            url = "https://dev.to/eteimz/understanding-langchains-recursivecharactertextsplitter-2846"
            st.info("Additional information")
            input_ids = tokenizer.encode(
                final_input_text, add_special_tokens=True, truncation=True
            )
            st.write(
                "Maximum number of tokens generated for inputs into the model: %s"
                % f"{len(input_ids):,}"
            )
            st.write("First 10 tokens:")
            first_10_tokens = input_ids[:10]
            first_10_tokens_text = tokenizer.convert_ids_to_tokens(first_10_tokens)
            st.write(first_10_tokens_text)
            st.write("First 500 tokens:")
            first_500_tokens = input_ids[:500]
            first_500_tokens_text = tokenizer.convert_ids_to_tokens(first_500_tokens)
            st.write(first_500_tokens_text)
            st.write("[RecursiveCharacterTextSplitter](%s) parameters used:" % url)
            st.write("        chunk_size=%s" % chunk_size)
            st.write("        chunk_overlap=%s" % chunk_overlap)
            st.write("        length_function=len")
            st.write("\n")
            st.write("Number of input text chunks: " + str(len(text_chunks)))
            st.write("")
            st.write("First three chunks:")
            st.write("\n")
            st.write(text_chunks[0])
            st.write("")
            st.write(text_chunks[1])
            st.write("")
            st.write(text_chunks[2])
            st.write("\n")
            st.write(
                "Extracted and cleaned text, less sentences containing excluded words:"
            )
            st.write("")
            st.write(final_input_text)
            st.markdown(
                """
                """,
                unsafe_allow_html=True,
            )


if __name__ == "__main__":
    main()