Spaces:
Configuration error
Configuration error
| # EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction | |
| # Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han | |
| # International Conference on Computer Vision (ICCV), 2023 | |
| import json | |
| import numpy as np | |
| import torch.nn as nn | |
| from src.efficientvit.apps.utils import CosineLRwithWarmup, build_optimizer | |
| __all__ = ["Scheduler", "RunConfig"] | |
class Scheduler:
    """Namespace for training progress shared at class level.

    ``PROGRESS`` is reassigned by ``RunConfig.update_global_step`` and
    ``RunConfig.step`` from ``RunConfig.progress``.
    """

    # Starts at 0 until a RunConfig publishes its progress.
    PROGRESS = 0
class RunConfig:
    """Hyper-parameters of a training run plus step/progress bookkeeping.

    Every annotated field below (including those annotated on subclasses or
    base classes) must be passed to ``__init__`` as a keyword argument and is
    type-checked there. Fields listed by ``none_allowed`` may also be ``None``.
    """

    n_epochs: int
    init_lr: float
    warmup_epochs: int
    warmup_lr: float

    lr_schedule_name: str
    lr_schedule_param: dict

    optimizer_name: str
    optimizer_params: dict

    weight_decay: float
    no_wd_keys: list
    grad_clip: float  # allow none to turn off grad clipping

    reset_bn: bool
    reset_bn_size: int
    reset_bn_batch_size: int

    eval_image_size: list  # allow none to use image_size in data_provider

    @property
    def none_allowed(self):
        """Names of the annotated fields that may be set to ``None``.

        Must be a property: ``__init__`` uses it in a membership test
        (``k in self.none_allowed``), which fails on a bound method.
        """
        return ["grad_clip", "eval_image_size"]

    def __init__(self, **kwargs):  # arguments must be passed as kwargs
        """Store ``kwargs`` as attributes and validate the annotated fields.

        Raises:
            AssertionError: if an annotated field is missing or has the wrong
                type (``None`` is accepted only for ``none_allowed`` fields).
        """
        for k, val in kwargs.items():
            setattr(self, k, val)

        # check that all relevant configs are there
        annotations = {}
        for clas in type(self).mro():
            if hasattr(clas, "__annotations__"):
                annotations.update(clas.__annotations__)

        none_allowed = self.none_allowed  # evaluate once, not per field
        for k, k_type in annotations.items():
            assert hasattr(
                self, k
            ), f"Key {k} with type {k_type} required for initialization."
            attr = getattr(self, k)
            if k in none_allowed:
                k_type = (k_type, type(None))
            assert isinstance(
                attr, k_type
            ), f"Key {k} must be type {k_type}, provided={attr}."

        self.global_step = 0
        self.batch_per_epoch = 1

    # NOTE(review): `any` in the return annotation is the builtin function, not
    # typing.Any; left unchanged here to avoid introducing a new import.
    def build_optimizer(self, network: nn.Module) -> tuple[any, any]:
        r"""require setting 'batch_per_epoch' before building optimizer & lr_scheduler

        Trainable parameters are grouped by their ``(weight_decay, lr)`` pair;
        a parameter whose name contains any entry of ``no_wd_keys`` gets a
        weight decay of 0.

        Returns:
            The ``(optimizer, lr_scheduler)`` pair.

        Raises:
            NotImplementedError: if ``lr_schedule_name`` is not ``"cosine"``.
        """
        param_dict = {}
        for name, param in network.named_parameters():
            if not param.requires_grad:
                continue
            opt_config = [self.weight_decay, self.init_lr]
            if self.no_wd_keys and any(key in name for key in self.no_wd_keys):
                opt_config[0] = 0
            # The JSON string serves as a hashable key for the config pair.
            param_dict.setdefault(json.dumps(opt_config), []).append(param)

        net_params = []
        for opt_key, param_list in param_dict.items():
            wd, lr = json.loads(opt_key)
            net_params.append({"params": param_list, "weight_decay": wd, "lr": lr})

        # Resolves to the module-level `build_optimizer` helper, not this method.
        optimizer = build_optimizer(
            net_params, self.optimizer_name, self.optimizer_params, self.init_lr
        )

        # build lr scheduler
        if self.lr_schedule_name == "cosine":
            # Epoch milestones are converted to step counts; the end of
            # training is always included as the final milestone.
            decay_steps = [
                epoch * self.batch_per_epoch
                for epoch in self.lr_schedule_param.get("step", [])
            ]
            decay_steps.append(self.n_epochs * self.batch_per_epoch)
            decay_steps.sort()
            lr_scheduler = CosineLRwithWarmup(
                optimizer,
                self.warmup_epochs * self.batch_per_epoch,
                self.warmup_lr,
                decay_steps,
            )
        else:
            raise NotImplementedError
        return optimizer, lr_scheduler

    def update_global_step(self, epoch, batch_id=0) -> None:
        """Set ``global_step`` from ``epoch``/``batch_id`` and publish progress."""
        self.global_step = epoch * self.batch_per_epoch + batch_id
        Scheduler.PROGRESS = self.progress

    @property
    def progress(self) -> float:
        """Post-warmup steps taken, divided by ``n_epochs * batch_per_epoch``.

        Must be a property so that ``Scheduler.PROGRESS`` receives the number
        rather than a bound method.
        """
        warmup_steps = self.warmup_epochs * self.batch_per_epoch
        steps = max(0, self.global_step - warmup_steps)
        return steps / (self.n_epochs * self.batch_per_epoch)

    def step(self) -> None:
        """Advance ``global_step`` by one and publish progress."""
        self.global_step += 1
        Scheduler.PROGRESS = self.progress

    def get_remaining_epoch(self, epoch, post=True) -> int:
        """Return ``n_epochs + warmup_epochs - epoch``, minus one when ``post``."""
        return self.n_epochs + self.warmup_epochs - epoch - int(post)

    def epoch_format(self, epoch: int) -> str:
        """Format ``epoch`` as ``[cur/total]``, zero-padded to the width of ``n_epochs``.

        ``cur`` is ``epoch + 1 - warmup_epochs``.
        """
        epoch_format = f"%.{len(str(self.n_epochs))}d"
        epoch_format = f"[{epoch_format}/{epoch_format}]"
        epoch_format = epoch_format % (epoch + 1 - self.warmup_epochs, self.n_epochs)
        return epoch_format