pszemraj committed on
Commit c006617 • 1 Parent(s): 2956200

🔊 improve logging and docs


Signed-off-by: peter szemraj <peterszemraj@gmail.com>

Files changed (1)
  1. summarize.py +20 -20
summarize.py CHANGED
@@ -1,25 +1,22 @@
 import logging

+logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(message)s")
+
 import torch
 from tqdm.auto import tqdm
 from transformers import AutoModelForSeq2SeqLM, AutoTokenizer


-def load_model_and_tokenizer(model_name):
+def load_model_and_tokenizer(model_name: str) -> tuple:
     """
-    load_model_and_tokenizer - a function that loads a model and tokenizer from huggingface
+    load_model_and_tokenizer - load a model and tokenizer from a model name/ID on the hub

-    Args:
-        model_name (str): the name of the model to load
-    Returns:
-        AutoModelForSeq2SeqLM: the model
-        AutoTokenizer: the tokenizer
+    :param str model_name: the model name/ID on the hub
+    :return tuple: a tuple containing the model and tokenizer
     """
     device = "cuda" if torch.cuda.is_available() else "cpu"
     model = AutoModelForSeq2SeqLM.from_pretrained(
         model_name,
-        # low_cpu_mem_usage=True,
-        # use_cache=False,
     ).to(device)
     model = model.eval()

@@ -32,7 +29,7 @@ def load_model_and_tokenizer(model_name):

 def summarize_and_score(
     ids, mask, model, tokenizer, is_general_attention_model=True, **kwargs
-):
+) -> tuple:
     """
     summarize_and_score - given a batch of ids and a mask, return a summary and a score for the summary

@@ -42,9 +39,9 @@ def summarize_and_score(
         model (): the model to use for summarization
         tokenizer (): the tokenizer to use for summarization
         is_general_attention_model (bool, optional): whether the model is a general attention model. Defaults to True.
-
+        **kwargs: any additional arguments to pass to the model
     Returns:
-        str: the summary of the batch
+        tuple (str, float): the summary, the score for the summary
     """

     ids = ids[None, :]
@@ -91,25 +88,29 @@ def summarize_via_tokenbatches(
     batch_length=2048,
     batch_stride=16,
     **kwargs,
-):
+) -> list:
     """
-    summarize_via_tokenbatches - a function that takes a string and returns a summary
+    summarize_via_tokenbatches - summarize a long string via batches of tokens

     Args:
         input_text (str): the text to summarize
-        model (): the model to use for summarizationz
+        model (): the model to use for summarization
         tokenizer (): the tokenizer to use for summarization
         batch_length (int, optional): the length of each batch. Defaults to 2048.
         batch_stride (int, optional): the stride of each batch. Defaults to 16. The stride is the number of tokens that overlap between batches.

     Returns:
-        str: the summary
+        list: a list of dictionaries containing the input tokens, the summary, and the summary score
     """
+
+    logger = logging.getLogger(__name__)
     # log all input parameters
     if batch_length < 512:
         batch_length = 512
-        print("WARNING: batch_length was set to 512")
-    print(
+        logger.warning(
+            f"batch_length must be at least 512. Setting batch_length to {batch_length}"
+        )
+    logger.info(
         f"input parameters: {kwargs}, batch_length={batch_length}, batch_stride={batch_stride}"
     )
     encoded_input = tokenizer(
@@ -129,7 +130,6 @@ def summarize_via_tokenbatches(
     pbar = tqdm(total=len(in_id_arr))

     for _id, _mask in zip(in_id_arr, att_arr):
-
         result, score = summarize_and_score(
             ids=_id,
             mask=_mask,
@@ -144,7 +144,7 @@ def summarize_via_tokenbatches(
             "summary_score": score,
         }
         gen_summaries.append(_sum)
-        print(f"\t{result[0]}\nScore:\t{score}")
+        logger.info(f"\t{result[0]}\nScore:\t{score}")
         pbar.update()

     pbar.close()
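For context, here is a minimal usage sketch of the two public functions after this change. The checkpoint name, the input file, and the "summary" key are illustrative assumptions (only "summary_score" is visible in the hunks above); any seq2seq summarization checkpoint on the hub should work the same way:

from summarize import load_model_and_tokenizer, summarize_via_tokenbatches

# placeholder checkpoint -- substitute any seq2seq summarization model on the hub
model, tokenizer = load_model_and_tokenizer(
    "pszemraj/long-t5-tglobal-base-16384-book-summary"
)

with open("report.txt") as f:  # hypothetical long input document
    long_text = f.read()

summaries = summarize_via_tokenbatches(
    long_text,
    model,
    tokenizer,
    batch_length=2048,  # tokens per batch; values under 512 are bumped to 512 with a warning
    batch_stride=16,  # tokens of overlap between consecutive batches
)
for entry in summaries:
    print(entry["summary"], entry["summary_score"])  # key names assumed from the docstring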
 
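The batch_stride overlap described in the docstring is the standard overflowing-tokens pattern from transformers. The actual tokenizer(...) arguments in summarize.py fall outside the hunks shown, so the call below is a sketch of the idea rather than the file's exact code:

# hedged sketch: split input_text into overlapping windows of batch_length tokens,
# with batch_stride tokens shared between consecutive windows
encoded_input = tokenizer(
    input_text,
    padding="max_length",
    truncation=True,
    max_length=batch_length,
    stride=batch_stride,
    return_overflowing_tokens=True,
    add_special_tokens=False,
    return_tensors="pt",
)
in_id_arr, att_arr = encoded_input.input_ids, encoded_input.attention_mask

These are the in_id_arr and att_arr tensors that the loop in the last two hunks iterates over, one window per batch.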
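Because the module now calls logging.basicConfig at import time and routes per-batch output through a logging.getLogger(__name__) logger instead of print, callers can tune verbosity without editing the file. A small sketch, assuming the file is imported as the summarize module:

import logging

import summarize  # the module-level basicConfig(level=logging.INFO, ...) runs on import

# silence the per-batch INFO messages but keep the batch_length warnings
logging.getLogger("summarize").setLevel(logging.WARNING)

logging.basicConfig is a no-op when the root logger already has handlers, so applications that configure logging before importing this module keep their own handlers and format.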