pszemraj commited on
Commit
0cef1e2
1 Parent(s): b542f3a

✨ mai improvements

Browse files

Signed-off-by: peter szemraj <peterszemraj@gmail.com>

Files changed (4) hide show
  1. app.py +99 -41
  2. requirements.txt +2 -2
  3. summarize.py +36 -22
  4. utils.py +14 -0
app.py CHANGED
@@ -1,3 +1,7 @@
 
 
 
 
1
  import logging
2
  import random
3
  import re
@@ -6,6 +10,7 @@ from pathlib import Path
6
 
7
  import gradio as gr
8
  import nltk
 
9
  from cleantext import clean
10
 
11
  from summarize import load_model_and_tokenizer, summarize_via_tokenbatches
@@ -13,22 +18,62 @@ from utils import load_example_filenames, truncate_word_count
13
 
14
  _here = Path(__file__).parent
15
 
16
- nltk.download("stopwords") # TODO=find where this requirement originates from
17
 
18
  logging.basicConfig(
19
- level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
20
  )
21
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
 
23
  def proc_submission(
24
  input_text: str,
25
- model_size: str,
26
- num_beams,
27
- token_batch_length,
28
- length_penalty,
29
- repetition_penalty,
30
- no_repeat_ngram_size,
31
- max_input_length: int = 1024,
32
  ):
33
  """
34
  proc_submission - a helper function for the gradio module to process submissions
@@ -41,12 +86,14 @@ def proc_submission(
41
  length_penalty (float): the length penalty to use
42
  repetition_penalty (float): the repetition penalty to use
43
  no_repeat_ngram_size (int): the no-repeat ngram size to use
44
- max_input_length (int, optional): the maximum input length to use. Defaults to 1024.
45
 
46
  Returns:
47
  str in HTML format, string of the summary, str of score
48
  """
49
 
 
 
50
  settings = {
51
  "length_penalty": float(length_penalty),
52
  "repetition_penalty": float(repetition_penalty),
@@ -58,14 +105,19 @@ def proc_submission(
58
  "early_stopping": True,
59
  "do_sample": False,
60
  }
 
 
 
 
 
 
61
  st = time.perf_counter()
62
  history = {}
63
  clean_text = clean(input_text, lower=False)
64
- max_input_length = 2048 if model_size == "base" else max_input_length
65
  processed = truncate_word_count(clean_text, max_input_length)
66
 
67
  if processed["was_truncated"]:
68
- tr_in = processed["truncated_text"]
69
  # create elaborate HTML warning
70
  input_wc = re.split(r"\s+", input_text)
71
  msg = f"""
@@ -77,7 +129,7 @@ def proc_submission(
77
  logging.warning(msg)
78
  history["WARNING"] = msg
79
  else:
80
- tr_in = input_text
81
  msg = None
82
 
83
  if len(input_text) < 50:
@@ -95,24 +147,25 @@ def proc_submission(
95
 
96
  return msg, "", []
97
 
98
- _summaries = summarize_via_tokenbatches(
99
- tr_in,
100
- model_sm if "base" in model_size.lower() else model,
101
- tokenizer_sm if "base" in model_size.lower() else tokenizer,
102
- batch_length=token_batch_length,
103
  **settings,
104
  )
105
- sum_text = [f"Section {i}: " + s["summary"][0] for i, s in enumerate(_summaries)]
 
 
106
  sum_scores = [
107
- f" - Section {i}: {round(s['summary_score'],4)}"
108
- for i, s in enumerate(_summaries)
109
  ]
110
 
111
  sum_text_out = "\n".join(sum_text)
112
  history["Summary Scores"] = "<br><br>"
113
  scores_out = "\n".join(sum_scores)
114
  rt = round((time.perf_counter() - st) / 60, 2)
115
- print(f"Runtime: {rt} minutes")
116
  html = ""
117
  html += f"<p>Runtime: {rt} minutes on CPU</p>"
118
  if msg is not None:
@@ -169,36 +222,38 @@ def load_uploaded_file(file_obj):
169
 
170
 
171
  if __name__ == "__main__":
172
-
173
- model, tokenizer = load_model_and_tokenizer("pszemraj/led-large-book-summary")
174
- model_sm, tokenizer_sm = load_model_and_tokenizer("pszemraj/led-base-book-summary")
175
-
176
  name_to_path = load_example_filenames(_here / "examples")
177
  logging.info(f"Loaded {len(name_to_path)} examples")
178
- demo = gr.Blocks()
 
 
179
  _examples = list(name_to_path.keys())
180
  with demo:
181
-
182
  gr.Markdown("# Long-Form Summarization: LED & BookSum")
183
  gr.Markdown(
184
  "LED models ([model card](https://huggingface.co/pszemraj/led-large-book-summary)) fine-tuned to summarize long-form text. A [space with other models can be found here](https://huggingface.co/spaces/pszemraj/document-summarization)"
185
  )
186
  with gr.Column():
187
-
188
  gr.Markdown("## Load Inputs & Select Parameters")
189
  gr.Markdown(
190
  "Enter or upload text below, and it will be summarized [using the selected parameters](https://huggingface.co/blog/how-to-generate). "
191
  )
192
  with gr.Row():
193
- model_size = gr.Radio(
194
- choices=["base", "large"], label="Model Variant", value="large"
 
 
195
  )
196
  num_beams = gr.Radio(
197
  choices=[2, 3, 4],
198
  label="Beam Search: # of Beams",
199
  value=2,
200
  )
201
- gr.Markdown("Load a a .txt - example or your own (_You may find [this OCR space](https://huggingface.co/spaces/pszemraj/pdf-ocr) useful_)")
 
 
202
  with gr.Row():
203
  example_name = gr.Dropdown(
204
  _examples,
@@ -213,7 +268,8 @@ if __name__ == "__main__":
213
  with gr.Row():
214
  input_text = gr.Textbox(
215
  lines=4,
216
- label="Input Text (for summarization)",
 
217
  placeholder="Enter text to summarize, the text will be cleaned and truncated on Spaces. Narrative, academic (both papers and lecture transcription), and article text work well. May take a bit to generate depending on the input text :)",
218
  )
219
  with gr.Column():
@@ -250,11 +306,11 @@ if __name__ == "__main__":
250
  with gr.Column():
251
  gr.Markdown("### Advanced Settings")
252
  with gr.Row():
253
- length_penalty = gr.inputs.Slider(
254
  minimum=0.5,
255
  maximum=1.0,
256
  label="length penalty",
257
- default=0.7,
258
  step=0.05,
259
  )
260
  token_batch_length = gr.Radio(
@@ -264,11 +320,11 @@ if __name__ == "__main__":
264
  )
265
 
266
  with gr.Row():
267
- repetition_penalty = gr.inputs.Slider(
268
  minimum=1.0,
269
  maximum=5.0,
270
  label="repetition penalty",
271
- default=3.5,
272
  step=0.1,
273
  )
274
  no_repeat_ngram_size = gr.Radio(
@@ -282,10 +338,10 @@ if __name__ == "__main__":
282
  "- [This model](https://huggingface.co/pszemraj/led-large-book-summary) is a fine-tuned checkpoint of [allenai/led-large-16384](https://huggingface.co/allenai/led-large-16384) on the [BookSum dataset](https://arxiv.org/abs/2105.08209).The goal was to create a model that can generalize well and is useful in summarizing lots of text in academic and daily usage."
283
  )
284
  gr.Markdown(
285
- "- The two most important parameters-empirically-are the `num_beams` and `token_batch_length`. "
286
  )
287
  gr.Markdown(
288
- "- The model can be used with tag [pszemraj/led-large-book-summary](https://huggingface.co/pszemraj/led-large-book-summary). See the model card for details on usage & a Colab notebook for a tutorial."
289
  )
290
  gr.Markdown("---")
291
 
@@ -301,7 +357,7 @@ if __name__ == "__main__":
301
  fn=proc_submission,
302
  inputs=[
303
  input_text,
304
- model_size,
305
  num_beams,
306
  token_batch_length,
307
  length_penalty,
@@ -311,4 +367,6 @@ if __name__ == "__main__":
311
  outputs=[output_text, summary_text, summary_scores],
312
  )
313
 
314
- demo.launch(enable_queue=True, share=True)
 
 
 
1
+ """
2
+ app.py - the main application file for the gradio app
3
+ """
4
+ import gc
5
  import logging
6
  import random
7
  import re
 
10
 
11
  import gradio as gr
12
  import nltk
13
+ import torch
14
  from cleantext import clean
15
 
16
  from summarize import load_model_and_tokenizer, summarize_via_tokenbatches
 
18
 
19
  _here = Path(__file__).parent
20
 
21
+ nltk.download("stopwords", quiet=True)
22
 
23
  logging.basicConfig(
24
+ level=logging.INFO, format="%(asctime)s - [%(levelname)s] %(name)s: %(message)s"
25
  )
26
 
27
+ MODEL_OPTIONS = [
28
+ "pszemraj/led-large-book-summary",
29
+ "pszemraj/led-large-book-summary-continued",
30
+ "pszemraj/led-base-book-summary",
31
+ ]
32
+
33
+
34
+ def predict(
35
+ input_text: str,
36
+ model_name: str,
37
+ token_batch_length: int = 2048,
38
+ empty_cache: bool = True,
39
+ **settings,
40
+ ) -> list:
41
+ """
42
+ predict - helper fn to support multiple models for summarization at once
43
+ :param str input_text: the input text to summarize
44
+ :param str model_name: model name to use
45
+ :param int token_batch_length: the length of the token batches to use
46
+ :param bool empty_cache: whether to empty the cache before loading a new= model
47
+ :return: list of dicts with keys "summary" and "score"
48
+ """
49
+ if torch.cuda.is_available() and empty_cache:
50
+ torch.cuda.empty_cache()
51
+
52
+ model, tokenizer = load_model_and_tokenizer(model_name)
53
+ summaries = summarize_via_tokenbatches(
54
+ input_text,
55
+ model,
56
+ tokenizer,
57
+ batch_length=token_batch_length,
58
+ **settings,
59
+ )
60
+
61
+ del model
62
+ del tokenizer
63
+ gc.collect()
64
+
65
+ return summaries
66
+
67
 
68
  def proc_submission(
69
  input_text: str,
70
+ model_name: str,
71
+ num_beams: int,
72
+ token_batch_length: int,
73
+ length_penalty: float,
74
+ repetition_penalty: float,
75
+ no_repeat_ngram_size: int,
76
+ max_input_length: int = 2560,
77
  ):
78
  """
79
  proc_submission - a helper function for the gradio module to process submissions
 
86
  length_penalty (float): the length penalty to use
87
  repetition_penalty (float): the repetition penalty to use
88
  no_repeat_ngram_size (int): the no-repeat ngram size to use
89
+ max_input_length (int, optional): the maximum input length to use. Defaults to 2560.
90
 
91
  Returns:
92
  str in HTML format, string of the summary, str of score
93
  """
94
 
95
+ logger = logging.getLogger(__name__)
96
+ logger.info("Processing submission")
97
  settings = {
98
  "length_penalty": float(length_penalty),
99
  "repetition_penalty": float(repetition_penalty),
 
105
  "early_stopping": True,
106
  "do_sample": False,
107
  }
108
+
109
+ if "base" in model_name:
110
+ logger.info("Updating max_input_length to for base model")
111
+ max_input_length = 4096
112
+
113
+ logger.info(f"max_input_length: {max_input_length}")
114
  st = time.perf_counter()
115
  history = {}
116
  clean_text = clean(input_text, lower=False)
 
117
  processed = truncate_word_count(clean_text, max_input_length)
118
 
119
  if processed["was_truncated"]:
120
+ truncated_input = processed["truncated_text"]
121
  # create elaborate HTML warning
122
  input_wc = re.split(r"\s+", input_text)
123
  msg = f"""
 
129
  logging.warning(msg)
130
  history["WARNING"] = msg
131
  else:
132
+ truncated_input = input_text
133
  msg = None
134
 
135
  if len(input_text) < 50:
 
147
 
148
  return msg, "", []
149
 
150
+ _summaries = predict(
151
+ input_text=truncated_input,
152
+ model_name=model_name,
153
+ token_batch_length=token_batch_length,
 
154
  **settings,
155
  )
156
+ sum_text = [
157
+ f"\nBatch {i}:\n\t" + s["summary"][0] for i, s in enumerate(_summaries, start=1)
158
+ ]
159
  sum_scores = [
160
+ f"\n- Batch {i}:\n\t{round(s['summary_score'],4)}"
161
+ for i, s in enumerate(_summaries, start=1)
162
  ]
163
 
164
  sum_text_out = "\n".join(sum_text)
165
  history["Summary Scores"] = "<br><br>"
166
  scores_out = "\n".join(sum_scores)
167
  rt = round((time.perf_counter() - st) / 60, 2)
168
+ logger.info(f"Runtime: {rt} minutes")
169
  html = ""
170
  html += f"<p>Runtime: {rt} minutes on CPU</p>"
171
  if msg is not None:
 
222
 
223
 
224
  if __name__ == "__main__":
225
+ logger = logging.getLogger(__name__)
226
+ logger.info("Starting up app")
 
 
227
  name_to_path = load_example_filenames(_here / "examples")
228
  logging.info(f"Loaded {len(name_to_path)} examples")
229
+ demo = gr.Blocks(
230
+ title="Summarize Long-Form Text",
231
+ )
232
  _examples = list(name_to_path.keys())
233
  with demo:
 
234
  gr.Markdown("# Long-Form Summarization: LED & BookSum")
235
  gr.Markdown(
236
  "LED models ([model card](https://huggingface.co/pszemraj/led-large-book-summary)) fine-tuned to summarize long-form text. A [space with other models can be found here](https://huggingface.co/spaces/pszemraj/document-summarization)"
237
  )
238
  with gr.Column():
 
239
  gr.Markdown("## Load Inputs & Select Parameters")
240
  gr.Markdown(
241
  "Enter or upload text below, and it will be summarized [using the selected parameters](https://huggingface.co/blog/how-to-generate). "
242
  )
243
  with gr.Row():
244
+ model_name = gr.Dropdown(
245
+ choices=MODEL_OPTIONS,
246
+ value=MODEL_OPTIONS[0],
247
+ label="Model Name",
248
  )
249
  num_beams = gr.Radio(
250
  choices=[2, 3, 4],
251
  label="Beam Search: # of Beams",
252
  value=2,
253
  )
254
+ gr.Markdown(
255
+ "Load a a .txt - example or your own (_You may find [this OCR space](https://huggingface.co/spaces/pszemraj/pdf-ocr) useful_)"
256
+ )
257
  with gr.Row():
258
  example_name = gr.Dropdown(
259
  _examples,
 
268
  with gr.Row():
269
  input_text = gr.Textbox(
270
  lines=4,
271
+ max_lines=12,
272
+ label="Text to Summarize",
273
  placeholder="Enter text to summarize, the text will be cleaned and truncated on Spaces. Narrative, academic (both papers and lecture transcription), and article text work well. May take a bit to generate depending on the input text :)",
274
  )
275
  with gr.Column():
 
306
  with gr.Column():
307
  gr.Markdown("### Advanced Settings")
308
  with gr.Row():
309
+ length_penalty = gr.Slider(
310
  minimum=0.5,
311
  maximum=1.0,
312
  label="length penalty",
313
+ value=0.7,
314
  step=0.05,
315
  )
316
  token_batch_length = gr.Radio(
 
320
  )
321
 
322
  with gr.Row():
323
+ repetition_penalty = gr.Slider(
324
  minimum=1.0,
325
  maximum=5.0,
326
  label="repetition penalty",
327
+ value=3.5,
328
  step=0.1,
329
  )
330
  no_repeat_ngram_size = gr.Radio(
 
338
  "- [This model](https://huggingface.co/pszemraj/led-large-book-summary) is a fine-tuned checkpoint of [allenai/led-large-16384](https://huggingface.co/allenai/led-large-16384) on the [BookSum dataset](https://arxiv.org/abs/2105.08209).The goal was to create a model that can generalize well and is useful in summarizing lots of text in academic and daily usage."
339
  )
340
  gr.Markdown(
341
+ "- The model can be used with tag [pszemraj/led-large-book-summary](https://huggingface.co/pszemraj/led-large-book-summary). See the model card for details on usage & a Colab notebook for a tutorial."
342
  )
343
  gr.Markdown(
344
+ "- **Update May 1, 2023:** Enabled faster inference times via `use_cache=True`, the number of words the model will processed has been increased! New [test model](https://huggingface.co/pszemraj/led-large-book-summary-continued) as an extension of `led-large-book-summary`."
345
  )
346
  gr.Markdown("---")
347
 
 
357
  fn=proc_submission,
358
  inputs=[
359
  input_text,
360
+ model_name,
361
  num_beams,
362
  token_batch_length,
363
  length_penalty,
 
367
  outputs=[output_text, summary_text, summary_scores],
368
  )
369
 
370
+ demo.launch(
371
+ enable_queue=True,
372
+ )
requirements.txt CHANGED
@@ -1,8 +1,8 @@
1
- clean-text[gpl]
2
  gradio
3
  natsort
4
  nltk
5
  torch
6
  tqdm
7
  transformers
8
- accelerate
 
1
+ clean-text
2
  gradio
3
  natsort
4
  nltk
5
  torch
6
  tqdm
7
  transformers
8
+ accelerate
summarize.py CHANGED
@@ -1,30 +1,40 @@
1
  import logging
 
2
 
 
 
 
3
  import torch
4
  from tqdm.auto import tqdm
5
  from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
6
 
7
 
8
- def load_model_and_tokenizer(model_name):
9
  """
10
- load_model_and_tokenizer - a function that loads a model and tokenizer from huggingface
11
-
12
- Args:
13
- model_name (str): the name of the model to load
14
- Returns:
15
- AutoModelForSeq2SeqLM: the model
16
- AutoTokenizer: the tokenizer
17
  """
18
-
 
19
  model = AutoModelForSeq2SeqLM.from_pretrained(
20
  model_name,
21
- # low_cpu_mem_usage=True,
22
- # use_cache=False,
23
- )
24
  tokenizer = AutoTokenizer.from_pretrained(model_name)
25
- model = model.to("cuda") if torch.cuda.is_available() else model
26
 
27
- logging.info(f"Loaded model {model_name}")
 
 
 
 
 
 
 
 
 
 
28
  return model, tokenizer
29
 
30
 
@@ -76,6 +86,7 @@ def summarize_via_tokenbatches(
76
  tokenizer,
77
  batch_length=2048,
78
  batch_stride=16,
 
79
  **kwargs,
80
  ):
81
  """
@@ -83,7 +94,7 @@ def summarize_via_tokenbatches(
83
 
84
  Args:
85
  input_text (str): the text to summarize
86
- model (): the model to use for summarizationz
87
  tokenizer (): the tokenizer to use for summarization
88
  batch_length (int, optional): the length of each batch. Defaults to 2048.
89
  batch_stride (int, optional): the stride of each batch. Defaults to 16. The stride is the number of tokens that overlap between batches.
@@ -92,12 +103,16 @@ def summarize_via_tokenbatches(
92
  str: the summary
93
  """
94
  # log all input parameters
95
- if batch_length < 512:
96
- batch_length = 512
97
- print("WARNING: batch_length was set to 512")
98
- print(
99
- f"input parameters: {kwargs}, batch_length={batch_length}, batch_stride={batch_stride}"
100
- )
 
 
 
 
101
  encoded_input = tokenizer(
102
  input_text,
103
  padding="max_length",
@@ -115,7 +130,6 @@ def summarize_via_tokenbatches(
115
  pbar = tqdm(total=len(in_id_arr))
116
 
117
  for _id, _mask in zip(in_id_arr, att_arr):
118
-
119
  result, score = summarize_and_score(
120
  ids=_id,
121
  mask=_mask,
 
1
  import logging
2
+ import pprint as pp
3
 
4
+ from utils import validate_pytorch2
5
+
6
+ logging.basicConfig(level=logging.INFO)
7
  import torch
8
  from tqdm.auto import tqdm
9
  from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
10
 
11
 
12
+ def load_model_and_tokenizer(model_name: str) -> tuple:
13
  """
14
+ load_model_and_tokenizer - load a model and tokenizer from a model name/ID on the hub
15
+ :param str model_name: the model name/ID on the hub
16
+ :return tuple: a tuple containing the model and tokenizer
 
 
 
 
17
  """
18
+ logger = logging.getLogger(__name__)
19
+ device = "cuda" if torch.cuda.is_available() else "cpu"
20
  model = AutoModelForSeq2SeqLM.from_pretrained(
21
  model_name,
22
+ ).to(device)
23
+ model = model.eval()
24
+
25
  tokenizer = AutoTokenizer.from_pretrained(model_name)
 
26
 
27
+ logger.info(f"Loaded model {model_name} to {device}")
28
+
29
+ if validate_pytorch2():
30
+ try:
31
+ logger.info("Compiling model with Torch 2.0")
32
+ model = torch.compile(model)
33
+ except Exception as e:
34
+ logger.warning(f"Could not compile model with Torch 2.0: {e}")
35
+ else:
36
+ logger.info("Torch 2.0 not detected, skipping compilation")
37
+
38
  return model, tokenizer
39
 
40
 
 
86
  tokenizer,
87
  batch_length=2048,
88
  batch_stride=16,
89
+ min_batch_length: int = 512,
90
  **kwargs,
91
  ):
92
  """
 
94
 
95
  Args:
96
  input_text (str): the text to summarize
97
+ model (): the model to use for summarization
98
  tokenizer (): the tokenizer to use for summarization
99
  batch_length (int, optional): the length of each batch. Defaults to 2048.
100
  batch_stride (int, optional): the stride of each batch. Defaults to 16. The stride is the number of tokens that overlap between batches.
 
103
  str: the summary
104
  """
105
  # log all input parameters
106
+ logger = logging.getLogger(__name__)
107
+ # log all input parameters
108
+ if batch_length < min_batch_length:
109
+ logger.warning(
110
+ f"batch_length must be at least {min_batch_length}. Setting batch_length to {min_batch_length}"
111
+ )
112
+ batch_length = min_batch_length
113
+
114
+ logger.info(f"input parameters:\n{pp.pformat(kwargs)}")
115
+ logger.info(f"batch_length: {batch_length}, batch_stride: {batch_stride}")
116
  encoded_input = tokenizer(
117
  input_text,
118
  padding="max_length",
 
130
  pbar = tqdm(total=len(in_id_arr))
131
 
132
  for _id, _mask in zip(in_id_arr, att_arr):
 
133
  result, score = summarize_and_score(
134
  ids=_id,
135
  mask=_mask,
utils.py CHANGED
@@ -2,12 +2,26 @@
2
  utils.py - Utility functions for the project.
3
  """
4
 
 
5
  import re
6
  from pathlib import Path
7
 
 
 
 
 
 
8
  from natsort import natsorted
9
 
10
 
 
 
 
 
 
 
 
 
11
  def truncate_word_count(text, max_words=512):
12
  """
13
  truncate_word_count - a helper function for the gradio module
 
2
  utils.py - Utility functions for the project.
3
  """
4
 
5
+ import logging
6
  import re
7
  from pathlib import Path
8
 
9
+ logging.basicConfig(
10
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
11
+ level=logging.INFO,
12
+ )
13
+ import torch
14
  from natsort import natsorted
15
 
16
 
17
+ def validate_pytorch2(torch_version: str = None):
18
+ torch_version = torch.__version__ if torch_version is None else torch_version
19
+
20
+ pattern = r"^2\.\d+(\.\d+)*"
21
+
22
+ return True if re.match(pattern, torch_version) else False
23
+
24
+
25
  def truncate_word_count(text, max_words=512):
26
  """
27
  truncate_word_count - a helper function for the gradio module