import torch
import abc
import os
import copy
import pytorch_lightning as pl

from utils.lr_scheduler import *
from torch import distributed as dist


class AbstractModel(pl.LightningModule):
    def __init__(self,
                 lr_scheduler_kwargs: dict = None,
                 optimizer_kwargs: dict = None,
                 save_path: str = None,
                 from_checkpoint: str = None,
                 load_prev_scheduler: bool = False,
                 save_weights_only: bool = True):
""" | |
Args: | |
lr_scheduler: Kwargs for lr_scheduler | |
optimizer_kwargs: Kwargs for optimizer_kwargs | |
save_path: Save trained model | |
from_checkpoint: Load model from checkpoint | |
load_prev_scheduler: Whether load previous scheduler from checkpoint | |
load_strict: Whether load model strictly | |
save_weights_only: Whether save only weights or also optimizer and lr_scheduler | |
""" | |
        super().__init__()
        self.initialize_model()

        self.metrics = {}
        for stage in ["train", "valid", "test"]:
            stage_metrics = self.initialize_metrics(stage)
            # Register metrics as attributes so that Lightning manages their device placement
            for metric_name, metric in stage_metrics.items():
                setattr(self, metric_name, metric)

            self.metrics[stage] = stage_metrics

        if lr_scheduler_kwargs is None:
            # Default lr_scheduler
            self.lr_scheduler_kwargs = {
                "class": "ConstantLRScheduler",
                "init_lr": 0,
            }
            print("No lr_scheduler_kwargs provided. The default learning rate is 0.")
        else:
            self.lr_scheduler_kwargs = lr_scheduler_kwargs

        if optimizer_kwargs is None:
            # Default optimizer
            self.optimizer_kwargs = {
                "class": "AdamW",
                "betas": (0.9, 0.98),
                "weight_decay": 0.01,
            }
            print("No optimizer_kwargs provided. The default optimizer is AdamW.")
        else:
            self.optimizer_kwargs = optimizer_kwargs

        self.init_optimizers()

        self.save_path = save_path
        self.save_weights_only = save_weights_only

        # temp_step is used for accumulating gradients
        self.temp_step = 0
        self.step = 0
        self.epoch = 0

        self.load_prev_scheduler = load_prev_scheduler
        self.from_checkpoint = from_checkpoint
        if from_checkpoint:
            self.load_checkpoint(from_checkpoint)

    @abc.abstractmethod
    def initialize_model(self) -> None:
        """
        All model initialization should be done here.
        Note that the whole model must be named "self.model" for model saving and loading.
        """
        raise NotImplementedError

    @abc.abstractmethod
    def forward(self, *args, **kwargs):
        """
        Forward propagation
        """
        raise NotImplementedError

    @abc.abstractmethod
    def initialize_metrics(self, stage: str) -> dict:
        """
        Initialize metrics for each stage
        Args:
            stage: "train", "valid" or "test"

        Returns:
            A dictionary of metrics for the stage. Keys are metric names and values are metric objects
        """
        raise NotImplementedError

    @abc.abstractmethod
    def loss_func(self, stage: str, outputs, labels) -> torch.Tensor:
        """
        Args:
            stage: "train", "valid" or "test"
            outputs: model outputs for calculating loss
            labels: labels for calculating loss

        Returns:
            loss
        """
        raise NotImplementedError
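
    # A minimal sketch of a concrete subclass implementing the abstract hooks
    # above is given at the bottom of this file.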

    def load_weights(self, model, weights):
        model_dict = model.state_dict()

        unused_params = []
        missed_params = list(model_dict.keys())

        for k, v in weights.items():
            if k in model_dict.keys():
                model_dict[k] = v
                missed_params.remove(k)
            else:
                unused_params.append(k)

        if len(missed_params) > 0:
            print(f"\033[31mSome weights of {type(model).__name__} were not "
                  f"initialized from the model checkpoint: {missed_params}\033[0m")

        if len(unused_params) > 0:
            print(f"\033[31mSome weights of the model checkpoint were not used: {unused_params}\033[0m")

        model.load_state_dict(model_dict)

    def optimizer_step(
            self,
            epoch: int,
            batch_idx: int,
            optimizer,
            optimizer_closure=None,
    ) -> None:
        super().optimizer_step(epoch, batch_idx, optimizer, optimizer_closure)

        self.temp_step += 1
        if self.temp_step == self.trainer.accumulate_grad_batches:
            self.step += 1
            self.temp_step = 0

    # For pytorch-lightning 1.9.5
    # def optimizer_step(
    #         self,
    #         epoch: int,
    #         batch_idx: int,
    #         optimizer,
    #         optimizer_idx: int = 0,
    #         optimizer_closure=None,
    #         on_tpu: bool = False,
    #         using_native_amp: bool = False,
    #         using_lbfgs: bool = False,
    # ) -> None:
    #     super().optimizer_step(
    #         epoch, batch_idx, optimizer, optimizer_idx, optimizer_closure, on_tpu, using_native_amp, using_lbfgs
    #     )
    #
    #     self.temp_step += 1
    #     if self.temp_step == self.trainer.accumulate_grad_batches:
    #         self.step += 1
    #         self.temp_step = 0

    def on_train_epoch_end(self):
        self.epoch += 1

    def training_step(self, batch, batch_idx):
        inputs, labels = batch
        outputs = self(**inputs)
        loss = self.loss_func('train', outputs, labels)

        self.log("loss", loss, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        inputs, labels = batch
        outputs = self(**inputs)
        loss = self.loss_func('valid', outputs, labels)

        self.valid_outputs.append(loss)
        return loss

    def test_step(self, batch, batch_idx):
        inputs, labels = batch
        outputs = self(**inputs)
        loss = self.loss_func('test', outputs, labels)

        self.test_outputs.append(loss)
        return loss

    def on_train_start(self) -> None:
        # Restore the previous scheduler state
        if getattr(self, "prev_scheduler", None) is not None:
            try:
                self.step = self.prev_scheduler["global_step"]
                self.epoch = self.prev_scheduler["epoch"]
                self.best_value = self.prev_scheduler["best_value"]
                self.lr_scheduler.load_state_dict(self.prev_scheduler["lr_scheduler"])
                print(f"Previous training global step: {self.step}")
                print(f"Previous training epoch: {self.epoch}")
                print(f"Previous best value: {self.best_value}")
                print(f"Previous lr_scheduler: {self.prev_scheduler['lr_scheduler']}")

                # Load optimizer state
                if hasattr(self.trainer.strategy, "deepspeed_engine"):
                    # For DeepSpeed strategy
                    try:
                        self.trainer.strategy.deepspeed_engine.load_checkpoint(self.from_checkpoint)
                    except Exception as e:
                        print(e)
                else:
                    # For DDP strategy
                    self.optimizer.load_state_dict(self.prev_scheduler["optimizer"])

            except Exception as e:
                print(e)
                raise Exception("Error in loading the previous scheduler. Please set load_prev_scheduler=False")

    def on_validation_epoch_start(self) -> None:
        self.valid_outputs = []

    def on_test_epoch_start(self) -> None:
        self.test_outputs = []

    def load_checkpoint(self, from_checkpoint: str) -> None:
        """
        Args:
            from_checkpoint: Path to checkpoint.
        """
        # If ``from_checkpoint`` is a directory, load the checkpoint inside it
        if os.path.isdir(from_checkpoint):
            basename = os.path.basename(from_checkpoint)
            from_checkpoint = os.path.join(from_checkpoint, f"{basename}.pt")

        state_dict = torch.load(from_checkpoint, map_location=self.device)
        self.load_weights(self.model, state_dict["model"])

        if self.load_prev_scheduler:
            state_dict.pop("model")
            self.prev_scheduler = state_dict

    def save_checkpoint(self, save_path: str, save_info: dict = None, save_weights_only: bool = True) -> None:
        """
        Save the model to save_path
        Args:
            save_path: Path to save the model
            save_info: Other info to save
            save_weights_only: Whether to save only the model weights
        """
        save_dir = os.path.dirname(save_path)
        os.makedirs(save_dir, exist_ok=True)

        state_dict = {} if save_info is None else save_info
        state_dict["model"] = self.model.state_dict()

        # Convert model weights to fp32
        for k, v in state_dict["model"].items():
            state_dict["model"][k] = v.float()

        if not save_weights_only:
            state_dict["global_step"] = self.step
            state_dict["epoch"] = self.epoch
            state_dict["best_value"] = getattr(self, "best_value", None)
            state_dict["lr_scheduler"] = self.lr_schedulers().state_dict()

            # If not using DeepSpeed, save the optimizer state
            if not hasattr(self.trainer.strategy, "deepspeed_engine"):
                state_dict["optimizer"] = self.optimizers().optimizer.state_dict()

        torch.save(state_dict, save_path)
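
    # Example (a sketch, not called anywhere in this class): a checkpoint written
    # with save_weights_only=False can be inspected offline. The path below is
    # hypothetical.
    #
    #   ckpt = torch.load("weights/example.pt", map_location="cpu")
    #   ckpt["model"]         # fp32 model weights, consumed by load_weights()
    #   ckpt["global_step"]   # restored in on_train_start when load_prev_scheduler=True
    #   ckpt["lr_scheduler"]  # scheduler state fed to lr_scheduler.load_state_dict()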

    def check_save_condition(self, now_value: float, mode: str, save_info: dict = None) -> None:
        """
        Check whether to save the model. If save_path is not None and now_value is the best so far, save the model.
        Args:
            now_value: Current metric value
            mode: "min" or "max", i.e. whether lower or higher values are better
            save_info: Other info to save
        """
        assert mode in ["min", "max"], "mode should be 'min' or 'max'"

        if self.save_path is not None:
            # In case there are variables to be interpolated into the save path
            save_path = eval(f"f'{self.save_path}'")
            save_dir = os.path.dirname(save_path)
            os.makedirs(save_dir, exist_ok=True)

            # Check whether the current value is the best
            best_value = getattr(self, "best_value", None)
            if best_value is not None:
                if (mode == "min" and now_value >= best_value) or (mode == "max" and now_value <= best_value):
                    return

            self.best_value = now_value

            # For DeepSpeed strategy
            if hasattr(self.trainer.strategy, "deepspeed_engine"):
                if not self.save_weights_only:
                    self.trainer.strategy.deepspeed_engine.save_checkpoint(save_path, tag="deepspeed_ckpt")

                # Save a complete checkpoint
                if dist.get_rank() == 0:
                    basename = os.path.basename(save_path)
                    ckpt_path = os.path.join(save_path, f"{basename}.pt")
                    self.save_checkpoint(ckpt_path, save_info, self.save_weights_only)

            # For normal situation
            else:
                if dist.get_rank() == 0:
                    self.save_checkpoint(save_path, save_info, self.save_weights_only)
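
    # Typical usage (a sketch; ``valid_acc`` is a hypothetical metric name that a
    # subclass would have registered in initialize_metrics):
    #
    #   def on_validation_epoch_end(self):
    #       log_dict = self.get_log_dict("valid")
    #       self.check_save_condition(log_dict["valid_acc"], mode="max")
    #       self.log_info(log_dict)
    #       self.reset_metrics("valid")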

    def reset_metrics(self, stage) -> None:
        """
        Reset metrics for the given stage
        Args:
            stage: "train", "valid" or "test"
        """
        for metric in self.metrics[stage].values():
            metric.reset()

    def get_log_dict(self, stage: str) -> dict:
        """
        Get the log dict for the stage
        Args:
            stage: "train", "valid" or "test"

        Returns:
            A dictionary of metrics for the stage. Keys are metric names and values are metric values
        """
        return {name: metric.compute() for name, metric in self.metrics[stage].items()}

    def log_info(self, info: dict) -> None:
        """
        Record metrics during training and testing
        Args:
            info: dict of metrics
        """
        if getattr(self, "logger", None) is not None and dist.get_rank() == 0:
            info["learning_rate"] = self.lr_scheduler.get_last_lr()[0]
            info["epoch"] = self.epoch
            self.logger.log_metrics(info, step=self.step)

    def init_optimizers(self):
        copy_optimizer_kwargs = copy.deepcopy(self.optimizer_kwargs)

        # No weight decay for layer norm and bias
        no_decay = ['LayerNorm.weight', 'bias']
        weight_decay = copy_optimizer_kwargs.pop("weight_decay")
        optimizer_grouped_parameters = [
            {'params': [p for n, p in self.model.named_parameters() if not any(nd in n for nd in no_decay)],
             'weight_decay': weight_decay},
            {'params': [p for n, p in self.model.named_parameters() if any(nd in n for nd in no_decay)],
             'weight_decay': 0.0}
        ]

        optimizer_cls = eval(f"torch.optim.{copy_optimizer_kwargs.pop('class')}")
        self.optimizer = optimizer_cls(optimizer_grouped_parameters,
                                       lr=self.lr_scheduler_kwargs['init_lr'],
                                       **copy_optimizer_kwargs)

        tmp_kwargs = copy.deepcopy(self.lr_scheduler_kwargs)
        lr_scheduler = tmp_kwargs.pop("class")
        self.lr_scheduler = eval(lr_scheduler)(self.optimizer, **tmp_kwargs)
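
    # For reference (a sketch): the kwargs consumed above look like the dicts
    # below. Scheduler classes other than ConstantLRScheduler would have to be
    # defined in utils.lr_scheduler, since "class" is resolved via eval().
    #
    #   lr_scheduler_kwargs = {"class": "ConstantLRScheduler", "init_lr": 1e-4}
    #   optimizer_kwargs = {"class": "AdamW", "betas": (0.9, 0.98), "weight_decay": 0.01}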

    def configure_optimizers(self):
        return {"optimizer": self.optimizer,
                "lr_scheduler": {"scheduler": self.lr_scheduler,
                                 "interval": "step",
                                 "frequency": 1}
                }
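

# ----------------------------------------------------------------------------
# Example: a minimal concrete subclass (a sketch, not part of the training
# code above). It assumes torchmetrics is installed; the class name, metric
# name and architecture are hypothetical. training_step unpacks the first
# element of the batch as keyword arguments, so the dataloader would yield
# batches of the form ({"x": tensor}, labels).
# ----------------------------------------------------------------------------
#
# import torchmetrics
# from torch import nn
#
#
# class ToyClassificationModel(AbstractModel):
#     def initialize_model(self) -> None:
#         # The whole model must be named "self.model" (see initialize_model)
#         self.model = nn.Linear(16, 2)
#
#     def initialize_metrics(self, stage: str) -> dict:
#         return {f"{stage}_acc": torchmetrics.Accuracy(task="multiclass", num_classes=2)}
#
#     def forward(self, x):
#         return self.model(x)
#
#     def loss_func(self, stage: str, outputs, labels) -> torch.Tensor:
#         # Update the stage metric and return the loss used for backprop
#         self.metrics[stage][f"{stage}_acc"].update(outputs, labels)
#         return torch.nn.functional.cross_entropy(outputs, labels)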