# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0
import math
from typing import Callable, Optional, Tuple

import numpy as np
import torch
from came_pytorch import CAME
from mmcv import Config
from mmcv.runner import OPTIMIZER_BUILDERS, OPTIMIZERS, DefaultOptimizerConstructor
from mmcv.runner import build_optimizer as mm_build_optimizer
from mmcv.utils import _BatchNorm, _InstanceNorm
from torch.nn import GroupNorm, LayerNorm
from torch.optim.optimizer import Optimizer

from .logger import get_root_logger

def auto_scale_lr(effective_bs, optimizer_cfg, rule="linear", base_batch_size=256):
    assert rule in ["linear", "sqrt"]
    logger = get_root_logger()
    # scale lr by the ratio of the effective batch size to the base batch size
    if rule == "sqrt":
        scale_ratio = math.sqrt(effective_bs / base_batch_size)
    elif rule == "linear":
        scale_ratio = effective_bs / base_batch_size
    optimizer_cfg["lr"] *= scale_ratio
    logger.info(f'Automatically adapting lr to {optimizer_cfg["lr"]:.5f} (using the {rule} scaling rule).')
    return scale_ratio
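
# Example (hypothetical numbers): with base_batch_size=256 and an effective
# batch size of 1024, the "linear" rule multiplies lr by 4.0 while "sqrt"
# multiplies it by 2.0:
#
#   cfg = dict(type="AdamW", lr=1e-4)
#   auto_scale_lr(1024, cfg, rule="sqrt")  # cfg["lr"] -> 2e-4
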
@OPTIMIZER_BUILDERS.register_module()
class MyOptimizerConstructor(DefaultOptimizerConstructor):
    def add_params(self, params, module, prefix="", is_dcn_module=None):
        """Add all parameters of module to the params list.

        The parameters of the given module will be added to the list of param
        groups, with specific rules defined by paramwise_cfg.

        Args:
            params (list[dict]): A list of param groups, it will be modified
                in place.
            module (nn.Module): The module to be added.
            prefix (str): The prefix of the module.
            is_dcn_module (int|float|None): If the current module is a
                submodule of DCN, `is_dcn_module` will be passed to control
                the conv_offset layer's learning rate. Defaults to None.
        """
        # get param-wise options
        custom_keys = self.paramwise_cfg.get("custom_keys", {})
        # first sort with alphabet order and then sort with reversed len of str
        # sorted_keys = sorted(sorted(custom_keys.keys()), key=len, reverse=True)
        bias_lr_mult = self.paramwise_cfg.get("bias_lr_mult", 1.0)
        bias_decay_mult = self.paramwise_cfg.get("bias_decay_mult", 1.0)
        norm_decay_mult = self.paramwise_cfg.get("norm_decay_mult", 1.0)
        bypass_duplicate = self.paramwise_cfg.get("bypass_duplicate", False)

        # special rules for norm layers and depth-wise conv layers
        is_norm = isinstance(module, (_BatchNorm, _InstanceNorm, GroupNorm, LayerNorm))
        for name, param in module.named_parameters(recurse=False):
            base_lr = self.base_lr
            if name == "bias" and not (is_norm or is_dcn_module):
                base_lr *= bias_lr_mult

            # apply weight decay policies
            base_wd = self.base_wd
            if self.base_wd is not None:
                # norm decay
                if is_norm:
                    base_wd *= norm_decay_mult
                # bias lr and decay
                elif name == "bias" and not is_dcn_module:
                    # TODO: current bias_decay_mult will also affect DCN
                    base_wd *= bias_decay_mult

            param_group = {"params": [param]}
            if not param.requires_grad:
                param_group["requires_grad"] = False
                params.append(param_group)
                continue
            if bypass_duplicate and self._is_in(param_group, params):
                logger = get_root_logger()
                logger.warning(f"{prefix} is duplicate. It is skipped since bypass_duplicate={bypass_duplicate}")
                continue
            # if the parameter matches one of the custom keys, ignore other rules
            is_custom = False
            for key in custom_keys:
                if isinstance(key, tuple):
                    scope, key_name = key
                else:
                    scope, key_name = None, key
                if scope is not None and scope not in f"{prefix}":
                    continue
                if key_name in f"{prefix}.{name}":
                    is_custom = True
                    if "lr_mult" in custom_keys[key]:
                        # if 'base_classes' in f'{prefix}.{name}' or 'attn_base' in f'{prefix}.{name}':
                        #     param_group['lr'] = self.base_lr
                        # else:
                        param_group["lr"] = self.base_lr * custom_keys[key]["lr_mult"]
                    elif "lr" not in param_group:
                        param_group["lr"] = base_lr
                    if self.base_wd is not None:
                        if "decay_mult" in custom_keys[key]:
                            param_group["weight_decay"] = self.base_wd * custom_keys[key]["decay_mult"]
                        elif "weight_decay" not in param_group:
                            param_group["weight_decay"] = base_wd

            if not is_custom:
                # bias_lr_mult affects all bias parameters
                # except for norm.bias and dcn.conv_offset.bias
                if base_lr != self.base_lr:
                    param_group["lr"] = base_lr
                if base_wd != self.base_wd:
                    param_group["weight_decay"] = base_wd
            params.append(param_group)

        for child_name, child_mod in module.named_children():
            child_prefix = f"{prefix}.{child_name}" if prefix else child_name
            self.add_params(params, child_mod, prefix=child_prefix, is_dcn_module=is_dcn_module)
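
# Sketch of how custom_keys are matched (hypothetical names): a plain string
# key matches by substring against "prefix.name", while a (scope, key) tuple
# additionally requires the scope to appear in the module prefix:
#
#   paramwise_cfg = dict(custom_keys={
#       "pos_embed": dict(lr_mult=0.1),              # any param whose full name contains "pos_embed"
#       ("blocks.0", "bias"): dict(decay_mult=0.0),  # only biases under a prefix containing "blocks.0"
#   })
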
def build_optimizer(model, optimizer_cfg):
    # default parameter-wise config
    logger = get_root_logger()

    if hasattr(model, "module"):
        model = model.module
    # set optimizer constructor
    optimizer_cfg.setdefault("constructor", "MyOptimizerConstructor")
    # parameter-wise setting: cancel weight decay for some specific modules
    custom_keys = dict()
    for name, module in model.named_modules():
        if hasattr(module, "zero_weight_decay"):
            custom_keys.update({(name, key): dict(decay_mult=0) for key in module.zero_weight_decay})

    paramwise_cfg = Config(dict(cfg=dict(custom_keys=custom_keys)))
    given_cfg = optimizer_cfg.get("paramwise_cfg")
    if given_cfg:
        paramwise_cfg.merge_from_dict(dict(cfg=given_cfg))
    optimizer_cfg["paramwise_cfg"] = paramwise_cfg.cfg
    # build optimizer
    optimizer = mm_build_optimizer(model, optimizer_cfg)

    weight_decay_groups = dict()
    lr_groups = dict()
    for group in optimizer.param_groups:
        if not group.get("requires_grad", True):
            continue
        lr_groups.setdefault(group["lr"], []).append(group)
        weight_decay_groups.setdefault(group["weight_decay"], []).append(group)

    learnable_count, fix_count = 0, 0
    for p in model.parameters():
        if p.requires_grad:
            learnable_count += 1
        else:
            fix_count += 1
    fix_info = f"{learnable_count} are learnable, {fix_count} are fixed"
    lr_info = "Lr group: " + ", ".join([f"{len(group)} params with lr {lr:.5f}" for lr, group in lr_groups.items()])
    wd_info = "Weight decay group: " + ", ".join(
        [f"{len(group)} params with weight decay {wd}" for wd, group in weight_decay_groups.items()]
    )
    opt_info = f"{optimizer.__class__.__name__} Optimizer: total {len(optimizer.param_groups)} param groups, {fix_info}. {lr_info}; {wd_info}."
    logger.info(opt_info)
    return optimizer
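
# Usage sketch (hypothetical config): any module can opt specific parameters
# out of weight decay by exposing a `zero_weight_decay` attribute; everything
# else follows the standard mmcv optimizer config:
#
#   optimizer_cfg = dict(
#       type="AdamW", lr=1e-4, weight_decay=0.03,
#       paramwise_cfg=dict(custom_keys={"pos_embed": dict(lr_mult=0.1)}),
#   )
#   optimizer = build_optimizer(model, optimizer_cfg)
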
@OPTIMIZERS.register_module()
class Lion(Optimizer):
    def __init__(
        self,
        params,
        lr: float = 1e-4,
        betas: Tuple[float, float] = (0.9, 0.99),
        weight_decay: float = 0.0,
    ):
        assert lr > 0.0
        assert all([0.0 <= beta <= 1.0 for beta in betas])
        defaults = dict(lr=lr, betas=betas, weight_decay=weight_decay)
        super().__init__(params, defaults)

    @staticmethod
    def update_fn(p, grad, exp_avg, lr, wd, beta1, beta2):
        # stepweight decay
        p.data.mul_(1 - lr * wd)

        # weight update: sign of the beta1-interpolation of momentum and gradient
        update = exp_avg.clone().lerp_(grad, 1 - beta1).sign_()
        p.add_(update, alpha=-lr)

        # decay the momentum running average coefficient
        exp_avg.lerp_(grad, 1 - beta2)
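
    # The Lion update rule, for reference (Chen et al., "Symbolic Discovery of
    # Optimization Algorithms"): with momentum m and gradient g,
    #   c  = beta1 * m + (1 - beta1) * g   # interpolation used for the update
    #   p -= lr * (wd * p + sign(c))       # decoupled weight decay + sign step
    #   m  = beta2 * m + (1 - beta2) * g   # momentum update
    # `lerp_(g, 1 - beta)` computes exactly the beta-interpolation above.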
    @staticmethod
    def exists(val):
        return val is not None

    @torch.no_grad()
    def step(self, closure: Optional[Callable] = None):
        loss = None
        if self.exists(closure):
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            for p in filter(lambda p: self.exists(p.grad), group["params"]):
                grad, lr, wd, beta1, beta2, state = (
                    p.grad,
                    group["lr"],
                    group["weight_decay"],
                    *group["betas"],
                    self.state[p],
                )

                # init state - exponential moving average of gradient values
                if len(state) == 0:
                    state["exp_avg"] = torch.zeros_like(p)

                exp_avg = state["exp_avg"]
                self.update_fn(p, grad, exp_avg, lr, wd, beta1, beta2)

        return loss
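
# Usage sketch (hypothetical hyperparameters): Lion typically wants a smaller
# lr and a larger weight decay than AdamW:
#
#   opt = Lion(model.parameters(), lr=1e-5, weight_decay=0.1)
#   loss = model(batch).sum()
#   loss.backward()
#   opt.step()
#   opt.zero_grad()
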
@OPTIMIZERS.register_module()
class CAMEWrapper(torch.optim.Optimizer):
    """Implements the CAME algorithm.

    This implementation is based on:
    `CAME: Confidence-guided Adaptive Memory Efficient Optimization`

    Args:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float, optional): external learning rate (default: None)
        eps (tuple[float, float]): regularization constants for square gradient
            and instability respectively (default: (1e-30, 1e-16))
        clip_threshold (float): threshold of root-mean-square of
            final gradient update (default: 1.0)
        betas (tuple[float, float, float]): coefficients used for computing running
            averages of update, square gradient and instability
            (default: (0.9, 0.999, 0.9999))
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
    """

    def __init__(
        self,
        params,
        lr=None,
        eps=(1e-30, 1e-16),
        clip_threshold=1.0,
        betas=(0.9, 0.999, 0.9999),
        weight_decay=0.0,
    ):
        assert lr is not None and lr > 0.0
        assert all([0.0 <= beta <= 1.0 for beta in betas])
        defaults = dict(
            lr=lr,
            eps=eps,
            clip_threshold=clip_threshold,
            betas=betas,
            weight_decay=weight_decay,
        )
        super().__init__(params, defaults)

    @property
    def supports_memory_efficient_fp16(self):
        return True

    @property
    def supports_flat_params(self):
        return False

    def _get_options(self, param_shape):
        if len(param_shape) == 4:  # Conv layer
            if param_shape[2] == 1 and param_shape[3] == 1:  # 1x1 conv
                return True, "1x1_conv"
            else:  # 3x3 conv or others
                return False, "conv"
        elif len(param_shape) == 2:  # Linear layer, exactly 2D
            return True, "linear"
        return False, "other"
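
    # For reference, a few example shapes and the options they map to:
    #   _get_options((1152, 1152))     -> (True,  "linear")   factored statistics
    #   _get_options((256, 128, 1, 1)) -> (True,  "1x1_conv") factored statistics
    #   _get_options((64, 64, 3, 3))   -> (False, "conv")     full second moment
    #   _get_options((1152,))          -> (False, "other")    full second moment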
    def _rms(self, tensor):
        return tensor.norm(2) / (tensor.numel() ** 0.5)

    def _approx_sq_grad(self, exp_avg_sq_row, exp_avg_sq_col):
        r_factor = (exp_avg_sq_row / exp_avg_sq_row.mean(dim=-1, keepdim=True)).rsqrt_().unsqueeze(-1)
        c_factor = exp_avg_sq_col.unsqueeze(-2).rsqrt()
        return torch.mul(r_factor, c_factor)
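
    # Adafactor-style rank-1 reconstruction of the second moment: the full
    # n x m matrix V is approximated from its row means r and column means c as
    #   V_ij ~= r_i * c_j / mean(r)
    # and this helper returns 1 / sqrt(V_ij), ready to be multiplied with the
    # gradient. Storing r and c costs O(n + m) memory instead of O(n * m).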

    def step(self, closure=None):
        """Performs a single optimization step.

        Args:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    continue
                grad = p.grad.data
                if grad.dtype in {torch.float16, torch.bfloat16}:
                    grad = grad.float()
                if grad.is_sparse:
                    raise RuntimeError("CAME does not support sparse gradients.")

                state = self.state[p]
                grad_shape = grad.shape

                # factored = self._get_options(grad_shape)
                factored, layer_type = self._get_options(grad_shape)
                # State Initialization
                if len(state) == 0:
                    state["step"] = 0
                    state["exp_avg"] = torch.zeros_like(grad)
                    if factored:
                        if layer_type == "1x1_conv" or layer_type == "linear":
                            # 1x1 conv and linear layers can be handled the same way
                            state["exp_avg_sq_row"] = torch.zeros(grad_shape[0]).type_as(grad)
                            state["exp_avg_sq_col"] = torch.zeros(grad_shape[1]).type_as(grad)
                            state["exp_avg_res_row"] = torch.zeros(grad_shape[0]).type_as(grad)
                            state["exp_avg_res_col"] = torch.zeros(grad_shape[1]).type_as(grad)
                        else:
                            state["exp_avg_sq"] = torch.zeros_like(grad)
                    else:
                        state["exp_avg_sq"] = torch.zeros_like(grad)
                    state["RMS"] = 0

                state["step"] += 1
                state["RMS"] = self._rms(p.data)
                update = (grad**2) + group["eps"][0]
                if factored:
                    exp_avg_sq_row = state["exp_avg_sq_row"]
                    exp_avg_sq_col = state["exp_avg_sq_col"]

                    if layer_type == "1x1_conv" or layer_type == "linear":
                        # Handle dimensions
                        if len(grad_shape) == 4:  # 1x1 conv
                            update_reshaped = update.squeeze(-1).squeeze(-1)  # Remove the trailing 1x1 dimensions
                        else:
                            update_reshaped = update

                        exp_avg_sq_row.mul_(group["betas"][1]).add_(
                            update_reshaped.mean(dim=1), alpha=1.0 - group["betas"][1]
                        )
                        exp_avg_sq_col.mul_(group["betas"][1]).add_(
                            update_reshaped.mean(dim=0), alpha=1.0 - group["betas"][1]
                        )

                        # Approximate calculation
                        update = self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col)
                        if layer_type == "1x1_conv":
                            # Reshape back to 4D
                            update = update.view(grad_shape[0], grad_shape[1], 1, 1)
                        update.mul_(grad)
                else:
                    # 3x3 conv or other cases: use the standard Adam-style second moment
                    exp_avg_sq = state["exp_avg_sq"]
                    exp_avg_sq.mul_(group["betas"][1]).add_(update, alpha=1.0 - group["betas"][1])
                    update = exp_avg_sq.rsqrt().mul_(grad)

                update.div_((self._rms(update) / group["clip_threshold"]).clamp_(min=1.0))
                exp_avg = state["exp_avg"]
                exp_avg.mul_(group["betas"][0]).add_(update, alpha=1 - group["betas"][0])

                # Confidence-guided strategy
                # Calculation of instability
                res = (update - exp_avg) ** 2 + group["eps"][1]

                if factored:
                    exp_avg_res_row = state["exp_avg_res_row"]
                    exp_avg_res_col = state["exp_avg_res_col"]

                    if layer_type == "1x1_conv" or layer_type == "linear":
                        # Handle dimensions
                        if len(grad_shape) == 4:  # 1x1 conv
                            res_reshaped = res.squeeze(-1).squeeze(-1)  # Remove the trailing 1x1 dimensions
                        else:
                            res_reshaped = res

                        # Update residual statistics
                        exp_avg_res_row.mul_(group["betas"][2]).add_(
                            res_reshaped.mean(dim=1), alpha=1.0 - group["betas"][2]
                        )
                        exp_avg_res_col.mul_(group["betas"][2]).add_(
                            res_reshaped.mean(dim=0), alpha=1.0 - group["betas"][2]
                        )

                        # Approximate calculation
                        res_approx = self._approx_sq_grad(exp_avg_res_row, exp_avg_res_col)
                        if layer_type == "1x1_conv":
                            # Need to reshape back to 4D
                            res_approx = res_approx.view(grad_shape[0], grad_shape[1], 1, 1)
                        update = res_approx.mul_(exp_avg)
                else:
                    update = exp_avg.clone()

                if group["weight_decay"] != 0:
                    p.data.add_(p.data, alpha=-group["weight_decay"] * group["lr"])

                update.mul_(group["lr"])
                p.data.add_(-update)

        return loss
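
# Usage sketch (hypothetical hyperparameters), mirroring the defaults above:
#
#   opt = CAMEWrapper(model.parameters(), lr=2e-4, weight_decay=0.0,
#                     betas=(0.9, 0.999, 0.9999))
#   loss = model(batch).sum()
#   loss.backward()
#   opt.step()
#   opt.zero_grad()
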
@OPTIMIZERS.register_module()
class CAME8BitWrapper(torch.optim.Optimizer):
    """Implements the 8-bit CAME algorithm.

    Args:
        params (iterable): parameters to optimize or dicts defining parameter groups
        lr (float, optional): external learning rate (default: None)
        eps (tuple[float, float]): regularization constants for square gradient
            and instability respectively (default: (1e-30, 1e-16))
        clip_threshold (float): threshold of root-mean-square of
            final gradient update (default: 1.0)
        betas (tuple[float, float, float]): coefficients used for computing running
            averages of update, square gradient and instability
            (default: (0.9, 0.999, 0.9999))
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
        block_size (int): quantization block size; larger blocks are more memory
            efficient but may reduce accuracy
        min_8bit_size (int): minimum parameter size for using 8-bit quantization;
            only layers larger than this value will be quantized

    Note:
        1. Only use 8-bit quantization for large Linear layers and 1x1 Conv layers
        2. Keep all factored statistics (exp_avg_sq_row, etc.) in 32-bit to ensure stability
        3. Use a simple min-max quantization strategy, quantizing each block separately
    """

    def __init__(
        self,
        params,
        lr=None,
        eps=(1e-30, 1e-16),
        clip_threshold=1.0,
        betas=(0.9, 0.999, 0.9999),
        weight_decay=0.0,
        block_size=2048,
        min_8bit_size=16384,
    ):
        assert lr is not None and lr > 0.0
        assert all([0.0 <= beta <= 1.0 for beta in betas])
        logger = get_root_logger()
        logger.info(f"Initializing CAME8bit with block_size={block_size}, min_8bit_size={min_8bit_size}")
        defaults = dict(
            lr=lr,
            eps=eps,
            clip_threshold=clip_threshold,
            betas=betas,
            weight_decay=weight_decay,
            block_size=block_size,
            min_8bit_size=min_8bit_size,
        )
        super().__init__(params, defaults)

    def print_layer_info(self, param_shape, use_8bit):
        """Print layer information: parameter size and whether 8-bit quantization is used.

        Args:
            param_shape (tuple): parameter shape
            use_8bit (bool): whether 8-bit quantization is used
        """
        size = np.prod(param_shape)
        layer_type = "unknown"
        if len(param_shape) == 1:
            layer_type = "1D Layer"
        elif len(param_shape) == 2:
            layer_type = "Linear"
        elif len(param_shape) == 4:
            if param_shape[2] == 1 and param_shape[3] == 1:
                layer_type = "1x1 Conv"
            else:
                layer_type = "Conv"
        status = "8bit" if use_8bit else "32bit"
        print(f"{layer_type} layer with shape {param_shape}: {size:,} params -> using {status}")

    def _should_use_8bit(self, param_shape):
        """Determine whether a parameter's optimizer state should be quantized to 8-bit.

        Rules:
            1. Linear layers: parameter size > min_8bit_size
            2. 1x1 conv layers: parameter size > min_8bit_size
            3. Other layers: use 32-bit
        """
        if len(param_shape) == 2:  # linear layer
            return param_shape[0] * param_shape[1] > self.defaults["min_8bit_size"]
        elif len(param_shape) == 4 and param_shape[2] == 1 and param_shape[3] == 1:  # 1x1 conv
            return param_shape[0] * param_shape[1] > self.defaults["min_8bit_size"]
        return False  # other layers are not quantized
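
    # With the default min_8bit_size=16384: a 128x128 linear weight (exactly
    # 16384 params) stays 32-bit since strictly greater is required, a 256x256
    # linear or 256x256x1x1 conv weight is quantized, and a 3x3 conv never is.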

    def _quantize_state(self, state_tensor, block_size=2048):
        """Quantize a state tensor to 8-bit.

        Args:
            state_tensor: tensor to be quantized
            block_size: quantization block size

        Returns:
            list of quantized data blocks (tensors with a single element are
            returned unchanged), where each block contains:
                - data: uint8 data
                - scale: quantization scale
                - min: minimum value
        """
        if state_tensor.numel() <= 1:
            return state_tensor

        quantized_chunks = []
        for chunk in state_tensor.split(block_size):
            # Calculate quantization parameters
            chunk_min = chunk.min()
            chunk_max = chunk.max()
            scale = (chunk_max - chunk_min) / 255
            if scale == 0:
                # Constant block (e.g. the all-zero init): any nonzero scale
                # avoids the division by zero below and round-trips exactly.
                scale = torch.ones_like(scale)
            # Quantize to the 0-255 range
            quantized_chunk = ((chunk - chunk_min) / scale).round().byte()
            quantized_chunks.append({"data": quantized_chunk, "scale": scale, "min": chunk_min})
        return quantized_chunks

    def _dequantize_state(self, quantized_chunks):
        """Dequantize 8-bit quantized data back to 32-bit float.

        Args:
            quantized_chunks: list of quantized data blocks

        Returns:
            dequantized 32-bit float tensor
        """
        if not isinstance(quantized_chunks, list):
            return quantized_chunks
        chunks = []
        for chunk_dict in quantized_chunks:
            # Dequantize: value = data * scale + min
            chunk = chunk_dict["data"].float() * chunk_dict["scale"] + chunk_dict["min"]
            chunks.append(chunk)
        return torch.cat(chunks)
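
    # Round-trip sketch (hypothetical tensor): per block, the absolute
    # reconstruction error is at most scale / 2 = (max - min) / 510.
    #
    #   t = torch.randn(4096)
    #   q = opt._quantize_state(t, block_size=2048)       # 2 uint8 blocks
    #   err = (opt._dequantize_state(q) - t).abs().max()  # bounded by scale / 2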

    def _dequantize_state_first_step(self, quantized_chunks):
        """Memory-frugal dequantization for the first step: frees each
        quantized block as soon as it has been dequantized."""
        if not isinstance(quantized_chunks, list):
            return quantized_chunks
        # 1. Dequantize all chunks, releasing the uint8 payloads as we go
        dequantized_chunks = []
        for chunk_dict in quantized_chunks:
            chunk = chunk_dict["data"].float() * chunk_dict["scale"] + chunk_dict["min"]
            dequantized_chunks.append(chunk)
            del chunk_dict["data"]
        torch.cuda.empty_cache()
        # 2. Concatenate all chunks
        result = torch.cat(dequantized_chunks)
        del dequantized_chunks
        torch.cuda.empty_cache()
        return result

    def _get_options(self, param_shape):
        if len(param_shape) == 4:  # Conv layer
            if param_shape[2] == 1 and param_shape[3] == 1:  # 1x1 conv
                return True, "1x1_conv"
            else:
                return False, "conv"
        elif len(param_shape) == 2:  # Linear layer
            return True, "linear"
        return False, "other"

    def _rms(self, tensor):
        return tensor.norm(2) / (tensor.numel() ** 0.5)

    def _approx_sq_grad(self, exp_avg_sq_row, exp_avg_sq_col):
        r_factor = (exp_avg_sq_row / exp_avg_sq_row.mean(dim=-1, keepdim=True)).rsqrt_().unsqueeze(-1)
        c_factor = exp_avg_sq_col.unsqueeze(-2).rsqrt()
        return torch.mul(r_factor, c_factor)

    def step(self, closure=None):
        """Perform a single optimization step.

        Main steps:
            1. Determine whether 8-bit quantization is needed
            2. Update first and second moment estimates
            3. Compute the update step
            4. Apply the confidence-guided strategy
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    continue
                grad = p.grad.data
                if grad.dtype in {torch.float16, torch.bfloat16}:
                    grad = grad.float()
                if grad.is_sparse:
                    raise RuntimeError("CAME8bit does not support sparse gradients.")

                state = self.state[p]
                grad_shape = grad.shape
                factored, layer_type = self._get_options(grad_shape)
                # Determine whether 8-bit quantization is used
                use_8bit = self._should_use_8bit(grad_shape)

                # State Initialization
                if len(state) == 0:
                    self.print_layer_info(grad_shape, use_8bit)
                    state["step"] = 0
                    # Only use 8-bit quantization for large matrices
                    if use_8bit:
                        state["exp_avg"] = self._quantize_state(torch.zeros_like(grad), group["block_size"])
                    else:
                        state["exp_avg"] = torch.zeros_like(grad)
                    if factored:
                        if layer_type == "1x1_conv" or layer_type == "linear":
                            # Keep row and column statistics in 32-bit
                            state["exp_avg_sq_row"] = torch.zeros(grad_shape[0]).type_as(grad)
                            state["exp_avg_sq_col"] = torch.zeros(grad_shape[1]).type_as(grad)
                            state["exp_avg_res_row"] = torch.zeros(grad_shape[0]).type_as(grad)
                            state["exp_avg_res_col"] = torch.zeros(grad_shape[1]).type_as(grad)
                        else:
                            if use_8bit:
                                state["exp_avg_sq"] = self._quantize_state(torch.zeros_like(grad), group["block_size"])
                            else:
                                state["exp_avg_sq"] = torch.zeros_like(grad)
                    else:
                        if use_8bit:
                            state["exp_avg_sq"] = self._quantize_state(torch.zeros_like(grad), group["block_size"])
                        else:
                            state["exp_avg_sq"] = torch.zeros_like(grad)
                    state["RMS"] = 0

                state["step"] += 1
                state["RMS"] = self._rms(p.data)
                exp_avg = self._dequantize_state(state["exp_avg"]) if use_8bit else state["exp_avg"]

                update = (grad**2) + group["eps"][0]
                if factored:
                    exp_avg_sq_row = state["exp_avg_sq_row"]  # 32-bit
                    exp_avg_sq_col = state["exp_avg_sq_col"]  # 32-bit

                    if layer_type == "1x1_conv" or layer_type == "linear":
                        if len(grad_shape) == 4:
                            update_reshaped = update.squeeze(-1).squeeze(-1)
                        else:
                            update_reshaped = update

                        # Update row and column statistics
                        exp_avg_sq_row.mul_(group["betas"][1]).add_(
                            update_reshaped.mean(dim=1), alpha=1.0 - group["betas"][1]
                        )
                        exp_avg_sq_col.mul_(group["betas"][1]).add_(
                            update_reshaped.mean(dim=0), alpha=1.0 - group["betas"][1]
                        )

                        update = self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col)
                        if layer_type == "1x1_conv":
                            update = update.view(grad_shape[0], grad_shape[1], 1, 1)
                        update.mul_(grad)
                else:
                    exp_avg_sq = self._dequantize_state(state["exp_avg_sq"]) if use_8bit else state["exp_avg_sq"]
                    exp_avg_sq.mul_(group["betas"][1]).add_(update, alpha=1.0 - group["betas"][1])
                    if use_8bit:
                        state["exp_avg_sq"] = self._quantize_state(exp_avg_sq, group["block_size"])
                    else:
                        state["exp_avg_sq"] = exp_avg_sq
                    update = exp_avg_sq.rsqrt().mul_(grad)

                # Gradient clipping
                update.div_((self._rms(update) / group["clip_threshold"]).clamp_(min=1.0))

                # Update first moment
                exp_avg.mul_(group["betas"][0]).add_(update, alpha=1 - group["betas"][0])
                # Re-quantize (if needed)
                if use_8bit:
                    state["exp_avg"] = self._quantize_state(exp_avg, group["block_size"])
                else:
                    state["exp_avg"] = exp_avg
                # Confidence-guided strategy
                res = (update - exp_avg) ** 2 + group["eps"][1]
                if factored:
                    exp_avg_res_row = state["exp_avg_res_row"]  # 32-bit
                    exp_avg_res_col = state["exp_avg_res_col"]  # 32-bit

                    if layer_type == "1x1_conv" or layer_type == "linear":
                        if len(grad_shape) == 4:
                            res_reshaped = res.squeeze(-1).squeeze(-1)
                        else:
                            res_reshaped = res

                        # Update residual statistics
                        exp_avg_res_row.mul_(group["betas"][2]).add_(
                            res_reshaped.mean(dim=1), alpha=1.0 - group["betas"][2]
                        )
                        exp_avg_res_col.mul_(group["betas"][2]).add_(
                            res_reshaped.mean(dim=0), alpha=1.0 - group["betas"][2]
                        )

                        res_approx = self._approx_sq_grad(exp_avg_res_row, exp_avg_res_col)
                        if layer_type == "1x1_conv":
                            res_approx = res_approx.view(grad_shape[0], grad_shape[1], 1, 1)
                        update = res_approx.mul_(exp_avg)
                else:
                    update = exp_avg.clone()

                # Weight decay
                if group["weight_decay"] != 0:
                    p.data.add_(p.data, alpha=-group["weight_decay"] * group["lr"])

                # Apply update
                update.mul_(group["lr"])
                p.data.add_(-update)

        return loss

    def load_state_dict(self, state_dict):
        """Load a state dict, keeping quantized states in 8-bit and dense states in 32-bit."""
        super().load_state_dict(state_dict)
        for state in self.state.values():
            for key in [
                "exp_avg",
                "exp_avg_sq",
                "exp_avg_sq_row",
                "exp_avg_sq_col",
                "exp_avg_res_row",
                "exp_avg_res_col",
            ]:
                if key in state:
                    if isinstance(state[key], list):
                        state[key] = [
                            {
                                "data": exp["data"].byte(),  # Keep the quantized payload as uint8
                                "scale": exp["scale"],  # Keep scale unchanged
                                "min": exp["min"],  # Keep min unchanged
                            }
                            for exp in state[key]
                        ]
                    elif isinstance(state[key], torch.Tensor):
                        # Dense tensors stay in 32-bit
                        state[key] = state[key].float()
        del state_dict
        torch.cuda.empty_cache()
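
# Smoke-test sketch (hypothetical shapes and hyperparameters): with the default
# min_8bit_size, the 256x256 weight below gets an 8-bit first moment (its
# factored second-moment statistics stay 32-bit), while the biases and the
# small output layer keep full 32-bit state:
#
#   model = torch.nn.Sequential(torch.nn.Linear(256, 256), torch.nn.Linear(256, 10))
#   opt = CAME8BitWrapper(model.parameters(), lr=2e-4, block_size=2048, min_8bit_size=16384)
#   for _ in range(3):
#       loss = model(torch.randn(8, 256)).pow(2).mean()
#       loss.backward()
#       opt.step()
#       opt.zero_grad()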