# modified from https://github.com/feng-yufei/shared_debugging_code/blob/main/model/t2s_lightning_module.py
import os
import sys

now_dir = os.getcwd()
sys.path.append(now_dir)

from typing import Dict

import torch
from pytorch_lightning import LightningModule

from AR.models.t2s_model import Text2SemanticDecoder
from AR.modules.lr_schedulers import WarmupCosineLRSchedule
from AR.modules.optim import ScaledAdam


class Text2SemanticLightningModule(LightningModule):
    def __init__(self, config, output_dir, is_train=True):
        super().__init__()
        self.config = config
        self.top_k = 3
        self.model = Text2SemanticDecoder(config=config, top_k=self.top_k)

        # Optionally resume from a pretrained stage-1 (s1) checkpoint; its
        # state dict is stored under the "weight" key.
        pretrained_s1 = config.get("pretrained_s1")
        if pretrained_s1 and is_train:
            # print(self.load_state_dict(torch.load(pretrained_s1, map_location="cpu")["state_dict"]))
            print(
                self.load_state_dict(
                    torch.load(pretrained_s1, map_location="cpu")["weight"]
                )
            )

        if is_train:
            # Manual optimization: training_step handles backward, optimizer
            # stepping, and LR scheduling itself.
            self.automatic_optimization = False
            self.save_hyperparameters()
            self.eval_dir = output_dir / "eval"
            self.eval_dir.mkdir(parents=True, exist_ok=True)

    def training_step(self, batch: Dict, batch_idx: int):
        opt = self.optimizers()
        scheduler = self.lr_schedulers()
        loss, acc = self.model.forward(
            batch["phoneme_ids"],
            batch["phoneme_ids_len"],
            batch["semantic_ids"],
            batch["semantic_ids_len"],
            batch["bert_feature"],
        )
        self.manual_backward(loss)

        # Accumulate gradients over 4 batches, then step the optimizer and
        # the LR scheduler together.
        if batch_idx > 0 and batch_idx % 4 == 0:
            opt.step()
            opt.zero_grad()
            scheduler.step()

        self.log(
            "total_loss",
            loss,
            on_step=True,
            on_epoch=True,
            prog_bar=True,
            sync_dist=True,
        )
        self.log(
            "lr",
            scheduler.get_last_lr()[0],
            on_epoch=True,
            prog_bar=True,
            sync_dist=True,
        )
        self.log(
            f"top_{self.top_k}_acc",
            acc,
            on_step=True,
            on_epoch=True,
            prog_bar=True,
            sync_dist=True,
        )

    def validation_step(self, batch: Dict, batch_idx: int):
        # Validation is currently disabled; the previous implementation is
        # kept below for reference.
        return

        # # get loss
        # loss, acc = self.model.forward(
        #     batch["phoneme_ids"], batch["phoneme_ids_len"],
        #     batch["semantic_ids"], batch["semantic_ids_len"],
        #     batch["bert_feature"]
        # )
        #
        # self.log(
        #     "val_total_loss",
        #     loss,
        #     on_step=True,
        #     on_epoch=True,
        #     prog_bar=True,
        #     sync_dist=True)
        # self.log(
        #     f"val_top_{self.top_k}_acc",
        #     acc,
        #     on_step=True,
        #     on_epoch=True,
        #     prog_bar=True,
        #     sync_dist=True)
        #
        # # get infer output
        # semantic_len = batch["semantic_ids"].size(1)
        # prompt_len = min(int(semantic_len * 0.5), 150)
        # prompt = batch["semantic_ids"][:, :prompt_len]
        # pred_semantic = self.model.infer(
        #     batch["phoneme_ids"],
        #     batch["phoneme_ids_len"],
        #     prompt,
        #     batch["bert_feature"],
        # )
        # save_name = f"semantic_toks_{batch_idx}.pt"
        # save_path = os.path.join(self.eval_dir, save_name)
        # torch.save(pred_semantic.detach().cpu(), save_path)

    def configure_optimizers(self):
        model_parameters = self.model.parameters()
        # ScaledAdam takes parameter names (one list per parameter group) for
        # its clipping diagnostics.
        parameters_names = []
        parameters_names.append(
            [name_param_pair[0] for name_param_pair in self.model.named_parameters()]
        )
        lm_opt = ScaledAdam(
            model_parameters,
            lr=0.01,
            betas=(0.9, 0.95),
            clipping_scale=2.0,
            parameters_names=parameters_names,
            show_dominant_parameters=False,
            clipping_update_period=1000,
        )

        return {
            "optimizer": lm_opt,
            "lr_scheduler": {
                "scheduler": WarmupCosineLRSchedule(
                    lm_opt,
                    init_lr=self.config["optimizer"]["lr_init"],
                    peak_lr=self.config["optimizer"]["lr"],
                    end_lr=self.config["optimizer"]["lr_end"],
                    warmup_steps=self.config["optimizer"]["warmup_steps"],
                    total_steps=self.config["optimizer"]["decay_steps"],
                )
            },
        }
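

# --- Usage sketch (illustrative addition, not part of the upstream module) ---
# A minimal sketch of how this LightningModule could be wired into a
# pytorch_lightning Trainer. Only the config keys read in this file are shown;
# Text2SemanticDecoder additionally needs a "model" section, and training needs
# a dataloader yielding the batch keys used in training_step (phoneme_ids,
# phoneme_ids_len, semantic_ids, semantic_ids_len, bert_feature). All concrete
# values and paths below are assumptions, not recommended settings.
if __name__ == "__main__":
    from pathlib import Path

    # Hypothetical config covering only the keys consumed above.
    config = {
        "pretrained_s1": None,  # optional path to a pretrained s1 checkpoint
        "optimizer": {
            "lr_init": 1e-5,
            "lr": 4e-4,
            "lr_end": 1e-4,
            "warmup_steps": 2000,
            "decay_steps": 40000,
        },
        # "model": {...},  # required by Text2SemanticDecoder; omitted here
    }
    output_dir = Path("logs_s1")  # hypothetical output directory
    output_dir.mkdir(parents=True, exist_ok=True)

    # With a complete config and a dataloader in place, a run would look like:
    # from pytorch_lightning import Trainer
    # module = Text2SemanticLightningModule(config, output_dir, is_train=True)
    # trainer = Trainer(max_epochs=1)
    # trainer.fit(module, train_dataloaders=train_loader)  # train_loader is assumed
    print("Usage sketch only; fill in config['model'] and a dataloader before fitting.")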