import base64
from langchain.document_loaders.pdf import PyMuPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
import os
import re
import streamlit as st
from streamlit_tags import st_tags
import time
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from transformers import pipeline

# Notes
# https://huggingface.co/docs/transformers/pad_truncation
# https://stackoverflow.com/questions/76431655/langchain-pypdfloader
# https://dev.to/eteimz/understanding-langchains-recursivecharactertextsplitter-2846


# file loader and preprocessor
def file_preprocessing(
    file, skipfirst, skiplast, chunk_size, chunk_overlap, exclude_words
):
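    """Load a PDF, optionally skip the first/last page, clean the extracted
    text, drop sentences containing any excluded words, and split the result
    into overlapping chunks.

    Returns the cleaned input text and the list of text chunks.
    """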
    loader = PyMuPDFLoader(file)
    pages = loader.load_and_split()
    # Skip user-specified page(s)
    if skipfirst:
        del pages[0]
    if skiplast:
        del pages[-1]
    input_text = ""
    for page in pages:
        input_text = input_text + page.page_content
    input_text = re.sub("-\n", "", input_text)
    input_text = re.sub(r"\n", " ", input_text)
    # Initialize a list to store valid sentences
    valid_sentences = []
    # Split the input_text into sentences
    sentences = re.split(r"(?<=[.!?])\s+", input_text)
    # Iterate through each sentence
    for sentence in sentences:
        # Check if any exclude_word is present in the sentence
        if any(word in sentence for word in exclude_words):
            continue  # Skip sentences with exclude_words
        valid_sentences.append(sentence)
    final_input_text = " ".join(valid_sentences)
    print("\n############## New article ##############\n")
    print("Cleaned and formatted input text:\n")
    print(final_input_text)
    print("\nExcluded words: " + str(exclude_words))
    print("\nChunking input text...\n")
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size,  # Number of characters
        chunk_overlap=chunk_overlap,
        length_function=len,
        separators=["\n\n", "\n", " ", ""],  # Default list
    )
    text_chunks = text_splitter.split_text(final_input_text)
    print("Number of chunks: " + str(len(text_chunks)), end="")
    chunks = ""
    for text in text_chunks:
        chunks = chunks + "\n\n" + text
    print(chunks)
    return final_input_text, text_chunks


# Function to count words in the input
def preprocessing_word_count(
    filepath, skipfirst, skiplast, chunk_size, chunk_overlap, exclude_words
):
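    """Preprocess the PDF and report the word count of the cleaned input text."""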
    final_input_text, text_chunks = file_preprocessing(
        filepath, skipfirst, skiplast, chunk_size, chunk_overlap, exclude_words
    )
    text_length = len(re.findall(r"\w+", final_input_text))
    print("\nInput word count: " f"{text_length:,}")
    print("Chunk size: " f"{chunk_size:,}")
    print("Chunk overlap: %s" % chunk_overlap)
    return final_input_text, text_chunks, text_length


# LLM pipeline for summarization
def llm_pipeline(
    tokenizer, base_model, final_input_text, model_source, minimum_token_number
):
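    """Run a Hugging Face summarization pipeline over the cleaned input text
    and return the generated summary.
    """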
    summarizer = pipeline(
        task="summarization",
        model=base_model,
        tokenizer=tokenizer,
        truncation=True,
    )
    print("Model source: %s" % (model_source))
    print("Summarizing...\n")
    result = summarizer(
        final_input_text,
        min_length=minimum_token_number,
        max_length=tokenizer.model_max_length,
    )
    summary = result[0]["summary_text"]
    print("Summary text:\n")
    print(summary)
    return summary


# Function to count words in the summary
def postprocessing_word_count(summary):
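    """Count the words in the generated summary."""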
    text_length = len(re.findall(r"\w+", summary))
    print("\nSummary word count: " f"{text_length:,}")
    return text_length


# Function to clean bart summary text
def clean_summary_text(summary):
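    """Normalize whitespace and spacing around punctuation in a raw BART summary."""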
    # Remove a newline followed by whitespace
    cleaned = re.sub(r"\n\s+", "", summary)
    # Trim leading/trailing whitespace
    cleaned = cleaned.strip()
    # Remove any spaces before punctuation (bart)
    cleaned = re.sub(r"\s+([.,;:)!?](?:\s|$))", r"\1", cleaned)
    # Remove any spaces after "("
    cleaned = re.sub(r"\(\s", r"(", cleaned)
    # Remove any spaces between a closing parenthesis and other punctuation
    cleaned = re.sub(r"(\))\s+([,.:;?!])", r"\1\2", cleaned)
    return cleaned


# Function to convert bart summary to sentence case
def convert_to_sentence_case(summary):
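    """Capitalize each sentence; note that str.capitalize() also lowercases
    the rest of the sentence, including acronyms.
    """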
    # Split the paragraph into sentences based on '.', '!', or '?'
    sentences = re.split(r"(?<=[.!?])\s+", summary)
    # Convert to sentence case and join the sentences back together
    formatted_sentences = [sentence.capitalize() for sentence in sentences]
    return " ".join(formatted_sentences)


def remove_duplicate_sentences(summary):
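    """Drop exact duplicate sentences while preserving their original order."""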
    # Split the paragraph into sentences
    sentences = re.split(r"(?<=[.!?])\s+", summary)
    # Initialize a set to store unique sentences
    unique_sentences = set()
    # Initialize a list to store valid sentences
    valid_sentences = []
    # Iterate through each sentence
    for sentence in sentences:
        # Check if the sentence is unique
        if sentence not in unique_sentences:
            unique_sentences.add(sentence)
            valid_sentences.append(sentence)
    # Join the remaining valid sentences to create the final_summary
    final_summary = " ".join(valid_sentences)
    return final_summary


# Function to remove incomplete last sentence from summary
def remove_incomplete_last_sentence(summary):
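    """Drop a trailing sentence fragment that lacks ending punctuation."""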
    # Split the paragraph into sentences based on '.', '!', or '?'
    sentences = re.split(r"(?<=[.!?])\s+", summary)
    # Check if the last sentence lacks punctuation at the end
    if (
        sentences
        and sentences[-1].strip()
        and not sentences[-1].strip().endswith((".", "!", "?"))
    ):
        # Remove the last sentence from the paragraph
        sentences.pop()
    # Join the sentences back together
    return " ".join(sentences)


# Function to display the PDF
@st.cache_data(ttl=60 * 60)
def displayPDF(file):
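    """Render the uploaded PDF in an embedded base64-encoded iframe."""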
    with open(file, "rb") as f:
        base64_pdf = base64.b64encode(f.read()).decode("utf-8")
    # Embed pdf in html
    pdf_display = f'<iframe src="data:application/pdf;base64,{base64_pdf}" width="100%" height="600" type="application/pdf"></iframe>'
    # Display file
    st.markdown(pdf_display, unsafe_allow_html=True)


# Streamlit code
st.set_page_config(layout="wide")


def main():
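    """Build the Streamlit UI: upload a PDF, pick a model and options,
    summarize, and display diagnostics.
    """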
    st.title("RASA: Research Article Summarization App")
    uploaded_file = st.file_uploader("Upload your PDF file", type=["pdf"])
    if uploaded_file is not None:
        st.subheader("Options")
        col1, col2, col3, col4 = st.columns([1, 1, 1, 2])
        with col1:
            model_source_names = ["Cached model", "Download model"]
            model_source = st.radio(
                "For development:",
                model_source_names,
                help="Defaults to a cached model; downloading will take longer",
            )
        with col2:
            model_names = [
                "T5-Small",
                "BART",
            ]
            selected_model = st.radio(
                "Select a model to use:",
                model_names,
            )
            if selected_model == "BART":
                chunk_size = 800
                chunk_overlap = 80
                checkpoint = "ccdv/lsg-bart-base-16384-pubmed"
                tokenizer = AutoTokenizer.from_pretrained(
                    checkpoint,
                    truncation=True,
                    model_max_length=512,
                    trust_remote_code=True,
                )
                if model_source == "Download model":
                    base_model = AutoModelForSeq2SeqLM.from_pretrained(
                        checkpoint,
                        torch_dtype=torch.float32,
                        trust_remote_code=True,
                    )
                else:
                    base_model = "model_cache/models--ccdv--lsg-bart-base-16384-pubmed/snapshots/4072bc1a7a94e2b4fd860a5fdf1b71d0487dcf15"
            else:
                chunk_size = 1000
                chunk_overlap = 100
                checkpoint = "MBZUAI/LaMini-Flan-T5-77M"
                tokenizer = AutoTokenizer.from_pretrained(
                    checkpoint,
                    truncation=True,
                    legacy=False,
                    model_max_length=512,
                )
                if model_source == "Download model":
                    base_model = AutoModelForSeq2SeqLM.from_pretrained(
                        checkpoint,
                        torch_dtype=torch.float32,
                    )
                else:
                    base_model = "model_cache/models--MBZUAI--LaMini-Flan-T5-77M/snapshots/c5b12d50a2616b9670a57189be20055d1357b474"
        with col3:
            st.write("Skip any pages?")
            skipfirst = st.checkbox(
                "Skip first page", help="Select if your PDF has a cover page"
            )
            skiplast = st.checkbox("Skip last page")
        with col4:
            st.write("Background information (links open in a new window)")
            st.write(
                "Model class: [T5-Small](https://huggingface.co/docs/transformers/main/en/model_doc/t5)"
                "&nbsp;&nbsp;|&nbsp;&nbsp;Model: [LaMini-Flan-T5-77M](https://huggingface.co/MBZUAI/LaMini-Flan-T5-77M)"
            )
            st.write(
                "Model class: [BART](https://huggingface.co/docs/transformers/main/en/model_doc/bart)"
                "&nbsp;&nbsp;|&nbsp;&nbsp;Model: [lsg-bart-base-16384-pubmed](https://huggingface.co/ccdv/lsg-bart-base-16384-pubmed)"
            )
        exclude_words = st_tags(
            label="Enter word(s) to exclude from the summary:",
            text="Press enter to add",
        )
        col1, col2, col3 = st.columns([1, 1, 5])
        with col1:
            minimum_token_number = st.number_input(
                "Minimum number of tokens",
                value=200,
                step=25,
                min_value=0,
                max_value=512,
                help="Use a larger number of tokens to increase summary length",
            )
        with col3:
            st.subheader("Notes")
            st.write(
                "To remove content from the summary, try copying and pasting the word(s) to exclude in the box above and summarize again."
            )
            st.write(
                "To lengthen or shorten the summary, increase or decrease the minimum number of tokens to the left and summarize again."
            )
        if st.button("Summarize"):
            col1, col2 = st.columns(2)
            filepath = "data/" + uploaded_file.name
            with open(filepath, "wb") as temp_file:
                temp_file.write(uploaded_file.read())
            with col1:
                (
                    final_input_text,
                    text_chunks,
                    preprocessing_text_length,
                ) = preprocessing_word_count(
                    filepath,
                    skipfirst,
                    skiplast,
                    chunk_size,
                    chunk_overlap,
                    exclude_words,
                )
                st.info(
                    "Uploaded PDF&nbsp;&nbsp;|&nbsp;&nbsp;Number of words: "
                    f"{preprocessing_text_length:,}"
                )
                displayPDF(filepath)
            with col2:
                start = time.time()
                with st.spinner("Summarizing..."):
                    summary = llm_pipeline(
                        tokenizer,
                        base_model,
                        final_input_text,
                        model_source,
                        minimum_token_number,
                    )
                    # Count summary words
                    postprocessing_text_length = postprocessing_word_count(summary)
                end = time.time()
                duration = end - start
                print(f"Duration: {duration:.0f} seconds")
                st.info(
                    "PDF Summary&nbsp;&nbsp;|&nbsp;&nbsp;Number of words: "
                    f"{postprocessing_text_length:,}"
                    + "&nbsp;&nbsp;|&nbsp;&nbsp;Summarization time: "
                    f"{duration:.0f}" + " seconds"
                )
                if selected_model == "BART":
                    # Use regex to clean the unformatted bart summary
                    summary_cleaned = clean_summary_text(summary)
                    # Convert to sentence case
                    summary_cleaned_sentence_case = convert_to_sentence_case(
                        summary_cleaned
                    )
                    # Remove duplicate sentences
                    summary_cleaned_sentence_case_dedup = remove_duplicate_sentences(
                        summary_cleaned_sentence_case
                    )
                    # Remove incomplete last sentence
                    summary_cleaned_final = remove_incomplete_last_sentence(
                        summary_cleaned_sentence_case_dedup
                    )
                    st.success(summary_cleaned_final)
                    with st.expander("Unformatted output"):
                        st.write(summary)
                else:  # T5 model
                    # Remove duplicate sentences
                    summary_dedup = remove_duplicate_sentences(summary)
                    # Remove incomplete last sentence
                    summary_final = remove_incomplete_last_sentence(summary_dedup)
                    st.success(summary_final)
                    with st.expander("Unformatted output"):
                        st.write(summary)
            url = "https://dev.to/eteimz/understanding-langchains-recursivecharactertextsplitter-2846"
            st.info("Additional information")
            input_ids = tokenizer.encode(
                final_input_text, add_special_tokens=True, truncation=True
            )
            st.write(
                "Number of tokens in the model input (after truncation): %s"
                % f"{len(input_ids):,}"
            )
            st.write("First 10 tokens:")
            first_10_tokens = input_ids[:10]
            first_10_tokens_text = tokenizer.convert_ids_to_tokens(first_10_tokens)
            st.write(first_10_tokens_text)
            st.write("First 500 tokens:")
            first_500_tokens = input_ids[:500]
            first_500_tokens_text = tokenizer.convert_ids_to_tokens(first_500_tokens)
            st.write(first_500_tokens_text)
            st.write("[RecursiveCharacterTextSplitter](%s) parameters used:" % url)
            st.write(
                "&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;chunk_size=%s"
                % chunk_size
            )
            st.write(
                "&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;chunk_overlap=%s"
                % chunk_overlap
            )
            st.write(
                "&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;length_function=len"
            )
            st.write("\n")
            st.write("Number of input text chunks: " + str(len(text_chunks)))
            st.write("")
            st.write("First three chunks:")
            st.write("\n")
            st.write(text_chunks[0])
            st.write("")
            st.write(text_chunks[1])
            st.write("")
            st.write(text_chunks[2])
            st.write("\n")
            st.write(
                "Extracted and cleaned text, excluding sentences that contain the excluded words:"
            )
            st.write("")
            st.write(final_input_text)


st.markdown(
    """<style>
div[class*="stRadio"] > label > div[data-testid="stMarkdownContainer"] > p {
    font-size: 1rem;
    font-weight: 400;
}
div[class*="stMarkdown"] > div[data-testid="stMarkdownContainer"] > p {
    margin-bottom: -15px;
}
div[class*="stCheckbox"] > label[data-baseweb="checkbox"] {
    margin-bottom: -15px;
}
div[class*="stNumberInput"] > label > div[data-testid="stMarkdownContainer"] > p {
    font-size: 1rem;
    font-weight: 400;
}
body > a {
    text-decoration: underline;
}
    </style>
    """,
    unsafe_allow_html=True,
)


if __name__ == "__main__":
    main()