import base64
import os
import re
import time

import streamlit as st
import torch
from langchain.document_loaders.pdf import PyMuPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from streamlit_tags import st_tags
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
# Notes
# https://huggingface.co/docs/transformers/pad_truncation
# https://stackoverflow.com/questions/76431655/langchain-pypdfloader
# https://dev.to/eteimz/understanding-langchains-recursivecharactertextsplitter-2846
# File loader and preprocessor
def file_preprocessing(
file, skipfirst, skiplast, chunk_size, chunk_overlap, exclude_words
):
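    """Load a PDF, skip the requested pages, clean and sentence-filter the
    text, and split it into overlapping chunks; returns (text, chunks)."""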
loader = PyMuPDFLoader(file)
pages = loader.load_and_split()
# Skip user-specified page(s)
    if skipfirst and not skiplast:
        del pages[0]
    elif not skipfirst and skiplast:
        del pages[-1]
    elif skipfirst and skiplast:
        del pages[0]
        del pages[-1]
    input_text = ""
    for page in pages:
        input_text += page.page_content
    # Rejoin words hyphenated across line breaks, then flatten remaining newlines
    input_text = re.sub("-\n", "", input_text)
    input_text = re.sub(r"\n", " ", input_text)
# Initialize a list to store valid sentences
valid_sentences = []
# Split the input_text into sentences
sentences = re.split(r"(?<=[.!?])\s+", input_text)
# Iterate through each sentence
for sentence in sentences:
        # Skip sentences containing any excluded word (case-sensitive substring match)
        if any(word in sentence for word in exclude_words):
            continue
        valid_sentences.append(sentence)
final_input_text = " ".join(valid_sentences)
print("\n############## New article ##############\n")
print("Cleaned and formatted input text:\n")
print(final_input_text)
print("\nExcluded words: " + str(exclude_words))
print("\nChunking input text...\n")
text_splitter = RecursiveCharacterTextSplitter(
chunk_size=chunk_size, # Number of characters
chunk_overlap=chunk_overlap,
length_function=len,
separators=["\n\n", "\n", " ", ""], # Default list
)
text_chunks = text_splitter.split_text(final_input_text)
print("Number of chunks: " + str(len(text_chunks)), end="")
    # Print the chunks for inspection
    print("\n\n" + "\n\n".join(text_chunks))
return final_input_text, text_chunks
# Function to count words in the input
def preprocessing_word_count(
filepath, skipfirst, skiplast, chunk_size, chunk_overlap, exclude_words
):
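    """Preprocess the file and report the input word count; returns
    (final_input_text, text_chunks, text_length)."""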
final_input_text, text_chunks = file_preprocessing(
filepath, skipfirst, skiplast, chunk_size, chunk_overlap, exclude_words
)
    text_length = len(re.findall(r"\w+", final_input_text))
    print(f"\nInput word count: {text_length:,}")
    print(f"Chunk size: {chunk_size:,}")
    print(f"Chunk overlap: {chunk_overlap:,}")
return final_input_text, text_chunks, text_length
# LLM pipeline for summarization
def llm_pipeline(
tokenizer, base_model, final_input_text, model_source, minimum_token_number
):
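    """Summarize the preprocessed text with a transformers summarization
    pipeline; base_model may be a loaded model or a local cache path."""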
summarizer = pipeline(
task="summarization",
model=base_model,
tokenizer=tokenizer,
truncation=True,
)
print("Model source: %s" % (model_source))
print("Summarizing...\n")
result = summarizer(
final_input_text,
min_length=minimum_token_number,
max_length=tokenizer.model_max_length,
)
summary = result[0]["summary_text"]
print("Summary text:\n")
print(summary)
return summary
# Function to count words in the summary
def postprocessing_word_count(summary):
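    """Print and return the word count of the summary."""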
text_length = len(re.findall(r"\w+", summary))
print("\nSummary word count: " f"{text_length:,}")
return text_length
# Function to clean bart summary text
def clean_summary_text(summary):
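    """Normalize whitespace and punctuation spacing in the raw BART summary."""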
    # Remove newlines and the whitespace that follows them
    summary = re.sub(r"\n\s+", "", summary)
    # Remove leading and trailing whitespace
    summary = summary.strip()
    # Remove any spaces before punctuation (bart)
    summary = re.sub(r"\s+([.,;:)!?](?:\s|$))", r"\1", summary)
    # Remove any spaces after "("
    summary = re.sub(r"\(\s", r"(", summary)
    # Remove any spaces between a closing parenthesis and other punctuation
    summary = re.sub(r"(\))\s+([,.:;?!])", r"\1\2", summary)
    return summary
# Function to convert bart summary to sentence case
def convert_to_sentence_case(summary):
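    """Capitalize the first letter of each sentence (str.capitalize() also
    lowercases the rest of the sentence)."""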
# Split the paragraph into sentences based on '.', '!', or '?'
sentences = re.split(r"(?<=[.!?])\s+", summary)
# Convert to sentence case and join the sentences back together
formatted_sentences = [sentence.capitalize() for sentence in sentences]
return " ".join(formatted_sentences)
def remove_duplicate_sentences(summary):
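    """Remove exact duplicate sentences while preserving their order."""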
# Split the paragraph into sentences
sentences = re.split(r"(?<=[.!?])\s+", summary)
# Initialize a set to store unique sentences
unique_sentences = set()
# Initialize a list to store valid sentences
valid_sentences = []
# Iterate through each sentence
for sentence in sentences:
# Check if the sentence is unique
if sentence not in unique_sentences:
unique_sentences.add(sentence)
valid_sentences.append(sentence)
# Join the remaining valid sentences to create the final_summary
final_summary = " ".join(valid_sentences)
return final_summary
# Function to remove incomplete last sentence from summary
def remove_incomplete_last_sentence(summary):
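    """Drop the final sentence if it does not end with '.', '!', or '?'."""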
# Split the paragraph into sentences based on '.', '!', or '?'
sentences = re.split(r"(?<=[.!?])\s+", summary)
# Check if the last sentence lacks punctuation at the end
if (
sentences
and sentences[-1].strip()
and not sentences[-1].strip().endswith((".", "!", "?"))
):
# Remove the last sentence from the paragraph
sentences.pop()
# Join the sentences back together
return " ".join(sentences)
# Function to display the PDF
@st.cache_data(ttl=60 * 60)
def displayPDF(file):
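    """Embed the uploaded PDF in the page as a base64-encoded iframe."""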
with open(file, "rb") as f:
base64_pdf = base64.b64encode(f.read()).decode("utf-8")
# Embed pdf in html
pdf_display = f'<iframe src="data:application/pdf;base64,{base64_pdf}" width="100%" height="600" type="application/pdf"></iframe>'
# Display file
st.markdown(pdf_display, unsafe_allow_html=True)
# Streamlit code
st.set_page_config(layout="wide")
def main():
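    """Streamlit UI: upload a PDF, set options, summarize, and show details."""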
st.title("RASA: Research Article Summarization App")
uploaded_file = st.file_uploader("Upload your PDF file", type=["pdf"])
if uploaded_file is not None:
st.subheader("Options")
col1, col2, col3, col4 = st.columns([1, 1, 1, 2])
with col1:
model_source_names = ["Cached model", "Download model"]
model_source = st.radio(
"For development:",
model_source_names,
help="Defaults to a cached model; downloading will take longer",
)
with col2:
model_names = [
"T5-Small",
"BART",
]
selected_model = st.radio(
"Select a model to use:",
model_names,
)
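            # Set chunking parameters and load the checkpoint, tokenizer,
            # and model for the selected option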
if selected_model == "BART":
chunk_size = 800
chunk_overlap = 80
checkpoint = "ccdv/lsg-bart-base-16384-pubmed"
tokenizer = AutoTokenizer.from_pretrained(
checkpoint,
truncation=True,
model_max_length=512,
trust_remote_code=True,
)
if model_source == "Download model":
base_model = AutoModelForSeq2SeqLM.from_pretrained(
checkpoint,
torch_dtype=torch.float32,
trust_remote_code=True,
)
else:
base_model = "model_cache/models--ccdv--lsg-bart-base-16384-pubmed/snapshots/4072bc1a7a94e2b4fd860a5fdf1b71d0487dcf15"
else:
chunk_size = 1000
chunk_overlap = 100
checkpoint = "MBZUAI/LaMini-Flan-T5-77M"
tokenizer = AutoTokenizer.from_pretrained(
checkpoint,
truncation=True,
legacy=False,
model_max_length=512,
)
if model_source == "Download model":
base_model = AutoModelForSeq2SeqLM.from_pretrained(
checkpoint,
torch_dtype=torch.float32,
)
else:
base_model = "model_cache/models--MBZUAI--LaMini-Flan-T5-77M/snapshots/c5b12d50a2616b9670a57189be20055d1357b474"
with col3:
st.write("Skip any pages?")
skipfirst = st.checkbox(
"Skip first page", help="Select if your PDF has a cover page"
)
skiplast = st.checkbox("Skip last page")
with col4:
st.write("Background information (links open in a new window)")
st.write(
"Model class: [T5-Small](https://huggingface.co/docs/transformers/main/en/model_doc/t5)"
"&nbsp;&nbsp;|&nbsp;&nbsp;Model: [LaMini-Flan-T5-77M](https://huggingface.co/MBZUAI/LaMini-Flan-T5-77M)"
)
st.write(
"Model class: [BART](https://huggingface.co/docs/transformers/main/en/model_doc/bart)"
"&nbsp;&nbsp;|&nbsp;&nbsp;Model: [lsg-bart-base-16384-pubmed](https://huggingface.co/ccdv/lsg-bart-base-16384-pubmed)"
)
exclude_words = st_tags(
label="Enter word(s) to exclude from the summary:",
text="Press enter to add",
)
col1, col2, col3 = st.columns([1, 1, 5])
with col1:
minimum_token_number = st.number_input(
"Minimum number of tokens",
value=200,
step=25,
min_value=0,
max_value=512,
help="Use a larger number of tokens to increase summary length",
)
with col3:
st.subheader("Notes")
            st.write(
                "To remove content from the summary, copy and paste the word(s) to exclude into the box above and summarize again."
            )
st.write(
"To lengthen or shorten the summary, increase or decrease the minimum number of tokens to the left and summarize again."
)
if st.button("Summarize"):
col1, col2 = st.columns(2)
            os.makedirs("data", exist_ok=True)  # Ensure the data directory exists
            filepath = "data/" + uploaded_file.name
            with open(filepath, "wb") as temp_file:
                temp_file.write(uploaded_file.read())
with col1:
(
final_input_text,
text_chunks,
preprocessing_text_length,
) = preprocessing_word_count(
filepath,
skipfirst,
skiplast,
chunk_size,
chunk_overlap,
exclude_words,
)
st.info(
"Uploaded PDF&nbsp;&nbsp;|&nbsp;&nbsp;Number of words: "
f"{preprocessing_text_length:,}"
)
                displayPDF(filepath)
with col2:
start = time.time()
with st.spinner("Summarizing..."):
summary = llm_pipeline(
tokenizer,
base_model,
final_input_text,
model_source,
minimum_token_number,
)
# Count summary words
postprocessing_text_length = postprocessing_word_count(summary)
end = time.time()
duration = end - start
print("Duration: " f"{duration:.0f}" + " seconds")
st.info(
"PDF Summary&nbsp;&nbsp;|&nbsp;&nbsp;Number of words: "
f"{postprocessing_text_length:,}"
+ "&nbsp;&nbsp;|&nbsp;&nbsp;Summarization time: "
f"{duration:.0f}" + " seconds"
)
if selected_model == "BART":
# Use regex to clean the unformatted bart summary
summary_cleaned = clean_summary_text(summary)
# Convert to sentence case
summary_cleaned_sentence_case = convert_to_sentence_case(
summary_cleaned
)
# Remove duplicate sentences
summary_cleaned_sentence_case_dedup = remove_duplicate_sentences(
summary_cleaned_sentence_case
)
# Remove incomplete last sentence
summary_cleaned_final = remove_incomplete_last_sentence(
summary_cleaned_sentence_case_dedup
)
st.success(summary_cleaned_final)
with st.expander("Unformatted output"):
st.write(summary)
else: # T5 model
# Remove duplicate sentences
summary_dedup = remove_duplicate_sentences(summary)
# Remove incomplete last sentence
summary_final = remove_incomplete_last_sentence(summary_dedup)
st.success(summary_final)
with st.expander("Unformatted output"):
st.write(summary)
url = "https://dev.to/eteimz/understanding-langchains-recursivecharactertextsplitter-2846"
st.info("Additional information")
input_ids = tokenizer.encode(
final_input_text, add_special_tokens=True, truncation=True
)
            st.write(
                "Number of tokens in the model input (truncated to the model maximum): %s"
                % f"{len(input_ids):,}"
            )
st.write("First 10 tokens:")
first_10_tokens = input_ids[:10]
first_10_tokens_text = tokenizer.convert_ids_to_tokens(first_10_tokens)
st.write(first_10_tokens_text)
st.write("First 500 tokens:")
first_500_tokens = input_ids[:500]
first_500_tokens_text = tokenizer.convert_ids_to_tokens(first_500_tokens)
st.write(first_500_tokens_text)
st.write("[RecursiveCharacterTextSplitter](%s) parameters used:" % url)
st.write(
"&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;chunk_size=%s"
% chunk_size
)
st.write(
"&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;chunk_overlap=%s"
% chunk_overlap
)
st.write(
"&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;length_function=len"
)
st.write("\n")
st.write("Number of input text chunks: " + str(len(text_chunks)))
st.write("")
st.write("First three chunks:")
st.write("\n")
st.write(text_chunks[0])
st.write("")
st.write(text_chunks[1])
st.write("")
st.write(text_chunks[2])
st.write("\n")
            st.write(
                "Extracted and cleaned text, with sentences containing excluded words removed:"
            )
st.write("")
st.write(final_input_text)
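    # Custom CSS to adjust widget label sizing and spacing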
st.markdown(
"""<style>
div[class*="stRadio"] > label > div[data-testid="stMarkdownContainer"] > p {
font-size: 1rem;
font-weight: 400;
}
div[class*="stMarkdown"] > div[data-testid="stMarkdownContainer"] > p {
margin-bottom: -15px;
}
div[class*="stCheckbox"] > label[data-baseweb="checkbox"] {
margin-bottom: -15px;
}
div[class*="stNumberInput"] > label > div[data-testid="stMarkdownContainer"] > p {
font-size: 1rem;
font-weight: 400;
}
body > a {
text-decoration: underline;
}
</style>
""",
unsafe_allow_html=True,
)
if __name__ == "__main__":
main()