# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
# International Conference on Computer Vision (ICCV), 2023

import os
from typing import Any, Optional

import torch
import torch.nn as nn

from efficientvit.apps.data_provider import DataProvider, parse_image_size
from efficientvit.apps.trainer.run_config import RunConfig
from efficientvit.apps.utils import (EMA, dist_barrier, get_dist_local_rank,
                                     is_master)
from efficientvit.models.nn.norm import reset_bn
from efficientvit.models.utils import is_parallel, load_state_dict_from_file

__all__ = ["Trainer"]


class Trainer:
    """Base trainer: checkpointing, EMA, logging, validation, and the distributed training loop."""

    def __init__(self, path: str, model: nn.Module, data_provider: DataProvider):
        self.path = os.path.realpath(os.path.expanduser(path))
        self.model = model.cuda()
        self.data_provider = data_provider

        self.ema = None

        self.checkpoint_path = os.path.join(self.path, "checkpoint")
        self.logs_path = os.path.join(self.path, "logs")
        for path in [self.path, self.checkpoint_path, self.logs_path]:
            os.makedirs(path, exist_ok=True)

        self.best_val = 0.0
        self.start_epoch = 0

    @property
    def network(self) -> nn.Module:
        return self.model.module if is_parallel(self.model) else self.model

    @property
    def eval_network(self) -> nn.Module:
        if self.ema is None:
            model = self.model
        else:
            model = self.ema.shadows
        model = model.module if is_parallel(model) else model
        return model

    def write_log(self, log_str, prefix="valid", print_log=True, mode="a") -> None:
        if is_master():
            with open(os.path.join(self.logs_path, f"{prefix}.log"), mode) as fout:
                fout.write(log_str + "\n")
                fout.flush()
            if print_log:
                print(log_str)

    def save_model(
        self,
        checkpoint=None,
        only_state_dict=True,
        epoch=0,
        model_name=None,
    ) -> None:
        if is_master():
            if checkpoint is None:
                if only_state_dict:
                    checkpoint = {"state_dict": self.network.state_dict()}
                else:
                    checkpoint = {
                        "state_dict": self.network.state_dict(),
                        "epoch": epoch,
                        "best_val": self.best_val,
                        "optimizer": self.optimizer.state_dict(),
                        "lr_scheduler": self.lr_scheduler.state_dict(),
                        "ema": self.ema.state_dict() if self.ema is not None else None,
                        "scaler": self.scaler.state_dict() if self.fp16 else None,
                    }

            model_name = model_name or "checkpoint.pt"

            latest_fname = os.path.join(self.checkpoint_path, "latest.txt")
            model_path = os.path.join(self.checkpoint_path, model_name)
            with open(latest_fname, "w") as _fout:
                _fout.write(model_path + "\n")
            torch.save(checkpoint, model_path)

    def load_model(self, model_fname=None) -> None:
        latest_fname = os.path.join(self.checkpoint_path, "latest.txt")
        if model_fname is None and os.path.exists(latest_fname):
            with open(latest_fname, "r") as fin:
                model_fname = fin.readline()
                if len(model_fname) > 0 and model_fname[-1] == "\n":
                    model_fname = model_fname[:-1]
        try:
            if model_fname is None:
                model_fname = f"{self.checkpoint_path}/checkpoint.pt"
            elif not os.path.exists(model_fname):
                model_fname = f"{self.checkpoint_path}/{os.path.basename(model_fname)}"
                if not os.path.exists(model_fname):
                    model_fname = f"{self.checkpoint_path}/checkpoint.pt"
            print(f"=> loading checkpoint {model_fname}")
            checkpoint = load_state_dict_from_file(model_fname, False)
        except Exception:
            self.write_log(f"failed to load checkpoint from {self.checkpoint_path}")
            return

        # load checkpoint
        self.network.load_state_dict(checkpoint["state_dict"], strict=False)
        log = []
        if "epoch" in checkpoint:
            self.start_epoch = checkpoint["epoch"] + 1
            self.run_config.update_global_step(self.start_epoch)
            log.append(f"epoch={self.start_epoch - 1}")
        if "best_val" in checkpoint:
            self.best_val = checkpoint["best_val"]
            log.append(f"best_val={self.best_val:.2f}")
        if "optimizer" in checkpoint:
            self.optimizer.load_state_dict(checkpoint["optimizer"])
            log.append("optimizer")
        if "lr_scheduler" in checkpoint:
            self.lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
            log.append("lr_scheduler")
        if "ema" in checkpoint and self.ema is not None:
            self.ema.load_state_dict(checkpoint["ema"])
            log.append("ema")
        if "scaler" in checkpoint and self.fp16:
            self.scaler.load_state_dict(checkpoint["scaler"])
            log.append("scaler")
        self.write_log("Loaded: " + ", ".join(log))

    """ validate """

    def reset_bn(
        self,
        network: Optional[nn.Module] = None,
        subset_size: int = 16000,
        subset_batch_size: int = 100,
        data_loader=None,
        progress_bar=False,
    ) -> None:
        network = network or self.network
        if data_loader is None:
            data_loader = []
            for data in self.data_provider.build_sub_train_loader(
                subset_size, subset_batch_size
            ):
                if isinstance(data, list):
                    data_loader.append(data[0])
                elif isinstance(data, dict):
                    data_loader.append(data["data"])
                elif isinstance(data, torch.Tensor):
                    data_loader.append(data)
                else:
                    raise NotImplementedError

        network.eval()
        reset_bn(
            network,
            data_loader,
            sync=True,
            progress_bar=progress_bar,
        )

    def _validate(self, model, data_loader, epoch) -> dict[str, Any]:
        raise NotImplementedError

    def validate(
        self, model=None, data_loader=None, is_test=True, epoch=0
    ) -> dict[str, Any]:
        model = model or self.eval_network
        if data_loader is None:
            if is_test:
                data_loader = self.data_provider.test
            else:
                data_loader = self.data_provider.valid

        model.eval()
        return self._validate(model, data_loader, epoch)

    def multires_validate(
        self,
        model=None,
        data_loader=None,
        is_test=True,
        epoch=0,
        eval_image_size=None,
    ) -> dict[str, dict[str, Any]]:
        eval_image_size = eval_image_size or self.run_config.eval_image_size
        eval_image_size = eval_image_size or self.data_provider.image_size
        model = model or self.eval_network

        if not isinstance(eval_image_size, list):
            eval_image_size = [eval_image_size]

        output_dict = {}
        for r in eval_image_size:
            self.data_provider.assign_active_image_size(parse_image_size(r))
            if self.run_config.reset_bn:
                self.reset_bn(
                    network=model,
                    subset_size=self.run_config.reset_bn_size,
                    subset_batch_size=self.run_config.reset_bn_batch_size,
                    progress_bar=True,
                )
            output_dict[f"r{r}"] = self.validate(model, data_loader, is_test, epoch)
        return output_dict

    """ training """

    def prep_for_training(
        self, run_config: RunConfig, ema_decay: Optional[float] = None, fp16=False
    ) -> None:
        self.run_config = run_config
        self.model = nn.parallel.DistributedDataParallel(
            self.model.cuda(),
            device_ids=[get_dist_local_rank()],
            static_graph=True,
        )
        self.run_config.global_step = 0
        self.run_config.batch_per_epoch = len(self.data_provider.train)
        assert self.run_config.batch_per_epoch > 0, "Training set is empty"

        # build optimizer
        self.optimizer, self.lr_scheduler = self.run_config.build_optimizer(self.model)

        if ema_decay is not None:
            self.ema = EMA(self.network, ema_decay)

        # fp16
        self.fp16 = fp16
        self.scaler = torch.cuda.amp.GradScaler(enabled=self.fp16)

    def sync_model(self):
        print("Sync model")
        self.save_model(model_name="sync.pt")
        dist_barrier()
        checkpoint = torch.load(
            os.path.join(self.checkpoint_path, "sync.pt"), map_location="cpu"
        )
        dist_barrier()
        if is_master():
            os.remove(os.path.join(self.checkpoint_path, "sync.pt"))
        dist_barrier()

        # load checkpoint
        self.network.load_state_dict(checkpoint["state_dict"], strict=False)
        if "optimizer" in checkpoint:
            self.optimizer.load_state_dict(checkpoint["optimizer"])
        if "lr_scheduler" in checkpoint:
            self.lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
        if "ema" in checkpoint and self.ema is not None:
            self.ema.load_state_dict(checkpoint["ema"])
        if "scaler" in checkpoint and self.fp16:
            self.scaler.load_state_dict(checkpoint["scaler"])

    def before_step(self, feed_dict: dict[str, Any]) -> dict[str, Any]:
        for key in feed_dict:
            if isinstance(feed_dict[key], torch.Tensor):
                feed_dict[key] = feed_dict[key].cuda()
        return feed_dict

    def run_step(self, feed_dict: dict[str, Any]) -> dict[str, Any]:
        raise NotImplementedError

    def after_step(self) -> None:
        self.scaler.unscale_(self.optimizer)
        # gradient clip
        if self.run_config.grad_clip is not None:
            torch.nn.utils.clip_grad_value_(
                self.model.parameters(), self.run_config.grad_clip
            )
        # update
        self.scaler.step(self.optimizer)
        self.scaler.update()

        self.lr_scheduler.step()
        self.run_config.step()
        # update ema
        if self.ema is not None:
            self.ema.step(self.network, self.run_config.global_step)

    def _train_one_epoch(self, epoch: int) -> dict[str, Any]:
        raise NotImplementedError

    def train_one_epoch(self, epoch: int) -> dict[str, Any]:
        self.model.train()

        self.data_provider.set_epoch(epoch)

        train_info_dict = self._train_one_epoch(epoch)

        return train_info_dict

    def train(self) -> None:
        raise NotImplementedError
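

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the upstream trainer): a
# hypothetical image-classification subclass showing how the abstract hooks
# `_validate`, `run_step`, `_train_one_epoch`, and `train` fit together with
# `before_step`/`after_step` and the AMP scaler. The feed-dict keys ("data",
# "label") and the `run_config.n_epochs` attribute are assumptions made for
# this example; concrete EfficientViT trainers define their own equivalents
# and must call `prep_for_training(...)` before `train()`.
class ExampleClsTrainer(Trainer):
    def _validate(self, model, data_loader, epoch) -> dict[str, Any]:
        correct, total = 0, 0
        with torch.no_grad():
            for feed_dict in data_loader:
                feed_dict = self.before_step(feed_dict)
                output = model(feed_dict["data"])
                correct += (output.argmax(dim=1) == feed_dict["label"]).sum().item()
                total += feed_dict["label"].numel()
        return {"val_top1": 100.0 * correct / max(total, 1)}

    def run_step(self, feed_dict: dict[str, Any]) -> dict[str, Any]:
        with torch.autocast(device_type="cuda", enabled=self.fp16):
            output = self.model(feed_dict["data"])
            loss = nn.functional.cross_entropy(output, feed_dict["label"])
        # scale the loss so fp16 gradients do not underflow; after_step() unscales
        self.scaler.scale(loss).backward()
        return {"loss": loss.item()}

    def _train_one_epoch(self, epoch: int) -> dict[str, Any]:
        step_info = {}
        for feed_dict in self.data_provider.train:
            feed_dict = self.before_step(feed_dict)
            self.optimizer.zero_grad()
            step_info = self.run_step(feed_dict)
            self.after_step()  # clip, optimizer/scheduler step, EMA update
        return step_info

    def train(self) -> None:
        # assumes prep_for_training(run_config, ema_decay, fp16) was called first
        for epoch in range(self.start_epoch, self.run_config.n_epochs):
            train_info = self.train_one_epoch(epoch)
            val_info = self.validate(is_test=False, epoch=epoch)
            self.best_val = max(self.best_val, val_info["val_top1"])
            self.write_log(f"epoch {epoch}: {train_info} {val_info}")
            self.save_model(only_state_dict=False, epoch=epoch)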