# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0
import math
from typing import Callable, Optional, Tuple

import numpy as np
import torch
from came_pytorch import CAME
from mmcv import Config
from mmcv.runner import OPTIMIZER_BUILDERS, OPTIMIZERS, DefaultOptimizerConstructor
from mmcv.runner import build_optimizer as mm_build_optimizer
from mmcv.utils import _BatchNorm, _InstanceNorm
from torch.nn import GroupNorm, LayerNorm
from torch.optim.optimizer import Optimizer

from .logger import get_root_logger


def auto_scale_lr(effective_bs, optimizer_cfg, rule="linear", base_batch_size=256):
    assert rule in ["linear", "sqrt"]
    logger = get_root_logger()
    # scale by world size
    if rule == "sqrt":
        scale_ratio = math.sqrt(effective_bs / base_batch_size)
    elif rule == "linear":
        scale_ratio = effective_bs / base_batch_size
    optimizer_cfg["lr"] *= scale_ratio
    logger.info(f'Automatically adapt lr to {optimizer_cfg["lr"]:.5f} (using {rule} scaling rule).')
    return scale_ratio
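
# Worked example (hypothetical numbers): with base_batch_size=256 and an
# effective batch size of 1024, the "linear" rule scales lr by 1024/256 = 4x,
# while the "sqrt" rule scales it by sqrt(4) = 2x:
#
#   cfg = dict(type="AdamW", lr=1e-4)
#   auto_scale_lr(1024, cfg, rule="sqrt")  # cfg["lr"] is now 2e-4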


# Registered so mm_build_optimizer can resolve constructor="MyOptimizerConstructor" by name.
@OPTIMIZER_BUILDERS.register_module()
class MyOptimizerConstructor(DefaultOptimizerConstructor):
    def add_params(self, params, module, prefix="", is_dcn_module=None):
        """Add all parameters of module to the params list.

        The parameters of the given module will be added to the list of param
        groups, with specific rules defined by paramwise_cfg.

        Args:
            params (list[dict]): A list of param groups; it will be modified
                in place.
            module (nn.Module): The module to be added.
            prefix (str): The prefix of the module.
        """
        # get param-wise options
        custom_keys = self.paramwise_cfg.get("custom_keys", {})
        # first sort with alphabet order and then sort with reversed len of str
        # sorted_keys = sorted(sorted(custom_keys.keys()), key=len, reverse=True)
        bias_lr_mult = self.paramwise_cfg.get("bias_lr_mult", 1.0)
        bias_decay_mult = self.paramwise_cfg.get("bias_decay_mult", 1.0)
        norm_decay_mult = self.paramwise_cfg.get("norm_decay_mult", 1.0)
        bypass_duplicate = self.paramwise_cfg.get("bypass_duplicate", False)

        # special rules for norm layers and depth-wise conv layers
        is_norm = isinstance(module, (_BatchNorm, _InstanceNorm, GroupNorm, LayerNorm))

        for name, param in module.named_parameters(recurse=False):
            base_lr = self.base_lr
            if name == "bias" and not (is_norm or is_dcn_module):
                base_lr *= bias_lr_mult

            # apply weight decay policies
            base_wd = self.base_wd
            if self.base_wd is not None:
                # norm decay
                if is_norm:
                    base_wd *= norm_decay_mult
                # bias lr and decay
                elif name == "bias" and not is_dcn_module:
                    # TODO: current bias_decay_mult will affect DCN
                    base_wd *= bias_decay_mult

            param_group = {"params": [param]}
            if not param.requires_grad:
                param_group["requires_grad"] = False
                params.append(param_group)
                continue
            if bypass_duplicate and self._is_in(param_group, params):
                logger = get_root_logger()
                logger.warning(f"{prefix} is duplicate. It is skipped since bypass_duplicate={bypass_duplicate}")
                continue

            # if the parameter matches one of the custom keys, ignore other rules
            is_custom = False
            for key in custom_keys:
                if isinstance(key, tuple):
                    scope, key_name = key
                else:
                    scope, key_name = None, key
                if scope is not None and scope not in f"{prefix}":
                    continue
                if key_name in f"{prefix}.{name}":
                    is_custom = True
                    if "lr_mult" in custom_keys[key]:
                        # if 'base_classes' in f'{prefix}.{name}' or 'attn_base' in f'{prefix}.{name}':
                        #     param_group['lr'] = self.base_lr
                        # else:
                        param_group["lr"] = self.base_lr * custom_keys[key]["lr_mult"]
                    elif "lr" not in param_group:
                        param_group["lr"] = base_lr
                    if self.base_wd is not None:
                        if "decay_mult" in custom_keys[key]:
                            param_group["weight_decay"] = self.base_wd * custom_keys[key]["decay_mult"]
                        elif "weight_decay" not in param_group:
                            param_group["weight_decay"] = base_wd
            if not is_custom:
                # bias_lr_mult affects all bias parameters
                # except for norm.bias and dcn.conv_offset.bias
                if base_lr != self.base_lr:
                    param_group["lr"] = base_lr
                if base_wd != self.base_wd:
                    param_group["weight_decay"] = base_wd
            params.append(param_group)

        for child_name, child_mod in module.named_children():
            child_prefix = f"{prefix}.{child_name}" if prefix else child_name
            self.add_params(params, child_mod, prefix=child_prefix, is_dcn_module=is_dcn_module)
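
# Sketch of how custom_keys is consumed (hypothetical module names): any
# parameter whose qualified name contains a key gets that key's multipliers,
# overriding the generic bias/norm rules above:
#
#   paramwise_cfg = dict(
#       custom_keys={
#           "pos_embed": dict(decay_mult=0.0),        # no weight decay
#           ("backbone", "norm"): dict(lr_mult=0.1),  # scoped key: only matches under `backbone`
#       }
#   )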


def build_optimizer(model, optimizer_cfg):
    # default parameter-wise config
    logger = get_root_logger()

    if hasattr(model, "module"):
        model = model.module
    # set optimizer constructor
    optimizer_cfg.setdefault("constructor", "MyOptimizerConstructor")
    # parameter-wise setting: cancel weight decay for some specific modules
    custom_keys = dict()
    for name, module in model.named_modules():
        if hasattr(module, "zero_weight_decay"):
            custom_keys.update({(name, key): dict(decay_mult=0) for key in module.zero_weight_decay})

    paramwise_cfg = Config(dict(cfg=dict(custom_keys=custom_keys)))
    given_cfg = optimizer_cfg.get("paramwise_cfg")
    if given_cfg:
        paramwise_cfg.merge_from_dict(dict(cfg=given_cfg))
    optimizer_cfg["paramwise_cfg"] = paramwise_cfg.cfg
    # build optimizer
    optimizer = mm_build_optimizer(model, optimizer_cfg)

    weight_decay_groups = dict()
    lr_groups = dict()
    for group in optimizer.param_groups:
        if not group.get("requires_grad", True):
            continue
        lr_groups.setdefault(group["lr"], []).append(group)
        weight_decay_groups.setdefault(group["weight_decay"], []).append(group)

    learnable_count, fix_count = 0, 0
    for p in model.parameters():
        if p.requires_grad:
            learnable_count += 1
        else:
            fix_count += 1

    fix_info = f"{learnable_count} are learnable, {fix_count} are fixed"
    lr_info = "Lr group: " + ", ".join([f"{len(group)} params with lr {lr:.5f}" for lr, group in lr_groups.items()])
    wd_info = "Weight decay group: " + ", ".join(
        [f"{len(group)} params with weight decay {wd}" for wd, group in weight_decay_groups.items()]
    )
    opt_info = (
        f"{optimizer.__class__.__name__} Optimizer: total {len(optimizer.param_groups)} param groups, "
        f"{fix_info}. {lr_info}; {wd_info}."
    )
    logger.info(opt_info)
    return optimizer
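
# Usage sketch (hypothetical config values; any nn.Module works):
#
#   optimizer_cfg = dict(type="CAMEWrapper", lr=1e-4, weight_decay=0.0,
#                        betas=(0.9, 0.999, 0.9999), eps=(1e-30, 1e-16))
#   optimizer = build_optimizer(model, optimizer_cfg)
#
# Modules may opt specific parameters out of weight decay by exposing a
# `zero_weight_decay` iterable of parameter-name substrings.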


@OPTIMIZERS.register_module()
class Lion(Optimizer):
    """Lion (EvoLved Sign Momentum) optimizer.

    Introduced in `Symbolic Discovery of Optimization Algorithms`
    (https://arxiv.org/abs/2302.06675): the update direction is the sign of an
    interpolation between the momentum and the current gradient.
    """

    def __init__(
        self,
        params,
        lr: float = 1e-4,
        betas: Tuple[float, float] = (0.9, 0.99),
        weight_decay: float = 0.0,
    ):
        assert lr > 0.0
        assert all([0.0 <= beta <= 1.0 for beta in betas])

        defaults = dict(lr=lr, betas=betas, weight_decay=weight_decay)
        super().__init__(params, defaults)

    @staticmethod
    def update_fn(p, grad, exp_avg, lr, wd, beta1, beta2):
        # stepweight (decoupled) decay
        p.data.mul_(1 - lr * wd)

        # weight update: step along the sign of the interpolated momentum
        update = exp_avg.clone().lerp_(grad, 1 - beta1).sign_()
        p.add_(update, alpha=-lr)

        # decay the momentum running average coefficient
        exp_avg.lerp_(grad, 1 - beta2)

    @staticmethod
    def exists(val):
        return val is not None

    def step(self, closure: Optional[Callable] = None):
        loss = None
        if self.exists(closure):
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            for p in filter(lambda p: self.exists(p.grad), group["params"]):
                grad, lr, wd, beta1, beta2, state = (
                    p.grad,
                    group["lr"],
                    group["weight_decay"],
                    *group["betas"],
                    self.state[p],
                )

                # init state - exponential moving average of gradient values
                if len(state) == 0:
                    state["exp_avg"] = torch.zeros_like(p)
                exp_avg = state["exp_avg"]

                self.update_fn(p, grad, exp_avg, lr, wd, beta1, beta2)

        return loss
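
# Lion update in equations (matching update_fn above):
#   c_t = sign(beta1 * m_{t-1} + (1 - beta1) * g_t)   # direction only, magnitude discarded
#   p_t = (1 - lr * wd) * p_{t-1} - lr * c_t          # decoupled weight decay, then signed step
#   m_t = beta2 * m_{t-1} + (1 - beta2) * g_t         # momentum update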


@OPTIMIZERS.register_module()
class CAMEWrapper(torch.optim.Optimizer):
    """Implements the CAME algorithm.

    This implementation is based on:
    `CAME: Confidence-guided Adaptive Memory Efficient Optimization`

    Args:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float, optional): external learning rate (default: None)
        eps (tuple[float, float]): regularization constants for the squared
            gradient and the instability, respectively (default: (1e-30, 1e-16))
        clip_threshold (float): threshold on the root-mean-square of the
            final gradient update (default: 1.0)
        betas (tuple[float, float, float]): coefficients used for computing the
            running averages of the update, squared gradient, and instability
            (default: (0.9, 0.999, 0.9999))
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
    """

    def __init__(
        self,
        params,
        lr=None,
        eps=(1e-30, 1e-16),
        clip_threshold=1.0,
        betas=(0.9, 0.999, 0.9999),
        weight_decay=0.0,
    ):
        assert lr is not None and lr > 0.0
        assert all([0.0 <= beta <= 1.0 for beta in betas])

        defaults = dict(
            lr=lr,
            eps=eps,
            clip_threshold=clip_threshold,
            betas=betas,
            weight_decay=weight_decay,
        )
        super().__init__(params, defaults)

    def supports_memory_efficient_fp16(self):
        return True

    def supports_flat_params(self):
        return False

    def _get_options(self, param_shape):
        if len(param_shape) == 4:  # Conv layer
            if param_shape[2] == 1 and param_shape[3] == 1:  # 1x1 conv
                return True, "1x1_conv"
            else:  # 3x3 conv or others
                return False, "conv"
        elif len(param_shape) == 2:  # Linear layer, exactly 2D
            return True, "linear"
        return False, "other"

    def _rms(self, tensor):
        return tensor.norm(2) / (tensor.numel() ** 0.5)

    def _approx_sq_grad(self, exp_avg_sq_row, exp_avg_sq_col):
        r_factor = (exp_avg_sq_row / exp_avg_sq_row.mean(dim=-1, keepdim=True)).rsqrt_().unsqueeze(-1)
        c_factor = exp_avg_sq_col.unsqueeze(-2).rsqrt()
        return torch.mul(r_factor, c_factor)
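
    # Note on _approx_sq_grad (the Adafactor-style factorization): instead of
    # storing the full matrix of second-moment estimates V, only its row means
    # R and column means C are kept; V is approximated by the rank-1
    # reconstruction R * C / mean(R), and this method returns the elementwise
    # inverse square root of that approximation, which scales the raw gradient
    # in step() below.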

    def step(self, closure=None):
        """Performs a single optimization step.

        Args:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    continue
                grad = p.grad.data
                if grad.dtype in {torch.float16, torch.bfloat16}:
                    grad = grad.float()
                if grad.is_sparse:
                    raise RuntimeError("CAME does not support sparse gradients.")

                state = self.state[p]
                grad_shape = grad.shape

                factored, layer_type = self._get_options(grad_shape)
                # State Initialization
                if len(state) == 0:
                    state["step"] = 0
                    state["exp_avg"] = torch.zeros_like(grad)
                    if factored:
                        if layer_type == "1x1_conv" or layer_type == "linear":
                            # 1x1 conv and linear layers can be handled the same way
                            state["exp_avg_sq_row"] = torch.zeros(grad_shape[0]).type_as(grad)
                            state["exp_avg_sq_col"] = torch.zeros(grad_shape[1]).type_as(grad)
                            state["exp_avg_res_row"] = torch.zeros(grad_shape[0]).type_as(grad)
                            state["exp_avg_res_col"] = torch.zeros(grad_shape[1]).type_as(grad)
                        else:
                            state["exp_avg_sq"] = torch.zeros_like(grad)
                    else:
                        state["exp_avg_sq"] = torch.zeros_like(grad)

                    state["RMS"] = 0

                state["step"] += 1
                state["RMS"] = self._rms(p.data)

                update = (grad**2) + group["eps"][0]
                if factored:
                    exp_avg_sq_row = state["exp_avg_sq_row"]
                    exp_avg_sq_col = state["exp_avg_sq_col"]

                    if layer_type == "1x1_conv" or layer_type == "linear":
                        # Handle dimensions
                        if len(grad_shape) == 4:  # 1x1 conv
                            update_reshaped = update.squeeze(-1).squeeze(-1)  # Remove last two dimensions
                        else:
                            update_reshaped = update

                        exp_avg_sq_row.mul_(group["betas"][1]).add_(
                            update_reshaped.mean(dim=1), alpha=1.0 - group["betas"][1]
                        )
                        exp_avg_sq_col.mul_(group["betas"][1]).add_(
                            update_reshaped.mean(dim=0), alpha=1.0 - group["betas"][1]
                        )

                        # Approximate calculation
                        update = self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col)
                        if layer_type == "1x1_conv":
                            # Need to reshape back to 4D
                            update = update.view(grad_shape[0], grad_shape[1], 1, 1)
                        update.mul_(grad)
                else:
                    # 3x3 conv or other cases: use the standard AdamW-style approach
                    exp_avg_sq = state["exp_avg_sq"]
                    exp_avg_sq.mul_(group["betas"][1]).add_(update, alpha=1.0 - group["betas"][1])
                    update = exp_avg_sq.rsqrt().mul_(grad)

                update.div_((self._rms(update) / group["clip_threshold"]).clamp_(min=1.0))

                exp_avg = state["exp_avg"]
                exp_avg.mul_(group["betas"][0]).add_(update, alpha=1 - group["betas"][0])

                # Confidence-guided strategy
                # Calculation of instability
                res = (update - exp_avg) ** 2 + group["eps"][1]

                if factored:
                    exp_avg_res_row = state["exp_avg_res_row"]
                    exp_avg_res_col = state["exp_avg_res_col"]

                    if layer_type == "1x1_conv" or layer_type == "linear":
                        # Handle dimensions
                        if len(grad_shape) == 4:  # 1x1 conv
                            res_reshaped = res.squeeze(-1).squeeze(-1)  # Remove last two dimensions
                        else:
                            res_reshaped = res

                        # Update residual statistics
                        exp_avg_res_row.mul_(group["betas"][2]).add_(
                            res_reshaped.mean(dim=1), alpha=1.0 - group["betas"][2]
                        )
                        exp_avg_res_col.mul_(group["betas"][2]).add_(
                            res_reshaped.mean(dim=0), alpha=1.0 - group["betas"][2]
                        )

                        # Approximate calculation
                        res_approx = self._approx_sq_grad(exp_avg_res_row, exp_avg_res_col)
                        if layer_type == "1x1_conv":
                            # Need to reshape back to 4D
                            res_approx = res_approx.view(grad_shape[0], grad_shape[1], 1, 1)
                        update = res_approx.mul_(exp_avg)
                else:
                    update = exp_avg.clone()

                if group["weight_decay"] != 0:
                    p.data.add_(p.data, alpha=-group["weight_decay"] * group["lr"])

                update.mul_(group["lr"])
                p.data.add_(-update)

        return loss
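
# Minimal usage sketch for CAMEWrapper (hyperparameters are illustrative):
#
#   optimizer = CAMEWrapper(model.parameters(), lr=1e-4,
#                           betas=(0.9, 0.999, 0.9999), weight_decay=0.0)
#   loss.backward()
#   optimizer.step()
#   optimizer.zero_grad()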


@OPTIMIZERS.register_module()
class CAME8BitWrapper(torch.optim.Optimizer):
    """Implements the 8-bit CAME algorithm.

    Args:
        params (iterable): parameters to optimize or dicts defining parameter groups
        lr (float, optional): external learning rate (default: None)
        eps (tuple[float, float]): regularization constants for the squared
            gradient and the instability, respectively (default: (1e-30, 1e-16))
        clip_threshold (float): threshold on the root-mean-square of the
            final gradient update (default: 1.0)
        betas (tuple[float, float, float]): coefficients used for computing the
            running averages of the update, squared gradient, and instability
            (default: (0.9, 0.999, 0.9999))
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
        block_size (int): quantization block size; larger blocks are more
            memory-efficient but may reduce accuracy
        min_8bit_size (int): minimum parameter count for 8-bit quantization;
            only layers larger than this value are quantized

    Note:
        1. 8-bit quantization is only used for large linear layers and 1x1 conv layers.
        2. All factored statistics (exp_avg_sq_row, etc.) are kept in 32-bit to ensure stability.
        3. A simple min-max quantization strategy is used, quantizing each block separately.
    """

    def __init__(
        self,
        params,
        lr=None,
        eps=(1e-30, 1e-16),
        clip_threshold=1.0,
        betas=(0.9, 0.999, 0.9999),
        weight_decay=0.0,
        block_size=2048,
        min_8bit_size=16384,
    ):
        assert lr is not None and lr > 0.0
        assert all([0.0 <= beta <= 1.0 for beta in betas])

        logger = get_root_logger()
        logger.info(f"Initializing CAME8bit with block_size={block_size}, min_8bit_size={min_8bit_size}")

        defaults = dict(
            lr=lr,
            eps=eps,
            clip_threshold=clip_threshold,
            betas=betas,
            weight_decay=weight_decay,
            block_size=block_size,
            min_8bit_size=min_8bit_size,
        )
        super().__init__(params, defaults)

    def print_layer_info(self, param_shape, use_8bit):
        """Print layer information: the parameter count and whether 8-bit quantization is used.

        Args:
            param_shape (tuple): parameter shape
            use_8bit (bool): whether 8-bit quantization is used
        """
        size = np.prod(param_shape)
        layer_type = "unknown"
        if len(param_shape) == 1:
            layer_type = "1D Layer"
        elif len(param_shape) == 2:
            layer_type = "Linear"
        elif len(param_shape) == 4:
            if param_shape[2] == 1 and param_shape[3] == 1:
                layer_type = "1x1 Conv"
            else:
                layer_type = "Conv"

        status = "8bit" if use_8bit else "32bit"
        print(f"{layer_type} layer with shape {param_shape}: {size:,} params -> using {status}")

    def _should_use_8bit(self, param_shape):
        """Determine whether a parameter should be quantized to 8-bit.

        Rules:
            1. Linear layers: parameter count > min_8bit_size
            2. 1x1 conv layers: parameter count > min_8bit_size
            3. Other layers: use 32-bit
        """
        if len(param_shape) == 2:  # linear layer
            return param_shape[0] * param_shape[1] > self.defaults["min_8bit_size"]
        elif len(param_shape) == 4 and param_shape[2] == 1 and param_shape[3] == 1:  # 1x1 conv
            return param_shape[0] * param_shape[1] > self.defaults["min_8bit_size"]
        return False  # other layers are not quantized
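
    # Example with the default min_8bit_size=16384: a (1152, 1152) linear weight
    # has 1,327,104 elements and is quantized, while a (128, 128) weight
    # (exactly 16,384 elements, not strictly greater) and all bias/norm vectors
    # stay in 32-bit.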

    def _quantize_state(self, state_tensor, block_size=2048):
        """Quantize a state tensor to 8-bit.

        Args:
            state_tensor: tensor to be quantized
            block_size: quantization block size

        Returns:
            list of quantized blocks, where each block contains:
                - data: uint8 data
                - scale: quantization scale
                - min: minimum value
        """
        if state_tensor.numel() <= 1:
            return state_tensor

        quantized_chunks = []
        for chunk in state_tensor.split(block_size):
            # Calculate quantization parameters; clamp the scale so constant
            # chunks (e.g. the all-zero init) do not divide by zero below
            chunk_min = chunk.min()
            chunk_max = chunk.max()
            scale = ((chunk_max - chunk_min) / 255).clamp_(min=1e-12)
            # Quantize to the 0-255 range
            quantized_chunk = ((chunk - chunk_min) / scale).round().byte()
            quantized_chunks.append({"data": quantized_chunk, "scale": scale, "min": chunk_min})
        return quantized_chunks

    def _dequantize_state(self, quantized_chunks):
        """Dequantize 8-bit quantized data back to 32-bit float.

        Args:
            quantized_chunks: list of quantized blocks

        Returns:
            dequantized 32-bit float tensor
        """
        if not isinstance(quantized_chunks, list):
            return quantized_chunks

        chunks = []
        for chunk_dict in quantized_chunks:
            # Dequantize: value = data * scale + min
            chunk = chunk_dict["data"].float() * chunk_dict["scale"] + chunk_dict["min"]
            chunks.append(chunk)
        return torch.cat(chunks)
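
    # Quantization round trip in numbers (hypothetical chunk): values in
    # [min, max] map to uint8 codes via code = round((x - min) / scale) with
    # scale = (max - min) / 255, and are recovered as x' = code * scale + min;
    # the worst-case per-element reconstruction error is scale / 2.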

    def _dequantize_state_first_step(self, quantized_chunks):
        """Memory-frugal dequantization for the first step."""
        if not isinstance(quantized_chunks, list):
            return quantized_chunks

        # 1. Dequantize chunk by chunk, freeing each uint8 buffer as we go
        dequantized_chunks = []
        for chunk_dict in quantized_chunks:
            chunk = chunk_dict["data"].float() * chunk_dict["scale"] + chunk_dict["min"]
            dequantized_chunks.append(chunk)
            del chunk_dict["data"]
            torch.cuda.empty_cache()

        # 2. Concatenate all chunks
        result = torch.cat(dequantized_chunks)
        del dequantized_chunks
        torch.cuda.empty_cache()
        return result

    def _get_options(self, param_shape):
        if len(param_shape) == 4:
            if param_shape[2] == 1 and param_shape[3] == 1:
                return True, "1x1_conv"
            else:
                return False, "conv"
        elif len(param_shape) == 2:
            return True, "linear"
        return False, "other"

    def _rms(self, tensor):
        return tensor.norm(2) / (tensor.numel() ** 0.5)

    def _approx_sq_grad(self, exp_avg_sq_row, exp_avg_sq_col):
        r_factor = (exp_avg_sq_row / exp_avg_sq_row.mean(dim=-1, keepdim=True)).rsqrt_().unsqueeze(-1)
        c_factor = exp_avg_sq_col.unsqueeze(-2).rsqrt()
        return torch.mul(r_factor, c_factor)

    def step(self, closure=None):
        """Perform a single optimization step.

        Main steps:
            1. Determine whether 8-bit quantization is needed
            2. Update the first and second moment estimates
            3. Compute the update step
            4. Apply the confidence-guided strategy
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    continue
                grad = p.grad.data
                if grad.dtype in {torch.float16, torch.bfloat16}:
                    grad = grad.float()
                if grad.is_sparse:
                    raise RuntimeError("CAME8bit does not support sparse gradients.")

                state = self.state[p]
                grad_shape = grad.shape

                factored, layer_type = self._get_options(grad_shape)
                # Determine whether 8-bit quantization is used
                use_8bit = self._should_use_8bit(grad_shape)

                # State Initialization
                if len(state) == 0:
                    self.print_layer_info(grad_shape, use_8bit)
                    state["step"] = 0
                    # Only use 8-bit quantization for large matrices
                    if use_8bit:
                        state["exp_avg"] = self._quantize_state(torch.zeros_like(grad), group["block_size"])
                    else:
                        state["exp_avg"] = torch.zeros_like(grad)

                    if factored:
                        if layer_type == "1x1_conv" or layer_type == "linear":
                            # Keep row and column statistics in 32-bit
                            state["exp_avg_sq_row"] = torch.zeros(grad_shape[0]).type_as(grad)
                            state["exp_avg_sq_col"] = torch.zeros(grad_shape[1]).type_as(grad)
                            state["exp_avg_res_row"] = torch.zeros(grad_shape[0]).type_as(grad)
                            state["exp_avg_res_col"] = torch.zeros(grad_shape[1]).type_as(grad)
                        else:
                            if use_8bit:
                                state["exp_avg_sq"] = self._quantize_state(torch.zeros_like(grad), group["block_size"])
                            else:
                                state["exp_avg_sq"] = torch.zeros_like(grad)
                    else:
                        if use_8bit:
                            state["exp_avg_sq"] = self._quantize_state(torch.zeros_like(grad), group["block_size"])
                        else:
                            state["exp_avg_sq"] = torch.zeros_like(grad)

                    state["RMS"] = 0

                state["step"] += 1
                state["RMS"] = self._rms(p.data)

                exp_avg = self._dequantize_state(state["exp_avg"]) if use_8bit else state["exp_avg"]

                update = (grad**2) + group["eps"][0]
                if factored:
                    exp_avg_sq_row = state["exp_avg_sq_row"]  # 32-bit
                    exp_avg_sq_col = state["exp_avg_sq_col"]  # 32-bit

                    if layer_type == "1x1_conv" or layer_type == "linear":
                        if len(grad_shape) == 4:
                            update_reshaped = update.squeeze(-1).squeeze(-1)
                        else:
                            update_reshaped = update

                        # Update row and column statistics
                        exp_avg_sq_row.mul_(group["betas"][1]).add_(
                            update_reshaped.mean(dim=1), alpha=1.0 - group["betas"][1]
                        )
                        exp_avg_sq_col.mul_(group["betas"][1]).add_(
                            update_reshaped.mean(dim=0), alpha=1.0 - group["betas"][1]
                        )

                        update = self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col)
                        if layer_type == "1x1_conv":
                            update = update.view(grad_shape[0], grad_shape[1], 1, 1)
                        update.mul_(grad)
                else:
                    exp_avg_sq = self._dequantize_state(state["exp_avg_sq"]) if use_8bit else state["exp_avg_sq"]
                    exp_avg_sq.mul_(group["betas"][1]).add_(update, alpha=1.0 - group["betas"][1])
                    if use_8bit:
                        state["exp_avg_sq"] = self._quantize_state(exp_avg_sq, group["block_size"])
                    else:
                        state["exp_avg_sq"] = exp_avg_sq
                    update = exp_avg_sq.rsqrt().mul_(grad)

                # Gradient clipping
                update.div_((self._rms(update) / group["clip_threshold"]).clamp_(min=1.0))

                # Update first moment
                exp_avg.mul_(group["betas"][0]).add_(update, alpha=1 - group["betas"][0])
                # Re-quantize (if needed)
                if use_8bit:
                    state["exp_avg"] = self._quantize_state(exp_avg, group["block_size"])
                else:
                    state["exp_avg"] = exp_avg

                # Confidence-guided strategy
                res = (update - exp_avg) ** 2 + group["eps"][1]
                if factored:
                    exp_avg_res_row = state["exp_avg_res_row"]  # 32-bit
                    exp_avg_res_col = state["exp_avg_res_col"]  # 32-bit

                    if layer_type == "1x1_conv" or layer_type == "linear":
                        if len(grad_shape) == 4:
                            res_reshaped = res.squeeze(-1).squeeze(-1)
                        else:
                            res_reshaped = res

                        # Update residual statistics
                        exp_avg_res_row.mul_(group["betas"][2]).add_(
                            res_reshaped.mean(dim=1), alpha=1.0 - group["betas"][2]
                        )
                        exp_avg_res_col.mul_(group["betas"][2]).add_(
                            res_reshaped.mean(dim=0), alpha=1.0 - group["betas"][2]
                        )

                        res_approx = self._approx_sq_grad(exp_avg_res_row, exp_avg_res_col)
                        if layer_type == "1x1_conv":
                            res_approx = res_approx.view(grad_shape[0], grad_shape[1], 1, 1)
                        update = res_approx.mul_(exp_avg)
                else:
                    update = exp_avg.clone()

                # Weight decay
                if group["weight_decay"] != 0:
                    p.data.add_(p.data, alpha=-group["weight_decay"] * group["lr"])

                # Apply update
                update.mul_(group["lr"])
                p.data.add_(-update)

        return loss

    def load_state_dict(self, state_dict):
        """Load a state dict and convert relevant states back to 8-bit."""
        super().load_state_dict(state_dict)
        for state in self.state.values():
            for key in [
                "exp_avg",
                "exp_avg_sq",
                "exp_avg_sq_row",
                "exp_avg_sq_col",
                "exp_avg_res_row",
                "exp_avg_res_col",
            ]:
                if key in state:
                    if isinstance(state[key], list):
                        state[key] = [
                            {
                                "data": exp["data"].byte(),  # Convert data back to 8-bit directly
                                "scale": exp["scale"],  # Keep scale unchanged
                                "min": exp["min"],  # Keep min unchanged
                            }
                            for exp in state[key]
                        ]
                    elif isinstance(state[key], torch.Tensor):
                        # If it is a tensor, keep it in 32-bit
                        state[key] = state[key].float()  # Ensure 32-bit

        del state_dict
        torch.cuda.empty_cache()
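

# Minimal usage sketch for CAME8BitWrapper (illustrative hyperparameters):
# the first moments of large linear / 1x1-conv weights are stored as
# block-quantized uint8 and dequantized on the fly each step, trading a bit of
# compute for optimizer-state memory.
#
#   optimizer = CAME8BitWrapper(model.parameters(), lr=1e-4,
#                               block_size=2048, min_8bit_size=16384)
#   loss.backward()
#   optimizer.step()
#   optimizer.zero_grad()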