from __future__ import annotations import torch from torch import amin # Necessary for arcsin import copy import torch.nn as nn import numpy as np from scipy.optimize import curve_fit from typing import Dict, Any, Tuple, List, Callable def quantization(x, **params): return (torch.div(1, replace_num(params['_s'], num=0, to=10000)) * torch.sqrt(domain_guard((params['_0'] * x), min=0.1, nan=0.1))) def dequantization(x, **params): return (torch.div(1, replace_num(params['_0'], num=0, to=10000)) * guarded_torch_power(params['_s'], torch.tensor(2)) * guarded_torch_power(x, torch.tensor(2))) def init_linear_scale( # Symmetric scale. From the study folder x: torch.Tensor, **kwargs: Dict[str, Any], ) -> torch.Tensor: assert "bits" in kwargs, "bits must be provided." assert "params" in kwargs, "params must be provided." assert "qtz_func" in kwargs, "qtz_func must be provided." bits = kwargs.get('bits') params = kwargs.get('params') qtz_func = kwargs.get('qtz_func') x_ = x.transpose(0, 1) x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs)) x_ = x_.transpose(0, 1) quant_min, quant_max = get_min_max_from_bits_signed(bits) min_vals, max_vals = torch.aminmax(x_, dim=1) min_vals = torch.min(min_vals, torch.zeros_like(min_vals)) max_vals = torch.max(max_vals, torch.zeros_like(max_vals)) eps = torch.finfo(torch.float32).eps abs_max_val_per_ch = torch.max(-min_vals, max_vals) scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2) scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device) # Introduces some noise in scale # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything scale = scale + 0.01 * torch.randn_like(scale) return scale def init_params(x: torch.Tensor, **kwargs: Dict[str, Any]) -> Dict[str, nn.Parameter]: params = { '_0': init_ones(x, qtz_func=quantization, deqtz_func=dequantization, param='_0', params_list=['_0', '_s'], **kwargs), } params['_s'] = init_linear_scale(x, params=params, qtz_func=quantization, **kwargs) params = {k: nn.Parameter(v, requires_grad=False) for k, v in params.items()} if 'post_init_hook' in kwargs: kwargs['post_init_hook'](parameters=params) if 'post_train_hook' in kwargs: kwargs['post_train_hook'](parameters=params) return params ############### Numpy Qtz ############### def np_quantization(x, _0, _s): return (np.divide(1, np_replace_num(_s, num=0, to=10000)) * np.sqrt(np_domain_guard((_0 * x), min=0.1, nan=0.1))) def np_dequantization(x, _0, _s): return (np.divide(1, np_replace_num(_0, num=0, to=10000)) * np_guarded_power(_s, np.array(2)) * np_guarded_power(x, np.array(2))) def fit_func(x, _0, _s): x_ = np_quantization(x, _0, _s) x_ = np_dequantization(x_, _0, _s) return x_ ############### HELPERS ############### def domain_guard( x: torch.Tensor, min: float = None, max: float = None, posinf: float = None, neginf: float = None, nan: float = None ) -> torch.Tensor: """Guard a tensor to a valid domain.""" x = torch.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan) if min is not None or max is not None: x = torch.clamp(x, min=min, max=max) return x def replace_num(x: torch.Tensor, num: float, to: float) -> torch.Tensor: """Replace a number in a tensor with another number. Args: x (torch.Tensor): The input tensor. num (float): The number to replace. to (float): The number to replace with. Returns: torch.Tensor: The tensor with the number replaced. """ return torch.where(x == num, to, x) def guarded_torch_power(x: torch.Tensor, exp: float) -> torch.Tensor: """Guard the power operation to a valid domain.""" return torch.pow(x, exp) if exp >= 1 else torch.pow(torch.relu(x), exp) def init_ones(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor: val = torch.amin(x, dim=1) return torch.ones_like(val, dtype=torch.float32, device=x.device) def init_rand(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor: val = torch.amin(x, dim=1) return torch.randn_like(val, dtype=torch.float32, device=x.device) def init_space_search( x: torch.Tensor, **kwargs: Dict[str, Any], ) -> torch.Tensor: def _build_initial_param(tensor: torch.Tensor, max_initial: int, n_params: int): """Generates the initial set of parameters. The first iteration generates 10 times more parameters.""" for _ in range(n_params * 10): # The first iteration generates 10 times more parameters yield init_rand(tensor) * max_initial # Generates n_params in range [-max_initial, max_initial] def _search_param(tensors: List[torch.tensor], n_params): """Takes the best parameters and generates new parameters around the mean of the best parameters.""" torch_tensors = torch.stack(tensors) min_vals, max_vals = torch.aminmax(torch_tensors, dim=0) abs_max_val_per_ch = torch.max(-min_vals, max_vals) mean = torch.mean(torch_tensors, dim=0) for _ in range(n_params): # Generates n_params around the mean of the tensors yield torch.randn_like(min_vals) * abs_max_val_per_ch + mean def _calc(x, qtz_func, deqtz_func, **params): x_ = x.transpose(0, 1) x_ = qtz_func(x=x_, **params) x_ = deqtz_func(x=x_, **params) x_ = x_.transpose(0, 1) return x_ assert "qtz_func" in kwargs, "qtz_func must be provided." assert "deqtz_func" in kwargs, "deqtz_func must be provided." assert "params_list" in kwargs, "params list must be provided." assert "param" in kwargs, "param must be provided." qtz_func = kwargs.get('qtz_func') deqtz_func = kwargs.get('deqtz_func') params_list = kwargs.get('params_list') param = kwargs.get('param') n_runs = 50 # Number of runs to try to find the best parameters n_random_params = 50 # Number of random parameters to generate n_best_to_pick = 5 # Number of best parameters to pick after each run max_initial = 10000 # Maximum value to initialize the parameters # Initializes the parameters base_params = { p: init_ones(x, **kwargs) for p in params_list if p != param } params = _build_initial_param(x, max_initial, n_random_params) # Performs the search for _ in range(n_runs): best_params = [] for param_ in params: try: x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: param_}) loss_ones = nn.MSELoss()(x, x_) if len(best_params) < n_best_to_pick: best_params.append((param_, loss_ones.item())) best_params = sorted(best_params, key=lambda x: x[1]) elif loss_ones < best_params[-1][1]: best_params[-1] = (param_, loss_ones.item()) best_params = sorted(best_params, key=lambda x: x[1]) except Exception: # The parameters might not be valid for the function's domain continue # Generates new parameters around the mean params = _search_param([p for p, _ in best_params], n_random_params) # Checks if the best parameter is better than the init_ones p_ones = init_ones(x, **kwargs) x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_ones}) loss_ones = nn.MSELoss()(x, x_) # Checks if the best parameter is better than the init_rand p_rand = init_rand(x, **kwargs) x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_rand}) loss_rand = nn.MSELoss()(x, x_) if loss_rand < best_params[0][1] and loss_rand < loss_ones: return p_rand elif loss_ones < best_params[0][1] and loss_ones < loss_rand: return p_ones else: return best_params[0][0] def init_linear_scale( # Symmetric scale. From the study folder x: torch.Tensor, **kwargs: Dict[str, Any], ) -> torch.Tensor: assert "bits" in kwargs, "bits must be provided." assert "params" in kwargs, "params must be provided." assert "qtz_func" in kwargs, "qtz_func must be provided." bits = kwargs.get('bits') params = kwargs.get('params') qtz_func = kwargs.get('qtz_func') x_ = x.transpose(0, 1) x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs)) x_ = x_.transpose(0, 1) quant_min, quant_max = get_min_max_from_bits_signed(bits) min_vals, max_vals = torch.aminmax(x_, dim=1) min_vals = torch.min(min_vals, torch.zeros_like(min_vals)) max_vals = torch.max(max_vals, torch.zeros_like(max_vals)) eps = torch.finfo(torch.float32).eps abs_max_val_per_ch = torch.max(-min_vals, max_vals) scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2) scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device) # Introduces some noise in scale # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything scale = scale + 0.01 * torch.randn_like(scale) return scale def init_non_linear_regression_fit( x: torch.Tensor, **kwargs: Dict[str, Any], ) -> torch.Tensor: assert "params_list" in kwargs, "params list must be provided." assert "np_fit_func" in kwargs, "np_fit_func must be provided." assert "p0" in kwargs, "p0 must be provided." np_fit_func = kwargs.get('np_fit_func') params_list = kwargs.get('params_list') p0 = kwargs.get('p0') def _fit(xdata: np.ndarray, ydata: np.ndarray, func: Callable, p0: List[float]): popt, _ = curve_fit( func, xdata, ydata, maxfev=1000, p0=p0, method='lm' ) return popt # 1. Needs to convert the torch tensor to numpy tensor xdata = x.cpu().numpy() # 2. Sorts the data so that it makes it easier to fit to it sorted_xdata = np.sort(xdata, axis=-1) p0 = {k: v.cpu().numpy() for k, v in p0.items()} params_list = sorted(params_list) # We need to make sure that it matches the numpy fit func arg order # 3. Finds the best parameters for each channel try: params = [] for i in range(sorted_xdata.shape[0]): xdata_ = sorted_xdata[i] p0_ = [p0[p][i] for p in params_list] ch_params = _fit(xdata_, xdata_, np_fit_func, p0_) params.append(ch_params) # 4. Builds the parameters result = {} for i, p in enumerate(params_list): result[p] = torch.tensor([p_[i] for p_ in params], dtype=torch.float32).to(x.device) return result except ValueError as e: print(f"Could not fit the function with error: {e}") print(f"Using fallback result...") return { k: torch.tensor(v, dtype=torch.float32).to(x.device) for k, v in p0.items() } def init_zeros(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor: val = torch.amin(x, dim=1) return torch.zeros_like(val, dtype=torch.float32, device=x.device) def init_inner_scale(tensor: torch.Tensor, _min: float = torch.inf, _max: float = torch.inf) -> torch.Tensor: # Calculate the original minimum and maximum values min_vals, max_vals = torch.aminmax(tensor, dim=-1) x_min = torch.min(min_vals, torch.zeros_like(min_vals)) x_max = torch.max(max_vals, torch.zeros_like(max_vals)) if _max is torch.inf: # We do not need to scale the tensor. Just need to move it return torch.ones_like(x_min) # Calculate the scale factor scale = (_max - _min) / (x_max - x_min) return scale ############## Quant ############### @torch.enable_grad() def learn_parameters( x: torch.Tensor, params: Dict[str, nn.Parameter], qtz_func: nn.Module, deqtz_func: nn.Module, bits: int, target_dtype: torch.dtype, epochs: int = 1000, early_stop: bool = True, do_report: bool = False ) -> Tuple[Dict[str, nn.Parameter], torch.Tensor]: # Requires gradients in the parameters for p in params.values(): p.requires_grad = True p.grad = None param_keys = list(params.keys()) param_values = list(params.values()) # Defines optimizer and loss function optimizer = torch.optim.Adam(param_values, lr=0.001) scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=1.0, end_factor=0.001, total_iters=epochs // 10) loss_fn = nn.MSELoss() # Contains the best loss and the best parameters best_loss = float("inf") best_params = None # Used to stop the search early min_delta = 1e-7 acc_loss = [] percent_epochs_before_stop = 0.1 for i in range(epochs): optimizer.zero_grad() quant = quantize(x, params, qtz_func, bits, target_dtype) dequant = dequantize(quant, params, deqtz_func, bits, x.dtype) loss = loss_fn(x, dequant) if loss.isnan() or loss.isinf(): raise Exception("Loss is NaN or Inf. Stopping the search.") loss.backward() optimizer.step() scheduler.step() acc_loss.append(loss.item()) # Reports loss every 10 steps if i % 10 == 0 and do_report: print(f"Epoch {i}: Loss {loss.item()}") # Optimizes the parameter search by storing the best loss and the parameters if loss.item() < best_loss: best_loss = loss.item() best_params = copy.deepcopy({ k: v for k, v in params.items() if k in param_keys }) # We also stop the search if the loss has not considerably during the last 10% epochs if early_stop: epochs_before_stop = int(epochs * percent_epochs_before_stop) if i > epochs_before_stop and abs(acc_loss[i - epochs_before_stop] - acc_loss[i]) < min_delta: break # No longer requires gradients in the parameters for p in best_params.values(): p.requires_grad = False p.grad = None if do_report: return best_params, acc_loss else: return best_params def quantize( x: torch.Tensor, params: Dict[str, nn.Parameter], func: nn.Module, bits: int, target_dtype: torch.dtype = torch.int8 ) -> torch.Tensor: quant_min, quant_max = get_min_max_from_bits_signed(bits) x = x.transpose(0, 1) # Aligns shapes x = func(x=x, **params) x = x.transpose(0, 1) x = torch.clamp(round_func_BPDA(x), quant_min, quant_max).to(target_dtype) return x def dequantize( x: torch.Tensor, params: Dict[str, nn.Parameter], func: nn.Module, bits: int, out_dtype: torch.dtype ) -> torch.Tensor: x = x.to(dtype=out_dtype) x = x.transpose(0, 1) x = func(x=x, **params) x = x.transpose(0, 1) return x def round_func_BPDA(input): # This is equivalent to replacing round function (non-differentiable) with # an identity function (differentiable) only when backward. forward_value = torch.round(input) out = input.clone() out.data = forward_value.data return out def get_min_max_from_bits_signed(bit_width: int) -> Tuple[int, int]: return -2 ** (bit_width - 1), 2 ** (bit_width - 1) - 1 ############## Numpy ############### def np_domain_guard( x: np.ndarray, min: float = None, max: float = None, posinf: float = None, neginf: float = None, nan: float = None ) -> np.ndarray: """Guard a tensor to a valid domain.""" x = np.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan) if min is not None or max is not None: x = np.clip(x, min, max) return x def np_replace_num(x: np.ndarray, num: float, to: float) -> np.ndarray: """Replace a number in a tensor with another number. Args: x (np.ndarray): The input tensor. num (float): The number to replace. to (float): The number to replace with. Returns: np.ndarray: The tensor with the number replaced. """ return np.where(x == num, to, x) def np_guarded_power(x: np.ndarray, exp: float) -> np.ndarray: """Guard the power operation to a valid domain.""" return np.power(x, exp) if exp >= 1 else np.power(np.maximum(x, 0), exp)