Peter commited on
Commit
4a607b7
1 Parent(s): 588689f

:sparkles: add summarization fns

Browse files
Files changed (1) hide show
  1. summarize.py +126 -0
summarize.py ADDED
@@ -0,0 +1,126 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from tqdm.auto import tqdm
3
+ from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
4
+
5
+ def load_model_and_tokenizer(model_name):
6
+ """
7
+ load_model_and_tokenizer - a function that loads a model and tokenizer from huggingface
8
+
9
+ Args:
10
+ model_name (str): the name of the model to load
11
+ Returns:
12
+ AutoModelForSeq2SeqLM: the model
13
+ AutoTokenizer: the tokenizer
14
+ """
15
+
16
+ model = AutoModelForSeq2SeqLM.from_pretrained(
17
+ model_name,
18
+ use_cache=False,
19
+ )
20
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
21
+ model = model.to("cuda") if torch.cuda.is_available() else model
22
+ return model, tokenizer
23
+
24
+ def summarize_and_score(ids, mask, model, tokenizer, **kwargs):
25
+ """
26
+ summarize_and_score - given a batch of ids and a mask, return a summary and a score for the summary
27
+
28
+ Args:
29
+ ids (): the batch of ids
30
+ mask (): the attention mask for the batch
31
+ model (): the model to use for summarization
32
+ tokenizer (): the tokenizer to use for summarization
33
+
34
+ Returns:
35
+ str: the summary of the batch
36
+ """
37
+
38
+
39
+ ids = ids[None, :]
40
+ mask = mask[None, :]
41
+
42
+ input_ids = ids.to("cuda") if torch.cuda.is_available() else ids
43
+ attention_mask = mask.to("cuda") if torch.cuda.is_available() else mask
44
+
45
+
46
+ attention_mask = mask.to("cuda")
47
+ global_attention_mask = torch.zeros_like(attention_mask)
48
+ # put global attention on <s> token
49
+ global_attention_mask[:, 0] = 1
50
+
51
+ summary_pred_ids = model.generate(
52
+ input_ids,
53
+ attention_mask=attention_mask,
54
+ global_attention_mask=global_attention_mask,
55
+ output_scores=True,
56
+ return_dict_in_generate=True,
57
+ **kwargs
58
+ )
59
+ summary = tokenizer.batch_decode(
60
+ summary_pred_ids.sequences,
61
+ skip_special_tokens=True,
62
+ remove_invalid_values=True,
63
+ )
64
+ score = round(summary_pred_ids.sequences_scores.cpu().numpy()[0], 4)
65
+
66
+ return summary, score
67
+
68
+ def summarize_via_tokenbatches(
69
+ input_text:str,
70
+ model, tokenizer,
71
+ batch_length=2048,
72
+ batch_stride=16,
73
+ **kwargs,
74
+ ):
75
+ """
76
+ summarize_via_tokenbatches - a function that takes a string and returns a summary
77
+
78
+ Args:
79
+ input_text (str): the text to summarize
80
+ model (): the model to use for summarization
81
+ tokenizer (): the tokenizer to use for summarization
82
+ batch_length (int, optional): the length of each batch. Defaults to 2048.
83
+ batch_stride (int, optional): the stride of each batch. Defaults to 16. The stride is the number of tokens that overlap between batches.
84
+
85
+ Returns:
86
+ str: the summary
87
+ """
88
+
89
+ encoded_input = tokenizer(
90
+ input_text,
91
+ padding='max_length',
92
+ truncation=True,
93
+ max_length=batch_length,
94
+ stride=batch_stride,
95
+ return_overflowing_tokens=True,
96
+ add_special_tokens =False,
97
+ return_tensors='pt',
98
+ )
99
+
100
+ in_id_arr, att_arr = encoded_input.input_ids, encoded_input.attention_mask
101
+ gen_summaries = []
102
+
103
+ pbar = tqdm(total=len(in_id_arr))
104
+
105
+ for _id, _mask in zip(in_id_arr, att_arr):
106
+
107
+ result, score = summarize_and_score(
108
+ ids=_id,
109
+ mask=_mask,
110
+ model=model,
111
+ tokenizer=tokenizer,
112
+ **kwargs,
113
+ )
114
+ score = round(float(score),4)
115
+ _sum = {
116
+ "input_tokens":_id,
117
+ "summary":result,
118
+ "summary_score":score,
119
+ }
120
+ gen_summaries.append(_sum)
121
+ print(f"\t{result[0]}\nScore:\t{score}")
122
+ pbar.update()
123
+
124
+ pbar.close()
125
+
126
+ return gen_summaries