# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
# International Conference on Computer Vision (ICCV), 2023

"""Base Trainer: checkpointing, validation, BN reset, and distributed training utilities."""

import os
from typing import Any

import torch
import torch.nn as nn

from src.efficientvit.apps.data_provider import DataProvider, parse_image_size
from src.efficientvit.apps.trainer.run_config import RunConfig
from src.efficientvit.apps.utils import EMA, dist_barrier, get_dist_local_rank, is_master
from src.efficientvit.models.nn.norm import reset_bn
from src.efficientvit.models.utils import is_parallel, load_state_dict_from_file

__all__ = ["Trainer"]

class Trainer:
    def __init__(self, path: str, model: nn.Module, data_provider: DataProvider):
        self.path = os.path.realpath(os.path.expanduser(path))
        self.model = model.cuda()
        self.data_provider = data_provider

        self.ema = None

        self.checkpoint_path = os.path.join(self.path, "checkpoint")
        self.logs_path = os.path.join(self.path, "logs")
        for path in [self.path, self.checkpoint_path, self.logs_path]:
            os.makedirs(path, exist_ok=True)

        self.best_val = 0.0
        self.start_epoch = 0
    @property
    def network(self) -> nn.Module:
        return self.model.module if is_parallel(self.model) else self.model

    @property
    def eval_network(self) -> nn.Module:
        if self.ema is None:
            model = self.model
        else:
            model = self.ema.shadows
        model = model.module if is_parallel(model) else model
        return model
    def write_log(self, log_str, prefix="valid", print_log=True, mode="a") -> None:
        if is_master():
            with open(os.path.join(self.logs_path, f"{prefix}.log"), mode) as fout:
                fout.write(log_str + "\n")
            if print_log:
                print(log_str)
    def save_model(
        self,
        checkpoint=None,
        only_state_dict=True,
        epoch=0,
        model_name=None,
    ) -> None:
        if is_master():
            if checkpoint is None:
                if only_state_dict:
                    checkpoint = {"state_dict": self.network.state_dict()}
                else:
                    checkpoint = {
                        "state_dict": self.network.state_dict(),
                        "epoch": epoch,
                        "best_val": self.best_val,
                        "optimizer": self.optimizer.state_dict(),
                        "lr_scheduler": self.lr_scheduler.state_dict(),
                        "ema": self.ema.state_dict() if self.ema is not None else None,
                        "scaler": self.scaler.state_dict() if self.fp16 else None,
                    }

            model_name = model_name or "checkpoint.pt"

            latest_fname = os.path.join(self.checkpoint_path, "latest.txt")
            model_path = os.path.join(self.checkpoint_path, model_name)
            with open(latest_fname, "w") as _fout:
                _fout.write(model_path + "\n")
            torch.save(checkpoint, model_path)
    def load_model(self, model_fname=None) -> None:
        latest_fname = os.path.join(self.checkpoint_path, "latest.txt")
        if model_fname is None and os.path.exists(latest_fname):
            with open(latest_fname, "r") as fin:
                model_fname = fin.readline()
                if len(model_fname) > 0 and model_fname[-1] == "\n":
                    model_fname = model_fname[:-1]
        try:
            if model_fname is None:
                model_fname = f"{self.checkpoint_path}/checkpoint.pt"
            elif not os.path.exists(model_fname):
                model_fname = f"{self.checkpoint_path}/{os.path.basename(model_fname)}"
                if not os.path.exists(model_fname):
                    model_fname = f"{self.checkpoint_path}/checkpoint.pt"
            print(f"=> loading checkpoint {model_fname}")
            checkpoint = load_state_dict_from_file(model_fname, False)
        except Exception:
            self.write_log(f"failed to load checkpoint from {self.checkpoint_path}")
            return

        # load checkpoint
        self.network.load_state_dict(checkpoint["state_dict"], strict=False)
        log = []
        if "epoch" in checkpoint:
            self.start_epoch = checkpoint["epoch"] + 1
            self.run_config.update_global_step(self.start_epoch)
            log.append(f"epoch={self.start_epoch - 1}")
        if "best_val" in checkpoint:
            self.best_val = checkpoint["best_val"]
            log.append(f"best_val={self.best_val:.2f}")
        if "optimizer" in checkpoint:
            self.optimizer.load_state_dict(checkpoint["optimizer"])
            log.append("optimizer")
        if "lr_scheduler" in checkpoint:
            self.lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
            log.append("lr_scheduler")
        if "ema" in checkpoint and self.ema is not None:
            self.ema.load_state_dict(checkpoint["ema"])
            log.append("ema")
        if "scaler" in checkpoint and self.fp16:
            self.scaler.load_state_dict(checkpoint["scaler"])
            log.append("scaler")
        self.write_log("Loaded: " + ", ".join(log))
| """ validate """ | |
| def reset_bn( | |
| self, | |
| network: nn.Module or None = None, | |
| subset_size: int = 16000, | |
| subset_batch_size: int = 100, | |
| data_loader=None, | |
| progress_bar=False, | |
| ) -> None: | |
| network = network or self.network | |
| if data_loader is None: | |
| data_loader = [] | |
| for data in self.data_provider.build_sub_train_loader( | |
| subset_size, subset_batch_size | |
| ): | |
| if isinstance(data, list): | |
| data_loader.append(data[0]) | |
| elif isinstance(data, dict): | |
| data_loader.append(data["data"]) | |
| elif isinstance(data, torch.Tensor): | |
| data_loader.append(data) | |
| else: | |
| raise NotImplementedError | |
| network.eval() | |
| reset_bn( | |
| network, | |
| data_loader, | |
| sync=True, | |
| progress_bar=progress_bar, | |
| ) | |
    def _validate(self, model, data_loader, epoch) -> dict[str, Any]:
        raise NotImplementedError

    def validate(self, model=None, data_loader=None, is_test=True, epoch=0) -> dict[str, Any]:
        model = model or self.eval_network
        if data_loader is None:
            if is_test:
                data_loader = self.data_provider.test
            else:
                data_loader = self.data_provider.valid

        model.eval()
        return self._validate(model, data_loader, epoch)
    def multires_validate(
        self,
        model=None,
        data_loader=None,
        is_test=True,
        epoch=0,
        eval_image_size=None,
    ) -> dict[str, dict[str, Any]]:
        eval_image_size = eval_image_size or self.run_config.eval_image_size
        eval_image_size = eval_image_size or self.data_provider.image_size
        model = model or self.eval_network

        if not isinstance(eval_image_size, list):
            eval_image_size = [eval_image_size]

        output_dict = {}
        for r in eval_image_size:
            self.data_provider.assign_active_image_size(parse_image_size(r))
            if self.run_config.reset_bn:
                self.reset_bn(
                    network=model,
                    subset_size=self.run_config.reset_bn_size,
                    subset_batch_size=self.run_config.reset_bn_batch_size,
                    progress_bar=True,
                )
            output_dict[f"r{r}"] = self.validate(model, data_loader, is_test, epoch)
        return output_dict
| """ training """ | |
| def prep_for_training( | |
| self, run_config: RunConfig, ema_decay: float or None = None, fp16=False | |
| ) -> None: | |
| self.run_config = run_config | |
| self.model = nn.parallel.DistributedDataParallel( | |
| self.model.cuda(), | |
| device_ids=[get_dist_local_rank()], | |
| static_graph=True, | |
| ) | |
| self.run_config.global_step = 0 | |
| self.run_config.batch_per_epoch = len(self.data_provider.train) | |
| assert self.run_config.batch_per_epoch > 0, "Training set is empty" | |
| # build optimizer | |
| self.optimizer, self.lr_scheduler = self.run_config.build_optimizer(self.model) | |
| if ema_decay is not None: | |
| self.ema = EMA(self.network, ema_decay) | |
| # fp16 | |
| self.fp16 = fp16 | |
| self.scaler = torch.cuda.amp.GradScaler(enabled=self.fp16) | |
    def sync_model(self):
        print("Sync model")
        self.save_model(model_name="sync.pt")
        dist_barrier()
        checkpoint = torch.load(
            os.path.join(self.checkpoint_path, "sync.pt"), map_location="cpu"
        )
        dist_barrier()
        if is_master():
            os.remove(os.path.join(self.checkpoint_path, "sync.pt"))
        dist_barrier()

        # load checkpoint
        self.network.load_state_dict(checkpoint["state_dict"], strict=False)
        if "optimizer" in checkpoint:
            self.optimizer.load_state_dict(checkpoint["optimizer"])
        if "lr_scheduler" in checkpoint:
            self.lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
        if "ema" in checkpoint and self.ema is not None:
            self.ema.load_state_dict(checkpoint["ema"])
        if "scaler" in checkpoint and self.fp16:
            self.scaler.load_state_dict(checkpoint["scaler"])
    def before_step(self, feed_dict: dict[str, Any]) -> dict[str, Any]:
        for key in feed_dict:
            if isinstance(feed_dict[key], torch.Tensor):
                feed_dict[key] = feed_dict[key].cuda()
        return feed_dict

    def run_step(self, feed_dict: dict[str, Any]) -> dict[str, Any]:
        raise NotImplementedError

    def after_step(self) -> None:
        self.scaler.unscale_(self.optimizer)
        # gradient clip
        if self.run_config.grad_clip is not None:
            torch.nn.utils.clip_grad_value_(
                self.model.parameters(), self.run_config.grad_clip
            )
        # update
        self.scaler.step(self.optimizer)
        self.scaler.update()

        self.lr_scheduler.step()
        self.run_config.step()
        # update ema
        if self.ema is not None:
            self.ema.step(self.network, self.run_config.global_step)
    def _train_one_epoch(self, epoch: int) -> dict[str, Any]:
        raise NotImplementedError

    def train_one_epoch(self, epoch: int) -> dict[str, Any]:
        self.model.train()

        self.data_provider.set_epoch(epoch)

        train_info_dict = self._train_one_epoch(epoch)

        return train_info_dict

    def train(self) -> None:
        raise NotImplementedError