shoukaku committed
Commit e319ff3
1 Parent(s): f9603b8

Upload summarizer.py

Files changed (1)
  1. skk/summarizer.py +215 -0
skk/summarizer.py ADDED
@@ -0,0 +1,215 @@
import torch
import pytorch_lightning as pl
import transformers as hf
import numpy as np

class LitModel(pl.LightningModule):
    ''' pytorch-lightning model '''

    def __init__(self, model, tokenizer, learning_rate = 5e-5):
        super().__init__()
        self.model = model
        self.tokenizer = tokenizer
        self.learning_rate = learning_rate

    def freeze_embeds(self):
        ''' freeze the positional and token embedding parameters of the model '''
        freeze_params(self.model.model.shared)
        for layer in [self.model.model.encoder, self.model.model.decoder]:
            freeze_params(layer.embed_positions)
            freeze_params(layer.embed_tokens)

    def forward(self, input_ids, **kwargs):
        return self.model(input_ids, **kwargs)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr = self.learning_rate)
        return optimizer

    def training_step(self, batch, batch_idx):
        # load the data into variables
        src_ids, src_mask = batch[0], batch[1]
        target_ids = batch[2]

        # shift the decoder tokens right
        decoder_input_ids = shift_tokens_right(target_ids, self.tokenizer.pad_token_id)

        # run the model and get the logits
        outputs = self(
            src_ids,
            attention_mask = src_mask,
            decoder_input_ids = decoder_input_ids,
            use_cache = False
        )
        logits = outputs[0]

        # create the loss function
        f_loss = torch.nn.CrossEntropyLoss(ignore_index = self.tokenizer.pad_token_id)

        # calculate the loss on the unshifted tokens
        loss = f_loss(logits.view(-1, logits.shape[-1]), target_ids.view(-1))

        return {'loss': loss}

    def validation_step(self, batch, batch_idx):
        src_ids, src_mask = batch[0], batch[1]
        target_ids = batch[2]
        decoder_input_ids = shift_tokens_right(target_ids, self.tokenizer.pad_token_id)
        outputs = self(
            src_ids,
            attention_mask = src_mask,
            decoder_input_ids = decoder_input_ids,
            use_cache = False
        )
        logits = outputs[0]
        f_loss = torch.nn.CrossEntropyLoss(ignore_index = self.tokenizer.pad_token_id)
        loss = f_loss(logits.view(-1, logits.shape[-1]), target_ids.view(-1))

        # log the validation loss under its own key so it does not shadow the training loss
        self.log('val_loss', loss)

        return {'loss': loss}

    def generate(self, text, min_length = 40, max_length = 256, eval_beams = 4, early_stopping = True):
        ''' generate summaries for a batch of tokenized text '''
        generated = self.model.generate(
            text['input_ids'],
            attention_mask = text['attention_mask'],
            use_cache = True,
            decoder_start_token_id = self.tokenizer.pad_token_id,
            min_length = min_length,
            max_length = max_length,
            num_beams = eval_beams,
            early_stopping = early_stopping
        )
        return [self.tokenizer.decode(
            w,
            skip_special_tokens = True,
            clean_up_tokenization_spaces = True
        ) for w in generated]

def freeze_params(model):
    ''' freeze the parameters of a (sub)module for faster training '''
    for param in model.parameters():
        param.requires_grad = False

class SummaryDataModule(pl.LightningDataModule):
    ''' pytorch-lightning dataloading module '''

    def __init__(self, tokenizer, dataframe, batch_size, num_examples = 20000):
        super().__init__()
        self.tokenizer = tokenizer
        self.dataframe = dataframe
        self.batch_size = batch_size
        self.num_examples = num_examples

    def prepare_data(self, split = [0.6, 0.2, 0.2]):
        ''' loads and splits the data into train / validation / test sets '''
        self.data = self.dataframe[:self.num_examples]
        self.train, self.validate, self.test = np.split(
            self.data.sample(frac = 1),
            [
                int(split[0] * len(self.data)),
                int(sum([split[i] for i in range(2)]) * len(self.data))
            ]
        )

    def setup(self, stage):
        ''' tokenizes each split into model inputs '''
        self.train = encode_sentences(self.tokenizer, self.train['source'], self.train['target'])
        self.validate = encode_sentences(self.tokenizer, self.validate['source'], self.validate['target'])
        self.test = encode_sentences(self.tokenizer, self.test['source'], self.test['target'])

    def train_dataloader(self):
        dataset = torch.utils.data.TensorDataset(
            self.train['input_ids'],
            self.train['attention_mask'],
            self.train['labels']
        )
        train_data = torch.utils.data.DataLoader(
            dataset,
            sampler = torch.utils.data.RandomSampler(dataset),
            batch_size = self.batch_size
        )
        return train_data

    def val_dataloader(self):
        dataset = torch.utils.data.TensorDataset(
            self.validate['input_ids'],
            self.validate['attention_mask'],
            self.validate['labels']
        )
        val_data = torch.utils.data.DataLoader(
            dataset,
            batch_size = self.batch_size
        )
        return val_data

    def test_dataloader(self):
        dataset = torch.utils.data.TensorDataset(
            self.test['input_ids'],
            self.test['attention_mask'],
            self.test['labels']
        )
        test_data = torch.utils.data.DataLoader(
            dataset,
            batch_size = self.batch_size
        )
        return test_data

def shift_tokens_right(input_ids, pad_token_id):
    ''' shift label tokens one position to the right to build BART-style decoder inputs '''
    prev_output_tokens = input_ids.clone()
    # index of the last non-padding token (the EOS token) in each sequence
    index_of_eos = (input_ids.ne(pad_token_id).sum(dim = 1) - 1).unsqueeze(-1)
    prev_output_tokens[:, 0] = input_ids.gather(1, index_of_eos).squeeze()
    prev_output_tokens[:, 1:] = input_ids[:, :-1]
    return prev_output_tokens

def encode_sentences(tokenizer, source_sentences, target_sentences, max_length = 128, pad_to_max_length = True, return_tensors = 'pt'):
    ''' tokenize source and target sentences into a single batch dictionary '''
    input_ids = []
    attention_masks = []
    target_ids = []

    for s in source_sentences:
        encoded_dict = tokenizer(
            s,
            max_length = max_length,
            padding = 'max_length' if pad_to_max_length else None,
            truncation = True,
            return_tensors = return_tensors,
            add_prefix_space = True
        )
        input_ids.append(encoded_dict['input_ids'])
        attention_masks.append(encoded_dict['attention_mask'])

    input_ids = torch.cat(input_ids, dim = 0)
    attention_masks = torch.cat(attention_masks, dim = 0)

    for s in target_sentences:
        encoded_dict = tokenizer(
            s,
            max_length = max_length,
            padding = 'max_length' if pad_to_max_length else None,
            truncation = True,
            return_tensors = return_tensors,
            add_prefix_space = True
        )
        target_ids.append(encoded_dict['input_ids'])

    target_ids = torch.cat(target_ids, dim = 0)

    batch = {
        'input_ids': input_ids,
        'attention_mask': attention_masks,
        'labels': target_ids
    }

    return batch

tokenizer = hf.BartTokenizer.from_pretrained('sshleifer/distilbart-cnn-12-6', add_prefix_space = True)
base_model = hf.BartForConditionalGeneration.from_pretrained('sshleifer/distilbart-cnn-12-6')
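
For context, a minimal sketch of how these classes might be wired together is shown below. It is illustrative only: the CSV path, the 'source'/'target' column names, the batch size, the trainer settings, and the assumption that the repository root is on PYTHONPATH (so that skk.summarizer is importable) are not part of the uploaded file.

# illustrative usage sketch (not part of summarizer.py); assumes a pandas
# DataFrame with 'source' and 'target' text columns, as expected by
# SummaryDataModule.setup() above
import pandas as pd
import pytorch_lightning as pl
from skk.summarizer import LitModel, SummaryDataModule, tokenizer, base_model

df = pd.read_csv('data/summaries.csv')            # hypothetical dataset path
summary_data = SummaryDataModule(tokenizer, df, batch_size = 8)

model = LitModel(base_model, tokenizer, learning_rate = 5e-5)
model.freeze_embeds()                             # optional: freeze embeddings for faster fine-tuning

trainer = pl.Trainer(max_epochs = 1)
trainer.fit(model, datamodule = summary_data)

# summarize a new document with beam search
inputs = tokenizer('some long article text ...', return_tensors = 'pt')
print(model.generate(inputs, eval_beams = 4))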