# Copyright 2024 NVIDIA CORPORATION & AFFILIATES # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # # SPDX-License-Identifier: Apache-2.0 import math from typing import Callable, Optional, Tuple import numpy as np import torch from came_pytorch import CAME from mmcv import Config from mmcv.runner import OPTIMIZER_BUILDERS, OPTIMIZERS, DefaultOptimizerConstructor from mmcv.runner import build_optimizer as mm_build_optimizer from mmcv.utils import _BatchNorm, _InstanceNorm from torch.nn import GroupNorm, LayerNorm from torch.optim.optimizer import Optimizer from .logger import get_root_logger def auto_scale_lr(effective_bs, optimizer_cfg, rule="linear", base_batch_size=256): assert rule in ["linear", "sqrt"] logger = get_root_logger() # scale by world size if rule == "sqrt": scale_ratio = math.sqrt(effective_bs / base_batch_size) elif rule == "linear": scale_ratio = effective_bs / base_batch_size optimizer_cfg["lr"] *= scale_ratio logger.info(f'Automatically adapt lr to {optimizer_cfg["lr"]:.5f} (using {rule} scaling rule).') return scale_ratio @OPTIMIZER_BUILDERS.register_module() class MyOptimizerConstructor(DefaultOptimizerConstructor): def add_params(self, params, module, prefix="", is_dcn_module=None): """Add all parameters of module to the params list. The parameters of the given module will be added to the list of param groups, with specific rules defined by paramwise_cfg. Args: params (list[dict]): A list of param groups, it will be modified in place. module (nn.Module): The module to be added. prefix (str): The prefix of the module """ # get param-wise options custom_keys = self.paramwise_cfg.get("custom_keys", {}) # first sort with alphabet order and then sort with reversed len of str # sorted_keys = sorted(sorted(custom_keys.keys()), key=len, reverse=True) bias_lr_mult = self.paramwise_cfg.get("bias_lr_mult", 1.0) bias_decay_mult = self.paramwise_cfg.get("bias_decay_mult", 1.0) norm_decay_mult = self.paramwise_cfg.get("norm_decay_mult", 1.0) bypass_duplicate = self.paramwise_cfg.get("bypass_duplicate", False) # special rules for norm layers and depth-wise conv layers is_norm = isinstance(module, (_BatchNorm, _InstanceNorm, GroupNorm, LayerNorm)) for name, param in module.named_parameters(recurse=False): base_lr = self.base_lr if name == "bias" and not (is_norm or is_dcn_module): base_lr *= bias_lr_mult # apply weight decay policies base_wd = self.base_wd if self.base_wd is not None: # norm decay if is_norm: base_wd *= norm_decay_mult # bias lr and decay elif name == "bias" and not is_dcn_module: # TODO: current bias_decay_mult will have affect on DCN base_wd *= bias_decay_mult param_group = {"params": [param]} if not param.requires_grad: param_group["requires_grad"] = False params.append(param_group) continue if bypass_duplicate and self._is_in(param_group, params): logger = get_root_logger() logger.warn(f"{prefix} is duplicate. It is skipped since " f"bypass_duplicate={bypass_duplicate}") continue # if the parameter match one of the custom keys, ignore other rules is_custom = False for key in custom_keys: if isinstance(key, tuple): scope, key_name = key else: scope, key_name = None, key if scope is not None and scope not in f"{prefix}": continue if key_name in f"{prefix}.{name}": is_custom = True if "lr_mult" in custom_keys[key]: # if 'base_classes' in f'{prefix}.{name}' or 'attn_base' in f'{prefix}.{name}': # param_group['lr'] = self.base_lr # else: param_group["lr"] = self.base_lr * custom_keys[key]["lr_mult"] elif "lr" not in param_group: param_group["lr"] = base_lr if self.base_wd is not None: if "decay_mult" in custom_keys[key]: param_group["weight_decay"] = self.base_wd * custom_keys[key]["decay_mult"] elif "weight_decay" not in param_group: param_group["weight_decay"] = base_wd if not is_custom: # bias_lr_mult affects all bias parameters # except for norm.bias dcn.conv_offset.bias if base_lr != self.base_lr: param_group["lr"] = base_lr if base_wd != self.base_wd: param_group["weight_decay"] = base_wd params.append(param_group) for child_name, child_mod in module.named_children(): child_prefix = f"{prefix}.{child_name}" if prefix else child_name self.add_params(params, child_mod, prefix=child_prefix, is_dcn_module=is_dcn_module) def build_optimizer(model, optimizer_cfg): # default parameter-wise config logger = get_root_logger() if hasattr(model, "module"): model = model.module # set optimizer constructor optimizer_cfg.setdefault("constructor", "MyOptimizerConstructor") # parameter-wise setting: cancel weight decay for some specific modules custom_keys = dict() for name, module in model.named_modules(): if hasattr(module, "zero_weight_decay"): custom_keys.update({(name, key): dict(decay_mult=0) for key in module.zero_weight_decay}) paramwise_cfg = Config(dict(cfg=dict(custom_keys=custom_keys))) given_cfg = optimizer_cfg.get("paramwise_cfg") if given_cfg: paramwise_cfg.merge_from_dict(dict(cfg=given_cfg)) optimizer_cfg["paramwise_cfg"] = paramwise_cfg.cfg # build optimizer optimizer = mm_build_optimizer(model, optimizer_cfg) weight_decay_groups = dict() lr_groups = dict() for group in optimizer.param_groups: if not group.get("requires_grad", True): continue lr_groups.setdefault(group["lr"], []).append(group) weight_decay_groups.setdefault(group["weight_decay"], []).append(group) learnable_count, fix_count = 0, 0 for p in model.parameters(): if p.requires_grad: learnable_count += 1 else: fix_count += 1 fix_info = f"{learnable_count} are learnable, {fix_count} are fix" lr_info = "Lr group: " + ", ".join([f"{len(group)} params with lr {lr:.5f}" for lr, group in lr_groups.items()]) wd_info = "Weight decay group: " + ", ".join( [f"{len(group)} params with weight decay {wd}" for wd, group in weight_decay_groups.items()] ) opt_info = f"{optimizer.__class__.__name__} Optimizer: total {len(optimizer.param_groups)} param groups, {fix_info}. {lr_info}; {wd_info}." logger.info(opt_info) return optimizer @OPTIMIZERS.register_module() class Lion(Optimizer): def __init__( self, params, lr: float = 1e-4, betas: Tuple[float, float] = (0.9, 0.99), weight_decay: float = 0.0, ): assert lr > 0.0 assert all([0.0 <= beta <= 1.0 for beta in betas]) defaults = dict(lr=lr, betas=betas, weight_decay=weight_decay) super().__init__(params, defaults) @staticmethod def update_fn(p, grad, exp_avg, lr, wd, beta1, beta2): # stepweight decay p.data.mul_(1 - lr * wd) # weight update update = exp_avg.clone().lerp_(grad, 1 - beta1).sign_() p.add_(update, alpha=-lr) # decay the momentum running average coefficient exp_avg.lerp_(grad, 1 - beta2) @staticmethod def exists(val): return val is not None @torch.no_grad() def step(self, closure: Optional[Callable] = None): loss = None if self.exists(closure): with torch.enable_grad(): loss = closure() for group in self.param_groups: for p in filter(lambda p: self.exists(p.grad), group["params"]): grad, lr, wd, beta1, beta2, state = ( p.grad, group["lr"], group["weight_decay"], *group["betas"], self.state[p], ) # init state - exponential moving average of gradient values if len(state) == 0: state["exp_avg"] = torch.zeros_like(p) exp_avg = state["exp_avg"] self.update_fn(p, grad, exp_avg, lr, wd, beta1, beta2) return loss @OPTIMIZERS.register_module() class CAMEWrapper(torch.optim.Optimizer): """Implements CAME algorithm. This implementation is based on: `CAME: Confidence-guided Adaptive Memory Efficient Optimization` Args: params (iterable): iterable of parameters to optimize or dicts defining parameter groups lr (float, optional): external learning rate (default: None) eps (tuple[float, float]): regularization constants for square gradient and instability respectively (default: (1e-30, 1e-16)) clip_threshold (float): threshold of root-mean-square of final gradient update (default: 1.0) betas (tuple[float, float, float]): coefficient used for computing running averages of update, square gradient and instability (default: (0.9, 0.999, 0.9999))) weight_decay (float, optional): weight decay (L2 penalty) (default: 0) """ def __init__( self, params, lr=None, eps=(1e-30, 1e-16), clip_threshold=1.0, betas=(0.9, 0.999, 0.9999), weight_decay=0.0, ): assert lr > 0.0 assert all([0.0 <= beta <= 1.0 for beta in betas]) defaults = dict( lr=lr, eps=eps, clip_threshold=clip_threshold, betas=betas, weight_decay=weight_decay, ) super().__init__(params, defaults) @property def supports_memory_efficient_fp16(self): return True @property def supports_flat_params(self): return False def _get_options(self, param_shape): if len(param_shape) == 4: # Conv layer if param_shape[2] == 1 and param_shape[3] == 1: # 1x1 conv return True, "1x1_conv" else: # 3x3 conv or others return False, "conv" elif len(param_shape) == 2: # Linear layer, exactly 2D return True, "linear" return False, "other" def _rms(self, tensor): return tensor.norm(2) / (tensor.numel() ** 0.5) def _approx_sq_grad(self, exp_avg_sq_row, exp_avg_sq_col): r_factor = (exp_avg_sq_row / exp_avg_sq_row.mean(dim=-1, keepdim=True)).rsqrt_().unsqueeze(-1) c_factor = exp_avg_sq_col.unsqueeze(-2).rsqrt() return torch.mul(r_factor, c_factor) def step(self, closure=None): """Performs a single optimization step. Args: closure (callable, optional): A closure that reevaluates the model and returns the loss. """ loss = None if closure is not None: loss = closure() for group in self.param_groups: for p in group["params"]: if p.grad is None: continue grad = p.grad.data if grad.dtype in {torch.float16, torch.bfloat16}: grad = grad.float() if grad.is_sparse: raise RuntimeError("CAME does not support sparse gradients.") state = self.state[p] grad_shape = grad.shape # factored = self._get_options(grad_shape) factored, layer_type = self._get_options(grad_shape) # State Initialization if len(state) == 0: state["step"] = 0 state["exp_avg"] = torch.zeros_like(grad) if factored: if layer_type == "1x1_conv" or layer_type == "linear": # 1x1 conv and linear layers can be handled the same way state["exp_avg_sq_row"] = torch.zeros(grad_shape[0]).type_as(grad) state["exp_avg_sq_col"] = torch.zeros(grad_shape[1]).type_as(grad) state["exp_avg_res_row"] = torch.zeros(grad_shape[0]).type_as(grad) state["exp_avg_res_col"] = torch.zeros(grad_shape[1]).type_as(grad) else: state["exp_avg_sq"] = torch.zeros_like(grad) else: state["exp_avg_sq"] = torch.zeros_like(grad) state["RMS"] = 0 state["step"] += 1 state["RMS"] = self._rms(p.data) update = (grad**2) + group["eps"][0] if factored: exp_avg_sq_row = state["exp_avg_sq_row"] exp_avg_sq_col = state["exp_avg_sq_col"] if layer_type == "1x1_conv" or layer_type == "linear": # Handle dimensions if len(grad_shape) == 4: # 1x1 conv update_reshaped = update.squeeze(-1).squeeze(-1) # Remove last two dimensions else: update_reshaped = update exp_avg_sq_row.mul_(group["betas"][1]).add_( update_reshaped.mean(dim=1), alpha=1.0 - group["betas"][1] ) exp_avg_sq_col.mul_(group["betas"][1]).add_( update_reshaped.mean(dim=0), alpha=1.0 - group["betas"][1] ) # Approximate calculation update = self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col) if layer_type == "1x1_conv": # Need to reshape back to 4D update = update.view(grad_shape[0], grad_shape[1], 1, 1) update.mul_(grad) else: # 3x3 conv or other cases: use standard AdamW approach exp_avg_sq = state["exp_avg_sq"] exp_avg_sq.mul_(group["betas"][1]).add_(update, alpha=1.0 - group["betas"][1]) update = exp_avg_sq.rsqrt().mul_(grad) update.div_((self._rms(update) / group["clip_threshold"]).clamp_(min=1.0)) exp_avg = state["exp_avg"] exp_avg.mul_(group["betas"][0]).add_(update, alpha=1 - group["betas"][0]) # Confidence-guided strategy # Calculation of instability res = (update - exp_avg) ** 2 + group["eps"][1] if factored: exp_avg_res_row = state["exp_avg_res_row"] exp_avg_res_col = state["exp_avg_res_col"] if layer_type == "1x1_conv" or layer_type == "linear": # Handle dimensions if len(grad_shape) == 4: # 1x1 conv res_reshaped = res.squeeze(-1).squeeze(-1) # Remove last two dimensions else: res_reshaped = res # Update residual statistics exp_avg_res_row.mul_(group["betas"][2]).add_( res_reshaped.mean(dim=1), alpha=1.0 - group["betas"][2] ) exp_avg_res_col.mul_(group["betas"][2]).add_( res_reshaped.mean(dim=0), alpha=1.0 - group["betas"][2] ) # Approximate calculation res_approx = self._approx_sq_grad(exp_avg_res_row, exp_avg_res_col) if layer_type == "1x1_conv": # 需要reshape回4D res_approx = res_approx.view(grad_shape[0], grad_shape[1], 1, 1) update = res_approx.mul_(exp_avg) else: update = exp_avg.clone() if group["weight_decay"] != 0: p.data.add_(p.data, alpha=-group["weight_decay"] * group["lr"]) update.mul_(group["lr"]) p.data.add_(-update) return loss @OPTIMIZERS.register_module() class CAME8BitWrapper(torch.optim.Optimizer): """Implements 8bit-CAME algorithm. Args: params (iterable): parameters to optimize or dicts defining parameter groups lr (float, optional): external learning rate (default: None) eps (tuple[float, float]): regularization constants for square gradient and instability respectively (default: (1e-30, 1e-16)) clip_threshold (float): threshold of root-mean-square of final gradient update (default: 1.0) betas (tuple[float, float, float]): coefficient used for computing running averages of update, square gradient and instability (default: (0.9, 0.999, 0.9999))) weight_decay (float, optional): weight decay (L2 penalty) (default: 0) block_size (int): quantization block size, larger memory efficiency, but may reduce accuracy min_8bit_size (int): minimum parameter size for using 8bit quantization, only layers larger than this value will be quantized Note: 1. Only use 8bit quantization for large Linear layers and 1x1 Conv layers 2. Keep all statistics (exp_avg_sq_row, etc.) in 32bit to ensure stability 3. Use simple min-max quantization strategy, quantize each block separately """ def __init__( self, params, lr=None, eps=(1e-30, 1e-16), clip_threshold=1.0, betas=(0.9, 0.999, 0.9999), weight_decay=0.0, block_size=2048, min_8bit_size=16384, ): assert lr > 0.0 assert all([0.0 <= beta <= 1.0 for beta in betas]) logger = get_root_logger() logger.info(f"Initializing CAME8bit with block_size={block_size}, min_8bit_size={min_8bit_size}") defaults = dict( lr=lr, eps=eps, clip_threshold=clip_threshold, betas=betas, weight_decay=weight_decay, block_size=block_size, min_8bit_size=min_8bit_size, ) super().__init__(params, defaults) def print_layer_info(self, param_shape, use_8bit): """Print layer information, including parameter size and whether 8bit quantization is used Args: param_shape (tuple): parameter shape use_8bit (bool): whether 8bit quantization is used """ size = np.prod(param_shape) layer_type = "unknown" if len(param_shape) == 1: layer_type = "1D Layer" elif len(param_shape) == 2: layer_type = "Linear" elif len(param_shape) == 4: if param_shape[2] == 1 and param_shape[3] == 1: layer_type = "1x1 Conv" else: layer_type = "Conv" status = "8bit" if use_8bit else "32bit" print(f"{layer_type} layer with shape {param_shape}: {size:,} params -> using {status}") def _should_use_8bit(self, param_shape): """Determine if a parameter should be quantized to 8bit Rules: 1. linear layers: parameter size > min_8bit_size 2. 1x1 conv layers: parameter size > min_8bit_size 3. other layers: use 32bit """ if len(param_shape) == 2: # linear layer return param_shape[0] * param_shape[1] > self.defaults["min_8bit_size"] elif len(param_shape) == 4 and param_shape[2] == 1 and param_shape[3] == 1: return param_shape[0] * param_shape[1] > self.defaults["min_8bit_size"] return False # other layers are not quantized def _quantize_state(self, state_tensor, block_size=2048): """Quantize a state tensor to 8bit Args: state_tensor: tensor to be quantized block_size: quantization block size Returns: list of quantized data blocks, each block contains: - data: uint8 data - scale: quantization scale - min: minimum value """ if state_tensor.numel() <= 1: return state_tensor quantized_chunks = [] for chunk in state_tensor.split(block_size): # Calculate quantization parameters chunk_min = chunk.min() chunk_max = chunk.max() scale = (chunk_max - chunk_min) / 255 # Quantize to 0-255 range quantized_chunk = ((chunk - chunk_min) / scale).round().byte() quantized_chunks.append({"data": quantized_chunk, "scale": scale, "min": chunk_min}) return quantized_chunks def _dequantize_state(self, quantized_chunks): """Dequantize 8bit quantized data to 32bit float Args: quantized_chunks: list of quantized data blocks Returns: dequantized 32bit float tensor """ if not isinstance(quantized_chunks, list): return quantized_chunks chunks = [] for chunk_dict in quantized_chunks: # Dequantize: value = data * scale + min chunk = chunk_dict["data"].float() * chunk_dict["scale"] + chunk_dict["min"] chunks.append(chunk) return torch.cat(chunks) def _dequantize_state_first_step(self, quantized_chunks): """Efficient dequantization for the first step""" if not isinstance(quantized_chunks, list): return quantized_chunks # 1. Dequantize all chunks to CPU dequantized_chunks = [] for chunk_dict in quantized_chunks: chunk = chunk_dict["data"].float() * chunk_dict["scale"] + chunk_dict["min"] dequantized_chunks.append(chunk) del chunk_dict["data"] torch.cuda.empty_cache() # 2. Concatenate all chunks result = torch.cat(dequantized_chunks) del dequantized_chunks torch.cuda.empty_cache() return result def _get_options(self, param_shape): if len(param_shape) == 4: if param_shape[2] == 1 and param_shape[3] == 1: return True, "1x1_conv" else: return False, "conv" elif len(param_shape) == 2: return True, "linear" return False, "other" def _rms(self, tensor): return tensor.norm(2) / (tensor.numel() ** 0.5) def _approx_sq_grad(self, exp_avg_sq_row, exp_avg_sq_col): r_factor = (exp_avg_sq_row / exp_avg_sq_row.mean(dim=-1, keepdim=True)).rsqrt_().unsqueeze(-1) c_factor = exp_avg_sq_col.unsqueeze(-2).rsqrt() return torch.mul(r_factor, c_factor) def step(self, closure=None): """Perform a single optimization step Main steps: 1. Determine if 8bit quantization is needed 2. Update first and second moment estimates 3. Compute update step 4. Apply confidence-guided strategy """ loss = None if closure is not None: loss = closure() for group in self.param_groups: for p in group["params"]: if p.grad is None: continue grad = p.grad.data if grad.dtype in {torch.float16, torch.bfloat16}: grad = grad.float() if grad.is_sparse: raise RuntimeError("CAME8bit does not support sparse gradients.") state = self.state[p] grad_shape = grad.shape factored, layer_type = self._get_options(grad_shape) # Determine if 8bit quantization is used use_8bit = self._should_use_8bit(grad_shape) # State Initialization if len(state) == 0: self.print_layer_info(grad_shape, use_8bit) state["step"] = 0 # Only use 8bit quantization for large matrices if use_8bit: state["exp_avg"] = self._quantize_state(torch.zeros_like(grad), group["block_size"]) else: state["exp_avg"] = torch.zeros_like(grad) if factored: if layer_type == "1x1_conv" or layer_type == "linear": # Keep row and column statistics in 32bit state["exp_avg_sq_row"] = torch.zeros(grad_shape[0]).type_as(grad) state["exp_avg_sq_col"] = torch.zeros(grad_shape[1]).type_as(grad) state["exp_avg_res_row"] = torch.zeros(grad_shape[0]).type_as(grad) state["exp_avg_res_col"] = torch.zeros(grad_shape[1]).type_as(grad) else: if use_8bit: state["exp_avg_sq"] = self._quantize_state(torch.zeros_like(grad), group["block_size"]) else: state["exp_avg_sq"] = torch.zeros_like(grad) else: if use_8bit: state["exp_avg_sq"] = self._quantize_state(torch.zeros_like(grad), group["block_size"]) else: state["exp_avg_sq"] = torch.zeros_like(grad) state["RMS"] = 0 state["step"] += 1 state["RMS"] = self._rms(p.data) exp_avg = self._dequantize_state(state["exp_avg"]) if use_8bit else state["exp_avg"] update = (grad**2) + group["eps"][0] if factored: exp_avg_sq_row = state["exp_avg_sq_row"] # 32bit exp_avg_sq_col = state["exp_avg_sq_col"] # 32bit if layer_type == "1x1_conv" or layer_type == "linear": if len(grad_shape) == 4: update_reshaped = update.squeeze(-1).squeeze(-1) else: update_reshaped = update # Update row and column statistics exp_avg_sq_row.mul_(group["betas"][1]).add_( update_reshaped.mean(dim=1), alpha=1.0 - group["betas"][1] ) exp_avg_sq_col.mul_(group["betas"][1]).add_( update_reshaped.mean(dim=0), alpha=1.0 - group["betas"][1] ) update = self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col) if layer_type == "1x1_conv": update = update.view(grad_shape[0], grad_shape[1], 1, 1) update.mul_(grad) else: exp_avg_sq = self._dequantize_state(state["exp_avg_sq"]) if use_8bit else state["exp_avg_sq"] exp_avg_sq.mul_(group["betas"][1]).add_(update, alpha=1.0 - group["betas"][1]) if use_8bit: state["exp_avg_sq"] = self._quantize_state(exp_avg_sq, group["block_size"]) else: state["exp_avg_sq"] = exp_avg_sq update = exp_avg_sq.rsqrt().mul_(grad) # Gradient clipping update.div_((self._rms(update) / group["clip_threshold"]).clamp_(min=1.0)) # Update first moment exp_avg.mul_(group["betas"][0]).add_(update, alpha=1 - group["betas"][0]) # Re-quantize (if needed) if use_8bit: state["exp_avg"] = self._quantize_state(exp_avg, group["block_size"]) else: state["exp_avg"] = exp_avg # Confidence-guided strategy res = (update - exp_avg) ** 2 + group["eps"][1] if factored: exp_avg_res_row = state["exp_avg_res_row"] # 32bit exp_avg_res_col = state["exp_avg_res_col"] # 32bit if layer_type == "1x1_conv" or layer_type == "linear": if len(grad_shape) == 4: res_reshaped = res.squeeze(-1).squeeze(-1) else: res_reshaped = res # Update residual statistics exp_avg_res_row.mul_(group["betas"][2]).add_( res_reshaped.mean(dim=1), alpha=1.0 - group["betas"][2] ) exp_avg_res_col.mul_(group["betas"][2]).add_( res_reshaped.mean(dim=0), alpha=1.0 - group["betas"][2] ) res_approx = self._approx_sq_grad(exp_avg_res_row, exp_avg_res_col) if layer_type == "1x1_conv": res_approx = res_approx.view(grad_shape[0], grad_shape[1], 1, 1) update = res_approx.mul_(exp_avg) else: update = exp_avg.clone() # Weight decay if group["weight_decay"] != 0: p.data.add_(p.data, alpha=-group["weight_decay"] * group["lr"]) # Apply update update.mul_(group["lr"]) p.data.add_(-update) return loss def load_state_dict(self, state_dict): """Load state dict and convert relevant states to 8bit""" super().load_state_dict(state_dict) for state in self.state.values(): for key in [ "exp_avg", "exp_avg_sq", "exp_avg_sq_row", "exp_avg_sq_col", "exp_avg_res_row", "exp_avg_res_col", ]: if key in state: if isinstance(state[key], list): state[key] = [ { "data": exp["data"].byte(), # Convert data to 8bit directly "scale": exp["scale"], # Keep scale unchanged "min": exp["min"], # Keep min unchanged } for exp in state[key] ] elif isinstance(state[key], torch.Tensor): # If tensor, keep as 32bit state[key] = state[key].float() # Ensure 32bit del state_dict torch.cuda.empty_cache()