added section additional information
Browse files
app.py
CHANGED
@@ -1,15 +1,16 @@
|
|
1 |
-
from transformers import pipeline
|
2 |
import base64
|
3 |
from langchain.chains.summarize import load_summarize_chain
|
4 |
from langchain.docstore.document import Document
|
5 |
from langchain.document_loaders.pdf import PyMuPDFLoader
|
6 |
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
7 |
from PyPDF2 import PdfReader
|
|
|
8 |
import streamlit as st
|
9 |
-
import
|
10 |
import time
|
11 |
import torch
|
12 |
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoModelForCausalLM
|
|
|
13 |
|
14 |
# notes
|
15 |
# https://huggingface.co/docs/transformers/pad_truncation
|
@@ -19,14 +20,6 @@ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoModelForCausa
|
|
19 |
def file_preprocessing(file, skipfirst, skiplast):
|
20 |
loader = PyMuPDFLoader(file)
|
21 |
pages = loader.load_and_split()
|
22 |
-
print("")
|
23 |
-
print("# pages[0] ##########")
|
24 |
-
print("")
|
25 |
-
print(pages[0])
|
26 |
-
print("")
|
27 |
-
print("# pages ##########")
|
28 |
-
print("")
|
29 |
-
print(pages)
|
30 |
# skip page(s)
|
31 |
if (skipfirst == 1) & (skiplast == 0):
|
32 |
del pages[0]
|
@@ -37,11 +30,15 @@ def file_preprocessing(file, skipfirst, skiplast):
|
|
37 |
del pages[-1]
|
38 |
else:
|
39 |
pages = pages
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
|
|
|
|
|
|
|
|
45 |
text_splitter = RecursiveCharacterTextSplitter(
|
46 |
chunk_size=1000, # number of characters
|
47 |
chunk_overlap=100,
|
@@ -49,10 +46,9 @@ def file_preprocessing(file, skipfirst, skiplast):
|
|
49 |
separators=["\n\n", "\n", " ", ""], # default list
|
50 |
)
|
51 |
# https://dev.to/eteimz/understanding-langchains-recursivecharactertextsplitter-2846
|
52 |
-
texts = text_splitter.
|
53 |
-
print("Number of tokens:" + str(len(texts)))
|
54 |
-
print("")
|
55 |
-
print("First three tokens:")
|
56 |
print(texts[0])
|
57 |
print("")
|
58 |
print(texts[1])
|
@@ -61,16 +57,24 @@ def file_preprocessing(file, skipfirst, skiplast):
|
|
61 |
print("")
|
62 |
final_texts = ""
|
63 |
for text in texts:
|
64 |
-
final_texts = final_texts + text
|
65 |
-
return final_texts
|
66 |
|
67 |
|
68 |
# function to count words in the input
|
69 |
def preproc_count(filepath, skipfirst, skiplast):
|
70 |
-
input_text = file_preprocessing(filepath, skipfirst, skiplast)
|
71 |
-
|
|
|
72 |
print("Input word count: " f"{text_length:,}")
|
73 |
-
return input_text, text_length
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
74 |
|
75 |
|
76 |
# llm pipeline
|
@@ -79,26 +83,40 @@ def llm_pipeline(tokenizer, base_model, input_text, model_source):
|
|
79 |
"summarization",
|
80 |
model=base_model,
|
81 |
tokenizer=tokenizer,
|
82 |
-
max_length=
|
83 |
-
min_length=
|
84 |
-
truncation=True
|
85 |
)
|
86 |
-
print("Model source: %s" %(model_source))
|
87 |
print("Summarizing...")
|
88 |
result = pipe_sum(input_text)
|
89 |
summary = result[0]["summary_text"]
|
90 |
-
print("Summarization finished")
|
|
|
|
|
|
|
91 |
return summary
|
92 |
|
93 |
|
94 |
# function to count words in the summary
|
95 |
def postproc_count(summary):
|
96 |
-
text_length = len(summary)
|
97 |
print("Summary word count: " f"{text_length:,}")
|
98 |
return text_length
|
99 |
|
100 |
|
101 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
102 |
# function to display the PDF
|
103 |
def displayPDF(file):
|
104 |
with open(file, "rb") as f:
|
@@ -120,33 +138,37 @@ def main():
|
|
120 |
st.subheader("Options")
|
121 |
col1, col2, col3, col4 = st.columns([1, 1, 1, 2])
|
122 |
with col1:
|
123 |
-
model_source_names = [
|
124 |
-
|
125 |
-
"
|
126 |
-
|
127 |
-
|
|
|
128 |
with col2:
|
129 |
model_names = [
|
130 |
"T5-Small",
|
131 |
"BART",
|
132 |
]
|
133 |
-
selected_model = st.radio(
|
|
|
|
|
|
|
|
|
134 |
if selected_model == "BART":
|
135 |
checkpoint = "ccdv/lsg-bart-base-16384-pubmed"
|
136 |
tokenizer = AutoTokenizer.from_pretrained(
|
137 |
checkpoint,
|
138 |
truncation=True,
|
139 |
-
legacy=False,
|
140 |
model_max_length=1000,
|
141 |
-
trust_remote_code=True,
|
142 |
)
|
143 |
if model_source == "Download model":
|
144 |
base_model = AutoModelForSeq2SeqLM.from_pretrained(
|
145 |
-
|
146 |
-
|
147 |
-
|
148 |
)
|
149 |
-
else:
|
150 |
base_model = "model_cache/models--ccdv--lsg-bart-base-16384-pubmed/snapshots/4072bc1a7a94e2b4fd860a5fdf1b71d0487dcf15"
|
151 |
else:
|
152 |
checkpoint = "MBZUAI/LaMini-Flan-T5-77M"
|
@@ -154,28 +176,30 @@ def main():
|
|
154 |
checkpoint,
|
155 |
truncation=True,
|
156 |
legacy=False,
|
157 |
-
model_max_length=1000,
|
158 |
)
|
159 |
if model_source == "Download model":
|
160 |
base_model = AutoModelForSeq2SeqLM.from_pretrained(
|
161 |
-
|
162 |
-
|
163 |
)
|
164 |
else:
|
165 |
base_model = "model_cache/models--MBZUAI--LaMini-Flan-T5-77M/snapshots/c5b12d50a2616b9670a57189be20055d1357b474"
|
166 |
with col3:
|
167 |
st.write("Skip any pages?")
|
168 |
-
skipfirst = st.checkbox(
|
|
|
|
|
169 |
skiplast = st.checkbox("Skip last page")
|
170 |
with col4:
|
171 |
st.write("Background information (links open in a new window)")
|
172 |
st.write(
|
173 |
"Model class: [T5-Small](https://huggingface.co/docs/transformers/main/en/model_doc/t5)"
|
174 |
-
" |
|
175 |
)
|
176 |
st.write(
|
177 |
"Model class: [BART](https://huggingface.co/docs/transformers/main/en/model_doc/bart)"
|
178 |
-
" |
|
179 |
)
|
180 |
if st.button("Summarize"):
|
181 |
col1, col2 = st.columns(2)
|
@@ -183,7 +207,9 @@ def main():
|
|
183 |
with open(filepath, "wb") as temp_file:
|
184 |
temp_file.write(uploaded_file.read())
|
185 |
with col1:
|
186 |
-
input_text, preproc_text_length = preproc_count(
|
|
|
|
|
187 |
st.info(
|
188 |
"Uploaded PDF | Number of words: "
|
189 |
f"{preproc_text_length:,}"
|
@@ -192,7 +218,9 @@ def main():
|
|
192 |
with col2:
|
193 |
start = time.time()
|
194 |
with st.spinner("Summarizing..."):
|
195 |
-
summary = llm_pipeline(
|
|
|
|
|
196 |
postproc_text_length = postproc_count(summary)
|
197 |
end = time.time()
|
198 |
duration = end - start
|
@@ -203,7 +231,36 @@ def main():
|
|
203 |
+ " | Summarization time: "
|
204 |
f"{duration:.0f}" + " seconds"
|
205 |
)
|
206 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
207 |
|
208 |
|
209 |
st.markdown(
|
@@ -215,7 +272,7 @@ div[class*="stRadio"] > label > div[data-testid="stMarkdownContainer"] > p {
|
|
215 |
div[class*="stMarkdown"] > div[data-testid="stMarkdownContainer"] > p {
|
216 |
margin-bottom: -15px;
|
217 |
}
|
218 |
-
div[class*="stCheckbox"] > label {
|
219 |
margin-bottom: -15px;
|
220 |
}
|
221 |
body > a {
|
|
|
|
|
1 |
import base64
|
2 |
from langchain.chains.summarize import load_summarize_chain
|
3 |
from langchain.docstore.document import Document
|
4 |
from langchain.document_loaders.pdf import PyMuPDFLoader
|
5 |
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
6 |
from PyPDF2 import PdfReader
|
7 |
+
import re
|
8 |
import streamlit as st
|
9 |
+
import sys
|
10 |
import time
|
11 |
import torch
|
12 |
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoModelForCausalLM
|
13 |
+
from transformers import pipeline
|
14 |
|
15 |
# notes
|
16 |
# https://huggingface.co/docs/transformers/pad_truncation
|
|
|
20 |
def file_preprocessing(file, skipfirst, skiplast):
|
21 |
loader = PyMuPDFLoader(file)
|
22 |
pages = loader.load_and_split()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
23 |
# skip page(s)
|
24 |
if (skipfirst == 1) & (skiplast == 0):
|
25 |
del pages[0]
|
|
|
30 |
del pages[-1]
|
31 |
else:
|
32 |
pages = pages
|
33 |
+
# https://stackoverflow.com/questions/76431655/langchain-pypdfloader
|
34 |
+
content = ""
|
35 |
+
for page in pages:
|
36 |
+
content = content + page.page_content
|
37 |
+
content = re.sub("-\n", "", content)
|
38 |
+
print("\n###### New article ######\n")
|
39 |
+
print("Input text:\n")
|
40 |
+
print(content)
|
41 |
+
print("\nChunking...")
|
42 |
text_splitter = RecursiveCharacterTextSplitter(
|
43 |
chunk_size=1000, # number of characters
|
44 |
chunk_overlap=100,
|
|
|
46 |
separators=["\n\n", "\n", " ", ""], # default list
|
47 |
)
|
48 |
# https://dev.to/eteimz/understanding-langchains-recursivecharactertextsplitter-2846
|
49 |
+
texts = text_splitter.split_text(content)
|
50 |
+
print("Number of tokens: " + str(len(texts)))
|
51 |
+
print("\nFirst three tokens:\n")
|
|
|
52 |
print(texts[0])
|
53 |
print("")
|
54 |
print(texts[1])
|
|
|
57 |
print("")
|
58 |
final_texts = ""
|
59 |
for text in texts:
|
60 |
+
final_texts = final_texts + text
|
61 |
+
return texts, final_texts
|
62 |
|
63 |
|
64 |
# function to count words in the input
|
65 |
def preproc_count(filepath, skipfirst, skiplast):
|
66 |
+
texts, input_text = file_preprocessing(filepath, skipfirst, skiplast)
|
67 |
+
input_text = input_text.replace("-", "")
|
68 |
+
text_length = len(re.findall(r"\w+", input_text))
|
69 |
print("Input word count: " f"{text_length:,}")
|
70 |
+
return texts, input_text, text_length
|
71 |
+
|
72 |
+
|
73 |
+
# function to covert (bart) summary to sentence case
|
74 |
+
def convert_to_sentence_case(text):
|
75 |
+
sentences = re.split(r"(?<=[.!?])\s+", text)
|
76 |
+
formatted_sentences = [sentence.capitalize() for sentence in sentences]
|
77 |
+
return " ".join(formatted_sentences)
|
78 |
|
79 |
|
80 |
# llm pipeline
|
|
|
83 |
"summarization",
|
84 |
model=base_model,
|
85 |
tokenizer=tokenizer,
|
86 |
+
max_length=300,
|
87 |
+
min_length=200,
|
88 |
+
truncation=True,
|
89 |
)
|
90 |
+
print("Model source: %s" % (model_source))
|
91 |
print("Summarizing...")
|
92 |
result = pipe_sum(input_text)
|
93 |
summary = result[0]["summary_text"]
|
94 |
+
print("Summarization finished\n")
|
95 |
+
print("Summary text:\n")
|
96 |
+
print(summary)
|
97 |
+
print("")
|
98 |
return summary
|
99 |
|
100 |
|
101 |
# function to count words in the summary
|
102 |
def postproc_count(summary):
|
103 |
+
text_length = len(re.findall(r"\w+", summary))
|
104 |
print("Summary word count: " f"{text_length:,}")
|
105 |
return text_length
|
106 |
|
107 |
|
108 |
+
# function to clean summary text
|
109 |
+
def clean_summary_text(summary):
|
110 |
+
# remove whitespace
|
111 |
+
summary_clean_1 = summary.strip()
|
112 |
+
# remove spaces before punctuation (bart)
|
113 |
+
summary_clean_2 = re.sub(r'\s([,.():;?!"](?:\s|$))', r"\1", summary_clean_1)
|
114 |
+
# convert to sentence case
|
115 |
+
summary_clean_3 = convert_to_sentence_case(summary_clean_2)
|
116 |
+
return summary_clean_3
|
117 |
+
|
118 |
+
|
119 |
+
@st.cache_data(ttl=60 * 60)
|
120 |
# function to display the PDF
|
121 |
def displayPDF(file):
|
122 |
with open(file, "rb") as f:
|
|
|
138 |
st.subheader("Options")
|
139 |
col1, col2, col3, col4 = st.columns([1, 1, 1, 2])
|
140 |
with col1:
|
141 |
+
model_source_names = ["Cached model", "Download model"]
|
142 |
+
model_source = st.radio(
|
143 |
+
"For development:",
|
144 |
+
model_source_names,
|
145 |
+
help="Defaults to a cached model; downloading will take longer",
|
146 |
+
)
|
147 |
with col2:
|
148 |
model_names = [
|
149 |
"T5-Small",
|
150 |
"BART",
|
151 |
]
|
152 |
+
selected_model = st.radio(
|
153 |
+
"Select a model to use:",
|
154 |
+
model_names,
|
155 |
+
help="Defauls to T5-Small as it summarizes better than BART",
|
156 |
+
)
|
157 |
if selected_model == "BART":
|
158 |
checkpoint = "ccdv/lsg-bart-base-16384-pubmed"
|
159 |
tokenizer = AutoTokenizer.from_pretrained(
|
160 |
checkpoint,
|
161 |
truncation=True,
|
|
|
162 |
model_max_length=1000,
|
163 |
+
trust_remote_code=True,
|
164 |
)
|
165 |
if model_source == "Download model":
|
166 |
base_model = AutoModelForSeq2SeqLM.from_pretrained(
|
167 |
+
checkpoint,
|
168 |
+
torch_dtype=torch.float32,
|
169 |
+
trust_remote_code=True,
|
170 |
)
|
171 |
+
else:
|
172 |
base_model = "model_cache/models--ccdv--lsg-bart-base-16384-pubmed/snapshots/4072bc1a7a94e2b4fd860a5fdf1b71d0487dcf15"
|
173 |
else:
|
174 |
checkpoint = "MBZUAI/LaMini-Flan-T5-77M"
|
|
|
176 |
checkpoint,
|
177 |
truncation=True,
|
178 |
legacy=False,
|
179 |
+
model_max_length=1000,
|
180 |
)
|
181 |
if model_source == "Download model":
|
182 |
base_model = AutoModelForSeq2SeqLM.from_pretrained(
|
183 |
+
checkpoint,
|
184 |
+
torch_dtype=torch.float32,
|
185 |
)
|
186 |
else:
|
187 |
base_model = "model_cache/models--MBZUAI--LaMini-Flan-T5-77M/snapshots/c5b12d50a2616b9670a57189be20055d1357b474"
|
188 |
with col3:
|
189 |
st.write("Skip any pages?")
|
190 |
+
skipfirst = st.checkbox(
|
191 |
+
"Skip first page", help="Select if your PDF has a cover page"
|
192 |
+
)
|
193 |
skiplast = st.checkbox("Skip last page")
|
194 |
with col4:
|
195 |
st.write("Background information (links open in a new window)")
|
196 |
st.write(
|
197 |
"Model class: [T5-Small](https://huggingface.co/docs/transformers/main/en/model_doc/t5)"
|
198 |
+
" | Model: [LaMini-Flan-T5-77M](https://huggingface.co/MBZUAI/LaMini-Flan-T5-77M)"
|
199 |
)
|
200 |
st.write(
|
201 |
"Model class: [BART](https://huggingface.co/docs/transformers/main/en/model_doc/bart)"
|
202 |
+
" | Model: [lsg-bart-base-16384-pubmed](https://huggingface.co/ccdv/lsg-bart-base-16384-pubmed)"
|
203 |
)
|
204 |
if st.button("Summarize"):
|
205 |
col1, col2 = st.columns(2)
|
|
|
207 |
with open(filepath, "wb") as temp_file:
|
208 |
temp_file.write(uploaded_file.read())
|
209 |
with col1:
|
210 |
+
texts, input_text, preproc_text_length = preproc_count(
|
211 |
+
filepath, skipfirst, skiplast
|
212 |
+
)
|
213 |
st.info(
|
214 |
"Uploaded PDF | Number of words: "
|
215 |
f"{preproc_text_length:,}"
|
|
|
218 |
with col2:
|
219 |
start = time.time()
|
220 |
with st.spinner("Summarizing..."):
|
221 |
+
summary = llm_pipeline(
|
222 |
+
tokenizer, base_model, input_text, model_source
|
223 |
+
)
|
224 |
postproc_text_length = postproc_count(summary)
|
225 |
end = time.time()
|
226 |
duration = end - start
|
|
|
231 |
+ " | Summarization time: "
|
232 |
f"{duration:.0f}" + " seconds"
|
233 |
)
|
234 |
+
if selected_model == "BART":
|
235 |
+
summary_cleaned = clean_summary_text(summary)
|
236 |
+
st.success(summary_cleaned)
|
237 |
+
with st.expander("Raw output"):
|
238 |
+
st.write(summary)
|
239 |
+
else:
|
240 |
+
st.success(summary)
|
241 |
+
col1 = st.columns(1)
|
242 |
+
url = "https://dev.to/eteimz/understanding-langchains-recursivecharactertextsplitter-2846"
|
243 |
+
st.info("Additional information")
|
244 |
+
st.write("")
|
245 |
+
st.write("[RecursiveCharacterTextSplitter](%s) parameters used:" % url)
|
246 |
+
st.write(" chunk_size=1000")
|
247 |
+
st.write(
|
248 |
+
" chunk_overlap=100"
|
249 |
+
)
|
250 |
+
st.write(
|
251 |
+
" length_function=len"
|
252 |
+
)
|
253 |
+
st.write("")
|
254 |
+
st.write("Number of tokens generated: " + str(len(texts)))
|
255 |
+
st.write("")
|
256 |
+
st.write("First three tokens:")
|
257 |
+
st.write("")
|
258 |
+
st.write(texts[0])
|
259 |
+
st.write("")
|
260 |
+
st.write(texts[1])
|
261 |
+
st.write("")
|
262 |
+
st.write(texts[2])
|
263 |
+
st.write("")
|
264 |
|
265 |
|
266 |
st.markdown(
|
|
|
272 |
div[class*="stMarkdown"] > div[data-testid="stMarkdownContainer"] > p {
|
273 |
margin-bottom: -15px;
|
274 |
}
|
275 |
+
div[class*="stCheckbox"] > label[data-baseweb="checkbox"] {
|
276 |
margin-bottom: -15px;
|
277 |
}
|
278 |
body > a {
|