wjjessen committed
Commit 6901ce4
1 Parent(s): a95a714

added section additional information

Files changed (1)
  app.py +110 -53
app.py CHANGED
@@ -1,15 +1,16 @@
-from transformers import pipeline
 import base64
 from langchain.chains.summarize import load_summarize_chain
 from langchain.docstore.document import Document
 from langchain.document_loaders.pdf import PyMuPDFLoader
 from langchain.text_splitter import RecursiveCharacterTextSplitter
 from PyPDF2 import PdfReader
+import re
 import streamlit as st
-import textwrap as tw
+import sys
 import time
 import torch
 from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoModelForCausalLM
+from transformers import pipeline
 
 # notes
 # https://huggingface.co/docs/transformers/pad_truncation
@@ -19,14 +20,6 @@ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoModelForCausalLM
 def file_preprocessing(file, skipfirst, skiplast):
     loader = PyMuPDFLoader(file)
     pages = loader.load_and_split()
-    print("")
-    print("# pages[0] ##########")
-    print("")
-    print(pages[0])
-    print("")
-    print("# pages ##########")
-    print("")
-    print(pages)
     # skip page(s)
     if (skipfirst == 1) & (skiplast == 0):
         del pages[0]
@@ -37,11 +30,15 @@ def file_preprocessing(file, skipfirst, skiplast):
         del pages[-1]
     else:
         pages = pages
-    print("")
-    print("# pages after skip(s) ##########")
-    print("")
-    print(pages)
-    print("")
+    # https://stackoverflow.com/questions/76431655/langchain-pypdfloader
+    content = ""
+    for page in pages:
+        content = content + page.page_content
+    content = re.sub("-\n", "", content)
+    print("\n###### New article ######\n")
+    print("Input text:\n")
+    print(content)
+    print("\nChunking...")
     text_splitter = RecursiveCharacterTextSplitter(
         chunk_size=1000,  # number of characters
         chunk_overlap=100,
@@ -49,10 +46,9 @@ def file_preprocessing(file, skipfirst, skiplast):
         separators=["\n\n", "\n", " ", ""],  # default list
     )
     # https://dev.to/eteimz/understanding-langchains-recursivecharactertextsplitter-2846
-    texts = text_splitter.split_documents(pages)
-    print("Number of tokens:" + str(len(texts)))
-    print("")
-    print("First three tokens:")
+    texts = text_splitter.split_text(content)
+    print("Number of tokens: " + str(len(texts)))
+    print("\nFirst three tokens:\n")
     print(texts[0])
     print("")
     print(texts[1])
@@ -61,16 +57,24 @@ def file_preprocessing(file, skipfirst, skiplast):
     print("")
     final_texts = ""
     for text in texts:
-        final_texts = final_texts + text.page_content
-    return final_texts
+        final_texts = final_texts + text
+    return texts, final_texts
 
 
 # function to count words in the input
 def preproc_count(filepath, skipfirst, skiplast):
-    input_text = file_preprocessing(filepath, skipfirst, skiplast)
-    text_length = len(input_text)
+    texts, input_text = file_preprocessing(filepath, skipfirst, skiplast)
+    input_text = input_text.replace("-", "")
+    text_length = len(re.findall(r"\w+", input_text))
     print("Input word count: " f"{text_length:,}")
-    return input_text, text_length
+    return texts, input_text, text_length
+
+
+# function to convert (bart) summary to sentence case
+def convert_to_sentence_case(text):
+    sentences = re.split(r"(?<=[.!?])\s+", text)
+    formatted_sentences = [sentence.capitalize() for sentence in sentences]
+    return " ".join(formatted_sentences)
 
 
 # llm pipeline
@@ -79,26 +83,40 @@ def llm_pipeline(tokenizer, base_model, input_text, model_source):
         "summarization",
         model=base_model,
         tokenizer=tokenizer,
-        max_length=600,
-        min_length=300,
-        truncation=True
+        max_length=300,
+        min_length=200,
+        truncation=True,
     )
-    print("Model source: %s" %(model_source))
+    print("Model source: %s" % (model_source))
     print("Summarizing...")
     result = pipe_sum(input_text)
     summary = result[0]["summary_text"]
-    print("Summarization finished")
+    print("Summarization finished\n")
+    print("Summary text:\n")
+    print(summary)
+    print("")
     return summary
 
 
 # function to count words in the summary
 def postproc_count(summary):
-    text_length = len(summary)
+    text_length = len(re.findall(r"\w+", summary))
     print("Summary word count: " f"{text_length:,}")
     return text_length
 
 
-@st.cache_data(ttl=60*60)
+# function to clean summary text
+def clean_summary_text(summary):
+    # remove whitespace
+    summary_clean_1 = summary.strip()
+    # remove spaces before punctuation (bart)
+    summary_clean_2 = re.sub(r'\s([,.():;?!"](?:\s|$))', r"\1", summary_clean_1)
+    # convert to sentence case
+    summary_clean_3 = convert_to_sentence_case(summary_clean_2)
+    return summary_clean_3
+
+
+@st.cache_data(ttl=60 * 60)
 # function to display the PDF
 def displayPDF(file):
     with open(file, "rb") as f:
@@ -120,33 +138,37 @@ def main():
     st.subheader("Options")
     col1, col2, col3, col4 = st.columns([1, 1, 1, 2])
     with col1:
-        model_source_names = [
-            "Cached model",
-            "Download model"
-        ]
-        model_source = st.radio("For development:", model_source_names)
+        model_source_names = ["Cached model", "Download model"]
+        model_source = st.radio(
+            "For development:",
+            model_source_names,
+            help="Defaults to a cached model; downloading will take longer",
+        )
     with col2:
         model_names = [
             "T5-Small",
             "BART",
         ]
-        selected_model = st.radio("Select a model to use:", model_names)
+        selected_model = st.radio(
+            "Select a model to use:",
+            model_names,
+            help="Defaults to T5-Small as it summarizes better than BART",
+        )
         if selected_model == "BART":
            checkpoint = "ccdv/lsg-bart-base-16384-pubmed"
            tokenizer = AutoTokenizer.from_pretrained(
                checkpoint,
                truncation=True,
-               legacy=False,
                model_max_length=1000,
-               trust_remote_code=True,
+               trust_remote_code=True,
            )
            if model_source == "Download model":
                base_model = AutoModelForSeq2SeqLM.from_pretrained(
-                   checkpoint,
-                   torch_dtype=torch.float32,
-                   trust_remote_code=True,
+                   checkpoint,
+                   torch_dtype=torch.float32,
+                   trust_remote_code=True,
                )
-           else:
+           else:
                base_model = "model_cache/models--ccdv--lsg-bart-base-16384-pubmed/snapshots/4072bc1a7a94e2b4fd860a5fdf1b71d0487dcf15"
        else:
            checkpoint = "MBZUAI/LaMini-Flan-T5-77M"
@@ -154,28 +176,30 @@ def main():
                checkpoint,
                truncation=True,
                legacy=False,
-               model_max_length=1000,
+               model_max_length=1000,
            )
            if model_source == "Download model":
                base_model = AutoModelForSeq2SeqLM.from_pretrained(
-                   checkpoint,
-                   torch_dtype=torch.float32,
+                   checkpoint,
+                   torch_dtype=torch.float32,
                )
            else:
                base_model = "model_cache/models--MBZUAI--LaMini-Flan-T5-77M/snapshots/c5b12d50a2616b9670a57189be20055d1357b474"
    with col3:
        st.write("Skip any pages?")
-       skipfirst = st.checkbox("Skip first page")
+       skipfirst = st.checkbox(
+           "Skip first page", help="Select if your PDF has a cover page"
+       )
        skiplast = st.checkbox("Skip last page")
    with col4:
        st.write("Background information (links open in a new window)")
        st.write(
            "Model class: [T5-Small](https://huggingface.co/docs/transformers/main/en/model_doc/t5)"
-           "&nbsp;&nbsp;|&nbsp;&nbsp;Specific model: [MBZUAI/LaMini-Flan-T5-77M](https://huggingface.co/MBZUAI/LaMini-Flan-T5-77M)"
+           "&nbsp;&nbsp;|&nbsp;&nbsp;Model: [LaMini-Flan-T5-77M](https://huggingface.co/MBZUAI/LaMini-Flan-T5-77M)"
        )
        st.write(
            "Model class: [BART](https://huggingface.co/docs/transformers/main/en/model_doc/bart)"
-           "&nbsp;&nbsp;|&nbsp;&nbsp;Specific model: [ccdv/lsg-bart-base-16384-pubmed](https://huggingface.co/ccdv/lsg-bart-base-16384-pubmed)"
+           "&nbsp;&nbsp;|&nbsp;&nbsp;Model: [lsg-bart-base-16384-pubmed](https://huggingface.co/ccdv/lsg-bart-base-16384-pubmed)"
        )
    if st.button("Summarize"):
        col1, col2 = st.columns(2)
@@ -183,7 +207,9 @@ def main():
        with open(filepath, "wb") as temp_file:
            temp_file.write(uploaded_file.read())
        with col1:
-           input_text, preproc_text_length = preproc_count(filepath, skipfirst, skiplast)
+           texts, input_text, preproc_text_length = preproc_count(
+               filepath, skipfirst, skiplast
+           )
            st.info(
                "Uploaded PDF&nbsp;&nbsp;|&nbsp;&nbsp;Number of words: "
                f"{preproc_text_length:,}"
@@ -192,7 +218,9 @@ def main():
        with col2:
            start = time.time()
            with st.spinner("Summarizing..."):
-               summary = llm_pipeline(tokenizer, base_model, input_text, model_source)
+               summary = llm_pipeline(
+                   tokenizer, base_model, input_text, model_source
+               )
                postproc_text_length = postproc_count(summary)
            end = time.time()
            duration = end - start
@@ -203,7 +231,36 @@ def main():
                + "&nbsp;&nbsp;|&nbsp;&nbsp;Summarization time: "
                f"{duration:.0f}" + " seconds"
            )
-           st.success(summary)
+           if selected_model == "BART":
+               summary_cleaned = clean_summary_text(summary)
+               st.success(summary_cleaned)
+               with st.expander("Raw output"):
+                   st.write(summary)
+           else:
+               st.success(summary)
+       col1 = st.columns(1)
+       url = "https://dev.to/eteimz/understanding-langchains-recursivecharactertextsplitter-2846"
+       st.info("Additional information")
+       st.write("")
+       st.write("[RecursiveCharacterTextSplitter](%s) parameters used:" % url)
+       st.write("&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;chunk_size=1000")
+       st.write(
+           "&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;chunk_overlap=100"
+       )
+       st.write(
+           "&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;length_function=len"
+       )
+       st.write("")
+       st.write("Number of tokens generated: " + str(len(texts)))
+       st.write("")
+       st.write("First three tokens:")
+       st.write("")
+       st.write(texts[0])
+       st.write("")
+       st.write(texts[1])
+       st.write("")
+       st.write(texts[2])
+       st.write("")
 
 
    st.markdown(
@@ -215,7 +272,7 @@ div[class*="stRadio"] > label > div[data-testid="stMarkdownContainer"] > p {
 div[class*="stMarkdown"] > div[data-testid="stMarkdownContainer"] > p {
     margin-bottom: -15px;
 }
-div[class*="stCheckbox"] > label {
+div[class*="stCheckbox"] > label[data-baseweb="checkbox"] {
     margin-bottom: -15px;
 }
 body > a {
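
For reference, a minimal sanity check of the two summary-cleaning helpers this commit adds (convert_to_sentence_case and clean_summary_text, copied verbatim from the diff above). The raw input string is a hypothetical BART-style output invented for illustration, not real model output:

import re


# function to convert (bart) summary to sentence case (as added in this commit)
def convert_to_sentence_case(text):
    sentences = re.split(r"(?<=[.!?])\s+", text)
    formatted_sentences = [sentence.capitalize() for sentence in sentences]
    return " ".join(formatted_sentences)


# function to clean summary text (as added in this commit)
def clean_summary_text(summary):
    # remove leading/trailing whitespace
    summary_clean_1 = summary.strip()
    # remove spaces before punctuation (bart)
    summary_clean_2 = re.sub(r'\s([,.():;?!"](?:\s|$))', r"\1", summary_clean_1)
    # convert to sentence case
    summary_clean_3 = convert_to_sentence_case(summary_clean_2)
    return summary_clean_3


# hypothetical lowercase, space-padded BART output, for illustration only
raw = "the trial was small . further work is needed ."
print(clean_summary_text(raw))
# -> "The trial was small. Further work is needed."

This mirrors what the app now shows when BART is selected: the cleaned string goes to st.success, while the unmodified model output stays available in the "Raw output" expander.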