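"""Unsupervised pretraining script for MT5, built on pytorch_lightning.

Supports continuing from a pretrained mT5 checkpoint, optionally trimming its
vocabulary to the token ids listed in --keep_tokens_path, and periodically
exporting HuggingFace-format checkpoints alongside the Lightning ones.
"""
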
import time
from builtins import print
import sys
import os
import torch
import argparse
import json
import pytorch_lightning as pl
from transformers import MT5Config, MT5Tokenizer
from pytorch_lightning import Trainer, loggers
from transformers import MT5ForConditionalGeneration
from pytorch_lightning.callbacks import LearningRateMonitor
# os.environ["CUDA_VISIBLE_DEVICES"] = '3'


class MT5PretrainModel(pl.LightningModule):

    @staticmethod
    def add_model_specific_args(parent_args):
        parser = parent_args.add_argument_group('BaseModel')
        parser.add_argument('--keep_tokens_path', default=None, type=str)
        return parent_args

    def __init__(self, args):
        super().__init__()
        self.save_hyperparameters(args)
        if args.tokenizer_type == 't5_tokenizer':
            if args.new_vocab_path is not None:
                # Continue training from mT5 while keeping only the Chinese and English
                # vocabulary; the new sentencepiece model is used for tokenization.
                assert args.keep_tokens_path is not None
                keep_tokens = json.load(open(args.keep_tokens_path))
                self.model = MT5ForConditionalGeneration.from_pretrained(
                    args.pretrained_model_path)
                new_config = self.model.config
                new_config.vocab_size = len(keep_tokens)
                print('vocab_size:', new_config.vocab_size)
                # Shrink every vocabulary-sized weight matrix down to the kept token ids.
                new_state_dict = self.model.state_dict()
                select_index = torch.tensor(keep_tokens)
                new_state_dict['encoder.embed_tokens.weight'] = torch.index_select(
                    new_state_dict['encoder.embed_tokens.weight'], dim=0, index=select_index)
                new_state_dict['shared.weight'] = torch.index_select(
                    new_state_dict['shared.weight'], dim=0, index=select_index)
                new_state_dict['decoder.embed_tokens.weight'] = torch.index_select(
                    new_state_dict['decoder.embed_tokens.weight'], dim=0, index=select_index)
                new_state_dict['lm_head.weight'] = torch.index_select(
                    new_state_dict['lm_head.weight'], dim=0, index=select_index)
                self.model = MT5ForConditionalGeneration.from_pretrained(
                    args.pretrained_model_path, config=new_config, state_dict=new_state_dict)
                # self.model = MT5ForConditionalGeneration(config=new_config)
            else:
                # Continue training from the full pretrained checkpoint.
                self.model = MT5ForConditionalGeneration.from_pretrained(
                    args.pretrained_model_path
                )
        else:
            # Train from scratch using only the pretrained model's configuration.
            self.model = MT5ForConditionalGeneration(
                MT5Config.from_pretrained(args.pretrained_model_path)
            )
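
    # setup() runs once the dataloaders are attached, so the dataset size is known
    # and the total number of optimizer steps can be estimated for LR scheduling.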
    def setup(self, stage) -> None:
        if stage == 'fit':
            train_loader = self.trainer._data_connector._train_dataloader_source.dataloader()

            # Calculate total steps
            if self.trainer.max_epochs > 0:
                world_size = self.trainer.world_size
                tb_size = self.hparams.train_batchsize * max(1, world_size)
                ab_size = self.trainer.accumulate_grad_batches * float(self.trainer.max_epochs)
                self.total_steps = (len(train_loader.dataset) *
                                    self.trainer.max_epochs // tb_size) // ab_size
            else:
                self.total_steps = self.trainer.max_steps // self.trainer.accumulate_grad_batches

            print('Total steps: {}'.format(self.total_steps))

    def configure_optimizers(self):
        from fengshen.models.model_utils import configure_optimizers
        return configure_optimizers(self)

    def training_step(self, batch, batch_idx):
        output = self.model(
            input_ids=batch['input_ids'], labels=batch['labels'])
        acc = self.comput_metrix(output.logits, batch['labels'])
        self.log('train_loss', output.loss, sync_dist=True)
        self.log('train_acc', acc, sync_dist=True)
        return output.loss

    def validation_step(self, batch, batch_idx):
        # print('is out of index: ', batch['input_ids'][batch['input_ids'] >= 32598])
        output = self.model(
            input_ids=batch['input_ids'], labels=batch['labels'])
        acc = self.comput_metrix(output.logits, batch['labels'])
        self.log('val_loss', output.loss, sync_dist=True)
        self.log('val_acc', acc, sync_dist=True)

    def comput_metrix(self, logits, labels):
        # Token-level accuracy: fraction of positions where the argmax prediction
        # matches the label.
        y_pred = torch.argmax(logits, dim=-1)
        y_pred = y_pred.view(size=(-1,))
        y_true = labels.view(size=(-1,)).float()
        corr = torch.eq(y_pred, y_true)
        acc = torch.sum(corr.float()) / y_true.shape[0]
        return acc
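
    # Besides Lightning's own .ckpt files, periodically export a HuggingFace-format
    # copy (save_pretrained) under the checkpoint directory on the rank-0 process.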
    def on_save_checkpoint(self, checkpoint) -> None:
        # Save the current loop info in the middle of an epoch.
        # If your pytorch_lightning version is <= 1.6.0, uncomment the line below:
        # checkpoint['loops'] = self.trainer.checkpoint_connector._get_loops_state_dict()
        if self.trainer.global_rank == 0 and self.trainer.global_step % self.hparams.every_n_train_steps == 0:
            self.model.save_pretrained(os.path.join(
                self.trainer.checkpoint_callback.dirpath,
                'hf_pretrained_epoch{}_step{}'.format(self.trainer.current_epoch, self.trainer.global_step)))

    def on_load_checkpoint(self, checkpoint) -> None:
        global_step_offset = checkpoint["global_step"]
        if 'global_samples' in checkpoint:
            self.consumed_samples = checkpoint['global_samples']
        self.trainer.fit_loop.epoch_loop._batches_that_stepped = global_step_offset


def get_time_str():
    return time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
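

# The argparse namespace is assembled from several sources: the data module, the
# Lightning Trainer, the UniversalCheckpoint callback, and the model itself. Flags
# referenced elsewhere in this file (e.g. train_batchsize, tokenizer_type,
# every_n_train_steps) are expected to be registered by those add_*_args helpers.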
def main():
    total_parser = argparse.ArgumentParser("Pretrain Unsupervise.")
    total_parser.add_argument(
        '--do_eval_only', action='store_true', default=False)
    total_parser.add_argument(
        '--pretrained_model_path', default=None, type=str)
    total_parser.add_argument(
        '--new_vocab_path', default=None, type=str)
    total_parser.add_argument('--max_seq_length', default=1024, type=int)
    total_parser.add_argument('--ckpt_path', default=None, type=str)
    sys.path.append('../../../')
    from fengshen.data.t5_dataloader.t5_datasets import UnsuperviseT5DataModel
    from fengshen.utils.universal_checkpoint import UniversalCheckpoint
    # * Args for data preprocessing
    total_parser = UnsuperviseT5DataModel.add_data_specific_args(total_parser)
    # * Args for training
    total_parser = Trainer.add_argparse_args(total_parser)
    total_parser = UniversalCheckpoint.add_argparse_args(total_parser)
    # * Args for base model
    total_parser = MT5PretrainModel.add_model_specific_args(total_parser)
    args = total_parser.parse_args()
    print('Argument parse success.')

    print('UnsuperviseT5DataModel load start {}'.format(get_time_str()))
    data_model = UnsuperviseT5DataModel(args)
    print('UnsuperviseT5DataModel load end {}'.format(get_time_str()))

    if not args.do_eval_only:
        model = MT5PretrainModel(args)
        checkpoint_callback = UniversalCheckpoint(args)
        lr_monitor = LearningRateMonitor(logging_interval='step')
        logger = loggers.TensorBoardLogger(save_dir=os.path.join(
            args.default_root_dir, 'logs/'))
        trainer = Trainer.from_argparse_args(args,
                                             logger=logger,
                                             callbacks=[checkpoint_callback, lr_monitor])
        trainer.fit(model, data_model, ckpt_path=args.ckpt_path)
    else:
        # Evaluation-only mode: decode a few prediction batches with the new tokenizer.
        tokenizer = MT5Tokenizer.from_pretrained(args.new_vocab_path, extra_ids=0)
        model = MT5PretrainModel(args=args)
        trainer = Trainer.from_argparse_args(args)
        result = trainer.predict(model, data_model)
        result = result[0]
        for i in range(4):
            print(tokenizer.batch_decode(result['input_ids'][i]))
            print(tokenizer.batch_decode(result['predict_ids'][i]))
            print(tokenizer.batch_decode(result['labels'][i]))


if __name__ == '__main__':
    main()
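
# A minimal launch sketch (the script name and all paths below are placeholders;
# data-module and checkpoint flags such as --train_batchsize, --tokenizer_type and
# --every_n_train_steps are defined in UnsuperviseT5DataModel / UniversalCheckpoint,
# so check those modules for the exact names and defaults):
#
#   python pretrain_t5.py \
#       --pretrained_model_path google/mt5-small \
#       --new_vocab_path /path/to/new_spm.model \
#       --keep_tokens_path /path/to/keep_tokens.json \
#       --tokenizer_type t5_tokenizer \
#       --max_seq_length 1024 \
#       --max_epochs 1 \
#       --default_root_dir ./outputs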