Gagan Bhatia committed on
Commit
b10a55f
1 Parent(s): 0842de0

Update train_model.py

Files changed (1)
  1. src/models/train_model.py +441 -0
src/models/train_model.py CHANGED
@@ -0,0 +1,441 @@
+ import torch
+ import pandas as pd
+ import pytorch_lightning as pl
+ from torch.utils.data import Dataset, DataLoader
+ from transformers import (
+     AdamW,
+     T5ForConditionalGeneration,
+     MT5ForConditionalGeneration,
+     T5TokenizerFast as T5Tokenizer,
+     MT5TokenizerFast as MT5Tokenizer,
+ )
+ from pytorch_lightning import LightningDataModule, LightningModule
+ from pytorch_lightning.callbacks.early_stopping import EarlyStopping
+ from pytorch_lightning.loggers import MLFlowLogger
+
+ torch.cuda.empty_cache()
+ pl.seed_everything(42)
+
+
+ class DataModule(Dataset):
+     """
+     PyTorch Dataset that tokenizes source/target text pairs for T5/MT5.
+     """
+
+     def __init__(
+         self,
+         data: pd.DataFrame,
+         tokenizer: T5Tokenizer,
+         source_max_token_len: int = 512,
+         target_max_token_len: int = 512,
+     ):
+         """
+         :param data: dataframe with "input_text", "output_text", "keywords" and "text" columns
+         :param tokenizer: T5/MT5 tokenizer
+         :param source_max_token_len: max token length of the source text
+         :param target_max_token_len: max token length of the target text
+         """
+         self.data = data
+         self.target_max_token_len = target_max_token_len
+         self.source_max_token_len = source_max_token_len
+         self.tokenizer = tokenizer
+
+     def __len__(self):
+         return len(self.data)
+
+     def __getitem__(self, index: int):
+         data_row = self.data.iloc[index]
+
+         input_encoding = self.tokenizer(
+             data_row["input_text"],
+             max_length=self.source_max_token_len,
+             padding="max_length",
+             truncation=True,
+             return_attention_mask=True,
+             add_special_tokens=True,
+             return_tensors="pt",
+         )
+
+         output_encoding = self.tokenizer(
+             data_row["output_text"],
+             max_length=self.target_max_token_len,
+             padding="max_length",
+             truncation=True,
+             return_attention_mask=True,
+             add_special_tokens=True,
+             return_tensors="pt",
+         )
+
+         labels = output_encoding["input_ids"]
+         # replace pad token ids (0 for T5) with -100 so they are
+         # ignored by the cross-entropy loss
+         labels[labels == 0] = -100
+
+         return dict(
+             keywords=data_row["keywords"],
+             text=data_row["text"],
+             keywords_input_ids=input_encoding["input_ids"].flatten(),
+             keywords_attention_mask=input_encoding["attention_mask"].flatten(),
+             labels=labels.flatten(),
+             labels_attention_mask=output_encoding["attention_mask"].flatten(),
+         )
+
+
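+ # Illustrative example (column values are assumptions, not taken from the
+ # actual training data): the dataset above expects rows that carry the raw
+ # keywords and reference text alongside the tokenized model inputs, e.g.
+ # (given a T5 tokenizer instance):
+ #
+ #     df = pd.DataFrame({
+ #         "keywords": ["rain | flood | rescue"],
+ #         "text": ["Heavy rain caused flooding in the city."],
+ #         "input_text": ["rain | flood | rescue"],
+ #         "output_text": ["Heavy rain caused flooding in the city."],
+ #     })
+ #     sample = DataModule(df, tokenizer)[0]  # dict of tensors + raw strings
+
+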
+ class PLDataModule(LightningDataModule):
+     def __init__(
+         self,
+         train_df: pd.DataFrame,
+         test_df: pd.DataFrame,
+         tokenizer: T5Tokenizer,
+         source_max_token_len: int = 512,
+         target_max_token_len: int = 512,
+         batch_size: int = 4,
+         split: float = 0.1,
+     ):
+         """
+         :param train_df: training dataframe
+         :param test_df: test/validation dataframe
+         :param tokenizer: T5/MT5 tokenizer
+         :param source_max_token_len: max token length of the source text
+         :param target_max_token_len: max token length of the target text
+         :param batch_size: batch size
+         :param split: train/validation split fraction (currently unused)
+         """
+         super().__init__()
+         self.train_df = train_df
+         self.test_df = test_df
+         self.split = split
+         self.batch_size = batch_size
+         self.target_max_token_len = target_max_token_len
+         self.source_max_token_len = source_max_token_len
+         self.tokenizer = tokenizer
+
+     def setup(self, stage=None):
+         self.train_dataset = DataModule(
+             self.train_df,
+             self.tokenizer,
+             self.source_max_token_len,
+             self.target_max_token_len,
+         )
+         self.test_dataset = DataModule(
+             self.test_df,
+             self.tokenizer,
+             self.source_max_token_len,
+             self.target_max_token_len,
+         )
+
+     def train_dataloader(self):
+         """training dataloader"""
+         return DataLoader(
+             self.train_dataset, batch_size=self.batch_size, shuffle=True, num_workers=2
+         )
+
+     def test_dataloader(self):
+         """test dataloader"""
+         return DataLoader(
+             self.test_dataset, batch_size=self.batch_size, shuffle=False, num_workers=2
+         )
+
+     def val_dataloader(self):
+         """validation dataloader"""
+         return DataLoader(
+             self.test_dataset, batch_size=self.batch_size, shuffle=False, num_workers=2
+         )
+
+
+ class LightningModel(LightningModule):
+     """PyTorch Lightning model class"""
+
+     def __init__(
+         self,
+         tokenizer,
+         model,
+         output: str = "outputs",
+         learning_rate: float = 1e-4,
+         weight_decay: float = 0.0,
+         adam_epsilon: float = 1e-8,
+     ):
+         """
+         initiates a PyTorch Lightning model
+         Args:
+             tokenizer : T5 tokenizer
+             model : T5 model
+             output (str, optional): output directory to save model checkpoints. Defaults to "outputs".
+             learning_rate (float, optional): optimizer learning rate. Defaults to 1e-4.
+             weight_decay (float, optional): weight decay applied to non-excluded parameters. Defaults to 0.0.
+             adam_epsilon (float, optional): epsilon for AdamW. Defaults to 1e-8.
+         """
+         super().__init__()
+         self.model = model
+         self.tokenizer = tokenizer
+         self.output = output
+         # store the optimizer settings on self.hparams so that
+         # configure_optimizers can read them
+         self.save_hyperparameters("learning_rate", "weight_decay", "adam_epsilon")
+
+     def forward(self, input_ids, attention_mask, decoder_attention_mask, labels=None):
+         """forward step"""
+         output = self.model(
+             input_ids,
+             attention_mask=attention_mask,
+             labels=labels,
+             decoder_attention_mask=decoder_attention_mask,
+         )
+         return output.loss, output.logits
+
+     def training_step(self, batch, batch_idx):
+         """training step"""
+         input_ids = batch["keywords_input_ids"]
+         attention_mask = batch["keywords_attention_mask"]
+         labels = batch["labels"]
+         labels_attention_mask = batch["labels_attention_mask"]
+
+         loss, outputs = self(
+             input_ids=input_ids,
+             attention_mask=attention_mask,
+             decoder_attention_mask=labels_attention_mask,
+             labels=labels,
+         )
+         self.log("train_loss", loss, prog_bar=True, logger=True)
+         return loss
+
+     def validation_step(self, batch, batch_idx):
+         """validation step"""
+         input_ids = batch["keywords_input_ids"]
+         attention_mask = batch["keywords_attention_mask"]
+         labels = batch["labels"]
+         labels_attention_mask = batch["labels_attention_mask"]
+
+         loss, outputs = self(
+             input_ids=input_ids,
+             attention_mask=attention_mask,
+             decoder_attention_mask=labels_attention_mask,
+             labels=labels,
+         )
+         self.log("val_loss", loss, prog_bar=True, logger=True)
+         return loss
+
+     def test_step(self, batch, batch_idx):
+         """test step"""
+         input_ids = batch["keywords_input_ids"]
+         attention_mask = batch["keywords_attention_mask"]
+         labels = batch["labels"]
+         labels_attention_mask = batch["labels_attention_mask"]
+
+         loss, outputs = self(
+             input_ids=input_ids,
+             attention_mask=attention_mask,
+             decoder_attention_mask=labels_attention_mask,
+             labels=labels,
+         )
+         self.log("test_loss", loss, prog_bar=True, logger=True)
+         return loss
+
+     def configure_optimizers(self):
+         """configure optimizers"""
+         model = self.model
+         # bias and LayerNorm weights are excluded from weight decay
+         no_decay = ["bias", "LayerNorm.weight"]
+         optimizer_grouped_parameters = [
+             {
+                 "params": [
+                     p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)
+                 ],
+                 "weight_decay": self.hparams.weight_decay,
+             },
+             {
+                 "params": [
+                     p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)
+                 ],
+                 "weight_decay": 0.0,
+             },
+         ]
+         optimizer = AdamW(
+             optimizer_grouped_parameters,
+             lr=self.hparams.learning_rate,
+             eps=self.hparams.adam_epsilon,
+         )
+         self.opt = optimizer
+         return [optimizer]
+
+
+ class Summarization:
+     """Custom summarization class"""
+
+     def __init__(self) -> None:
+         """initiates the Summarization class"""
+         pass
+
+     def from_pretrained(self, model_name="t5-base") -> None:
+         """
+         loads a T5/MT5 model for training/finetuning
+         Args:
+             model_name (str, optional): exact model architecture name, e.g. "t5-base" or "google/mt5-base". Defaults to "t5-base".
+         """
+         # pick the MT5 classes for mt5 checkpoints, T5 otherwise
+         if "mt5" in model_name:
+             self.tokenizer = MT5Tokenizer.from_pretrained(model_name)
+             self.model = MT5ForConditionalGeneration.from_pretrained(
+                 model_name, return_dict=True
+             )
+         else:
+             self.tokenizer = T5Tokenizer.from_pretrained(model_name)
+             self.model = T5ForConditionalGeneration.from_pretrained(
+                 model_name, return_dict=True
+             )
+
+     def train(
+         self,
+         train_df: pd.DataFrame,
+         eval_df: pd.DataFrame,
+         source_max_token_len: int = 512,
+         target_max_token_len: int = 512,
+         batch_size: int = 8,
+         max_epochs: int = 5,
+         use_gpu: bool = True,
+         outputdir: str = "model",
+         early_stopping_patience_epochs: int = 0,  # 0 disables early stopping
+     ):
+         """
+         trains a T5/MT5 model on a custom dataset
+         Args:
+             train_df (pd.DataFrame): training dataframe. Must have the columns "input_text", "output_text", "keywords" and "text".
+             eval_df (pd.DataFrame): validation dataframe with the same columns.
+             source_max_token_len (int, optional): max token length of source text. Defaults to 512.
+             target_max_token_len (int, optional): max token length of target text. Defaults to 512.
+             batch_size (int, optional): batch size. Defaults to 8.
+             max_epochs (int, optional): max number of epochs. Defaults to 5.
+             use_gpu (bool, optional): if True, the model uses the gpu for training. Defaults to True.
+             outputdir (str, optional): output directory to save model checkpoints. Defaults to "model".
+             early_stopping_patience_epochs (int, optional): monitors val_loss at epoch end and stops training if val_loss does not improve after the specified number of epochs. Set 0 to disable early stopping. Defaults to 0 (disabled).
+         """
+         self.target_max_token_len = target_max_token_len
+         self.data_module = PLDataModule(
+             train_df,
+             eval_df,
+             self.tokenizer,
+             batch_size=batch_size,
+             source_max_token_len=source_max_token_len,
+             target_max_token_len=target_max_token_len,
+         )
+
+         self.T5Model = LightningModel(
+             tokenizer=self.tokenizer, model=self.model, output=outputdir
+         )
+
+         logger = MLFlowLogger(experiment_name="Summarization")
+
+         early_stop_callback = (
+             [
+                 EarlyStopping(
+                     monitor="val_loss",
+                     min_delta=0.00,
+                     patience=early_stopping_patience_epochs,
+                     verbose=True,
+                     mode="min",
+                 )
+             ]
+             if early_stopping_patience_epochs > 0
+             else None
+         )
+
+         gpus = 1 if use_gpu else 0
+
+         trainer = pl.Trainer(
+             logger=logger,
+             callbacks=early_stop_callback,
+             max_epochs=max_epochs,
+             gpus=gpus,
+             progress_bar_refresh_rate=5,
+         )
+
+         trainer.fit(self.T5Model, self.data_module)
+
+     def load_model(self, model_dir: str = "model", use_gpu: bool = False):
+         """
+         loads a saved model for inference/prediction
+         Args:
+             model_dir (str, optional): path to the model directory. Defaults to "model".
+             use_gpu (bool, optional): if True, the model uses the gpu for inference/prediction. Defaults to False.
+         """
+         self.model = T5ForConditionalGeneration.from_pretrained(model_dir)
+         self.tokenizer = T5Tokenizer.from_pretrained(model_dir)
+
+         if use_gpu:
+             if torch.cuda.is_available():
+                 self.device = torch.device("cuda")
+             else:
+                 raise Exception("No GPU found; set use_gpu=False to run on the CPU")
+         else:
+             self.device = torch.device("cpu")
+
+         self.model = self.model.to(self.device)
+
+     def save_model(self, model_dir="model"):
+         """
+         saves the tokenizer and the model to a directory
+         :param model_dir: target directory
+         """
+         self.tokenizer.save_pretrained(model_dir)
+         self.model.save_pretrained(model_dir)
+
+     def predict(
+         self,
+         source_text: str,
+         max_length: int = 512,
+         num_return_sequences: int = 1,
+         num_beams: int = 2,
+         top_k: int = 50,
+         top_p: float = 0.95,
+         do_sample: bool = True,
+         repetition_penalty: float = 2.5,
+         length_penalty: float = 1.0,
+         early_stopping: bool = True,
+         skip_special_tokens: bool = True,
+         clean_up_tokenization_spaces: bool = True,
+     ):
+         """
+         generates a prediction with the T5/MT5 model
+         Args:
+             source_text (str): input text to generate predictions for
+             max_length (int, optional): max token length of the prediction. Defaults to 512.
+             num_return_sequences (int, optional): number of predictions to return. Defaults to 1.
+             num_beams (int, optional): number of beams. Defaults to 2.
+             top_k (int, optional): Defaults to 50.
+             top_p (float, optional): Defaults to 0.95.
+             do_sample (bool, optional): Defaults to True.
+             repetition_penalty (float, optional): Defaults to 2.5.
+             length_penalty (float, optional): Defaults to 1.0.
+             early_stopping (bool, optional): Defaults to True.
+             skip_special_tokens (bool, optional): Defaults to True.
+             clean_up_tokenization_spaces (bool, optional): Defaults to True.
+         Returns:
+             list[str]: the generated predictions
+         """
+         input_ids = self.tokenizer.encode(
+             source_text, return_tensors="pt", add_special_tokens=True
+         )
+         # move the inputs to whichever device the model lives on
+         input_ids = input_ids.to(self.model.device)
+         generated_ids = self.model.generate(
+             input_ids=input_ids,
+             num_beams=num_beams,
+             max_length=max_length,
+             repetition_penalty=repetition_penalty,
+             length_penalty=length_penalty,
+             early_stopping=early_stopping,
+             top_p=top_p,
+             top_k=top_k,
+             do_sample=do_sample,
+             num_return_sequences=num_return_sequences,
+         )
+         preds = [
+             self.tokenizer.decode(
+                 g,
+                 skip_special_tokens=skip_special_tokens,
+                 clean_up_tokenization_spaces=clean_up_tokenization_spaces,
+             )
+             for g in generated_ids
+         ]
+         return preds
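+
+
+ # Minimal usage sketch. The toy dataframe, file paths and keyword format
+ # below are illustrative assumptions, not part of the original training
+ # setup; it only shows how the pieces above fit together.
+ if __name__ == "__main__":
+     df = pd.DataFrame(
+         {
+             "keywords": ["rain | flood | rescue"] * 8,
+             "text": ["Heavy rain caused flooding in the city."] * 8,
+             "input_text": ["rain | flood | rescue"] * 8,
+             "output_text": ["Heavy rain caused flooding in the city."] * 8,
+         }
+     )
+     model = Summarization()
+     model.from_pretrained("t5-base")
+     model.train(train_df=df, eval_df=df, batch_size=2, max_epochs=1, use_gpu=False)
+     model.save_model("model")
+     model.load_model("model", use_gpu=False)
+     print(model.predict("rain | flood | rescue"))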