diff --git a/fn_gen/ones_noisy_scale/0/__pycache__/fn.cpython-310.pyc b/fn_gen/ones_noisy_scale/0/__pycache__/fn.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..61a8ef862058d84ff30994a1f3c304a00f13e59a Binary files /dev/null and b/fn_gen/ones_noisy_scale/0/__pycache__/fn.cpython-310.pyc differ diff --git a/fn_gen/ones_noisy_scale/0/distortion.png b/fn_gen/ones_noisy_scale/0/distortion.png new file mode 100644 index 0000000000000000000000000000000000000000..ec976f455fceb1a2a1e0291096d83912c24c71f2 Binary files /dev/null and b/fn_gen/ones_noisy_scale/0/distortion.png differ diff --git a/fn_gen/ones_noisy_scale/0/expressions.txt b/fn_gen/ones_noisy_scale/0/expressions.txt new file mode 100644 index 0000000000000000000000000000000000000000..0a7e5be4566beeb4727d82f95d24241966d158dc --- /dev/null +++ b/fn_gen/ones_noisy_scale/0/expressions.txt @@ -0,0 +1,2 @@ +log(_0*x)/_s +exp(_s*x)/_0 \ No newline at end of file diff --git a/fn_gen/ones_noisy_scale/0/fn.py b/fn_gen/ones_noisy_scale/0/fn.py new file mode 100644 index 0000000000000000000000000000000000000000..1aecf17db3a44a110dde724f60e641b5e1aba00a --- /dev/null +++ b/fn_gen/ones_noisy_scale/0/fn.py @@ -0,0 +1,489 @@ +from __future__ import annotations + +import torch +from torch import amin # Necessary for arcsin +import copy +import torch.nn as nn +import numpy as np + +from scipy.optimize import curve_fit +from typing import Dict, Any, Tuple, List, Callable + + +def quantization(x, **params): + return (torch.div(1, replace_num(params['_s'], num=0, to=10000)) * torch.log(domain_guard((params['_0'] * x), min=1e-5, nan=1e-5))) + + +def dequantization(x, **params): + return (torch.div(1, replace_num(params['_0'], num=0, to=10000)) * torch.exp((params['_s'] * x))) + + +def init_linear_scale( # Symmetric scale. From the study folder + x: torch.Tensor, + **kwargs: Dict[str, Any], + ) -> torch.Tensor: + assert "bits" in kwargs, "bits must be provided." + assert "params" in kwargs, "params must be provided." + assert "qtz_func" in kwargs, "qtz_func must be provided." + + bits = kwargs.get('bits') + params = kwargs.get('params') + qtz_func = kwargs.get('qtz_func') + + x_ = x.transpose(0, 1) + x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs)) + x_ = x_.transpose(0, 1) + + quant_min, quant_max = get_min_max_from_bits_signed(bits) + min_vals, max_vals = torch.aminmax(x_, dim=1) + min_vals = torch.min(min_vals, torch.zeros_like(min_vals)) + max_vals = torch.max(max_vals, torch.zeros_like(max_vals)) + + eps = torch.finfo(torch.float32).eps + + abs_max_val_per_ch = torch.max(-min_vals, max_vals) + scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2) + + scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device) + + # Introduces some noise in scale + # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything + scale = scale + 0.01 * torch.randn_like(scale) + return scale + + +def init_params(x: torch.Tensor, **kwargs: Dict[str, Any]) -> Dict[str, nn.Parameter]: + params = { + '_0': init_ones(x, qtz_func=quantization, deqtz_func=dequantization, param='_0', params_list=['_0', '_s'], **kwargs), + } + params['_s'] = init_linear_scale(x, params=params, qtz_func=quantization, **kwargs) + params = {k: nn.Parameter(v, requires_grad=False) for k, v in params.items()} + + if 'post_init_hook' in kwargs: + kwargs['post_init_hook'](parameters=params) + + + if 'post_train_hook' in kwargs: + kwargs['post_train_hook'](parameters=params) + + return params + + +############### Numpy Qtz ############### + + +def np_quantization(x, _0, _s): + return (np.divide(1, np_replace_num(_s, num=0, to=10000)) * np.log(np_domain_guard((_0 * x), min=1e-5, nan=1e-5))) + + +def np_dequantization(x, _0, _s): + return (np.divide(1, np_replace_num(_0, num=0, to=10000)) * np.exp((_s * x))) + + +def fit_func(x, _0, _s): + x_ = np_quantization(x, _0, _s) + x_ = np_dequantization(x_, _0, _s) + return x_ + + + +############### HELPERS ############### + +def domain_guard( + x: torch.Tensor, + min: float = None, + max: float = None, + posinf: float = None, + neginf: float = None, + nan: float = None + ) -> torch.Tensor: + """Guard a tensor to a valid domain.""" + x = torch.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan) + if min is not None or max is not None: + x = torch.clamp(x, min=min, max=max) + return x + + +def replace_num(x: torch.Tensor, num: float, to: float) -> torch.Tensor: + """Replace a number in a tensor with another number. + + Args: + x (torch.Tensor): The input tensor. + num (float): The number to replace. + to (float): The number to replace with. + + Returns: + torch.Tensor: The tensor with the number replaced. + """ + return torch.where(x == num, to, x) + + +def guarded_torch_power(x: torch.Tensor, exp: float) -> torch.Tensor: + """Guard the power operation to a valid domain.""" + return torch.pow(x, exp) if exp >= 1 else torch.pow(torch.relu(x), exp) + + +def init_ones(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor: + val = torch.amin(x, dim=1) + return torch.ones_like(val, dtype=torch.float32, device=x.device) + + +def init_rand(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor: + val = torch.amin(x, dim=1) + return torch.randn_like(val, dtype=torch.float32, device=x.device) + + +def init_space_search( + x: torch.Tensor, + **kwargs: Dict[str, Any], + ) -> torch.Tensor: + + def _build_initial_param(tensor: torch.Tensor, max_initial: int, n_params: int): + """Generates the initial set of parameters. The first iteration generates 10 times more parameters.""" + for _ in range(n_params * 10): # The first iteration generates 10 times more parameters + yield init_rand(tensor) * max_initial # Generates n_params in range [-max_initial, max_initial] + + def _search_param(tensors: List[torch.tensor], n_params): + """Takes the best parameters and generates new parameters around the mean of the best parameters.""" + torch_tensors = torch.stack(tensors) + min_vals, max_vals = torch.aminmax(torch_tensors, dim=0) + abs_max_val_per_ch = torch.max(-min_vals, max_vals) + mean = torch.mean(torch_tensors, dim=0) + for _ in range(n_params): # Generates n_params around the mean of the tensors + yield torch.randn_like(min_vals) * abs_max_val_per_ch + mean + + def _calc(x, qtz_func, deqtz_func, **params): + x_ = x.transpose(0, 1) + x_ = qtz_func(x=x_, **params) + x_ = deqtz_func(x=x_, **params) + x_ = x_.transpose(0, 1) + return x_ + + assert "qtz_func" in kwargs, "qtz_func must be provided." + assert "deqtz_func" in kwargs, "deqtz_func must be provided." + assert "params_list" in kwargs, "params list must be provided." + assert "param" in kwargs, "param must be provided." + + qtz_func = kwargs.get('qtz_func') + deqtz_func = kwargs.get('deqtz_func') + params_list = kwargs.get('params_list') + param = kwargs.get('param') + + n_runs = 50 # Number of runs to try to find the best parameters + n_random_params = 50 # Number of random parameters to generate + n_best_to_pick = 5 # Number of best parameters to pick after each run + max_initial = 10000 # Maximum value to initialize the parameters + + # Initializes the parameters + base_params = { p: init_ones(x, **kwargs) for p in params_list if p != param } + params = _build_initial_param(x, max_initial, n_random_params) + + # Performs the search + for _ in range(n_runs): + + best_params = [] + for param_ in params: + try: + x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: param_}) + loss_ones = nn.MSELoss()(x, x_) + + if len(best_params) < n_best_to_pick: + best_params.append((param_, loss_ones.item())) + best_params = sorted(best_params, key=lambda x: x[1]) + elif loss_ones < best_params[-1][1]: + best_params[-1] = (param_, loss_ones.item()) + best_params = sorted(best_params, key=lambda x: x[1]) + + except Exception: # The parameters might not be valid for the function's domain + continue + + # Generates new parameters around the mean + params = _search_param([p for p, _ in best_params], n_random_params) + + # Checks if the best parameter is better than the init_ones + p_ones = init_ones(x, **kwargs) + x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_ones}) + loss_ones = nn.MSELoss()(x, x_) + + # Checks if the best parameter is better than the init_rand + p_rand = init_rand(x, **kwargs) + x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_rand}) + loss_rand = nn.MSELoss()(x, x_) + + if loss_rand < best_params[0][1] and loss_rand < loss_ones: + return p_rand + elif loss_ones < best_params[0][1] and loss_ones < loss_rand: + return p_ones + else: + return best_params[0][0] + + +def init_linear_scale( # Symmetric scale. From the study folder + x: torch.Tensor, + **kwargs: Dict[str, Any], + ) -> torch.Tensor: + assert "bits" in kwargs, "bits must be provided." + assert "params" in kwargs, "params must be provided." + assert "qtz_func" in kwargs, "qtz_func must be provided." + + bits = kwargs.get('bits') + params = kwargs.get('params') + qtz_func = kwargs.get('qtz_func') + + x_ = x.transpose(0, 1) + x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs)) + x_ = x_.transpose(0, 1) + + quant_min, quant_max = get_min_max_from_bits_signed(bits) + min_vals, max_vals = torch.aminmax(x_, dim=1) + min_vals = torch.min(min_vals, torch.zeros_like(min_vals)) + max_vals = torch.max(max_vals, torch.zeros_like(max_vals)) + + eps = torch.finfo(torch.float32).eps + + abs_max_val_per_ch = torch.max(-min_vals, max_vals) + scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2) + + scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device) + + # Introduces some noise in scale + # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything + scale = scale + 0.01 * torch.randn_like(scale) + return scale + + +def init_non_linear_regression_fit( + x: torch.Tensor, + **kwargs: Dict[str, Any], + ) -> torch.Tensor: + + assert "params_list" in kwargs, "params list must be provided." + assert "np_fit_func" in kwargs, "np_fit_func must be provided." + assert "p0" in kwargs, "p0 must be provided." + np_fit_func = kwargs.get('np_fit_func') + params_list = kwargs.get('params_list') + p0 = kwargs.get('p0') + + def _fit(xdata: np.ndarray, ydata: np.ndarray, func: Callable, p0: List[float]): + popt, _ = curve_fit( + func, + xdata, + ydata, + maxfev=1000, + p0=p0, + method='lm' + ) + return popt + + # 1. Needs to convert the torch tensor to numpy tensor + xdata = x.cpu().numpy() + + # 2. Sorts the data so that it makes it easier to fit to it + sorted_xdata = np.sort(xdata, axis=-1) + + p0 = {k: v.cpu().numpy() for k, v in p0.items()} + params_list = sorted(params_list) # We need to make sure that it matches the numpy fit func arg order + + # 3. Finds the best parameters for each channel + try: + params = [] + for i in range(sorted_xdata.shape[0]): + xdata_ = sorted_xdata[i] + p0_ = [p0[p][i] for p in params_list] + ch_params = _fit(xdata_, xdata_, np_fit_func, p0_) + params.append(ch_params) + + # 4. Builds the parameters + result = {} + for i, p in enumerate(params_list): + result[p] = torch.tensor([p_[i] for p_ in params], dtype=torch.float32).to(x.device) + + return result + + except ValueError as e: + print(f"Could not fit the function with error: {e}") + print(f"Using fallback result...") + return { + k: torch.tensor(v, dtype=torch.float32).to(x.device) for k, v in p0.items() + } + + +def init_zeros(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor: + val = torch.amin(x, dim=1) + return torch.zeros_like(val, dtype=torch.float32, device=x.device) + + +def init_inner_scale(tensor: torch.Tensor, _min: float = torch.inf, _max: float = torch.inf) -> torch.Tensor: + # Calculate the original minimum and maximum values + min_vals, max_vals = torch.aminmax(tensor, dim=-1) + x_min = torch.min(min_vals, torch.zeros_like(min_vals)) + x_max = torch.max(max_vals, torch.zeros_like(max_vals)) + + if _max is torch.inf: # We do not need to scale the tensor. Just need to move it + return torch.ones_like(x_min) + + # Calculate the scale factor + scale = (_max - _min) / (x_max - x_min) + return scale + + + +############## Quant ############### + +@torch.enable_grad() +def learn_parameters( + x: torch.Tensor, + params: Dict[str, nn.Parameter], + qtz_func: nn.Module, + deqtz_func: nn.Module, + bits: int, + target_dtype: torch.dtype, + epochs: int = 1000, + early_stop: bool = True, + do_report: bool = False +) -> Tuple[Dict[str, nn.Parameter], torch.Tensor]: + + # Requires gradients in the parameters + for p in params.values(): + p.requires_grad = True + p.grad = None + + param_keys = list(params.keys()) + param_values = list(params.values()) + + # Defines optimizer and loss function + optimizer = torch.optim.Adam(param_values, lr=0.001) + scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=1.0, end_factor=0.001, total_iters=epochs // 10) + loss_fn = nn.MSELoss() + + # Contains the best loss and the best parameters + best_loss = float("inf") + best_params = None + + # Used to stop the search early + min_delta = 1e-7 + acc_loss = [] + percent_epochs_before_stop = 0.1 + + for i in range(epochs): + optimizer.zero_grad() + + quant = quantize(x, params, qtz_func, bits, target_dtype) + dequant = dequantize(quant, params, deqtz_func, bits, x.dtype) + loss = loss_fn(x, dequant) + + if loss.isnan() or loss.isinf(): + raise Exception("Loss is NaN or Inf. Stopping the search.") + + loss.backward() + optimizer.step() + scheduler.step() + + acc_loss.append(loss.item()) + + # Reports loss every 10 steps + if i % 10 == 0 and do_report: + print(f"Epoch {i}: Loss {loss.item()}") + + # Optimizes the parameter search by storing the best loss and the parameters + if loss.item() < best_loss: + best_loss = loss.item() + best_params = copy.deepcopy({ + k: v for k, v in params.items() if k in param_keys + }) + + # We also stop the search if the loss has not considerably during the last 10% epochs + if early_stop: + epochs_before_stop = int(epochs * percent_epochs_before_stop) + if i > epochs_before_stop and abs(acc_loss[i - epochs_before_stop] - acc_loss[i]) < min_delta: + break + + # No longer requires gradients in the parameters + for p in best_params.values(): + p.requires_grad = False + p.grad = None + + if do_report: + return best_params, acc_loss + else: + return best_params + + +def quantize( + x: torch.Tensor, + params: Dict[str, nn.Parameter], + func: nn.Module, + bits: int, + target_dtype: torch.dtype = torch.int8 +) -> torch.Tensor: + quant_min, quant_max = get_min_max_from_bits_signed(bits) + x = x.transpose(0, 1) # Aligns shapes + x = func(x=x, **params) + x = x.transpose(0, 1) + x = torch.clamp(round_func_BPDA(x), quant_min, quant_max).to(target_dtype) + return x + + +def dequantize( + x: torch.Tensor, + params: Dict[str, nn.Parameter], + func: nn.Module, + bits: int, + out_dtype: torch.dtype +) -> torch.Tensor: + x = x.to(dtype=out_dtype) + x = x.transpose(0, 1) + x = func(x=x, **params) + x = x.transpose(0, 1) + return x + + +def round_func_BPDA(input): + # This is equivalent to replacing round function (non-differentiable) with + # an identity function (differentiable) only when backward. + forward_value = torch.round(input) + out = input.clone() + out.data = forward_value.data + return out + + +def get_min_max_from_bits_signed(bit_width: int) -> Tuple[int, int]: + return -2 ** (bit_width - 1), 2 ** (bit_width - 1) - 1 + + + +############## Numpy ############### + +def np_domain_guard( + x: np.ndarray, + min: float = None, + max: float = None, + posinf: float = None, + neginf: float = None, + nan: float = None + ) -> np.ndarray: + """Guard a tensor to a valid domain.""" + x = np.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan) + if min is not None or max is not None: + x = np.clip(x, min, max) + return x + + +def np_replace_num(x: np.ndarray, num: float, to: float) -> np.ndarray: + """Replace a number in a tensor with another number. + + Args: + x (np.ndarray): The input tensor. + num (float): The number to replace. + to (float): The number to replace with. + + Returns: + np.ndarray: The tensor with the number replaced. + """ + return np.where(x == num, to, x) + + +def np_guarded_power(x: np.ndarray, exp: float) -> np.ndarray: + """Guard the power operation to a valid domain.""" + return np.power(x, exp) if exp >= 1 else np.power(np.maximum(x, 0), exp) + diff --git a/fn_gen/ones_noisy_scale/0/loss.png b/fn_gen/ones_noisy_scale/0/loss.png new file mode 100644 index 0000000000000000000000000000000000000000..bc68ecf43260091aa5f5756d99559d91a4f2e993 Binary files /dev/null and b/fn_gen/ones_noisy_scale/0/loss.png differ diff --git a/fn_gen/ones_noisy_scale/0/quantization.png b/fn_gen/ones_noisy_scale/0/quantization.png new file mode 100644 index 0000000000000000000000000000000000000000..6afd4f321148dcb57d7626cbb1a3156062d7f192 Binary files /dev/null and b/fn_gen/ones_noisy_scale/0/quantization.png differ diff --git a/fn_gen/ones_noisy_scale/1/__pycache__/fn.cpython-310.pyc b/fn_gen/ones_noisy_scale/1/__pycache__/fn.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e0f422d04357a6e340f72e213399fab71d46c115 Binary files /dev/null and b/fn_gen/ones_noisy_scale/1/__pycache__/fn.cpython-310.pyc differ diff --git a/fn_gen/ones_noisy_scale/1/distortion.png b/fn_gen/ones_noisy_scale/1/distortion.png new file mode 100644 index 0000000000000000000000000000000000000000..a90bb99bbdf2e749549736008aa3456b53315779 Binary files /dev/null and b/fn_gen/ones_noisy_scale/1/distortion.png differ diff --git a/fn_gen/ones_noisy_scale/1/expressions.txt b/fn_gen/ones_noisy_scale/1/expressions.txt new file mode 100644 index 0000000000000000000000000000000000000000..9aa25379a9d1d5a93d60659c6609b2e24e79234d --- /dev/null +++ b/fn_gen/ones_noisy_scale/1/expressions.txt @@ -0,0 +1,2 @@ +exp(_0*x)/_s +log(_s*x)/_0 \ No newline at end of file diff --git a/fn_gen/ones_noisy_scale/1/fn.py b/fn_gen/ones_noisy_scale/1/fn.py new file mode 100644 index 0000000000000000000000000000000000000000..b9ba04ab13dbc76780b48e081cf8e66b966860e2 --- /dev/null +++ b/fn_gen/ones_noisy_scale/1/fn.py @@ -0,0 +1,489 @@ +from __future__ import annotations + +import torch +from torch import amin # Necessary for arcsin +import copy +import torch.nn as nn +import numpy as np + +from scipy.optimize import curve_fit +from typing import Dict, Any, Tuple, List, Callable + + +def quantization(x, **params): + return (torch.div(1, replace_num(params['_s'], num=0, to=10000)) * torch.exp((params['_0'] * x))) + + +def dequantization(x, **params): + return (torch.div(1, replace_num(params['_0'], num=0, to=10000)) * torch.log(domain_guard((params['_s'] * x), min=1e-5, nan=1e-5))) + + +def init_linear_scale( # Symmetric scale. From the study folder + x: torch.Tensor, + **kwargs: Dict[str, Any], + ) -> torch.Tensor: + assert "bits" in kwargs, "bits must be provided." + assert "params" in kwargs, "params must be provided." + assert "qtz_func" in kwargs, "qtz_func must be provided." + + bits = kwargs.get('bits') + params = kwargs.get('params') + qtz_func = kwargs.get('qtz_func') + + x_ = x.transpose(0, 1) + x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs)) + x_ = x_.transpose(0, 1) + + quant_min, quant_max = get_min_max_from_bits_signed(bits) + min_vals, max_vals = torch.aminmax(x_, dim=1) + min_vals = torch.min(min_vals, torch.zeros_like(min_vals)) + max_vals = torch.max(max_vals, torch.zeros_like(max_vals)) + + eps = torch.finfo(torch.float32).eps + + abs_max_val_per_ch = torch.max(-min_vals, max_vals) + scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2) + + scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device) + + # Introduces some noise in scale + # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything + scale = scale + 0.01 * torch.randn_like(scale) + return scale + + +def init_params(x: torch.Tensor, **kwargs: Dict[str, Any]) -> Dict[str, nn.Parameter]: + params = { + '_0': init_ones(x, qtz_func=quantization, deqtz_func=dequantization, param='_0', params_list=['_0', '_s'], **kwargs), + } + params['_s'] = init_linear_scale(x, params=params, qtz_func=quantization, **kwargs) + params = {k: nn.Parameter(v, requires_grad=False) for k, v in params.items()} + + if 'post_init_hook' in kwargs: + kwargs['post_init_hook'](parameters=params) + + + if 'post_train_hook' in kwargs: + kwargs['post_train_hook'](parameters=params) + + return params + + +############### Numpy Qtz ############### + + +def np_quantization(x, _0, _s): + return (np.divide(1, np_replace_num(_s, num=0, to=10000)) * np.exp((_0 * x))) + + +def np_dequantization(x, _0, _s): + return (np.divide(1, np_replace_num(_0, num=0, to=10000)) * np.log(np_domain_guard((_s * x), min=1e-5, nan=1e-5))) + + +def fit_func(x, _0, _s): + x_ = np_quantization(x, _0, _s) + x_ = np_dequantization(x_, _0, _s) + return x_ + + + +############### HELPERS ############### + +def domain_guard( + x: torch.Tensor, + min: float = None, + max: float = None, + posinf: float = None, + neginf: float = None, + nan: float = None + ) -> torch.Tensor: + """Guard a tensor to a valid domain.""" + x = torch.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan) + if min is not None or max is not None: + x = torch.clamp(x, min=min, max=max) + return x + + +def replace_num(x: torch.Tensor, num: float, to: float) -> torch.Tensor: + """Replace a number in a tensor with another number. + + Args: + x (torch.Tensor): The input tensor. + num (float): The number to replace. + to (float): The number to replace with. + + Returns: + torch.Tensor: The tensor with the number replaced. + """ + return torch.where(x == num, to, x) + + +def guarded_torch_power(x: torch.Tensor, exp: float) -> torch.Tensor: + """Guard the power operation to a valid domain.""" + return torch.pow(x, exp) if exp >= 1 else torch.pow(torch.relu(x), exp) + + +def init_ones(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor: + val = torch.amin(x, dim=1) + return torch.ones_like(val, dtype=torch.float32, device=x.device) + + +def init_rand(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor: + val = torch.amin(x, dim=1) + return torch.randn_like(val, dtype=torch.float32, device=x.device) + + +def init_space_search( + x: torch.Tensor, + **kwargs: Dict[str, Any], + ) -> torch.Tensor: + + def _build_initial_param(tensor: torch.Tensor, max_initial: int, n_params: int): + """Generates the initial set of parameters. The first iteration generates 10 times more parameters.""" + for _ in range(n_params * 10): # The first iteration generates 10 times more parameters + yield init_rand(tensor) * max_initial # Generates n_params in range [-max_initial, max_initial] + + def _search_param(tensors: List[torch.tensor], n_params): + """Takes the best parameters and generates new parameters around the mean of the best parameters.""" + torch_tensors = torch.stack(tensors) + min_vals, max_vals = torch.aminmax(torch_tensors, dim=0) + abs_max_val_per_ch = torch.max(-min_vals, max_vals) + mean = torch.mean(torch_tensors, dim=0) + for _ in range(n_params): # Generates n_params around the mean of the tensors + yield torch.randn_like(min_vals) * abs_max_val_per_ch + mean + + def _calc(x, qtz_func, deqtz_func, **params): + x_ = x.transpose(0, 1) + x_ = qtz_func(x=x_, **params) + x_ = deqtz_func(x=x_, **params) + x_ = x_.transpose(0, 1) + return x_ + + assert "qtz_func" in kwargs, "qtz_func must be provided." + assert "deqtz_func" in kwargs, "deqtz_func must be provided." + assert "params_list" in kwargs, "params list must be provided." + assert "param" in kwargs, "param must be provided." + + qtz_func = kwargs.get('qtz_func') + deqtz_func = kwargs.get('deqtz_func') + params_list = kwargs.get('params_list') + param = kwargs.get('param') + + n_runs = 50 # Number of runs to try to find the best parameters + n_random_params = 50 # Number of random parameters to generate + n_best_to_pick = 5 # Number of best parameters to pick after each run + max_initial = 10000 # Maximum value to initialize the parameters + + # Initializes the parameters + base_params = { p: init_ones(x, **kwargs) for p in params_list if p != param } + params = _build_initial_param(x, max_initial, n_random_params) + + # Performs the search + for _ in range(n_runs): + + best_params = [] + for param_ in params: + try: + x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: param_}) + loss_ones = nn.MSELoss()(x, x_) + + if len(best_params) < n_best_to_pick: + best_params.append((param_, loss_ones.item())) + best_params = sorted(best_params, key=lambda x: x[1]) + elif loss_ones < best_params[-1][1]: + best_params[-1] = (param_, loss_ones.item()) + best_params = sorted(best_params, key=lambda x: x[1]) + + except Exception: # The parameters might not be valid for the function's domain + continue + + # Generates new parameters around the mean + params = _search_param([p for p, _ in best_params], n_random_params) + + # Checks if the best parameter is better than the init_ones + p_ones = init_ones(x, **kwargs) + x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_ones}) + loss_ones = nn.MSELoss()(x, x_) + + # Checks if the best parameter is better than the init_rand + p_rand = init_rand(x, **kwargs) + x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_rand}) + loss_rand = nn.MSELoss()(x, x_) + + if loss_rand < best_params[0][1] and loss_rand < loss_ones: + return p_rand + elif loss_ones < best_params[0][1] and loss_ones < loss_rand: + return p_ones + else: + return best_params[0][0] + + +def init_linear_scale( # Symmetric scale. From the study folder + x: torch.Tensor, + **kwargs: Dict[str, Any], + ) -> torch.Tensor: + assert "bits" in kwargs, "bits must be provided." + assert "params" in kwargs, "params must be provided." + assert "qtz_func" in kwargs, "qtz_func must be provided." + + bits = kwargs.get('bits') + params = kwargs.get('params') + qtz_func = kwargs.get('qtz_func') + + x_ = x.transpose(0, 1) + x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs)) + x_ = x_.transpose(0, 1) + + quant_min, quant_max = get_min_max_from_bits_signed(bits) + min_vals, max_vals = torch.aminmax(x_, dim=1) + min_vals = torch.min(min_vals, torch.zeros_like(min_vals)) + max_vals = torch.max(max_vals, torch.zeros_like(max_vals)) + + eps = torch.finfo(torch.float32).eps + + abs_max_val_per_ch = torch.max(-min_vals, max_vals) + scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2) + + scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device) + + # Introduces some noise in scale + # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything + scale = scale + 0.01 * torch.randn_like(scale) + return scale + + +def init_non_linear_regression_fit( + x: torch.Tensor, + **kwargs: Dict[str, Any], + ) -> torch.Tensor: + + assert "params_list" in kwargs, "params list must be provided." + assert "np_fit_func" in kwargs, "np_fit_func must be provided." + assert "p0" in kwargs, "p0 must be provided." + np_fit_func = kwargs.get('np_fit_func') + params_list = kwargs.get('params_list') + p0 = kwargs.get('p0') + + def _fit(xdata: np.ndarray, ydata: np.ndarray, func: Callable, p0: List[float]): + popt, _ = curve_fit( + func, + xdata, + ydata, + maxfev=1000, + p0=p0, + method='lm' + ) + return popt + + # 1. Needs to convert the torch tensor to numpy tensor + xdata = x.cpu().numpy() + + # 2. Sorts the data so that it makes it easier to fit to it + sorted_xdata = np.sort(xdata, axis=-1) + + p0 = {k: v.cpu().numpy() for k, v in p0.items()} + params_list = sorted(params_list) # We need to make sure that it matches the numpy fit func arg order + + # 3. Finds the best parameters for each channel + try: + params = [] + for i in range(sorted_xdata.shape[0]): + xdata_ = sorted_xdata[i] + p0_ = [p0[p][i] for p in params_list] + ch_params = _fit(xdata_, xdata_, np_fit_func, p0_) + params.append(ch_params) + + # 4. Builds the parameters + result = {} + for i, p in enumerate(params_list): + result[p] = torch.tensor([p_[i] for p_ in params], dtype=torch.float32).to(x.device) + + return result + + except ValueError as e: + print(f"Could not fit the function with error: {e}") + print(f"Using fallback result...") + return { + k: torch.tensor(v, dtype=torch.float32).to(x.device) for k, v in p0.items() + } + + +def init_zeros(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor: + val = torch.amin(x, dim=1) + return torch.zeros_like(val, dtype=torch.float32, device=x.device) + + +def init_inner_scale(tensor: torch.Tensor, _min: float = torch.inf, _max: float = torch.inf) -> torch.Tensor: + # Calculate the original minimum and maximum values + min_vals, max_vals = torch.aminmax(tensor, dim=-1) + x_min = torch.min(min_vals, torch.zeros_like(min_vals)) + x_max = torch.max(max_vals, torch.zeros_like(max_vals)) + + if _max is torch.inf: # We do not need to scale the tensor. Just need to move it + return torch.ones_like(x_min) + + # Calculate the scale factor + scale = (_max - _min) / (x_max - x_min) + return scale + + + +############## Quant ############### + +@torch.enable_grad() +def learn_parameters( + x: torch.Tensor, + params: Dict[str, nn.Parameter], + qtz_func: nn.Module, + deqtz_func: nn.Module, + bits: int, + target_dtype: torch.dtype, + epochs: int = 1000, + early_stop: bool = True, + do_report: bool = False +) -> Tuple[Dict[str, nn.Parameter], torch.Tensor]: + + # Requires gradients in the parameters + for p in params.values(): + p.requires_grad = True + p.grad = None + + param_keys = list(params.keys()) + param_values = list(params.values()) + + # Defines optimizer and loss function + optimizer = torch.optim.Adam(param_values, lr=0.001) + scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=1.0, end_factor=0.001, total_iters=epochs // 10) + loss_fn = nn.MSELoss() + + # Contains the best loss and the best parameters + best_loss = float("inf") + best_params = None + + # Used to stop the search early + min_delta = 1e-7 + acc_loss = [] + percent_epochs_before_stop = 0.1 + + for i in range(epochs): + optimizer.zero_grad() + + quant = quantize(x, params, qtz_func, bits, target_dtype) + dequant = dequantize(quant, params, deqtz_func, bits, x.dtype) + loss = loss_fn(x, dequant) + + if loss.isnan() or loss.isinf(): + raise Exception("Loss is NaN or Inf. Stopping the search.") + + loss.backward() + optimizer.step() + scheduler.step() + + acc_loss.append(loss.item()) + + # Reports loss every 10 steps + if i % 10 == 0 and do_report: + print(f"Epoch {i}: Loss {loss.item()}") + + # Optimizes the parameter search by storing the best loss and the parameters + if loss.item() < best_loss: + best_loss = loss.item() + best_params = copy.deepcopy({ + k: v for k, v in params.items() if k in param_keys + }) + + # We also stop the search if the loss has not considerably during the last 10% epochs + if early_stop: + epochs_before_stop = int(epochs * percent_epochs_before_stop) + if i > epochs_before_stop and abs(acc_loss[i - epochs_before_stop] - acc_loss[i]) < min_delta: + break + + # No longer requires gradients in the parameters + for p in best_params.values(): + p.requires_grad = False + p.grad = None + + if do_report: + return best_params, acc_loss + else: + return best_params + + +def quantize( + x: torch.Tensor, + params: Dict[str, nn.Parameter], + func: nn.Module, + bits: int, + target_dtype: torch.dtype = torch.int8 +) -> torch.Tensor: + quant_min, quant_max = get_min_max_from_bits_signed(bits) + x = x.transpose(0, 1) # Aligns shapes + x = func(x=x, **params) + x = x.transpose(0, 1) + x = torch.clamp(round_func_BPDA(x), quant_min, quant_max).to(target_dtype) + return x + + +def dequantize( + x: torch.Tensor, + params: Dict[str, nn.Parameter], + func: nn.Module, + bits: int, + out_dtype: torch.dtype +) -> torch.Tensor: + x = x.to(dtype=out_dtype) + x = x.transpose(0, 1) + x = func(x=x, **params) + x = x.transpose(0, 1) + return x + + +def round_func_BPDA(input): + # This is equivalent to replacing round function (non-differentiable) with + # an identity function (differentiable) only when backward. + forward_value = torch.round(input) + out = input.clone() + out.data = forward_value.data + return out + + +def get_min_max_from_bits_signed(bit_width: int) -> Tuple[int, int]: + return -2 ** (bit_width - 1), 2 ** (bit_width - 1) - 1 + + + +############## Numpy ############### + +def np_domain_guard( + x: np.ndarray, + min: float = None, + max: float = None, + posinf: float = None, + neginf: float = None, + nan: float = None + ) -> np.ndarray: + """Guard a tensor to a valid domain.""" + x = np.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan) + if min is not None or max is not None: + x = np.clip(x, min, max) + return x + + +def np_replace_num(x: np.ndarray, num: float, to: float) -> np.ndarray: + """Replace a number in a tensor with another number. + + Args: + x (np.ndarray): The input tensor. + num (float): The number to replace. + to (float): The number to replace with. + + Returns: + np.ndarray: The tensor with the number replaced. + """ + return np.where(x == num, to, x) + + +def np_guarded_power(x: np.ndarray, exp: float) -> np.ndarray: + """Guard the power operation to a valid domain.""" + return np.power(x, exp) if exp >= 1 else np.power(np.maximum(x, 0), exp) + diff --git a/fn_gen/ones_noisy_scale/1/loss.png b/fn_gen/ones_noisy_scale/1/loss.png new file mode 100644 index 0000000000000000000000000000000000000000..3ed5edddb6a07c6abc7cf5a42fddb5d34be769df Binary files /dev/null and b/fn_gen/ones_noisy_scale/1/loss.png differ diff --git a/fn_gen/ones_noisy_scale/1/quantization.png b/fn_gen/ones_noisy_scale/1/quantization.png new file mode 100644 index 0000000000000000000000000000000000000000..74ed6a06ed150d82fb315e81e19eeed605b311f0 Binary files /dev/null and b/fn_gen/ones_noisy_scale/1/quantization.png differ diff --git a/fn_gen/ones_noisy_scale/10/__pycache__/fn.cpython-310.pyc b/fn_gen/ones_noisy_scale/10/__pycache__/fn.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..03e0e2899fb592a13f48a6d95883073c0d1e05b1 Binary files /dev/null and b/fn_gen/ones_noisy_scale/10/__pycache__/fn.cpython-310.pyc differ diff --git a/fn_gen/ones_noisy_scale/10/distortion.png b/fn_gen/ones_noisy_scale/10/distortion.png new file mode 100644 index 0000000000000000000000000000000000000000..a2ee371b1e8485c42a4c71d82570d92f84b5ccd0 Binary files /dev/null and b/fn_gen/ones_noisy_scale/10/distortion.png differ diff --git a/fn_gen/ones_noisy_scale/10/expressions.txt b/fn_gen/ones_noisy_scale/10/expressions.txt new file mode 100644 index 0000000000000000000000000000000000000000..dbb6da0fc54c6f23dc12daf2e2c3a395819e1bf4 --- /dev/null +++ b/fn_gen/ones_noisy_scale/10/expressions.txt @@ -0,0 +1,2 @@ +x**2/_s +sqrt(_s*x) \ No newline at end of file diff --git a/fn_gen/ones_noisy_scale/10/fn.py b/fn_gen/ones_noisy_scale/10/fn.py new file mode 100644 index 0000000000000000000000000000000000000000..b95b879e410733d747572481e3cbae6ea3b3645a --- /dev/null +++ b/fn_gen/ones_noisy_scale/10/fn.py @@ -0,0 +1,488 @@ +from __future__ import annotations + +import torch +from torch import amin # Necessary for arcsin +import copy +import torch.nn as nn +import numpy as np + +from scipy.optimize import curve_fit +from typing import Dict, Any, Tuple, List, Callable + + +def quantization(x, **params): + return (torch.div(1, replace_num(params['_s'], num=0, to=10000)) * guarded_torch_power(x, torch.tensor(2))) + + +def dequantization(x, **params): + return torch.sqrt(domain_guard((params['_s'] * x), min=0.1, nan=0.1)) + + +def init_linear_scale( # Symmetric scale. From the study folder + x: torch.Tensor, + **kwargs: Dict[str, Any], + ) -> torch.Tensor: + assert "bits" in kwargs, "bits must be provided." + assert "params" in kwargs, "params must be provided." + assert "qtz_func" in kwargs, "qtz_func must be provided." + + bits = kwargs.get('bits') + params = kwargs.get('params') + qtz_func = kwargs.get('qtz_func') + + x_ = x.transpose(0, 1) + x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs)) + x_ = x_.transpose(0, 1) + + quant_min, quant_max = get_min_max_from_bits_signed(bits) + min_vals, max_vals = torch.aminmax(x_, dim=1) + min_vals = torch.min(min_vals, torch.zeros_like(min_vals)) + max_vals = torch.max(max_vals, torch.zeros_like(max_vals)) + + eps = torch.finfo(torch.float32).eps + + abs_max_val_per_ch = torch.max(-min_vals, max_vals) + scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2) + + scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device) + + # Introduces some noise in scale + # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything + scale = scale + 0.01 * torch.randn_like(scale) + return scale + + +def init_params(x: torch.Tensor, **kwargs: Dict[str, Any]) -> Dict[str, nn.Parameter]: + params = { + } + params['_s'] = init_linear_scale(x, params=params, qtz_func=quantization, **kwargs) + params = {k: nn.Parameter(v, requires_grad=False) for k, v in params.items()} + + if 'post_init_hook' in kwargs: + kwargs['post_init_hook'](parameters=params) + + + if 'post_train_hook' in kwargs: + kwargs['post_train_hook'](parameters=params) + + return params + + +############### Numpy Qtz ############### + + +def np_quantization(x, _s): + return (np.divide(1, np_replace_num(_s, num=0, to=10000)) * np_guarded_power(x, np.array(2))) + + +def np_dequantization(x, _s): + return np.sqrt(np_domain_guard((_s * x), min=0.1, nan=0.1)) + + +def fit_func(x, _s): + x_ = np_quantization(x, _s) + x_ = np_dequantization(x_, _s) + return x_ + + + +############### HELPERS ############### + +def domain_guard( + x: torch.Tensor, + min: float = None, + max: float = None, + posinf: float = None, + neginf: float = None, + nan: float = None + ) -> torch.Tensor: + """Guard a tensor to a valid domain.""" + x = torch.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan) + if min is not None or max is not None: + x = torch.clamp(x, min=min, max=max) + return x + + +def replace_num(x: torch.Tensor, num: float, to: float) -> torch.Tensor: + """Replace a number in a tensor with another number. + + Args: + x (torch.Tensor): The input tensor. + num (float): The number to replace. + to (float): The number to replace with. + + Returns: + torch.Tensor: The tensor with the number replaced. + """ + return torch.where(x == num, to, x) + + +def guarded_torch_power(x: torch.Tensor, exp: float) -> torch.Tensor: + """Guard the power operation to a valid domain.""" + return torch.pow(x, exp) if exp >= 1 else torch.pow(torch.relu(x), exp) + + +def init_ones(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor: + val = torch.amin(x, dim=1) + return torch.ones_like(val, dtype=torch.float32, device=x.device) + + +def init_rand(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor: + val = torch.amin(x, dim=1) + return torch.randn_like(val, dtype=torch.float32, device=x.device) + + +def init_space_search( + x: torch.Tensor, + **kwargs: Dict[str, Any], + ) -> torch.Tensor: + + def _build_initial_param(tensor: torch.Tensor, max_initial: int, n_params: int): + """Generates the initial set of parameters. The first iteration generates 10 times more parameters.""" + for _ in range(n_params * 10): # The first iteration generates 10 times more parameters + yield init_rand(tensor) * max_initial # Generates n_params in range [-max_initial, max_initial] + + def _search_param(tensors: List[torch.tensor], n_params): + """Takes the best parameters and generates new parameters around the mean of the best parameters.""" + torch_tensors = torch.stack(tensors) + min_vals, max_vals = torch.aminmax(torch_tensors, dim=0) + abs_max_val_per_ch = torch.max(-min_vals, max_vals) + mean = torch.mean(torch_tensors, dim=0) + for _ in range(n_params): # Generates n_params around the mean of the tensors + yield torch.randn_like(min_vals) * abs_max_val_per_ch + mean + + def _calc(x, qtz_func, deqtz_func, **params): + x_ = x.transpose(0, 1) + x_ = qtz_func(x=x_, **params) + x_ = deqtz_func(x=x_, **params) + x_ = x_.transpose(0, 1) + return x_ + + assert "qtz_func" in kwargs, "qtz_func must be provided." + assert "deqtz_func" in kwargs, "deqtz_func must be provided." + assert "params_list" in kwargs, "params list must be provided." + assert "param" in kwargs, "param must be provided." + + qtz_func = kwargs.get('qtz_func') + deqtz_func = kwargs.get('deqtz_func') + params_list = kwargs.get('params_list') + param = kwargs.get('param') + + n_runs = 50 # Number of runs to try to find the best parameters + n_random_params = 50 # Number of random parameters to generate + n_best_to_pick = 5 # Number of best parameters to pick after each run + max_initial = 10000 # Maximum value to initialize the parameters + + # Initializes the parameters + base_params = { p: init_ones(x, **kwargs) for p in params_list if p != param } + params = _build_initial_param(x, max_initial, n_random_params) + + # Performs the search + for _ in range(n_runs): + + best_params = [] + for param_ in params: + try: + x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: param_}) + loss_ones = nn.MSELoss()(x, x_) + + if len(best_params) < n_best_to_pick: + best_params.append((param_, loss_ones.item())) + best_params = sorted(best_params, key=lambda x: x[1]) + elif loss_ones < best_params[-1][1]: + best_params[-1] = (param_, loss_ones.item()) + best_params = sorted(best_params, key=lambda x: x[1]) + + except Exception: # The parameters might not be valid for the function's domain + continue + + # Generates new parameters around the mean + params = _search_param([p for p, _ in best_params], n_random_params) + + # Checks if the best parameter is better than the init_ones + p_ones = init_ones(x, **kwargs) + x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_ones}) + loss_ones = nn.MSELoss()(x, x_) + + # Checks if the best parameter is better than the init_rand + p_rand = init_rand(x, **kwargs) + x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_rand}) + loss_rand = nn.MSELoss()(x, x_) + + if loss_rand < best_params[0][1] and loss_rand < loss_ones: + return p_rand + elif loss_ones < best_params[0][1] and loss_ones < loss_rand: + return p_ones + else: + return best_params[0][0] + + +def init_linear_scale( # Symmetric scale. From the study folder + x: torch.Tensor, + **kwargs: Dict[str, Any], + ) -> torch.Tensor: + assert "bits" in kwargs, "bits must be provided." + assert "params" in kwargs, "params must be provided." + assert "qtz_func" in kwargs, "qtz_func must be provided." + + bits = kwargs.get('bits') + params = kwargs.get('params') + qtz_func = kwargs.get('qtz_func') + + x_ = x.transpose(0, 1) + x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs)) + x_ = x_.transpose(0, 1) + + quant_min, quant_max = get_min_max_from_bits_signed(bits) + min_vals, max_vals = torch.aminmax(x_, dim=1) + min_vals = torch.min(min_vals, torch.zeros_like(min_vals)) + max_vals = torch.max(max_vals, torch.zeros_like(max_vals)) + + eps = torch.finfo(torch.float32).eps + + abs_max_val_per_ch = torch.max(-min_vals, max_vals) + scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2) + + scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device) + + # Introduces some noise in scale + # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything + scale = scale + 0.01 * torch.randn_like(scale) + return scale + + +def init_non_linear_regression_fit( + x: torch.Tensor, + **kwargs: Dict[str, Any], + ) -> torch.Tensor: + + assert "params_list" in kwargs, "params list must be provided." + assert "np_fit_func" in kwargs, "np_fit_func must be provided." + assert "p0" in kwargs, "p0 must be provided." + np_fit_func = kwargs.get('np_fit_func') + params_list = kwargs.get('params_list') + p0 = kwargs.get('p0') + + def _fit(xdata: np.ndarray, ydata: np.ndarray, func: Callable, p0: List[float]): + popt, _ = curve_fit( + func, + xdata, + ydata, + maxfev=1000, + p0=p0, + method='lm' + ) + return popt + + # 1. Needs to convert the torch tensor to numpy tensor + xdata = x.cpu().numpy() + + # 2. Sorts the data so that it makes it easier to fit to it + sorted_xdata = np.sort(xdata, axis=-1) + + p0 = {k: v.cpu().numpy() for k, v in p0.items()} + params_list = sorted(params_list) # We need to make sure that it matches the numpy fit func arg order + + # 3. Finds the best parameters for each channel + try: + params = [] + for i in range(sorted_xdata.shape[0]): + xdata_ = sorted_xdata[i] + p0_ = [p0[p][i] for p in params_list] + ch_params = _fit(xdata_, xdata_, np_fit_func, p0_) + params.append(ch_params) + + # 4. Builds the parameters + result = {} + for i, p in enumerate(params_list): + result[p] = torch.tensor([p_[i] for p_ in params], dtype=torch.float32).to(x.device) + + return result + + except ValueError as e: + print(f"Could not fit the function with error: {e}") + print(f"Using fallback result...") + return { + k: torch.tensor(v, dtype=torch.float32).to(x.device) for k, v in p0.items() + } + + +def init_zeros(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor: + val = torch.amin(x, dim=1) + return torch.zeros_like(val, dtype=torch.float32, device=x.device) + + +def init_inner_scale(tensor: torch.Tensor, _min: float = torch.inf, _max: float = torch.inf) -> torch.Tensor: + # Calculate the original minimum and maximum values + min_vals, max_vals = torch.aminmax(tensor, dim=-1) + x_min = torch.min(min_vals, torch.zeros_like(min_vals)) + x_max = torch.max(max_vals, torch.zeros_like(max_vals)) + + if _max is torch.inf: # We do not need to scale the tensor. Just need to move it + return torch.ones_like(x_min) + + # Calculate the scale factor + scale = (_max - _min) / (x_max - x_min) + return scale + + + +############## Quant ############### + +@torch.enable_grad() +def learn_parameters( + x: torch.Tensor, + params: Dict[str, nn.Parameter], + qtz_func: nn.Module, + deqtz_func: nn.Module, + bits: int, + target_dtype: torch.dtype, + epochs: int = 1000, + early_stop: bool = True, + do_report: bool = False +) -> Tuple[Dict[str, nn.Parameter], torch.Tensor]: + + # Requires gradients in the parameters + for p in params.values(): + p.requires_grad = True + p.grad = None + + param_keys = list(params.keys()) + param_values = list(params.values()) + + # Defines optimizer and loss function + optimizer = torch.optim.Adam(param_values, lr=0.001) + scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=1.0, end_factor=0.001, total_iters=epochs // 10) + loss_fn = nn.MSELoss() + + # Contains the best loss and the best parameters + best_loss = float("inf") + best_params = None + + # Used to stop the search early + min_delta = 1e-7 + acc_loss = [] + percent_epochs_before_stop = 0.1 + + for i in range(epochs): + optimizer.zero_grad() + + quant = quantize(x, params, qtz_func, bits, target_dtype) + dequant = dequantize(quant, params, deqtz_func, bits, x.dtype) + loss = loss_fn(x, dequant) + + if loss.isnan() or loss.isinf(): + raise Exception("Loss is NaN or Inf. Stopping the search.") + + loss.backward() + optimizer.step() + scheduler.step() + + acc_loss.append(loss.item()) + + # Reports loss every 10 steps + if i % 10 == 0 and do_report: + print(f"Epoch {i}: Loss {loss.item()}") + + # Optimizes the parameter search by storing the best loss and the parameters + if loss.item() < best_loss: + best_loss = loss.item() + best_params = copy.deepcopy({ + k: v for k, v in params.items() if k in param_keys + }) + + # We also stop the search if the loss has not considerably during the last 10% epochs + if early_stop: + epochs_before_stop = int(epochs * percent_epochs_before_stop) + if i > epochs_before_stop and abs(acc_loss[i - epochs_before_stop] - acc_loss[i]) < min_delta: + break + + # No longer requires gradients in the parameters + for p in best_params.values(): + p.requires_grad = False + p.grad = None + + if do_report: + return best_params, acc_loss + else: + return best_params + + +def quantize( + x: torch.Tensor, + params: Dict[str, nn.Parameter], + func: nn.Module, + bits: int, + target_dtype: torch.dtype = torch.int8 +) -> torch.Tensor: + quant_min, quant_max = get_min_max_from_bits_signed(bits) + x = x.transpose(0, 1) # Aligns shapes + x = func(x=x, **params) + x = x.transpose(0, 1) + x = torch.clamp(round_func_BPDA(x), quant_min, quant_max).to(target_dtype) + return x + + +def dequantize( + x: torch.Tensor, + params: Dict[str, nn.Parameter], + func: nn.Module, + bits: int, + out_dtype: torch.dtype +) -> torch.Tensor: + x = x.to(dtype=out_dtype) + x = x.transpose(0, 1) + x = func(x=x, **params) + x = x.transpose(0, 1) + return x + + +def round_func_BPDA(input): + # This is equivalent to replacing round function (non-differentiable) with + # an identity function (differentiable) only when backward. + forward_value = torch.round(input) + out = input.clone() + out.data = forward_value.data + return out + + +def get_min_max_from_bits_signed(bit_width: int) -> Tuple[int, int]: + return -2 ** (bit_width - 1), 2 ** (bit_width - 1) - 1 + + + +############## Numpy ############### + +def np_domain_guard( + x: np.ndarray, + min: float = None, + max: float = None, + posinf: float = None, + neginf: float = None, + nan: float = None + ) -> np.ndarray: + """Guard a tensor to a valid domain.""" + x = np.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan) + if min is not None or max is not None: + x = np.clip(x, min, max) + return x + + +def np_replace_num(x: np.ndarray, num: float, to: float) -> np.ndarray: + """Replace a number in a tensor with another number. + + Args: + x (np.ndarray): The input tensor. + num (float): The number to replace. + to (float): The number to replace with. + + Returns: + np.ndarray: The tensor with the number replaced. + """ + return np.where(x == num, to, x) + + +def np_guarded_power(x: np.ndarray, exp: float) -> np.ndarray: + """Guard the power operation to a valid domain.""" + return np.power(x, exp) if exp >= 1 else np.power(np.maximum(x, 0), exp) + diff --git a/fn_gen/ones_noisy_scale/10/loss.png b/fn_gen/ones_noisy_scale/10/loss.png new file mode 100644 index 0000000000000000000000000000000000000000..524665796c13bd6cead300f2a498241a8e93a316 Binary files /dev/null and b/fn_gen/ones_noisy_scale/10/loss.png differ diff --git a/fn_gen/ones_noisy_scale/10/quantization.png b/fn_gen/ones_noisy_scale/10/quantization.png new file mode 100644 index 0000000000000000000000000000000000000000..00260b545b403e72bd448da1a29adcaaafc1ffb6 Binary files /dev/null and b/fn_gen/ones_noisy_scale/10/quantization.png differ diff --git a/fn_gen/ones_noisy_scale/11/__pycache__/fn.cpython-310.pyc b/fn_gen/ones_noisy_scale/11/__pycache__/fn.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c3e68e50f5dc78ef0c085b115cdfd8623f793992 Binary files /dev/null and b/fn_gen/ones_noisy_scale/11/__pycache__/fn.cpython-310.pyc differ diff --git a/fn_gen/ones_noisy_scale/11/distortion.png b/fn_gen/ones_noisy_scale/11/distortion.png new file mode 100644 index 0000000000000000000000000000000000000000..2224d1a035fa276aa7645453d1108c16a40c6e0e Binary files /dev/null and b/fn_gen/ones_noisy_scale/11/distortion.png differ diff --git a/fn_gen/ones_noisy_scale/11/expressions.txt b/fn_gen/ones_noisy_scale/11/expressions.txt new file mode 100644 index 0000000000000000000000000000000000000000..23606e9f370f2e4adb43ed623c49d7fcaabd7355 --- /dev/null +++ b/fn_gen/ones_noisy_scale/11/expressions.txt @@ -0,0 +1,2 @@ +tan(_0*x)/_s +atan(_s*x)/_0 \ No newline at end of file diff --git a/fn_gen/ones_noisy_scale/11/fn.py b/fn_gen/ones_noisy_scale/11/fn.py new file mode 100644 index 0000000000000000000000000000000000000000..9d1776bf217a0d577368d90567c69bd483061b43 --- /dev/null +++ b/fn_gen/ones_noisy_scale/11/fn.py @@ -0,0 +1,489 @@ +from __future__ import annotations + +import torch +from torch import amin # Necessary for arcsin +import copy +import torch.nn as nn +import numpy as np + +from scipy.optimize import curve_fit +from typing import Dict, Any, Tuple, List, Callable + + +def quantization(x, **params): + return (torch.div(1, replace_num(params['_s'], num=0, to=10000)) * torch.tan(domain_guard((params['_0'] * x), posinf=1, neginf=-1, nan=0))) + + +def dequantization(x, **params): + return (torch.div(1, replace_num(params['_0'], num=0, to=10000)) * torch.atan((params['_s'] * x))) + + +def init_linear_scale( # Symmetric scale. From the study folder + x: torch.Tensor, + **kwargs: Dict[str, Any], + ) -> torch.Tensor: + assert "bits" in kwargs, "bits must be provided." + assert "params" in kwargs, "params must be provided." + assert "qtz_func" in kwargs, "qtz_func must be provided." + + bits = kwargs.get('bits') + params = kwargs.get('params') + qtz_func = kwargs.get('qtz_func') + + x_ = x.transpose(0, 1) + x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs)) + x_ = x_.transpose(0, 1) + + quant_min, quant_max = get_min_max_from_bits_signed(bits) + min_vals, max_vals = torch.aminmax(x_, dim=1) + min_vals = torch.min(min_vals, torch.zeros_like(min_vals)) + max_vals = torch.max(max_vals, torch.zeros_like(max_vals)) + + eps = torch.finfo(torch.float32).eps + + abs_max_val_per_ch = torch.max(-min_vals, max_vals) + scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2) + + scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device) + + # Introduces some noise in scale + # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything + scale = scale + 0.01 * torch.randn_like(scale) + return scale + + +def init_params(x: torch.Tensor, **kwargs: Dict[str, Any]) -> Dict[str, nn.Parameter]: + params = { + '_0': init_ones(x, qtz_func=quantization, deqtz_func=dequantization, param='_0', params_list=['_0', '_s'], **kwargs), + } + params['_s'] = init_linear_scale(x, params=params, qtz_func=quantization, **kwargs) + params = {k: nn.Parameter(v, requires_grad=False) for k, v in params.items()} + + if 'post_init_hook' in kwargs: + kwargs['post_init_hook'](parameters=params) + + + if 'post_train_hook' in kwargs: + kwargs['post_train_hook'](parameters=params) + + return params + + +############### Numpy Qtz ############### + + +def np_quantization(x, _0, _s): + return (np.divide(1, np_replace_num(_s, num=0, to=10000)) * np.tan(np_domain_guard((_0 * x), posinf=1, neginf=-1, nan=0))) + + +def np_dequantization(x, _0, _s): + return (np.divide(1, np_replace_num(_0, num=0, to=10000)) * np.arctan((_s * x))) + + +def fit_func(x, _0, _s): + x_ = np_quantization(x, _0, _s) + x_ = np_dequantization(x_, _0, _s) + return x_ + + + +############### HELPERS ############### + +def domain_guard( + x: torch.Tensor, + min: float = None, + max: float = None, + posinf: float = None, + neginf: float = None, + nan: float = None + ) -> torch.Tensor: + """Guard a tensor to a valid domain.""" + x = torch.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan) + if min is not None or max is not None: + x = torch.clamp(x, min=min, max=max) + return x + + +def replace_num(x: torch.Tensor, num: float, to: float) -> torch.Tensor: + """Replace a number in a tensor with another number. + + Args: + x (torch.Tensor): The input tensor. + num (float): The number to replace. + to (float): The number to replace with. + + Returns: + torch.Tensor: The tensor with the number replaced. + """ + return torch.where(x == num, to, x) + + +def guarded_torch_power(x: torch.Tensor, exp: float) -> torch.Tensor: + """Guard the power operation to a valid domain.""" + return torch.pow(x, exp) if exp >= 1 else torch.pow(torch.relu(x), exp) + + +def init_ones(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor: + val = torch.amin(x, dim=1) + return torch.ones_like(val, dtype=torch.float32, device=x.device) + + +def init_rand(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor: + val = torch.amin(x, dim=1) + return torch.randn_like(val, dtype=torch.float32, device=x.device) + + +def init_space_search( + x: torch.Tensor, + **kwargs: Dict[str, Any], + ) -> torch.Tensor: + + def _build_initial_param(tensor: torch.Tensor, max_initial: int, n_params: int): + """Generates the initial set of parameters. The first iteration generates 10 times more parameters.""" + for _ in range(n_params * 10): # The first iteration generates 10 times more parameters + yield init_rand(tensor) * max_initial # Generates n_params in range [-max_initial, max_initial] + + def _search_param(tensors: List[torch.tensor], n_params): + """Takes the best parameters and generates new parameters around the mean of the best parameters.""" + torch_tensors = torch.stack(tensors) + min_vals, max_vals = torch.aminmax(torch_tensors, dim=0) + abs_max_val_per_ch = torch.max(-min_vals, max_vals) + mean = torch.mean(torch_tensors, dim=0) + for _ in range(n_params): # Generates n_params around the mean of the tensors + yield torch.randn_like(min_vals) * abs_max_val_per_ch + mean + + def _calc(x, qtz_func, deqtz_func, **params): + x_ = x.transpose(0, 1) + x_ = qtz_func(x=x_, **params) + x_ = deqtz_func(x=x_, **params) + x_ = x_.transpose(0, 1) + return x_ + + assert "qtz_func" in kwargs, "qtz_func must be provided." + assert "deqtz_func" in kwargs, "deqtz_func must be provided." + assert "params_list" in kwargs, "params list must be provided." + assert "param" in kwargs, "param must be provided." + + qtz_func = kwargs.get('qtz_func') + deqtz_func = kwargs.get('deqtz_func') + params_list = kwargs.get('params_list') + param = kwargs.get('param') + + n_runs = 50 # Number of runs to try to find the best parameters + n_random_params = 50 # Number of random parameters to generate + n_best_to_pick = 5 # Number of best parameters to pick after each run + max_initial = 10000 # Maximum value to initialize the parameters + + # Initializes the parameters + base_params = { p: init_ones(x, **kwargs) for p in params_list if p != param } + params = _build_initial_param(x, max_initial, n_random_params) + + # Performs the search + for _ in range(n_runs): + + best_params = [] + for param_ in params: + try: + x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: param_}) + loss_ones = nn.MSELoss()(x, x_) + + if len(best_params) < n_best_to_pick: + best_params.append((param_, loss_ones.item())) + best_params = sorted(best_params, key=lambda x: x[1]) + elif loss_ones < best_params[-1][1]: + best_params[-1] = (param_, loss_ones.item()) + best_params = sorted(best_params, key=lambda x: x[1]) + + except Exception: # The parameters might not be valid for the function's domain + continue + + # Generates new parameters around the mean + params = _search_param([p for p, _ in best_params], n_random_params) + + # Checks if the best parameter is better than the init_ones + p_ones = init_ones(x, **kwargs) + x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_ones}) + loss_ones = nn.MSELoss()(x, x_) + + # Checks if the best parameter is better than the init_rand + p_rand = init_rand(x, **kwargs) + x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_rand}) + loss_rand = nn.MSELoss()(x, x_) + + if loss_rand < best_params[0][1] and loss_rand < loss_ones: + return p_rand + elif loss_ones < best_params[0][1] and loss_ones < loss_rand: + return p_ones + else: + return best_params[0][0] + + +def init_linear_scale( # Symmetric scale. From the study folder + x: torch.Tensor, + **kwargs: Dict[str, Any], + ) -> torch.Tensor: + assert "bits" in kwargs, "bits must be provided." + assert "params" in kwargs, "params must be provided." + assert "qtz_func" in kwargs, "qtz_func must be provided." + + bits = kwargs.get('bits') + params = kwargs.get('params') + qtz_func = kwargs.get('qtz_func') + + x_ = x.transpose(0, 1) + x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs)) + x_ = x_.transpose(0, 1) + + quant_min, quant_max = get_min_max_from_bits_signed(bits) + min_vals, max_vals = torch.aminmax(x_, dim=1) + min_vals = torch.min(min_vals, torch.zeros_like(min_vals)) + max_vals = torch.max(max_vals, torch.zeros_like(max_vals)) + + eps = torch.finfo(torch.float32).eps + + abs_max_val_per_ch = torch.max(-min_vals, max_vals) + scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2) + + scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device) + + # Introduces some noise in scale + # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything + scale = scale + 0.01 * torch.randn_like(scale) + return scale + + +def init_non_linear_regression_fit( + x: torch.Tensor, + **kwargs: Dict[str, Any], + ) -> torch.Tensor: + + assert "params_list" in kwargs, "params list must be provided." + assert "np_fit_func" in kwargs, "np_fit_func must be provided." + assert "p0" in kwargs, "p0 must be provided." + np_fit_func = kwargs.get('np_fit_func') + params_list = kwargs.get('params_list') + p0 = kwargs.get('p0') + + def _fit(xdata: np.ndarray, ydata: np.ndarray, func: Callable, p0: List[float]): + popt, _ = curve_fit( + func, + xdata, + ydata, + maxfev=1000, + p0=p0, + method='lm' + ) + return popt + + # 1. Needs to convert the torch tensor to numpy tensor + xdata = x.cpu().numpy() + + # 2. Sorts the data so that it makes it easier to fit to it + sorted_xdata = np.sort(xdata, axis=-1) + + p0 = {k: v.cpu().numpy() for k, v in p0.items()} + params_list = sorted(params_list) # We need to make sure that it matches the numpy fit func arg order + + # 3. Finds the best parameters for each channel + try: + params = [] + for i in range(sorted_xdata.shape[0]): + xdata_ = sorted_xdata[i] + p0_ = [p0[p][i] for p in params_list] + ch_params = _fit(xdata_, xdata_, np_fit_func, p0_) + params.append(ch_params) + + # 4. Builds the parameters + result = {} + for i, p in enumerate(params_list): + result[p] = torch.tensor([p_[i] for p_ in params], dtype=torch.float32).to(x.device) + + return result + + except ValueError as e: + print(f"Could not fit the function with error: {e}") + print(f"Using fallback result...") + return { + k: torch.tensor(v, dtype=torch.float32).to(x.device) for k, v in p0.items() + } + + +def init_zeros(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor: + val = torch.amin(x, dim=1) + return torch.zeros_like(val, dtype=torch.float32, device=x.device) + + +def init_inner_scale(tensor: torch.Tensor, _min: float = torch.inf, _max: float = torch.inf) -> torch.Tensor: + # Calculate the original minimum and maximum values + min_vals, max_vals = torch.aminmax(tensor, dim=-1) + x_min = torch.min(min_vals, torch.zeros_like(min_vals)) + x_max = torch.max(max_vals, torch.zeros_like(max_vals)) + + if _max is torch.inf: # We do not need to scale the tensor. Just need to move it + return torch.ones_like(x_min) + + # Calculate the scale factor + scale = (_max - _min) / (x_max - x_min) + return scale + + + +############## Quant ############### + +@torch.enable_grad() +def learn_parameters( + x: torch.Tensor, + params: Dict[str, nn.Parameter], + qtz_func: nn.Module, + deqtz_func: nn.Module, + bits: int, + target_dtype: torch.dtype, + epochs: int = 1000, + early_stop: bool = True, + do_report: bool = False +) -> Tuple[Dict[str, nn.Parameter], torch.Tensor]: + + # Requires gradients in the parameters + for p in params.values(): + p.requires_grad = True + p.grad = None + + param_keys = list(params.keys()) + param_values = list(params.values()) + + # Defines optimizer and loss function + optimizer = torch.optim.Adam(param_values, lr=0.001) + scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=1.0, end_factor=0.001, total_iters=epochs // 10) + loss_fn = nn.MSELoss() + + # Contains the best loss and the best parameters + best_loss = float("inf") + best_params = None + + # Used to stop the search early + min_delta = 1e-7 + acc_loss = [] + percent_epochs_before_stop = 0.1 + + for i in range(epochs): + optimizer.zero_grad() + + quant = quantize(x, params, qtz_func, bits, target_dtype) + dequant = dequantize(quant, params, deqtz_func, bits, x.dtype) + loss = loss_fn(x, dequant) + + if loss.isnan() or loss.isinf(): + raise Exception("Loss is NaN or Inf. Stopping the search.") + + loss.backward() + optimizer.step() + scheduler.step() + + acc_loss.append(loss.item()) + + # Reports loss every 10 steps + if i % 10 == 0 and do_report: + print(f"Epoch {i}: Loss {loss.item()}") + + # Optimizes the parameter search by storing the best loss and the parameters + if loss.item() < best_loss: + best_loss = loss.item() + best_params = copy.deepcopy({ + k: v for k, v in params.items() if k in param_keys + }) + + # We also stop the search if the loss has not considerably during the last 10% epochs + if early_stop: + epochs_before_stop = int(epochs * percent_epochs_before_stop) + if i > epochs_before_stop and abs(acc_loss[i - epochs_before_stop] - acc_loss[i]) < min_delta: + break + + # No longer requires gradients in the parameters + for p in best_params.values(): + p.requires_grad = False + p.grad = None + + if do_report: + return best_params, acc_loss + else: + return best_params + + +def quantize( + x: torch.Tensor, + params: Dict[str, nn.Parameter], + func: nn.Module, + bits: int, + target_dtype: torch.dtype = torch.int8 +) -> torch.Tensor: + quant_min, quant_max = get_min_max_from_bits_signed(bits) + x = x.transpose(0, 1) # Aligns shapes + x = func(x=x, **params) + x = x.transpose(0, 1) + x = torch.clamp(round_func_BPDA(x), quant_min, quant_max).to(target_dtype) + return x + + +def dequantize( + x: torch.Tensor, + params: Dict[str, nn.Parameter], + func: nn.Module, + bits: int, + out_dtype: torch.dtype +) -> torch.Tensor: + x = x.to(dtype=out_dtype) + x = x.transpose(0, 1) + x = func(x=x, **params) + x = x.transpose(0, 1) + return x + + +def round_func_BPDA(input): + # This is equivalent to replacing round function (non-differentiable) with + # an identity function (differentiable) only when backward. + forward_value = torch.round(input) + out = input.clone() + out.data = forward_value.data + return out + + +def get_min_max_from_bits_signed(bit_width: int) -> Tuple[int, int]: + return -2 ** (bit_width - 1), 2 ** (bit_width - 1) - 1 + + + +############## Numpy ############### + +def np_domain_guard( + x: np.ndarray, + min: float = None, + max: float = None, + posinf: float = None, + neginf: float = None, + nan: float = None + ) -> np.ndarray: + """Guard a tensor to a valid domain.""" + x = np.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan) + if min is not None or max is not None: + x = np.clip(x, min, max) + return x + + +def np_replace_num(x: np.ndarray, num: float, to: float) -> np.ndarray: + """Replace a number in a tensor with another number. + + Args: + x (np.ndarray): The input tensor. + num (float): The number to replace. + to (float): The number to replace with. + + Returns: + np.ndarray: The tensor with the number replaced. + """ + return np.where(x == num, to, x) + + +def np_guarded_power(x: np.ndarray, exp: float) -> np.ndarray: + """Guard the power operation to a valid domain.""" + return np.power(x, exp) if exp >= 1 else np.power(np.maximum(x, 0), exp) + diff --git a/fn_gen/ones_noisy_scale/11/loss.png b/fn_gen/ones_noisy_scale/11/loss.png new file mode 100644 index 0000000000000000000000000000000000000000..406c7f8cb67caadcb4a583cda204271697b04974 Binary files /dev/null and b/fn_gen/ones_noisy_scale/11/loss.png differ diff --git a/fn_gen/ones_noisy_scale/11/quantization.png b/fn_gen/ones_noisy_scale/11/quantization.png new file mode 100644 index 0000000000000000000000000000000000000000..d4bc7af2f1ab095725781f9e8613b58726070bb9 Binary files /dev/null and b/fn_gen/ones_noisy_scale/11/quantization.png differ diff --git a/fn_gen/ones_noisy_scale/12/__pycache__/fn.cpython-310.pyc b/fn_gen/ones_noisy_scale/12/__pycache__/fn.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..971bca89ef98adcb83bd2c96144410fe34f31cad Binary files /dev/null and b/fn_gen/ones_noisy_scale/12/__pycache__/fn.cpython-310.pyc differ diff --git a/fn_gen/ones_noisy_scale/12/distortion.png b/fn_gen/ones_noisy_scale/12/distortion.png new file mode 100644 index 0000000000000000000000000000000000000000..4365b061d736e1737717e357c815c8fe2e9fae92 Binary files /dev/null and b/fn_gen/ones_noisy_scale/12/distortion.png differ diff --git a/fn_gen/ones_noisy_scale/12/expressions.txt b/fn_gen/ones_noisy_scale/12/expressions.txt new file mode 100644 index 0000000000000000000000000000000000000000..8c0b1579c06c048d5603aa39c80e392c5906a879 --- /dev/null +++ b/fn_gen/ones_noisy_scale/12/expressions.txt @@ -0,0 +1,2 @@ +cos(_0*x)/_s +acos(_s*x)/_0 \ No newline at end of file diff --git a/fn_gen/ones_noisy_scale/12/fn.py b/fn_gen/ones_noisy_scale/12/fn.py new file mode 100644 index 0000000000000000000000000000000000000000..52244974552fd516206c453c5a2adc46ec0281a3 --- /dev/null +++ b/fn_gen/ones_noisy_scale/12/fn.py @@ -0,0 +1,489 @@ +from __future__ import annotations + +import torch +from torch import amin # Necessary for arcsin +import copy +import torch.nn as nn +import numpy as np + +from scipy.optimize import curve_fit +from typing import Dict, Any, Tuple, List, Callable + + +def quantization(x, **params): + return (torch.div(1, replace_num(params['_s'], num=0, to=10000)) * torch.cos((params['_0'] * x))) + + +def dequantization(x, **params): + return (torch.div(1, replace_num(params['_0'], num=0, to=10000)) * torch.acos(domain_guard((params['_s'] * x), min=-0.99999, max=0.99999, nan=0))) + + +def init_linear_scale( # Symmetric scale. From the study folder + x: torch.Tensor, + **kwargs: Dict[str, Any], + ) -> torch.Tensor: + assert "bits" in kwargs, "bits must be provided." + assert "params" in kwargs, "params must be provided." + assert "qtz_func" in kwargs, "qtz_func must be provided." + + bits = kwargs.get('bits') + params = kwargs.get('params') + qtz_func = kwargs.get('qtz_func') + + x_ = x.transpose(0, 1) + x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs)) + x_ = x_.transpose(0, 1) + + quant_min, quant_max = get_min_max_from_bits_signed(bits) + min_vals, max_vals = torch.aminmax(x_, dim=1) + min_vals = torch.min(min_vals, torch.zeros_like(min_vals)) + max_vals = torch.max(max_vals, torch.zeros_like(max_vals)) + + eps = torch.finfo(torch.float32).eps + + abs_max_val_per_ch = torch.max(-min_vals, max_vals) + scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2) + + scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device) + + # Introduces some noise in scale + # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything + scale = scale + 0.01 * torch.randn_like(scale) + return scale + + +def init_params(x: torch.Tensor, **kwargs: Dict[str, Any]) -> Dict[str, nn.Parameter]: + params = { + '_0': init_ones(x, qtz_func=quantization, deqtz_func=dequantization, param='_0', params_list=['_0', '_s'], **kwargs), + } + params['_s'] = init_linear_scale(x, params=params, qtz_func=quantization, **kwargs) + params = {k: nn.Parameter(v, requires_grad=False) for k, v in params.items()} + + if 'post_init_hook' in kwargs: + kwargs['post_init_hook'](parameters=params) + + + if 'post_train_hook' in kwargs: + kwargs['post_train_hook'](parameters=params) + + return params + + +############### Numpy Qtz ############### + + +def np_quantization(x, _0, _s): + return (np.divide(1, np_replace_num(_s, num=0, to=10000)) * np.cos((_0 * x))) + + +def np_dequantization(x, _0, _s): + return (np.divide(1, np_replace_num(_0, num=0, to=10000)) * np.arccos(np_domain_guard((_s * x), min=-0.99999, max=0.99999, nan=0))) + + +def fit_func(x, _0, _s): + x_ = np_quantization(x, _0, _s) + x_ = np_dequantization(x_, _0, _s) + return x_ + + + +############### HELPERS ############### + +def domain_guard( + x: torch.Tensor, + min: float = None, + max: float = None, + posinf: float = None, + neginf: float = None, + nan: float = None + ) -> torch.Tensor: + """Guard a tensor to a valid domain.""" + x = torch.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan) + if min is not None or max is not None: + x = torch.clamp(x, min=min, max=max) + return x + + +def replace_num(x: torch.Tensor, num: float, to: float) -> torch.Tensor: + """Replace a number in a tensor with another number. + + Args: + x (torch.Tensor): The input tensor. + num (float): The number to replace. + to (float): The number to replace with. + + Returns: + torch.Tensor: The tensor with the number replaced. + """ + return torch.where(x == num, to, x) + + +def guarded_torch_power(x: torch.Tensor, exp: float) -> torch.Tensor: + """Guard the power operation to a valid domain.""" + return torch.pow(x, exp) if exp >= 1 else torch.pow(torch.relu(x), exp) + + +def init_ones(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor: + val = torch.amin(x, dim=1) + return torch.ones_like(val, dtype=torch.float32, device=x.device) + + +def init_rand(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor: + val = torch.amin(x, dim=1) + return torch.randn_like(val, dtype=torch.float32, device=x.device) + + +def init_space_search( + x: torch.Tensor, + **kwargs: Dict[str, Any], + ) -> torch.Tensor: + + def _build_initial_param(tensor: torch.Tensor, max_initial: int, n_params: int): + """Generates the initial set of parameters. The first iteration generates 10 times more parameters.""" + for _ in range(n_params * 10): # The first iteration generates 10 times more parameters + yield init_rand(tensor) * max_initial # Generates n_params in range [-max_initial, max_initial] + + def _search_param(tensors: List[torch.tensor], n_params): + """Takes the best parameters and generates new parameters around the mean of the best parameters.""" + torch_tensors = torch.stack(tensors) + min_vals, max_vals = torch.aminmax(torch_tensors, dim=0) + abs_max_val_per_ch = torch.max(-min_vals, max_vals) + mean = torch.mean(torch_tensors, dim=0) + for _ in range(n_params): # Generates n_params around the mean of the tensors + yield torch.randn_like(min_vals) * abs_max_val_per_ch + mean + + def _calc(x, qtz_func, deqtz_func, **params): + x_ = x.transpose(0, 1) + x_ = qtz_func(x=x_, **params) + x_ = deqtz_func(x=x_, **params) + x_ = x_.transpose(0, 1) + return x_ + + assert "qtz_func" in kwargs, "qtz_func must be provided." + assert "deqtz_func" in kwargs, "deqtz_func must be provided." + assert "params_list" in kwargs, "params list must be provided." + assert "param" in kwargs, "param must be provided." + + qtz_func = kwargs.get('qtz_func') + deqtz_func = kwargs.get('deqtz_func') + params_list = kwargs.get('params_list') + param = kwargs.get('param') + + n_runs = 50 # Number of runs to try to find the best parameters + n_random_params = 50 # Number of random parameters to generate + n_best_to_pick = 5 # Number of best parameters to pick after each run + max_initial = 10000 # Maximum value to initialize the parameters + + # Initializes the parameters + base_params = { p: init_ones(x, **kwargs) for p in params_list if p != param } + params = _build_initial_param(x, max_initial, n_random_params) + + # Performs the search + for _ in range(n_runs): + + best_params = [] + for param_ in params: + try: + x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: param_}) + loss_ones = nn.MSELoss()(x, x_) + + if len(best_params) < n_best_to_pick: + best_params.append((param_, loss_ones.item())) + best_params = sorted(best_params, key=lambda x: x[1]) + elif loss_ones < best_params[-1][1]: + best_params[-1] = (param_, loss_ones.item()) + best_params = sorted(best_params, key=lambda x: x[1]) + + except Exception: # The parameters might not be valid for the function's domain + continue + + # Generates new parameters around the mean + params = _search_param([p for p, _ in best_params], n_random_params) + + # Checks if the best parameter is better than the init_ones + p_ones = init_ones(x, **kwargs) + x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_ones}) + loss_ones = nn.MSELoss()(x, x_) + + # Checks if the best parameter is better than the init_rand + p_rand = init_rand(x, **kwargs) + x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_rand}) + loss_rand = nn.MSELoss()(x, x_) + + if loss_rand < best_params[0][1] and loss_rand < loss_ones: + return p_rand + elif loss_ones < best_params[0][1] and loss_ones < loss_rand: + return p_ones + else: + return best_params[0][0] + + +def init_linear_scale( # Symmetric scale. From the study folder + x: torch.Tensor, + **kwargs: Dict[str, Any], + ) -> torch.Tensor: + assert "bits" in kwargs, "bits must be provided." + assert "params" in kwargs, "params must be provided." + assert "qtz_func" in kwargs, "qtz_func must be provided." + + bits = kwargs.get('bits') + params = kwargs.get('params') + qtz_func = kwargs.get('qtz_func') + + x_ = x.transpose(0, 1) + x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs)) + x_ = x_.transpose(0, 1) + + quant_min, quant_max = get_min_max_from_bits_signed(bits) + min_vals, max_vals = torch.aminmax(x_, dim=1) + min_vals = torch.min(min_vals, torch.zeros_like(min_vals)) + max_vals = torch.max(max_vals, torch.zeros_like(max_vals)) + + eps = torch.finfo(torch.float32).eps + + abs_max_val_per_ch = torch.max(-min_vals, max_vals) + scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2) + + scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device) + + # Introduces some noise in scale + # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything + scale = scale + 0.01 * torch.randn_like(scale) + return scale + + +def init_non_linear_regression_fit( + x: torch.Tensor, + **kwargs: Dict[str, Any], + ) -> torch.Tensor: + + assert "params_list" in kwargs, "params list must be provided." + assert "np_fit_func" in kwargs, "np_fit_func must be provided." + assert "p0" in kwargs, "p0 must be provided." + np_fit_func = kwargs.get('np_fit_func') + params_list = kwargs.get('params_list') + p0 = kwargs.get('p0') + + def _fit(xdata: np.ndarray, ydata: np.ndarray, func: Callable, p0: List[float]): + popt, _ = curve_fit( + func, + xdata, + ydata, + maxfev=1000, + p0=p0, + method='lm' + ) + return popt + + # 1. Needs to convert the torch tensor to numpy tensor + xdata = x.cpu().numpy() + + # 2. Sorts the data so that it makes it easier to fit to it + sorted_xdata = np.sort(xdata, axis=-1) + + p0 = {k: v.cpu().numpy() for k, v in p0.items()} + params_list = sorted(params_list) # We need to make sure that it matches the numpy fit func arg order + + # 3. Finds the best parameters for each channel + try: + params = [] + for i in range(sorted_xdata.shape[0]): + xdata_ = sorted_xdata[i] + p0_ = [p0[p][i] for p in params_list] + ch_params = _fit(xdata_, xdata_, np_fit_func, p0_) + params.append(ch_params) + + # 4. Builds the parameters + result = {} + for i, p in enumerate(params_list): + result[p] = torch.tensor([p_[i] for p_ in params], dtype=torch.float32).to(x.device) + + return result + + except ValueError as e: + print(f"Could not fit the function with error: {e}") + print(f"Using fallback result...") + return { + k: torch.tensor(v, dtype=torch.float32).to(x.device) for k, v in p0.items() + } + + +def init_zeros(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor: + val = torch.amin(x, dim=1) + return torch.zeros_like(val, dtype=torch.float32, device=x.device) + + +def init_inner_scale(tensor: torch.Tensor, _min: float = torch.inf, _max: float = torch.inf) -> torch.Tensor: + # Calculate the original minimum and maximum values + min_vals, max_vals = torch.aminmax(tensor, dim=-1) + x_min = torch.min(min_vals, torch.zeros_like(min_vals)) + x_max = torch.max(max_vals, torch.zeros_like(max_vals)) + + if _max is torch.inf: # We do not need to scale the tensor. Just need to move it + return torch.ones_like(x_min) + + # Calculate the scale factor + scale = (_max - _min) / (x_max - x_min) + return scale + + + +############## Quant ############### + +@torch.enable_grad() +def learn_parameters( + x: torch.Tensor, + params: Dict[str, nn.Parameter], + qtz_func: nn.Module, + deqtz_func: nn.Module, + bits: int, + target_dtype: torch.dtype, + epochs: int = 1000, + early_stop: bool = True, + do_report: bool = False +) -> Tuple[Dict[str, nn.Parameter], torch.Tensor]: + + # Requires gradients in the parameters + for p in params.values(): + p.requires_grad = True + p.grad = None + + param_keys = list(params.keys()) + param_values = list(params.values()) + + # Defines optimizer and loss function + optimizer = torch.optim.Adam(param_values, lr=0.001) + scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=1.0, end_factor=0.001, total_iters=epochs // 10) + loss_fn = nn.MSELoss() + + # Contains the best loss and the best parameters + best_loss = float("inf") + best_params = None + + # Used to stop the search early + min_delta = 1e-7 + acc_loss = [] + percent_epochs_before_stop = 0.1 + + for i in range(epochs): + optimizer.zero_grad() + + quant = quantize(x, params, qtz_func, bits, target_dtype) + dequant = dequantize(quant, params, deqtz_func, bits, x.dtype) + loss = loss_fn(x, dequant) + + if loss.isnan() or loss.isinf(): + raise Exception("Loss is NaN or Inf. Stopping the search.") + + loss.backward() + optimizer.step() + scheduler.step() + + acc_loss.append(loss.item()) + + # Reports loss every 10 steps + if i % 10 == 0 and do_report: + print(f"Epoch {i}: Loss {loss.item()}") + + # Optimizes the parameter search by storing the best loss and the parameters + if loss.item() < best_loss: + best_loss = loss.item() + best_params = copy.deepcopy({ + k: v for k, v in params.items() if k in param_keys + }) + + # We also stop the search if the loss has not considerably during the last 10% epochs + if early_stop: + epochs_before_stop = int(epochs * percent_epochs_before_stop) + if i > epochs_before_stop and abs(acc_loss[i - epochs_before_stop] - acc_loss[i]) < min_delta: + break + + # No longer requires gradients in the parameters + for p in best_params.values(): + p.requires_grad = False + p.grad = None + + if do_report: + return best_params, acc_loss + else: + return best_params + + +def quantize( + x: torch.Tensor, + params: Dict[str, nn.Parameter], + func: nn.Module, + bits: int, + target_dtype: torch.dtype = torch.int8 +) -> torch.Tensor: + quant_min, quant_max = get_min_max_from_bits_signed(bits) + x = x.transpose(0, 1) # Aligns shapes + x = func(x=x, **params) + x = x.transpose(0, 1) + x = torch.clamp(round_func_BPDA(x), quant_min, quant_max).to(target_dtype) + return x + + +def dequantize( + x: torch.Tensor, + params: Dict[str, nn.Parameter], + func: nn.Module, + bits: int, + out_dtype: torch.dtype +) -> torch.Tensor: + x = x.to(dtype=out_dtype) + x = x.transpose(0, 1) + x = func(x=x, **params) + x = x.transpose(0, 1) + return x + + +def round_func_BPDA(input): + # This is equivalent to replacing round function (non-differentiable) with + # an identity function (differentiable) only when backward. + forward_value = torch.round(input) + out = input.clone() + out.data = forward_value.data + return out + + +def get_min_max_from_bits_signed(bit_width: int) -> Tuple[int, int]: + return -2 ** (bit_width - 1), 2 ** (bit_width - 1) - 1 + + + +############## Numpy ############### + +def np_domain_guard( + x: np.ndarray, + min: float = None, + max: float = None, + posinf: float = None, + neginf: float = None, + nan: float = None + ) -> np.ndarray: + """Guard a tensor to a valid domain.""" + x = np.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan) + if min is not None or max is not None: + x = np.clip(x, min, max) + return x + + +def np_replace_num(x: np.ndarray, num: float, to: float) -> np.ndarray: + """Replace a number in a tensor with another number. + + Args: + x (np.ndarray): The input tensor. + num (float): The number to replace. + to (float): The number to replace with. + + Returns: + np.ndarray: The tensor with the number replaced. + """ + return np.where(x == num, to, x) + + +def np_guarded_power(x: np.ndarray, exp: float) -> np.ndarray: + """Guard the power operation to a valid domain.""" + return np.power(x, exp) if exp >= 1 else np.power(np.maximum(x, 0), exp) + diff --git a/fn_gen/ones_noisy_scale/12/loss.png b/fn_gen/ones_noisy_scale/12/loss.png new file mode 100644 index 0000000000000000000000000000000000000000..dce1d80b9610417915b60780da7b7d19b15265b6 Binary files /dev/null and b/fn_gen/ones_noisy_scale/12/loss.png differ diff --git a/fn_gen/ones_noisy_scale/12/quantization.png b/fn_gen/ones_noisy_scale/12/quantization.png new file mode 100644 index 0000000000000000000000000000000000000000..387558865a81291ea9778658b357b6e7985aa159 Binary files /dev/null and b/fn_gen/ones_noisy_scale/12/quantization.png differ diff --git a/fn_gen/ones_noisy_scale/13/__pycache__/fn.cpython-310.pyc b/fn_gen/ones_noisy_scale/13/__pycache__/fn.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4749a37ea1841a068a8d7613df88475ffc6b9780 Binary files /dev/null and b/fn_gen/ones_noisy_scale/13/__pycache__/fn.cpython-310.pyc differ diff --git a/fn_gen/ones_noisy_scale/13/distortion.png b/fn_gen/ones_noisy_scale/13/distortion.png new file mode 100644 index 0000000000000000000000000000000000000000..a0d8f23ff96aededdeea59feaf8d5d4627aa5438 Binary files /dev/null and b/fn_gen/ones_noisy_scale/13/distortion.png differ diff --git a/fn_gen/ones_noisy_scale/13/expressions.txt b/fn_gen/ones_noisy_scale/13/expressions.txt new file mode 100644 index 0000000000000000000000000000000000000000..b835531ccc3a3813012a9a9487415f4f73afabc7 --- /dev/null +++ b/fn_gen/ones_noisy_scale/13/expressions.txt @@ -0,0 +1,2 @@ +sinh(_0*x)/_s +log(_s*x - sqrt(_s**2*x**2 + 1))/_0 \ No newline at end of file diff --git a/fn_gen/ones_noisy_scale/13/fn.py b/fn_gen/ones_noisy_scale/13/fn.py new file mode 100644 index 0000000000000000000000000000000000000000..3cb403b38a2836c90c996284995ae8c614557f93 --- /dev/null +++ b/fn_gen/ones_noisy_scale/13/fn.py @@ -0,0 +1,489 @@ +from __future__ import annotations + +import torch +from torch import amin # Necessary for arcsin +import copy +import torch.nn as nn +import numpy as np + +from scipy.optimize import curve_fit +from typing import Dict, Any, Tuple, List, Callable + + +def quantization(x, **params): + return (torch.div(1, replace_num(params['_s'], num=0, to=10000)) * torch.sinh((params['_0'] * x))) + + +def dequantization(x, **params): + return (torch.div(1, replace_num(params['_0'], num=0, to=10000)) * torch.log(domain_guard(((torch.tensor(-1) * torch.sqrt(domain_guard((torch.tensor(1) + (guarded_torch_power(params['_s'], torch.tensor(2)) * guarded_torch_power(x, torch.tensor(2)))), min=0.1, nan=0.1))) + (params['_s'] * x)), min=1e-5, nan=1e-5))) + + +def init_linear_scale( # Symmetric scale. From the study folder + x: torch.Tensor, + **kwargs: Dict[str, Any], + ) -> torch.Tensor: + assert "bits" in kwargs, "bits must be provided." + assert "params" in kwargs, "params must be provided." + assert "qtz_func" in kwargs, "qtz_func must be provided." + + bits = kwargs.get('bits') + params = kwargs.get('params') + qtz_func = kwargs.get('qtz_func') + + x_ = x.transpose(0, 1) + x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs)) + x_ = x_.transpose(0, 1) + + quant_min, quant_max = get_min_max_from_bits_signed(bits) + min_vals, max_vals = torch.aminmax(x_, dim=1) + min_vals = torch.min(min_vals, torch.zeros_like(min_vals)) + max_vals = torch.max(max_vals, torch.zeros_like(max_vals)) + + eps = torch.finfo(torch.float32).eps + + abs_max_val_per_ch = torch.max(-min_vals, max_vals) + scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2) + + scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device) + + # Introduces some noise in scale + # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything + scale = scale + 0.01 * torch.randn_like(scale) + return scale + + +def init_params(x: torch.Tensor, **kwargs: Dict[str, Any]) -> Dict[str, nn.Parameter]: + params = { + '_0': init_ones(x, qtz_func=quantization, deqtz_func=dequantization, param='_0', params_list=['_0', '_s'], **kwargs), + } + params['_s'] = init_linear_scale(x, params=params, qtz_func=quantization, **kwargs) + params = {k: nn.Parameter(v, requires_grad=False) for k, v in params.items()} + + if 'post_init_hook' in kwargs: + kwargs['post_init_hook'](parameters=params) + + + if 'post_train_hook' in kwargs: + kwargs['post_train_hook'](parameters=params) + + return params + + +############### Numpy Qtz ############### + + +def np_quantization(x, _0, _s): + return (np.divide(1, np_replace_num(_s, num=0, to=10000)) * np.sinh((_0 * x))) + + +def np_dequantization(x, _0, _s): + return (np.divide(1, np_replace_num(_0, num=0, to=10000)) * np.log(np_domain_guard(((np.array(-1) * np.sqrt(np_domain_guard((np.array(1) + (np_guarded_power(_s, np.array(2)) * np_guarded_power(x, np.array(2)))), min=0.1, nan=0.1))) + (_s * x)), min=1e-5, nan=1e-5))) + + +def fit_func(x, _0, _s): + x_ = np_quantization(x, _0, _s) + x_ = np_dequantization(x_, _0, _s) + return x_ + + + +############### HELPERS ############### + +def domain_guard( + x: torch.Tensor, + min: float = None, + max: float = None, + posinf: float = None, + neginf: float = None, + nan: float = None + ) -> torch.Tensor: + """Guard a tensor to a valid domain.""" + x = torch.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan) + if min is not None or max is not None: + x = torch.clamp(x, min=min, max=max) + return x + + +def replace_num(x: torch.Tensor, num: float, to: float) -> torch.Tensor: + """Replace a number in a tensor with another number. + + Args: + x (torch.Tensor): The input tensor. + num (float): The number to replace. + to (float): The number to replace with. + + Returns: + torch.Tensor: The tensor with the number replaced. + """ + return torch.where(x == num, to, x) + + +def guarded_torch_power(x: torch.Tensor, exp: float) -> torch.Tensor: + """Guard the power operation to a valid domain.""" + return torch.pow(x, exp) if exp >= 1 else torch.pow(torch.relu(x), exp) + + +def init_ones(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor: + val = torch.amin(x, dim=1) + return torch.ones_like(val, dtype=torch.float32, device=x.device) + + +def init_rand(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor: + val = torch.amin(x, dim=1) + return torch.randn_like(val, dtype=torch.float32, device=x.device) + + +def init_space_search( + x: torch.Tensor, + **kwargs: Dict[str, Any], + ) -> torch.Tensor: + + def _build_initial_param(tensor: torch.Tensor, max_initial: int, n_params: int): + """Generates the initial set of parameters. The first iteration generates 10 times more parameters.""" + for _ in range(n_params * 10): # The first iteration generates 10 times more parameters + yield init_rand(tensor) * max_initial # Generates n_params in range [-max_initial, max_initial] + + def _search_param(tensors: List[torch.tensor], n_params): + """Takes the best parameters and generates new parameters around the mean of the best parameters.""" + torch_tensors = torch.stack(tensors) + min_vals, max_vals = torch.aminmax(torch_tensors, dim=0) + abs_max_val_per_ch = torch.max(-min_vals, max_vals) + mean = torch.mean(torch_tensors, dim=0) + for _ in range(n_params): # Generates n_params around the mean of the tensors + yield torch.randn_like(min_vals) * abs_max_val_per_ch + mean + + def _calc(x, qtz_func, deqtz_func, **params): + x_ = x.transpose(0, 1) + x_ = qtz_func(x=x_, **params) + x_ = deqtz_func(x=x_, **params) + x_ = x_.transpose(0, 1) + return x_ + + assert "qtz_func" in kwargs, "qtz_func must be provided." + assert "deqtz_func" in kwargs, "deqtz_func must be provided." + assert "params_list" in kwargs, "params list must be provided." + assert "param" in kwargs, "param must be provided." + + qtz_func = kwargs.get('qtz_func') + deqtz_func = kwargs.get('deqtz_func') + params_list = kwargs.get('params_list') + param = kwargs.get('param') + + n_runs = 50 # Number of runs to try to find the best parameters + n_random_params = 50 # Number of random parameters to generate + n_best_to_pick = 5 # Number of best parameters to pick after each run + max_initial = 10000 # Maximum value to initialize the parameters + + # Initializes the parameters + base_params = { p: init_ones(x, **kwargs) for p in params_list if p != param } + params = _build_initial_param(x, max_initial, n_random_params) + + # Performs the search + for _ in range(n_runs): + + best_params = [] + for param_ in params: + try: + x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: param_}) + loss_ones = nn.MSELoss()(x, x_) + + if len(best_params) < n_best_to_pick: + best_params.append((param_, loss_ones.item())) + best_params = sorted(best_params, key=lambda x: x[1]) + elif loss_ones < best_params[-1][1]: + best_params[-1] = (param_, loss_ones.item()) + best_params = sorted(best_params, key=lambda x: x[1]) + + except Exception: # The parameters might not be valid for the function's domain + continue + + # Generates new parameters around the mean + params = _search_param([p for p, _ in best_params], n_random_params) + + # Checks if the best parameter is better than the init_ones + p_ones = init_ones(x, **kwargs) + x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_ones}) + loss_ones = nn.MSELoss()(x, x_) + + # Checks if the best parameter is better than the init_rand + p_rand = init_rand(x, **kwargs) + x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_rand}) + loss_rand = nn.MSELoss()(x, x_) + + if loss_rand < best_params[0][1] and loss_rand < loss_ones: + return p_rand + elif loss_ones < best_params[0][1] and loss_ones < loss_rand: + return p_ones + else: + return best_params[0][0] + + +def init_linear_scale( # Symmetric scale. From the study folder + x: torch.Tensor, + **kwargs: Dict[str, Any], + ) -> torch.Tensor: + assert "bits" in kwargs, "bits must be provided." + assert "params" in kwargs, "params must be provided." + assert "qtz_func" in kwargs, "qtz_func must be provided." + + bits = kwargs.get('bits') + params = kwargs.get('params') + qtz_func = kwargs.get('qtz_func') + + x_ = x.transpose(0, 1) + x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs)) + x_ = x_.transpose(0, 1) + + quant_min, quant_max = get_min_max_from_bits_signed(bits) + min_vals, max_vals = torch.aminmax(x_, dim=1) + min_vals = torch.min(min_vals, torch.zeros_like(min_vals)) + max_vals = torch.max(max_vals, torch.zeros_like(max_vals)) + + eps = torch.finfo(torch.float32).eps + + abs_max_val_per_ch = torch.max(-min_vals, max_vals) + scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2) + + scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device) + + # Introduces some noise in scale + # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything + scale = scale + 0.01 * torch.randn_like(scale) + return scale + + +def init_non_linear_regression_fit( + x: torch.Tensor, + **kwargs: Dict[str, Any], + ) -> torch.Tensor: + + assert "params_list" in kwargs, "params list must be provided." + assert "np_fit_func" in kwargs, "np_fit_func must be provided." + assert "p0" in kwargs, "p0 must be provided." + np_fit_func = kwargs.get('np_fit_func') + params_list = kwargs.get('params_list') + p0 = kwargs.get('p0') + + def _fit(xdata: np.ndarray, ydata: np.ndarray, func: Callable, p0: List[float]): + popt, _ = curve_fit( + func, + xdata, + ydata, + maxfev=1000, + p0=p0, + method='lm' + ) + return popt + + # 1. Needs to convert the torch tensor to numpy tensor + xdata = x.cpu().numpy() + + # 2. Sorts the data so that it makes it easier to fit to it + sorted_xdata = np.sort(xdata, axis=-1) + + p0 = {k: v.cpu().numpy() for k, v in p0.items()} + params_list = sorted(params_list) # We need to make sure that it matches the numpy fit func arg order + + # 3. Finds the best parameters for each channel + try: + params = [] + for i in range(sorted_xdata.shape[0]): + xdata_ = sorted_xdata[i] + p0_ = [p0[p][i] for p in params_list] + ch_params = _fit(xdata_, xdata_, np_fit_func, p0_) + params.append(ch_params) + + # 4. Builds the parameters + result = {} + for i, p in enumerate(params_list): + result[p] = torch.tensor([p_[i] for p_ in params], dtype=torch.float32).to(x.device) + + return result + + except ValueError as e: + print(f"Could not fit the function with error: {e}") + print(f"Using fallback result...") + return { + k: torch.tensor(v, dtype=torch.float32).to(x.device) for k, v in p0.items() + } + + +def init_zeros(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor: + val = torch.amin(x, dim=1) + return torch.zeros_like(val, dtype=torch.float32, device=x.device) + + +def init_inner_scale(tensor: torch.Tensor, _min: float = torch.inf, _max: float = torch.inf) -> torch.Tensor: + # Calculate the original minimum and maximum values + min_vals, max_vals = torch.aminmax(tensor, dim=-1) + x_min = torch.min(min_vals, torch.zeros_like(min_vals)) + x_max = torch.max(max_vals, torch.zeros_like(max_vals)) + + if _max is torch.inf: # We do not need to scale the tensor. Just need to move it + return torch.ones_like(x_min) + + # Calculate the scale factor + scale = (_max - _min) / (x_max - x_min) + return scale + + + +############## Quant ############### + +@torch.enable_grad() +def learn_parameters( + x: torch.Tensor, + params: Dict[str, nn.Parameter], + qtz_func: nn.Module, + deqtz_func: nn.Module, + bits: int, + target_dtype: torch.dtype, + epochs: int = 1000, + early_stop: bool = True, + do_report: bool = False +) -> Tuple[Dict[str, nn.Parameter], torch.Tensor]: + + # Requires gradients in the parameters + for p in params.values(): + p.requires_grad = True + p.grad = None + + param_keys = list(params.keys()) + param_values = list(params.values()) + + # Defines optimizer and loss function + optimizer = torch.optim.Adam(param_values, lr=0.001) + scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=1.0, end_factor=0.001, total_iters=epochs // 10) + loss_fn = nn.MSELoss() + + # Contains the best loss and the best parameters + best_loss = float("inf") + best_params = None + + # Used to stop the search early + min_delta = 1e-7 + acc_loss = [] + percent_epochs_before_stop = 0.1 + + for i in range(epochs): + optimizer.zero_grad() + + quant = quantize(x, params, qtz_func, bits, target_dtype) + dequant = dequantize(quant, params, deqtz_func, bits, x.dtype) + loss = loss_fn(x, dequant) + + if loss.isnan() or loss.isinf(): + raise Exception("Loss is NaN or Inf. Stopping the search.") + + loss.backward() + optimizer.step() + scheduler.step() + + acc_loss.append(loss.item()) + + # Reports loss every 10 steps + if i % 10 == 0 and do_report: + print(f"Epoch {i}: Loss {loss.item()}") + + # Optimizes the parameter search by storing the best loss and the parameters + if loss.item() < best_loss: + best_loss = loss.item() + best_params = copy.deepcopy({ + k: v for k, v in params.items() if k in param_keys + }) + + # We also stop the search if the loss has not considerably during the last 10% epochs + if early_stop: + epochs_before_stop = int(epochs * percent_epochs_before_stop) + if i > epochs_before_stop and abs(acc_loss[i - epochs_before_stop] - acc_loss[i]) < min_delta: + break + + # No longer requires gradients in the parameters + for p in best_params.values(): + p.requires_grad = False + p.grad = None + + if do_report: + return best_params, acc_loss + else: + return best_params + + +def quantize( + x: torch.Tensor, + params: Dict[str, nn.Parameter], + func: nn.Module, + bits: int, + target_dtype: torch.dtype = torch.int8 +) -> torch.Tensor: + quant_min, quant_max = get_min_max_from_bits_signed(bits) + x = x.transpose(0, 1) # Aligns shapes + x = func(x=x, **params) + x = x.transpose(0, 1) + x = torch.clamp(round_func_BPDA(x), quant_min, quant_max).to(target_dtype) + return x + + +def dequantize( + x: torch.Tensor, + params: Dict[str, nn.Parameter], + func: nn.Module, + bits: int, + out_dtype: torch.dtype +) -> torch.Tensor: + x = x.to(dtype=out_dtype) + x = x.transpose(0, 1) + x = func(x=x, **params) + x = x.transpose(0, 1) + return x + + +def round_func_BPDA(input): + # This is equivalent to replacing round function (non-differentiable) with + # an identity function (differentiable) only when backward. + forward_value = torch.round(input) + out = input.clone() + out.data = forward_value.data + return out + + +def get_min_max_from_bits_signed(bit_width: int) -> Tuple[int, int]: + return -2 ** (bit_width - 1), 2 ** (bit_width - 1) - 1 + + + +############## Numpy ############### + +def np_domain_guard( + x: np.ndarray, + min: float = None, + max: float = None, + posinf: float = None, + neginf: float = None, + nan: float = None + ) -> np.ndarray: + """Guard a tensor to a valid domain.""" + x = np.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan) + if min is not None or max is not None: + x = np.clip(x, min, max) + return x + + +def np_replace_num(x: np.ndarray, num: float, to: float) -> np.ndarray: + """Replace a number in a tensor with another number. + + Args: + x (np.ndarray): The input tensor. + num (float): The number to replace. + to (float): The number to replace with. + + Returns: + np.ndarray: The tensor with the number replaced. + """ + return np.where(x == num, to, x) + + +def np_guarded_power(x: np.ndarray, exp: float) -> np.ndarray: + """Guard the power operation to a valid domain.""" + return np.power(x, exp) if exp >= 1 else np.power(np.maximum(x, 0), exp) + diff --git a/fn_gen/ones_noisy_scale/13/loss.png b/fn_gen/ones_noisy_scale/13/loss.png new file mode 100644 index 0000000000000000000000000000000000000000..17c28aefee68859b270c5235c433da88ae66817f Binary files /dev/null and b/fn_gen/ones_noisy_scale/13/loss.png differ diff --git a/fn_gen/ones_noisy_scale/13/quantization.png b/fn_gen/ones_noisy_scale/13/quantization.png new file mode 100644 index 0000000000000000000000000000000000000000..8d53861e04d5aa6b8922d521961c0da1e078a344 Binary files /dev/null and b/fn_gen/ones_noisy_scale/13/quantization.png differ diff --git a/fn_gen/ones_noisy_scale/14/__pycache__/fn.cpython-310.pyc b/fn_gen/ones_noisy_scale/14/__pycache__/fn.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d0cae601ed3d624aec8d6b28d4e1aabdbe86849c Binary files /dev/null and b/fn_gen/ones_noisy_scale/14/__pycache__/fn.cpython-310.pyc differ diff --git a/fn_gen/ones_noisy_scale/14/distortion.png b/fn_gen/ones_noisy_scale/14/distortion.png new file mode 100644 index 0000000000000000000000000000000000000000..281d1e84b7ad0ef19f60432176a35a0ba68f9559 Binary files /dev/null and b/fn_gen/ones_noisy_scale/14/distortion.png differ diff --git a/fn_gen/ones_noisy_scale/14/expressions.txt b/fn_gen/ones_noisy_scale/14/expressions.txt new file mode 100644 index 0000000000000000000000000000000000000000..c545adce8b3c320e195336b81461c79d0cc385e6 --- /dev/null +++ b/fn_gen/ones_noisy_scale/14/expressions.txt @@ -0,0 +1,2 @@ +asinh(_0*x)/_s +sinh(_s*x)/_0 \ No newline at end of file diff --git a/fn_gen/ones_noisy_scale/14/fn.py b/fn_gen/ones_noisy_scale/14/fn.py new file mode 100644 index 0000000000000000000000000000000000000000..508f4f75c98c0cd800dcd53c5080275bc130e7f4 --- /dev/null +++ b/fn_gen/ones_noisy_scale/14/fn.py @@ -0,0 +1,489 @@ +from __future__ import annotations + +import torch +from torch import amin # Necessary for arcsin +import copy +import torch.nn as nn +import numpy as np + +from scipy.optimize import curve_fit +from typing import Dict, Any, Tuple, List, Callable + + +def quantization(x, **params): + return (torch.div(1, replace_num(params['_s'], num=0, to=10000)) * torch.asinh((params['_0'] * x))) + + +def dequantization(x, **params): + return (torch.div(1, replace_num(params['_0'], num=0, to=10000)) * torch.sinh((params['_s'] * x))) + + +def init_linear_scale( # Symmetric scale. From the study folder + x: torch.Tensor, + **kwargs: Dict[str, Any], + ) -> torch.Tensor: + assert "bits" in kwargs, "bits must be provided." + assert "params" in kwargs, "params must be provided." + assert "qtz_func" in kwargs, "qtz_func must be provided." + + bits = kwargs.get('bits') + params = kwargs.get('params') + qtz_func = kwargs.get('qtz_func') + + x_ = x.transpose(0, 1) + x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs)) + x_ = x_.transpose(0, 1) + + quant_min, quant_max = get_min_max_from_bits_signed(bits) + min_vals, max_vals = torch.aminmax(x_, dim=1) + min_vals = torch.min(min_vals, torch.zeros_like(min_vals)) + max_vals = torch.max(max_vals, torch.zeros_like(max_vals)) + + eps = torch.finfo(torch.float32).eps + + abs_max_val_per_ch = torch.max(-min_vals, max_vals) + scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2) + + scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device) + + # Introduces some noise in scale + # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything + scale = scale + 0.01 * torch.randn_like(scale) + return scale + + +def init_params(x: torch.Tensor, **kwargs: Dict[str, Any]) -> Dict[str, nn.Parameter]: + params = { + '_0': init_ones(x, qtz_func=quantization, deqtz_func=dequantization, param='_0', params_list=['_0', '_s'], **kwargs), + } + params['_s'] = init_linear_scale(x, params=params, qtz_func=quantization, **kwargs) + params = {k: nn.Parameter(v, requires_grad=False) for k, v in params.items()} + + if 'post_init_hook' in kwargs: + kwargs['post_init_hook'](parameters=params) + + + if 'post_train_hook' in kwargs: + kwargs['post_train_hook'](parameters=params) + + return params + + +############### Numpy Qtz ############### + + +def np_quantization(x, _0, _s): + return (np.divide(1, np_replace_num(_s, num=0, to=10000)) * np.arcsinh((_0 * x))) + + +def np_dequantization(x, _0, _s): + return (np.divide(1, np_replace_num(_0, num=0, to=10000)) * np.sinh((_s * x))) + + +def fit_func(x, _0, _s): + x_ = np_quantization(x, _0, _s) + x_ = np_dequantization(x_, _0, _s) + return x_ + + + +############### HELPERS ############### + +def domain_guard( + x: torch.Tensor, + min: float = None, + max: float = None, + posinf: float = None, + neginf: float = None, + nan: float = None + ) -> torch.Tensor: + """Guard a tensor to a valid domain.""" + x = torch.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan) + if min is not None or max is not None: + x = torch.clamp(x, min=min, max=max) + return x + + +def replace_num(x: torch.Tensor, num: float, to: float) -> torch.Tensor: + """Replace a number in a tensor with another number. + + Args: + x (torch.Tensor): The input tensor. + num (float): The number to replace. + to (float): The number to replace with. + + Returns: + torch.Tensor: The tensor with the number replaced. + """ + return torch.where(x == num, to, x) + + +def guarded_torch_power(x: torch.Tensor, exp: float) -> torch.Tensor: + """Guard the power operation to a valid domain.""" + return torch.pow(x, exp) if exp >= 1 else torch.pow(torch.relu(x), exp) + + +def init_ones(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor: + val = torch.amin(x, dim=1) + return torch.ones_like(val, dtype=torch.float32, device=x.device) + + +def init_rand(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor: + val = torch.amin(x, dim=1) + return torch.randn_like(val, dtype=torch.float32, device=x.device) + + +def init_space_search( + x: torch.Tensor, + **kwargs: Dict[str, Any], + ) -> torch.Tensor: + + def _build_initial_param(tensor: torch.Tensor, max_initial: int, n_params: int): + """Generates the initial set of parameters. The first iteration generates 10 times more parameters.""" + for _ in range(n_params * 10): # The first iteration generates 10 times more parameters + yield init_rand(tensor) * max_initial # Generates n_params in range [-max_initial, max_initial] + + def _search_param(tensors: List[torch.tensor], n_params): + """Takes the best parameters and generates new parameters around the mean of the best parameters.""" + torch_tensors = torch.stack(tensors) + min_vals, max_vals = torch.aminmax(torch_tensors, dim=0) + abs_max_val_per_ch = torch.max(-min_vals, max_vals) + mean = torch.mean(torch_tensors, dim=0) + for _ in range(n_params): # Generates n_params around the mean of the tensors + yield torch.randn_like(min_vals) * abs_max_val_per_ch + mean + + def _calc(x, qtz_func, deqtz_func, **params): + x_ = x.transpose(0, 1) + x_ = qtz_func(x=x_, **params) + x_ = deqtz_func(x=x_, **params) + x_ = x_.transpose(0, 1) + return x_ + + assert "qtz_func" in kwargs, "qtz_func must be provided." + assert "deqtz_func" in kwargs, "deqtz_func must be provided." + assert "params_list" in kwargs, "params list must be provided." + assert "param" in kwargs, "param must be provided." + + qtz_func = kwargs.get('qtz_func') + deqtz_func = kwargs.get('deqtz_func') + params_list = kwargs.get('params_list') + param = kwargs.get('param') + + n_runs = 50 # Number of runs to try to find the best parameters + n_random_params = 50 # Number of random parameters to generate + n_best_to_pick = 5 # Number of best parameters to pick after each run + max_initial = 10000 # Maximum value to initialize the parameters + + # Initializes the parameters + base_params = { p: init_ones(x, **kwargs) for p in params_list if p != param } + params = _build_initial_param(x, max_initial, n_random_params) + + # Performs the search + for _ in range(n_runs): + + best_params = [] + for param_ in params: + try: + x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: param_}) + loss_ones = nn.MSELoss()(x, x_) + + if len(best_params) < n_best_to_pick: + best_params.append((param_, loss_ones.item())) + best_params = sorted(best_params, key=lambda x: x[1]) + elif loss_ones < best_params[-1][1]: + best_params[-1] = (param_, loss_ones.item()) + best_params = sorted(best_params, key=lambda x: x[1]) + + except Exception: # The parameters might not be valid for the function's domain + continue + + # Generates new parameters around the mean + params = _search_param([p for p, _ in best_params], n_random_params) + + # Checks if the best parameter is better than the init_ones + p_ones = init_ones(x, **kwargs) + x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_ones}) + loss_ones = nn.MSELoss()(x, x_) + + # Checks if the best parameter is better than the init_rand + p_rand = init_rand(x, **kwargs) + x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_rand}) + loss_rand = nn.MSELoss()(x, x_) + + if loss_rand < best_params[0][1] and loss_rand < loss_ones: + return p_rand + elif loss_ones < best_params[0][1] and loss_ones < loss_rand: + return p_ones + else: + return best_params[0][0] + + +def init_linear_scale( # Symmetric scale. From the study folder + x: torch.Tensor, + **kwargs: Dict[str, Any], + ) -> torch.Tensor: + assert "bits" in kwargs, "bits must be provided." + assert "params" in kwargs, "params must be provided." + assert "qtz_func" in kwargs, "qtz_func must be provided." + + bits = kwargs.get('bits') + params = kwargs.get('params') + qtz_func = kwargs.get('qtz_func') + + x_ = x.transpose(0, 1) + x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs)) + x_ = x_.transpose(0, 1) + + quant_min, quant_max = get_min_max_from_bits_signed(bits) + min_vals, max_vals = torch.aminmax(x_, dim=1) + min_vals = torch.min(min_vals, torch.zeros_like(min_vals)) + max_vals = torch.max(max_vals, torch.zeros_like(max_vals)) + + eps = torch.finfo(torch.float32).eps + + abs_max_val_per_ch = torch.max(-min_vals, max_vals) + scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2) + + scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device) + + # Introduces some noise in scale + # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything + scale = scale + 0.01 * torch.randn_like(scale) + return scale + + +def init_non_linear_regression_fit( + x: torch.Tensor, + **kwargs: Dict[str, Any], + ) -> torch.Tensor: + + assert "params_list" in kwargs, "params list must be provided." + assert "np_fit_func" in kwargs, "np_fit_func must be provided." + assert "p0" in kwargs, "p0 must be provided." + np_fit_func = kwargs.get('np_fit_func') + params_list = kwargs.get('params_list') + p0 = kwargs.get('p0') + + def _fit(xdata: np.ndarray, ydata: np.ndarray, func: Callable, p0: List[float]): + popt, _ = curve_fit( + func, + xdata, + ydata, + maxfev=1000, + p0=p0, + method='lm' + ) + return popt + + # 1. Needs to convert the torch tensor to numpy tensor + xdata = x.cpu().numpy() + + # 2. Sorts the data so that it makes it easier to fit to it + sorted_xdata = np.sort(xdata, axis=-1) + + p0 = {k: v.cpu().numpy() for k, v in p0.items()} + params_list = sorted(params_list) # We need to make sure that it matches the numpy fit func arg order + + # 3. Finds the best parameters for each channel + try: + params = [] + for i in range(sorted_xdata.shape[0]): + xdata_ = sorted_xdata[i] + p0_ = [p0[p][i] for p in params_list] + ch_params = _fit(xdata_, xdata_, np_fit_func, p0_) + params.append(ch_params) + + # 4. Builds the parameters + result = {} + for i, p in enumerate(params_list): + result[p] = torch.tensor([p_[i] for p_ in params], dtype=torch.float32).to(x.device) + + return result + + except ValueError as e: + print(f"Could not fit the function with error: {e}") + print(f"Using fallback result...") + return { + k: torch.tensor(v, dtype=torch.float32).to(x.device) for k, v in p0.items() + } + + +def init_zeros(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor: + val = torch.amin(x, dim=1) + return torch.zeros_like(val, dtype=torch.float32, device=x.device) + + +def init_inner_scale(tensor: torch.Tensor, _min: float = torch.inf, _max: float = torch.inf) -> torch.Tensor: + # Calculate the original minimum and maximum values + min_vals, max_vals = torch.aminmax(tensor, dim=-1) + x_min = torch.min(min_vals, torch.zeros_like(min_vals)) + x_max = torch.max(max_vals, torch.zeros_like(max_vals)) + + if _max is torch.inf: # We do not need to scale the tensor. Just need to move it + return torch.ones_like(x_min) + + # Calculate the scale factor + scale = (_max - _min) / (x_max - x_min) + return scale + + + +############## Quant ############### + +@torch.enable_grad() +def learn_parameters( + x: torch.Tensor, + params: Dict[str, nn.Parameter], + qtz_func: nn.Module, + deqtz_func: nn.Module, + bits: int, + target_dtype: torch.dtype, + epochs: int = 1000, + early_stop: bool = True, + do_report: bool = False +) -> Tuple[Dict[str, nn.Parameter], torch.Tensor]: + + # Requires gradients in the parameters + for p in params.values(): + p.requires_grad = True + p.grad = None + + param_keys = list(params.keys()) + param_values = list(params.values()) + + # Defines optimizer and loss function + optimizer = torch.optim.Adam(param_values, lr=0.001) + scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=1.0, end_factor=0.001, total_iters=epochs // 10) + loss_fn = nn.MSELoss() + + # Contains the best loss and the best parameters + best_loss = float("inf") + best_params = None + + # Used to stop the search early + min_delta = 1e-7 + acc_loss = [] + percent_epochs_before_stop = 0.1 + + for i in range(epochs): + optimizer.zero_grad() + + quant = quantize(x, params, qtz_func, bits, target_dtype) + dequant = dequantize(quant, params, deqtz_func, bits, x.dtype) + loss = loss_fn(x, dequant) + + if loss.isnan() or loss.isinf(): + raise Exception("Loss is NaN or Inf. Stopping the search.") + + loss.backward() + optimizer.step() + scheduler.step() + + acc_loss.append(loss.item()) + + # Reports loss every 10 steps + if i % 10 == 0 and do_report: + print(f"Epoch {i}: Loss {loss.item()}") + + # Optimizes the parameter search by storing the best loss and the parameters + if loss.item() < best_loss: + best_loss = loss.item() + best_params = copy.deepcopy({ + k: v for k, v in params.items() if k in param_keys + }) + + # We also stop the search if the loss has not considerably during the last 10% epochs + if early_stop: + epochs_before_stop = int(epochs * percent_epochs_before_stop) + if i > epochs_before_stop and abs(acc_loss[i - epochs_before_stop] - acc_loss[i]) < min_delta: + break + + # No longer requires gradients in the parameters + for p in best_params.values(): + p.requires_grad = False + p.grad = None + + if do_report: + return best_params, acc_loss + else: + return best_params + + +def quantize( + x: torch.Tensor, + params: Dict[str, nn.Parameter], + func: nn.Module, + bits: int, + target_dtype: torch.dtype = torch.int8 +) -> torch.Tensor: + quant_min, quant_max = get_min_max_from_bits_signed(bits) + x = x.transpose(0, 1) # Aligns shapes + x = func(x=x, **params) + x = x.transpose(0, 1) + x = torch.clamp(round_func_BPDA(x), quant_min, quant_max).to(target_dtype) + return x + + +def dequantize( + x: torch.Tensor, + params: Dict[str, nn.Parameter], + func: nn.Module, + bits: int, + out_dtype: torch.dtype +) -> torch.Tensor: + x = x.to(dtype=out_dtype) + x = x.transpose(0, 1) + x = func(x=x, **params) + x = x.transpose(0, 1) + return x + + +def round_func_BPDA(input): + # This is equivalent to replacing round function (non-differentiable) with + # an identity function (differentiable) only when backward. + forward_value = torch.round(input) + out = input.clone() + out.data = forward_value.data + return out + + +def get_min_max_from_bits_signed(bit_width: int) -> Tuple[int, int]: + return -2 ** (bit_width - 1), 2 ** (bit_width - 1) - 1 + + + +############## Numpy ############### + +def np_domain_guard( + x: np.ndarray, + min: float = None, + max: float = None, + posinf: float = None, + neginf: float = None, + nan: float = None + ) -> np.ndarray: + """Guard a tensor to a valid domain.""" + x = np.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan) + if min is not None or max is not None: + x = np.clip(x, min, max) + return x + + +def np_replace_num(x: np.ndarray, num: float, to: float) -> np.ndarray: + """Replace a number in a tensor with another number. + + Args: + x (np.ndarray): The input tensor. + num (float): The number to replace. + to (float): The number to replace with. + + Returns: + np.ndarray: The tensor with the number replaced. + """ + return np.where(x == num, to, x) + + +def np_guarded_power(x: np.ndarray, exp: float) -> np.ndarray: + """Guard the power operation to a valid domain.""" + return np.power(x, exp) if exp >= 1 else np.power(np.maximum(x, 0), exp) + diff --git a/fn_gen/ones_noisy_scale/14/loss.png b/fn_gen/ones_noisy_scale/14/loss.png new file mode 100644 index 0000000000000000000000000000000000000000..d8d4e1592db50c75f469670700715e8571fc5f31 Binary files /dev/null and b/fn_gen/ones_noisy_scale/14/loss.png differ diff --git a/fn_gen/ones_noisy_scale/14/quantization.png b/fn_gen/ones_noisy_scale/14/quantization.png new file mode 100644 index 0000000000000000000000000000000000000000..14b1f23875a08d19089e72fb0426bbea4399fa18 Binary files /dev/null and b/fn_gen/ones_noisy_scale/14/quantization.png differ diff --git a/fn_gen/ones_noisy_scale/15/__pycache__/fn.cpython-310.pyc b/fn_gen/ones_noisy_scale/15/__pycache__/fn.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e60de42a5bd9563eafd2794bd8f9391b48043635 Binary files /dev/null and b/fn_gen/ones_noisy_scale/15/__pycache__/fn.cpython-310.pyc differ diff --git a/fn_gen/ones_noisy_scale/15/distortion.png b/fn_gen/ones_noisy_scale/15/distortion.png new file mode 100644 index 0000000000000000000000000000000000000000..deded020df990a5b7989d68fec9bd05580f86f5a Binary files /dev/null and b/fn_gen/ones_noisy_scale/15/distortion.png differ diff --git a/fn_gen/ones_noisy_scale/15/expressions.txt b/fn_gen/ones_noisy_scale/15/expressions.txt new file mode 100644 index 0000000000000000000000000000000000000000..3758ee2a62aa8d95c3b7da1dd3fafa11b027ad9b --- /dev/null +++ b/fn_gen/ones_noisy_scale/15/expressions.txt @@ -0,0 +1,2 @@ +cosh(_0*x)/_s +log(_s*x - sqrt(_s**2*x**2 - 1))/_0 \ No newline at end of file diff --git a/fn_gen/ones_noisy_scale/15/fn.py b/fn_gen/ones_noisy_scale/15/fn.py new file mode 100644 index 0000000000000000000000000000000000000000..2e972908033242abcdf3d2192c341a5f20a6925e --- /dev/null +++ b/fn_gen/ones_noisy_scale/15/fn.py @@ -0,0 +1,489 @@ +from __future__ import annotations + +import torch +from torch import amin # Necessary for arcsin +import copy +import torch.nn as nn +import numpy as np + +from scipy.optimize import curve_fit +from typing import Dict, Any, Tuple, List, Callable + + +def quantization(x, **params): + return (torch.div(1, replace_num(params['_s'], num=0, to=10000)) * torch.cosh((params['_0'] * x))) + + +def dequantization(x, **params): + return (torch.div(1, replace_num(params['_0'], num=0, to=10000)) * torch.log(domain_guard(((torch.tensor(-1) * torch.sqrt(domain_guard((torch.tensor(-1) + (guarded_torch_power(params['_s'], torch.tensor(2)) * guarded_torch_power(x, torch.tensor(2)))), min=0.1, nan=0.1))) + (params['_s'] * x)), min=1e-5, nan=1e-5))) + + +def init_linear_scale( # Symmetric scale. From the study folder + x: torch.Tensor, + **kwargs: Dict[str, Any], + ) -> torch.Tensor: + assert "bits" in kwargs, "bits must be provided." + assert "params" in kwargs, "params must be provided." + assert "qtz_func" in kwargs, "qtz_func must be provided." + + bits = kwargs.get('bits') + params = kwargs.get('params') + qtz_func = kwargs.get('qtz_func') + + x_ = x.transpose(0, 1) + x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs)) + x_ = x_.transpose(0, 1) + + quant_min, quant_max = get_min_max_from_bits_signed(bits) + min_vals, max_vals = torch.aminmax(x_, dim=1) + min_vals = torch.min(min_vals, torch.zeros_like(min_vals)) + max_vals = torch.max(max_vals, torch.zeros_like(max_vals)) + + eps = torch.finfo(torch.float32).eps + + abs_max_val_per_ch = torch.max(-min_vals, max_vals) + scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2) + + scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device) + + # Introduces some noise in scale + # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything + scale = scale + 0.01 * torch.randn_like(scale) + return scale + + +def init_params(x: torch.Tensor, **kwargs: Dict[str, Any]) -> Dict[str, nn.Parameter]: + params = { + '_0': init_ones(x, qtz_func=quantization, deqtz_func=dequantization, param='_0', params_list=['_0', '_s'], **kwargs), + } + params['_s'] = init_linear_scale(x, params=params, qtz_func=quantization, **kwargs) + params = {k: nn.Parameter(v, requires_grad=False) for k, v in params.items()} + + if 'post_init_hook' in kwargs: + kwargs['post_init_hook'](parameters=params) + + + if 'post_train_hook' in kwargs: + kwargs['post_train_hook'](parameters=params) + + return params + + +############### Numpy Qtz ############### + + +def np_quantization(x, _0, _s): + return (np.divide(1, np_replace_num(_s, num=0, to=10000)) * np.cosh((_0 * x))) + + +def np_dequantization(x, _0, _s): + return (np.divide(1, np_replace_num(_0, num=0, to=10000)) * np.log(np_domain_guard(((np.array(-1) * np.sqrt(np_domain_guard((np.array(-1) + (np_guarded_power(_s, np.array(2)) * np_guarded_power(x, np.array(2)))), min=0.1, nan=0.1))) + (_s * x)), min=1e-5, nan=1e-5))) + + +def fit_func(x, _0, _s): + x_ = np_quantization(x, _0, _s) + x_ = np_dequantization(x_, _0, _s) + return x_ + + + +############### HELPERS ############### + +def domain_guard( + x: torch.Tensor, + min: float = None, + max: float = None, + posinf: float = None, + neginf: float = None, + nan: float = None + ) -> torch.Tensor: + """Guard a tensor to a valid domain.""" + x = torch.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan) + if min is not None or max is not None: + x = torch.clamp(x, min=min, max=max) + return x + + +def replace_num(x: torch.Tensor, num: float, to: float) -> torch.Tensor: + """Replace a number in a tensor with another number. + + Args: + x (torch.Tensor): The input tensor. + num (float): The number to replace. + to (float): The number to replace with. + + Returns: + torch.Tensor: The tensor with the number replaced. + """ + return torch.where(x == num, to, x) + + +def guarded_torch_power(x: torch.Tensor, exp: float) -> torch.Tensor: + """Guard the power operation to a valid domain.""" + return torch.pow(x, exp) if exp >= 1 else torch.pow(torch.relu(x), exp) + + +def init_ones(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor: + val = torch.amin(x, dim=1) + return torch.ones_like(val, dtype=torch.float32, device=x.device) + + +def init_rand(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor: + val = torch.amin(x, dim=1) + return torch.randn_like(val, dtype=torch.float32, device=x.device) + + +def init_space_search( + x: torch.Tensor, + **kwargs: Dict[str, Any], + ) -> torch.Tensor: + + def _build_initial_param(tensor: torch.Tensor, max_initial: int, n_params: int): + """Generates the initial set of parameters. The first iteration generates 10 times more parameters.""" + for _ in range(n_params * 10): # The first iteration generates 10 times more parameters + yield init_rand(tensor) * max_initial # Generates n_params in range [-max_initial, max_initial] + + def _search_param(tensors: List[torch.tensor], n_params): + """Takes the best parameters and generates new parameters around the mean of the best parameters.""" + torch_tensors = torch.stack(tensors) + min_vals, max_vals = torch.aminmax(torch_tensors, dim=0) + abs_max_val_per_ch = torch.max(-min_vals, max_vals) + mean = torch.mean(torch_tensors, dim=0) + for _ in range(n_params): # Generates n_params around the mean of the tensors + yield torch.randn_like(min_vals) * abs_max_val_per_ch + mean + + def _calc(x, qtz_func, deqtz_func, **params): + x_ = x.transpose(0, 1) + x_ = qtz_func(x=x_, **params) + x_ = deqtz_func(x=x_, **params) + x_ = x_.transpose(0, 1) + return x_ + + assert "qtz_func" in kwargs, "qtz_func must be provided." + assert "deqtz_func" in kwargs, "deqtz_func must be provided." + assert "params_list" in kwargs, "params list must be provided." + assert "param" in kwargs, "param must be provided." + + qtz_func = kwargs.get('qtz_func') + deqtz_func = kwargs.get('deqtz_func') + params_list = kwargs.get('params_list') + param = kwargs.get('param') + + n_runs = 50 # Number of runs to try to find the best parameters + n_random_params = 50 # Number of random parameters to generate + n_best_to_pick = 5 # Number of best parameters to pick after each run + max_initial = 10000 # Maximum value to initialize the parameters + + # Initializes the parameters + base_params = { p: init_ones(x, **kwargs) for p in params_list if p != param } + params = _build_initial_param(x, max_initial, n_random_params) + + # Performs the search + for _ in range(n_runs): + + best_params = [] + for param_ in params: + try: + x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: param_}) + loss_ones = nn.MSELoss()(x, x_) + + if len(best_params) < n_best_to_pick: + best_params.append((param_, loss_ones.item())) + best_params = sorted(best_params, key=lambda x: x[1]) + elif loss_ones < best_params[-1][1]: + best_params[-1] = (param_, loss_ones.item()) + best_params = sorted(best_params, key=lambda x: x[1]) + + except Exception: # The parameters might not be valid for the function's domain + continue + + # Generates new parameters around the mean + params = _search_param([p for p, _ in best_params], n_random_params) + + # Checks if the best parameter is better than the init_ones + p_ones = init_ones(x, **kwargs) + x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_ones}) + loss_ones = nn.MSELoss()(x, x_) + + # Checks if the best parameter is better than the init_rand + p_rand = init_rand(x, **kwargs) + x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_rand}) + loss_rand = nn.MSELoss()(x, x_) + + if loss_rand < best_params[0][1] and loss_rand < loss_ones: + return p_rand + elif loss_ones < best_params[0][1] and loss_ones < loss_rand: + return p_ones + else: + return best_params[0][0] + + +def init_linear_scale( # Symmetric scale. From the study folder + x: torch.Tensor, + **kwargs: Dict[str, Any], + ) -> torch.Tensor: + assert "bits" in kwargs, "bits must be provided." + assert "params" in kwargs, "params must be provided." + assert "qtz_func" in kwargs, "qtz_func must be provided." + + bits = kwargs.get('bits') + params = kwargs.get('params') + qtz_func = kwargs.get('qtz_func') + + x_ = x.transpose(0, 1) + x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs)) + x_ = x_.transpose(0, 1) + + quant_min, quant_max = get_min_max_from_bits_signed(bits) + min_vals, max_vals = torch.aminmax(x_, dim=1) + min_vals = torch.min(min_vals, torch.zeros_like(min_vals)) + max_vals = torch.max(max_vals, torch.zeros_like(max_vals)) + + eps = torch.finfo(torch.float32).eps + + abs_max_val_per_ch = torch.max(-min_vals, max_vals) + scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2) + + scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device) + + # Introduces some noise in scale + # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything + scale = scale + 0.01 * torch.randn_like(scale) + return scale + + +def init_non_linear_regression_fit( + x: torch.Tensor, + **kwargs: Dict[str, Any], + ) -> torch.Tensor: + + assert "params_list" in kwargs, "params list must be provided." + assert "np_fit_func" in kwargs, "np_fit_func must be provided." + assert "p0" in kwargs, "p0 must be provided." + np_fit_func = kwargs.get('np_fit_func') + params_list = kwargs.get('params_list') + p0 = kwargs.get('p0') + + def _fit(xdata: np.ndarray, ydata: np.ndarray, func: Callable, p0: List[float]): + popt, _ = curve_fit( + func, + xdata, + ydata, + maxfev=1000, + p0=p0, + method='lm' + ) + return popt + + # 1. Needs to convert the torch tensor to numpy tensor + xdata = x.cpu().numpy() + + # 2. Sorts the data so that it makes it easier to fit to it + sorted_xdata = np.sort(xdata, axis=-1) + + p0 = {k: v.cpu().numpy() for k, v in p0.items()} + params_list = sorted(params_list) # We need to make sure that it matches the numpy fit func arg order + + # 3. Finds the best parameters for each channel + try: + params = [] + for i in range(sorted_xdata.shape[0]): + xdata_ = sorted_xdata[i] + p0_ = [p0[p][i] for p in params_list] + ch_params = _fit(xdata_, xdata_, np_fit_func, p0_) + params.append(ch_params) + + # 4. Builds the parameters + result = {} + for i, p in enumerate(params_list): + result[p] = torch.tensor([p_[i] for p_ in params], dtype=torch.float32).to(x.device) + + return result + + except ValueError as e: + print(f"Could not fit the function with error: {e}") + print(f"Using fallback result...") + return { + k: torch.tensor(v, dtype=torch.float32).to(x.device) for k, v in p0.items() + } + + +def init_zeros(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor: + val = torch.amin(x, dim=1) + return torch.zeros_like(val, dtype=torch.float32, device=x.device) + + +def init_inner_scale(tensor: torch.Tensor, _min: float = torch.inf, _max: float = torch.inf) -> torch.Tensor: + # Calculate the original minimum and maximum values + min_vals, max_vals = torch.aminmax(tensor, dim=-1) + x_min = torch.min(min_vals, torch.zeros_like(min_vals)) + x_max = torch.max(max_vals, torch.zeros_like(max_vals)) + + if _max is torch.inf: # We do not need to scale the tensor. Just need to move it + return torch.ones_like(x_min) + + # Calculate the scale factor + scale = (_max - _min) / (x_max - x_min) + return scale + + + +############## Quant ############### + +@torch.enable_grad() +def learn_parameters( + x: torch.Tensor, + params: Dict[str, nn.Parameter], + qtz_func: nn.Module, + deqtz_func: nn.Module, + bits: int, + target_dtype: torch.dtype, + epochs: int = 1000, + early_stop: bool = True, + do_report: bool = False +) -> Tuple[Dict[str, nn.Parameter], torch.Tensor]: + + # Requires gradients in the parameters + for p in params.values(): + p.requires_grad = True + p.grad = None + + param_keys = list(params.keys()) + param_values = list(params.values()) + + # Defines optimizer and loss function + optimizer = torch.optim.Adam(param_values, lr=0.001) + scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=1.0, end_factor=0.001, total_iters=epochs // 10) + loss_fn = nn.MSELoss() + + # Contains the best loss and the best parameters + best_loss = float("inf") + best_params = None + + # Used to stop the search early + min_delta = 1e-7 + acc_loss = [] + percent_epochs_before_stop = 0.1 + + for i in range(epochs): + optimizer.zero_grad() + + quant = quantize(x, params, qtz_func, bits, target_dtype) + dequant = dequantize(quant, params, deqtz_func, bits, x.dtype) + loss = loss_fn(x, dequant) + + if loss.isnan() or loss.isinf(): + raise Exception("Loss is NaN or Inf. Stopping the search.") + + loss.backward() + optimizer.step() + scheduler.step() + + acc_loss.append(loss.item()) + + # Reports loss every 10 steps + if i % 10 == 0 and do_report: + print(f"Epoch {i}: Loss {loss.item()}") + + # Optimizes the parameter search by storing the best loss and the parameters + if loss.item() < best_loss: + best_loss = loss.item() + best_params = copy.deepcopy({ + k: v for k, v in params.items() if k in param_keys + }) + + # We also stop the search if the loss has not considerably during the last 10% epochs + if early_stop: + epochs_before_stop = int(epochs * percent_epochs_before_stop) + if i > epochs_before_stop and abs(acc_loss[i - epochs_before_stop] - acc_loss[i]) < min_delta: + break + + # No longer requires gradients in the parameters + for p in best_params.values(): + p.requires_grad = False + p.grad = None + + if do_report: + return best_params, acc_loss + else: + return best_params + + +def quantize( + x: torch.Tensor, + params: Dict[str, nn.Parameter], + func: nn.Module, + bits: int, + target_dtype: torch.dtype = torch.int8 +) -> torch.Tensor: + quant_min, quant_max = get_min_max_from_bits_signed(bits) + x = x.transpose(0, 1) # Aligns shapes + x = func(x=x, **params) + x = x.transpose(0, 1) + x = torch.clamp(round_func_BPDA(x), quant_min, quant_max).to(target_dtype) + return x + + +def dequantize( + x: torch.Tensor, + params: Dict[str, nn.Parameter], + func: nn.Module, + bits: int, + out_dtype: torch.dtype +) -> torch.Tensor: + x = x.to(dtype=out_dtype) + x = x.transpose(0, 1) + x = func(x=x, **params) + x = x.transpose(0, 1) + return x + + +def round_func_BPDA(input): + # This is equivalent to replacing round function (non-differentiable) with + # an identity function (differentiable) only when backward. + forward_value = torch.round(input) + out = input.clone() + out.data = forward_value.data + return out + + +def get_min_max_from_bits_signed(bit_width: int) -> Tuple[int, int]: + return -2 ** (bit_width - 1), 2 ** (bit_width - 1) - 1 + + + +############## Numpy ############### + +def np_domain_guard( + x: np.ndarray, + min: float = None, + max: float = None, + posinf: float = None, + neginf: float = None, + nan: float = None + ) -> np.ndarray: + """Guard a tensor to a valid domain.""" + x = np.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan) + if min is not None or max is not None: + x = np.clip(x, min, max) + return x + + +def np_replace_num(x: np.ndarray, num: float, to: float) -> np.ndarray: + """Replace a number in a tensor with another number. + + Args: + x (np.ndarray): The input tensor. + num (float): The number to replace. + to (float): The number to replace with. + + Returns: + np.ndarray: The tensor with the number replaced. + """ + return np.where(x == num, to, x) + + +def np_guarded_power(x: np.ndarray, exp: float) -> np.ndarray: + """Guard the power operation to a valid domain.""" + return np.power(x, exp) if exp >= 1 else np.power(np.maximum(x, 0), exp) + diff --git a/fn_gen/ones_noisy_scale/15/loss.png b/fn_gen/ones_noisy_scale/15/loss.png new file mode 100644 index 0000000000000000000000000000000000000000..dda793f8333640383fad15a09e111a32a08b510f Binary files /dev/null and b/fn_gen/ones_noisy_scale/15/loss.png differ diff --git a/fn_gen/ones_noisy_scale/15/quantization.png b/fn_gen/ones_noisy_scale/15/quantization.png new file mode 100644 index 0000000000000000000000000000000000000000..0d517c8d22981e08e5efddd10c392b054bae07f6 Binary files /dev/null and b/fn_gen/ones_noisy_scale/15/quantization.png differ diff --git a/fn_gen/ones_noisy_scale/16/__pycache__/fn.cpython-310.pyc b/fn_gen/ones_noisy_scale/16/__pycache__/fn.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..39a6afc298a2efe7419ad2607d18222aaf6d7f80 Binary files /dev/null and b/fn_gen/ones_noisy_scale/16/__pycache__/fn.cpython-310.pyc differ diff --git a/fn_gen/ones_noisy_scale/16/distortion.png b/fn_gen/ones_noisy_scale/16/distortion.png new file mode 100644 index 0000000000000000000000000000000000000000..953359ce6c2245c685ce4a5f593c8078d3adf1e3 Binary files /dev/null and b/fn_gen/ones_noisy_scale/16/distortion.png differ diff --git a/fn_gen/ones_noisy_scale/16/expressions.txt b/fn_gen/ones_noisy_scale/16/expressions.txt new file mode 100644 index 0000000000000000000000000000000000000000..74791fc40576643d62f6366a8b4eda20eb1ad252 --- /dev/null +++ b/fn_gen/ones_noisy_scale/16/expressions.txt @@ -0,0 +1,2 @@ +x**3/_s +(_s*x)**(1/3) \ No newline at end of file diff --git a/fn_gen/ones_noisy_scale/16/fn.py b/fn_gen/ones_noisy_scale/16/fn.py new file mode 100644 index 0000000000000000000000000000000000000000..2e79759584ebbb7bb108d53b863065beb47370ab --- /dev/null +++ b/fn_gen/ones_noisy_scale/16/fn.py @@ -0,0 +1,488 @@ +from __future__ import annotations + +import torch +from torch import amin # Necessary for arcsin +import copy +import torch.nn as nn +import numpy as np + +from scipy.optimize import curve_fit +from typing import Dict, Any, Tuple, List, Callable + + +def quantization(x, **params): + return (torch.div(1, replace_num(params['_s'], num=0, to=10000)) * guarded_torch_power(x, torch.tensor(3))) + + +def dequantization(x, **params): + return guarded_torch_power((params['_s'] * x), 1 / 3) + + +def init_linear_scale( # Symmetric scale. From the study folder + x: torch.Tensor, + **kwargs: Dict[str, Any], + ) -> torch.Tensor: + assert "bits" in kwargs, "bits must be provided." + assert "params" in kwargs, "params must be provided." + assert "qtz_func" in kwargs, "qtz_func must be provided." + + bits = kwargs.get('bits') + params = kwargs.get('params') + qtz_func = kwargs.get('qtz_func') + + x_ = x.transpose(0, 1) + x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs)) + x_ = x_.transpose(0, 1) + + quant_min, quant_max = get_min_max_from_bits_signed(bits) + min_vals, max_vals = torch.aminmax(x_, dim=1) + min_vals = torch.min(min_vals, torch.zeros_like(min_vals)) + max_vals = torch.max(max_vals, torch.zeros_like(max_vals)) + + eps = torch.finfo(torch.float32).eps + + abs_max_val_per_ch = torch.max(-min_vals, max_vals) + scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2) + + scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device) + + # Introduces some noise in scale + # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything + scale = scale + 0.01 * torch.randn_like(scale) + return scale + + +def init_params(x: torch.Tensor, **kwargs: Dict[str, Any]) -> Dict[str, nn.Parameter]: + params = { + } + params['_s'] = init_linear_scale(x, params=params, qtz_func=quantization, **kwargs) + params = {k: nn.Parameter(v, requires_grad=False) for k, v in params.items()} + + if 'post_init_hook' in kwargs: + kwargs['post_init_hook'](parameters=params) + + + if 'post_train_hook' in kwargs: + kwargs['post_train_hook'](parameters=params) + + return params + + +############### Numpy Qtz ############### + + +def np_quantization(x, _s): + return (np.divide(1, np_replace_num(_s, num=0, to=10000)) * np_guarded_power(x, np.array(3))) + + +def np_dequantization(x, _s): + return np_guarded_power((_s * x), 1 / 3) + + +def fit_func(x, _s): + x_ = np_quantization(x, _s) + x_ = np_dequantization(x_, _s) + return x_ + + + +############### HELPERS ############### + +def domain_guard( + x: torch.Tensor, + min: float = None, + max: float = None, + posinf: float = None, + neginf: float = None, + nan: float = None + ) -> torch.Tensor: + """Guard a tensor to a valid domain.""" + x = torch.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan) + if min is not None or max is not None: + x = torch.clamp(x, min=min, max=max) + return x + + +def replace_num(x: torch.Tensor, num: float, to: float) -> torch.Tensor: + """Replace a number in a tensor with another number. + + Args: + x (torch.Tensor): The input tensor. + num (float): The number to replace. + to (float): The number to replace with. + + Returns: + torch.Tensor: The tensor with the number replaced. + """ + return torch.where(x == num, to, x) + + +def guarded_torch_power(x: torch.Tensor, exp: float) -> torch.Tensor: + """Guard the power operation to a valid domain.""" + return torch.pow(x, exp) if exp >= 1 else torch.pow(torch.relu(x), exp) + + +def init_ones(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor: + val = torch.amin(x, dim=1) + return torch.ones_like(val, dtype=torch.float32, device=x.device) + + +def init_rand(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor: + val = torch.amin(x, dim=1) + return torch.randn_like(val, dtype=torch.float32, device=x.device) + + +def init_space_search( + x: torch.Tensor, + **kwargs: Dict[str, Any], + ) -> torch.Tensor: + + def _build_initial_param(tensor: torch.Tensor, max_initial: int, n_params: int): + """Generates the initial set of parameters. The first iteration generates 10 times more parameters.""" + for _ in range(n_params * 10): # The first iteration generates 10 times more parameters + yield init_rand(tensor) * max_initial # Generates n_params in range [-max_initial, max_initial] + + def _search_param(tensors: List[torch.tensor], n_params): + """Takes the best parameters and generates new parameters around the mean of the best parameters.""" + torch_tensors = torch.stack(tensors) + min_vals, max_vals = torch.aminmax(torch_tensors, dim=0) + abs_max_val_per_ch = torch.max(-min_vals, max_vals) + mean = torch.mean(torch_tensors, dim=0) + for _ in range(n_params): # Generates n_params around the mean of the tensors + yield torch.randn_like(min_vals) * abs_max_val_per_ch + mean + + def _calc(x, qtz_func, deqtz_func, **params): + x_ = x.transpose(0, 1) + x_ = qtz_func(x=x_, **params) + x_ = deqtz_func(x=x_, **params) + x_ = x_.transpose(0, 1) + return x_ + + assert "qtz_func" in kwargs, "qtz_func must be provided." + assert "deqtz_func" in kwargs, "deqtz_func must be provided." + assert "params_list" in kwargs, "params list must be provided." + assert "param" in kwargs, "param must be provided." + + qtz_func = kwargs.get('qtz_func') + deqtz_func = kwargs.get('deqtz_func') + params_list = kwargs.get('params_list') + param = kwargs.get('param') + + n_runs = 50 # Number of runs to try to find the best parameters + n_random_params = 50 # Number of random parameters to generate + n_best_to_pick = 5 # Number of best parameters to pick after each run + max_initial = 10000 # Maximum value to initialize the parameters + + # Initializes the parameters + base_params = { p: init_ones(x, **kwargs) for p in params_list if p != param } + params = _build_initial_param(x, max_initial, n_random_params) + + # Performs the search + for _ in range(n_runs): + + best_params = [] + for param_ in params: + try: + x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: param_}) + loss_ones = nn.MSELoss()(x, x_) + + if len(best_params) < n_best_to_pick: + best_params.append((param_, loss_ones.item())) + best_params = sorted(best_params, key=lambda x: x[1]) + elif loss_ones < best_params[-1][1]: + best_params[-1] = (param_, loss_ones.item()) + best_params = sorted(best_params, key=lambda x: x[1]) + + except Exception: # The parameters might not be valid for the function's domain + continue + + # Generates new parameters around the mean + params = _search_param([p for p, _ in best_params], n_random_params) + + # Checks if the best parameter is better than the init_ones + p_ones = init_ones(x, **kwargs) + x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_ones}) + loss_ones = nn.MSELoss()(x, x_) + + # Checks if the best parameter is better than the init_rand + p_rand = init_rand(x, **kwargs) + x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_rand}) + loss_rand = nn.MSELoss()(x, x_) + + if loss_rand < best_params[0][1] and loss_rand < loss_ones: + return p_rand + elif loss_ones < best_params[0][1] and loss_ones < loss_rand: + return p_ones + else: + return best_params[0][0] + + +def init_linear_scale( # Symmetric scale. From the study folder + x: torch.Tensor, + **kwargs: Dict[str, Any], + ) -> torch.Tensor: + assert "bits" in kwargs, "bits must be provided." + assert "params" in kwargs, "params must be provided." + assert "qtz_func" in kwargs, "qtz_func must be provided." + + bits = kwargs.get('bits') + params = kwargs.get('params') + qtz_func = kwargs.get('qtz_func') + + x_ = x.transpose(0, 1) + x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs)) + x_ = x_.transpose(0, 1) + + quant_min, quant_max = get_min_max_from_bits_signed(bits) + min_vals, max_vals = torch.aminmax(x_, dim=1) + min_vals = torch.min(min_vals, torch.zeros_like(min_vals)) + max_vals = torch.max(max_vals, torch.zeros_like(max_vals)) + + eps = torch.finfo(torch.float32).eps + + abs_max_val_per_ch = torch.max(-min_vals, max_vals) + scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2) + + scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device) + + # Introduces some noise in scale + # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything + scale = scale + 0.01 * torch.randn_like(scale) + return scale + + +def init_non_linear_regression_fit( + x: torch.Tensor, + **kwargs: Dict[str, Any], + ) -> torch.Tensor: + + assert "params_list" in kwargs, "params list must be provided." + assert "np_fit_func" in kwargs, "np_fit_func must be provided." + assert "p0" in kwargs, "p0 must be provided." + np_fit_func = kwargs.get('np_fit_func') + params_list = kwargs.get('params_list') + p0 = kwargs.get('p0') + + def _fit(xdata: np.ndarray, ydata: np.ndarray, func: Callable, p0: List[float]): + popt, _ = curve_fit( + func, + xdata, + ydata, + maxfev=1000, + p0=p0, + method='lm' + ) + return popt + + # 1. Needs to convert the torch tensor to numpy tensor + xdata = x.cpu().numpy() + + # 2. Sorts the data so that it makes it easier to fit to it + sorted_xdata = np.sort(xdata, axis=-1) + + p0 = {k: v.cpu().numpy() for k, v in p0.items()} + params_list = sorted(params_list) # We need to make sure that it matches the numpy fit func arg order + + # 3. Finds the best parameters for each channel + try: + params = [] + for i in range(sorted_xdata.shape[0]): + xdata_ = sorted_xdata[i] + p0_ = [p0[p][i] for p in params_list] + ch_params = _fit(xdata_, xdata_, np_fit_func, p0_) + params.append(ch_params) + + # 4. Builds the parameters + result = {} + for i, p in enumerate(params_list): + result[p] = torch.tensor([p_[i] for p_ in params], dtype=torch.float32).to(x.device) + + return result + + except ValueError as e: + print(f"Could not fit the function with error: {e}") + print(f"Using fallback result...") + return { + k: torch.tensor(v, dtype=torch.float32).to(x.device) for k, v in p0.items() + } + + +def init_zeros(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor: + val = torch.amin(x, dim=1) + return torch.zeros_like(val, dtype=torch.float32, device=x.device) + + +def init_inner_scale(tensor: torch.Tensor, _min: float = torch.inf, _max: float = torch.inf) -> torch.Tensor: + # Calculate the original minimum and maximum values + min_vals, max_vals = torch.aminmax(tensor, dim=-1) + x_min = torch.min(min_vals, torch.zeros_like(min_vals)) + x_max = torch.max(max_vals, torch.zeros_like(max_vals)) + + if _max is torch.inf: # We do not need to scale the tensor. Just need to move it + return torch.ones_like(x_min) + + # Calculate the scale factor + scale = (_max - _min) / (x_max - x_min) + return scale + + + +############## Quant ############### + +@torch.enable_grad() +def learn_parameters( + x: torch.Tensor, + params: Dict[str, nn.Parameter], + qtz_func: nn.Module, + deqtz_func: nn.Module, + bits: int, + target_dtype: torch.dtype, + epochs: int = 1000, + early_stop: bool = True, + do_report: bool = False +) -> Tuple[Dict[str, nn.Parameter], torch.Tensor]: + + # Requires gradients in the parameters + for p in params.values(): + p.requires_grad = True + p.grad = None + + param_keys = list(params.keys()) + param_values = list(params.values()) + + # Defines optimizer and loss function + optimizer = torch.optim.Adam(param_values, lr=0.001) + scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=1.0, end_factor=0.001, total_iters=epochs // 10) + loss_fn = nn.MSELoss() + + # Contains the best loss and the best parameters + best_loss = float("inf") + best_params = None + + # Used to stop the search early + min_delta = 1e-7 + acc_loss = [] + percent_epochs_before_stop = 0.1 + + for i in range(epochs): + optimizer.zero_grad() + + quant = quantize(x, params, qtz_func, bits, target_dtype) + dequant = dequantize(quant, params, deqtz_func, bits, x.dtype) + loss = loss_fn(x, dequant) + + if loss.isnan() or loss.isinf(): + raise Exception("Loss is NaN or Inf. Stopping the search.") + + loss.backward() + optimizer.step() + scheduler.step() + + acc_loss.append(loss.item()) + + # Reports loss every 10 steps + if i % 10 == 0 and do_report: + print(f"Epoch {i}: Loss {loss.item()}") + + # Optimizes the parameter search by storing the best loss and the parameters + if loss.item() < best_loss: + best_loss = loss.item() + best_params = copy.deepcopy({ + k: v for k, v in params.items() if k in param_keys + }) + + # We also stop the search if the loss has not considerably during the last 10% epochs + if early_stop: + epochs_before_stop = int(epochs * percent_epochs_before_stop) + if i > epochs_before_stop and abs(acc_loss[i - epochs_before_stop] - acc_loss[i]) < min_delta: + break + + # No longer requires gradients in the parameters + for p in best_params.values(): + p.requires_grad = False + p.grad = None + + if do_report: + return best_params, acc_loss + else: + return best_params + + +def quantize( + x: torch.Tensor, + params: Dict[str, nn.Parameter], + func: nn.Module, + bits: int, + target_dtype: torch.dtype = torch.int8 +) -> torch.Tensor: + quant_min, quant_max = get_min_max_from_bits_signed(bits) + x = x.transpose(0, 1) # Aligns shapes + x = func(x=x, **params) + x = x.transpose(0, 1) + x = torch.clamp(round_func_BPDA(x), quant_min, quant_max).to(target_dtype) + return x + + +def dequantize( + x: torch.Tensor, + params: Dict[str, nn.Parameter], + func: nn.Module, + bits: int, + out_dtype: torch.dtype +) -> torch.Tensor: + x = x.to(dtype=out_dtype) + x = x.transpose(0, 1) + x = func(x=x, **params) + x = x.transpose(0, 1) + return x + + +def round_func_BPDA(input): + # This is equivalent to replacing round function (non-differentiable) with + # an identity function (differentiable) only when backward. + forward_value = torch.round(input) + out = input.clone() + out.data = forward_value.data + return out + + +def get_min_max_from_bits_signed(bit_width: int) -> Tuple[int, int]: + return -2 ** (bit_width - 1), 2 ** (bit_width - 1) - 1 + + + +############## Numpy ############### + +def np_domain_guard( + x: np.ndarray, + min: float = None, + max: float = None, + posinf: float = None, + neginf: float = None, + nan: float = None + ) -> np.ndarray: + """Guard a tensor to a valid domain.""" + x = np.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan) + if min is not None or max is not None: + x = np.clip(x, min, max) + return x + + +def np_replace_num(x: np.ndarray, num: float, to: float) -> np.ndarray: + """Replace a number in a tensor with another number. + + Args: + x (np.ndarray): The input tensor. + num (float): The number to replace. + to (float): The number to replace with. + + Returns: + np.ndarray: The tensor with the number replaced. + """ + return np.where(x == num, to, x) + + +def np_guarded_power(x: np.ndarray, exp: float) -> np.ndarray: + """Guard the power operation to a valid domain.""" + return np.power(x, exp) if exp >= 1 else np.power(np.maximum(x, 0), exp) + diff --git a/fn_gen/ones_noisy_scale/16/loss.png b/fn_gen/ones_noisy_scale/16/loss.png new file mode 100644 index 0000000000000000000000000000000000000000..4706738cf709a0703e475aa165994e26433653b4 Binary files /dev/null and b/fn_gen/ones_noisy_scale/16/loss.png differ diff --git a/fn_gen/ones_noisy_scale/16/quantization.png b/fn_gen/ones_noisy_scale/16/quantization.png new file mode 100644 index 0000000000000000000000000000000000000000..1e9c79008088817e6a2fb2280b26e4a2b87bafef Binary files /dev/null and b/fn_gen/ones_noisy_scale/16/quantization.png differ diff --git a/fn_gen/ones_noisy_scale/17/__pycache__/fn.cpython-310.pyc b/fn_gen/ones_noisy_scale/17/__pycache__/fn.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9e7a694e7f926aa31bd8e746094c5b2e01d0c6e5 Binary files /dev/null and b/fn_gen/ones_noisy_scale/17/__pycache__/fn.cpython-310.pyc differ diff --git a/fn_gen/ones_noisy_scale/17/distortion.png b/fn_gen/ones_noisy_scale/17/distortion.png new file mode 100644 index 0000000000000000000000000000000000000000..c1796af77557a1d74c389a36d5f05d3effae2efc Binary files /dev/null and b/fn_gen/ones_noisy_scale/17/distortion.png differ diff --git a/fn_gen/ones_noisy_scale/17/expressions.txt b/fn_gen/ones_noisy_scale/17/expressions.txt new file mode 100644 index 0000000000000000000000000000000000000000..ed99293c42843616c361d59b23d32ae553cc0f8d --- /dev/null +++ b/fn_gen/ones_noisy_scale/17/expressions.txt @@ -0,0 +1,2 @@ +atanh(_0*x)/_s +tanh(_s*x)/_0 \ No newline at end of file diff --git a/fn_gen/ones_noisy_scale/17/fn.py b/fn_gen/ones_noisy_scale/17/fn.py new file mode 100644 index 0000000000000000000000000000000000000000..dd1338e01fc9178ebf21769486f25c21ee2c463c --- /dev/null +++ b/fn_gen/ones_noisy_scale/17/fn.py @@ -0,0 +1,489 @@ +from __future__ import annotations + +import torch +from torch import amin # Necessary for arcsin +import copy +import torch.nn as nn +import numpy as np + +from scipy.optimize import curve_fit +from typing import Dict, Any, Tuple, List, Callable + + +def quantization(x, **params): + return (torch.div(1, replace_num(params['_s'], num=0, to=10000)) * torch.atanh(domain_guard((params['_0'] * x), min=-0.9999, max=0.9999, nan=0))) + + +def dequantization(x, **params): + return (torch.div(1, replace_num(params['_0'], num=0, to=10000)) * torch.tanh((params['_s'] * x))) + + +def init_linear_scale( # Symmetric scale. From the study folder + x: torch.Tensor, + **kwargs: Dict[str, Any], + ) -> torch.Tensor: + assert "bits" in kwargs, "bits must be provided." + assert "params" in kwargs, "params must be provided." + assert "qtz_func" in kwargs, "qtz_func must be provided." + + bits = kwargs.get('bits') + params = kwargs.get('params') + qtz_func = kwargs.get('qtz_func') + + x_ = x.transpose(0, 1) + x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs)) + x_ = x_.transpose(0, 1) + + quant_min, quant_max = get_min_max_from_bits_signed(bits) + min_vals, max_vals = torch.aminmax(x_, dim=1) + min_vals = torch.min(min_vals, torch.zeros_like(min_vals)) + max_vals = torch.max(max_vals, torch.zeros_like(max_vals)) + + eps = torch.finfo(torch.float32).eps + + abs_max_val_per_ch = torch.max(-min_vals, max_vals) + scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2) + + scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device) + + # Introduces some noise in scale + # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything + scale = scale + 0.01 * torch.randn_like(scale) + return scale + + +def init_params(x: torch.Tensor, **kwargs: Dict[str, Any]) -> Dict[str, nn.Parameter]: + params = { + '_0': init_ones(x, qtz_func=quantization, deqtz_func=dequantization, param='_0', params_list=['_0', '_s'], **kwargs), + } + params['_s'] = init_linear_scale(x, params=params, qtz_func=quantization, **kwargs) + params = {k: nn.Parameter(v, requires_grad=False) for k, v in params.items()} + + if 'post_init_hook' in kwargs: + kwargs['post_init_hook'](parameters=params) + + + if 'post_train_hook' in kwargs: + kwargs['post_train_hook'](parameters=params) + + return params + + +############### Numpy Qtz ############### + + +def np_quantization(x, _0, _s): + return (np.divide(1, np_replace_num(_s, num=0, to=10000)) * np.arctanh(np_domain_guard((_0 * x), min=-0.9999, max=0.9999, nan=0))) + + +def np_dequantization(x, _0, _s): + return (np.divide(1, np_replace_num(_0, num=0, to=10000)) * np.tanh((_s * x))) + + +def fit_func(x, _0, _s): + x_ = np_quantization(x, _0, _s) + x_ = np_dequantization(x_, _0, _s) + return x_ + + + +############### HELPERS ############### + +def domain_guard( + x: torch.Tensor, + min: float = None, + max: float = None, + posinf: float = None, + neginf: float = None, + nan: float = None + ) -> torch.Tensor: + """Guard a tensor to a valid domain.""" + x = torch.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan) + if min is not None or max is not None: + x = torch.clamp(x, min=min, max=max) + return x + + +def replace_num(x: torch.Tensor, num: float, to: float) -> torch.Tensor: + """Replace a number in a tensor with another number. + + Args: + x (torch.Tensor): The input tensor. + num (float): The number to replace. + to (float): The number to replace with. + + Returns: + torch.Tensor: The tensor with the number replaced. + """ + return torch.where(x == num, to, x) + + +def guarded_torch_power(x: torch.Tensor, exp: float) -> torch.Tensor: + """Guard the power operation to a valid domain.""" + return torch.pow(x, exp) if exp >= 1 else torch.pow(torch.relu(x), exp) + + +def init_ones(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor: + val = torch.amin(x, dim=1) + return torch.ones_like(val, dtype=torch.float32, device=x.device) + + +def init_rand(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor: + val = torch.amin(x, dim=1) + return torch.randn_like(val, dtype=torch.float32, device=x.device) + + +def init_space_search( + x: torch.Tensor, + **kwargs: Dict[str, Any], + ) -> torch.Tensor: + + def _build_initial_param(tensor: torch.Tensor, max_initial: int, n_params: int): + """Generates the initial set of parameters. The first iteration generates 10 times more parameters.""" + for _ in range(n_params * 10): # The first iteration generates 10 times more parameters + yield init_rand(tensor) * max_initial # Generates n_params in range [-max_initial, max_initial] + + def _search_param(tensors: List[torch.tensor], n_params): + """Takes the best parameters and generates new parameters around the mean of the best parameters.""" + torch_tensors = torch.stack(tensors) + min_vals, max_vals = torch.aminmax(torch_tensors, dim=0) + abs_max_val_per_ch = torch.max(-min_vals, max_vals) + mean = torch.mean(torch_tensors, dim=0) + for _ in range(n_params): # Generates n_params around the mean of the tensors + yield torch.randn_like(min_vals) * abs_max_val_per_ch + mean + + def _calc(x, qtz_func, deqtz_func, **params): + x_ = x.transpose(0, 1) + x_ = qtz_func(x=x_, **params) + x_ = deqtz_func(x=x_, **params) + x_ = x_.transpose(0, 1) + return x_ + + assert "qtz_func" in kwargs, "qtz_func must be provided." + assert "deqtz_func" in kwargs, "deqtz_func must be provided." + assert "params_list" in kwargs, "params list must be provided." + assert "param" in kwargs, "param must be provided." + + qtz_func = kwargs.get('qtz_func') + deqtz_func = kwargs.get('deqtz_func') + params_list = kwargs.get('params_list') + param = kwargs.get('param') + + n_runs = 50 # Number of runs to try to find the best parameters + n_random_params = 50 # Number of random parameters to generate + n_best_to_pick = 5 # Number of best parameters to pick after each run + max_initial = 10000 # Maximum value to initialize the parameters + + # Initializes the parameters + base_params = { p: init_ones(x, **kwargs) for p in params_list if p != param } + params = _build_initial_param(x, max_initial, n_random_params) + + # Performs the search + for _ in range(n_runs): + + best_params = [] + for param_ in params: + try: + x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: param_}) + loss_ones = nn.MSELoss()(x, x_) + + if len(best_params) < n_best_to_pick: + best_params.append((param_, loss_ones.item())) + best_params = sorted(best_params, key=lambda x: x[1]) + elif loss_ones < best_params[-1][1]: + best_params[-1] = (param_, loss_ones.item()) + best_params = sorted(best_params, key=lambda x: x[1]) + + except Exception: # The parameters might not be valid for the function's domain + continue + + # Generates new parameters around the mean + params = _search_param([p for p, _ in best_params], n_random_params) + + # Checks if the best parameter is better than the init_ones + p_ones = init_ones(x, **kwargs) + x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_ones}) + loss_ones = nn.MSELoss()(x, x_) + + # Checks if the best parameter is better than the init_rand + p_rand = init_rand(x, **kwargs) + x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_rand}) + loss_rand = nn.MSELoss()(x, x_) + + if loss_rand < best_params[0][1] and loss_rand < loss_ones: + return p_rand + elif loss_ones < best_params[0][1] and loss_ones < loss_rand: + return p_ones + else: + return best_params[0][0] + + +def init_linear_scale( # Symmetric scale. From the study folder + x: torch.Tensor, + **kwargs: Dict[str, Any], + ) -> torch.Tensor: + assert "bits" in kwargs, "bits must be provided." + assert "params" in kwargs, "params must be provided." + assert "qtz_func" in kwargs, "qtz_func must be provided." + + bits = kwargs.get('bits') + params = kwargs.get('params') + qtz_func = kwargs.get('qtz_func') + + x_ = x.transpose(0, 1) + x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs)) + x_ = x_.transpose(0, 1) + + quant_min, quant_max = get_min_max_from_bits_signed(bits) + min_vals, max_vals = torch.aminmax(x_, dim=1) + min_vals = torch.min(min_vals, torch.zeros_like(min_vals)) + max_vals = torch.max(max_vals, torch.zeros_like(max_vals)) + + eps = torch.finfo(torch.float32).eps + + abs_max_val_per_ch = torch.max(-min_vals, max_vals) + scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2) + + scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device) + + # Introduces some noise in scale + # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything + scale = scale + 0.01 * torch.randn_like(scale) + return scale + + +def init_non_linear_regression_fit( + x: torch.Tensor, + **kwargs: Dict[str, Any], + ) -> torch.Tensor: + + assert "params_list" in kwargs, "params list must be provided." + assert "np_fit_func" in kwargs, "np_fit_func must be provided." + assert "p0" in kwargs, "p0 must be provided." + np_fit_func = kwargs.get('np_fit_func') + params_list = kwargs.get('params_list') + p0 = kwargs.get('p0') + + def _fit(xdata: np.ndarray, ydata: np.ndarray, func: Callable, p0: List[float]): + popt, _ = curve_fit( + func, + xdata, + ydata, + maxfev=1000, + p0=p0, + method='lm' + ) + return popt + + # 1. Needs to convert the torch tensor to numpy tensor + xdata = x.cpu().numpy() + + # 2. Sorts the data so that it makes it easier to fit to it + sorted_xdata = np.sort(xdata, axis=-1) + + p0 = {k: v.cpu().numpy() for k, v in p0.items()} + params_list = sorted(params_list) # We need to make sure that it matches the numpy fit func arg order + + # 3. Finds the best parameters for each channel + try: + params = [] + for i in range(sorted_xdata.shape[0]): + xdata_ = sorted_xdata[i] + p0_ = [p0[p][i] for p in params_list] + ch_params = _fit(xdata_, xdata_, np_fit_func, p0_) + params.append(ch_params) + + # 4. Builds the parameters + result = {} + for i, p in enumerate(params_list): + result[p] = torch.tensor([p_[i] for p_ in params], dtype=torch.float32).to(x.device) + + return result + + except ValueError as e: + print(f"Could not fit the function with error: {e}") + print(f"Using fallback result...") + return { + k: torch.tensor(v, dtype=torch.float32).to(x.device) for k, v in p0.items() + } + + +def init_zeros(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor: + val = torch.amin(x, dim=1) + return torch.zeros_like(val, dtype=torch.float32, device=x.device) + + +def init_inner_scale(tensor: torch.Tensor, _min: float = torch.inf, _max: float = torch.inf) -> torch.Tensor: + # Calculate the original minimum and maximum values + min_vals, max_vals = torch.aminmax(tensor, dim=-1) + x_min = torch.min(min_vals, torch.zeros_like(min_vals)) + x_max = torch.max(max_vals, torch.zeros_like(max_vals)) + + if _max is torch.inf: # We do not need to scale the tensor. Just need to move it + return torch.ones_like(x_min) + + # Calculate the scale factor + scale = (_max - _min) / (x_max - x_min) + return scale + + + +############## Quant ############### + +@torch.enable_grad() +def learn_parameters( + x: torch.Tensor, + params: Dict[str, nn.Parameter], + qtz_func: nn.Module, + deqtz_func: nn.Module, + bits: int, + target_dtype: torch.dtype, + epochs: int = 1000, + early_stop: bool = True, + do_report: bool = False +) -> Tuple[Dict[str, nn.Parameter], torch.Tensor]: + + # Requires gradients in the parameters + for p in params.values(): + p.requires_grad = True + p.grad = None + + param_keys = list(params.keys()) + param_values = list(params.values()) + + # Defines optimizer and loss function + optimizer = torch.optim.Adam(param_values, lr=0.001) + scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=1.0, end_factor=0.001, total_iters=epochs // 10) + loss_fn = nn.MSELoss() + + # Contains the best loss and the best parameters + best_loss = float("inf") + best_params = None + + # Used to stop the search early + min_delta = 1e-7 + acc_loss = [] + percent_epochs_before_stop = 0.1 + + for i in range(epochs): + optimizer.zero_grad() + + quant = quantize(x, params, qtz_func, bits, target_dtype) + dequant = dequantize(quant, params, deqtz_func, bits, x.dtype) + loss = loss_fn(x, dequant) + + if loss.isnan() or loss.isinf(): + raise Exception("Loss is NaN or Inf. Stopping the search.") + + loss.backward() + optimizer.step() + scheduler.step() + + acc_loss.append(loss.item()) + + # Reports loss every 10 steps + if i % 10 == 0 and do_report: + print(f"Epoch {i}: Loss {loss.item()}") + + # Optimizes the parameter search by storing the best loss and the parameters + if loss.item() < best_loss: + best_loss = loss.item() + best_params = copy.deepcopy({ + k: v for k, v in params.items() if k in param_keys + }) + + # We also stop the search if the loss has not considerably during the last 10% epochs + if early_stop: + epochs_before_stop = int(epochs * percent_epochs_before_stop) + if i > epochs_before_stop and abs(acc_loss[i - epochs_before_stop] - acc_loss[i]) < min_delta: + break + + # No longer requires gradients in the parameters + for p in best_params.values(): + p.requires_grad = False + p.grad = None + + if do_report: + return best_params, acc_loss + else: + return best_params + + +def quantize( + x: torch.Tensor, + params: Dict[str, nn.Parameter], + func: nn.Module, + bits: int, + target_dtype: torch.dtype = torch.int8 +) -> torch.Tensor: + quant_min, quant_max = get_min_max_from_bits_signed(bits) + x = x.transpose(0, 1) # Aligns shapes + x = func(x=x, **params) + x = x.transpose(0, 1) + x = torch.clamp(round_func_BPDA(x), quant_min, quant_max).to(target_dtype) + return x + + +def dequantize( + x: torch.Tensor, + params: Dict[str, nn.Parameter], + func: nn.Module, + bits: int, + out_dtype: torch.dtype +) -> torch.Tensor: + x = x.to(dtype=out_dtype) + x = x.transpose(0, 1) + x = func(x=x, **params) + x = x.transpose(0, 1) + return x + + +def round_func_BPDA(input): + # This is equivalent to replacing round function (non-differentiable) with + # an identity function (differentiable) only when backward. + forward_value = torch.round(input) + out = input.clone() + out.data = forward_value.data + return out + + +def get_min_max_from_bits_signed(bit_width: int) -> Tuple[int, int]: + return -2 ** (bit_width - 1), 2 ** (bit_width - 1) - 1 + + + +############## Numpy ############### + +def np_domain_guard( + x: np.ndarray, + min: float = None, + max: float = None, + posinf: float = None, + neginf: float = None, + nan: float = None + ) -> np.ndarray: + """Guard a tensor to a valid domain.""" + x = np.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan) + if min is not None or max is not None: + x = np.clip(x, min, max) + return x + + +def np_replace_num(x: np.ndarray, num: float, to: float) -> np.ndarray: + """Replace a number in a tensor with another number. + + Args: + x (np.ndarray): The input tensor. + num (float): The number to replace. + to (float): The number to replace with. + + Returns: + np.ndarray: The tensor with the number replaced. + """ + return np.where(x == num, to, x) + + +def np_guarded_power(x: np.ndarray, exp: float) -> np.ndarray: + """Guard the power operation to a valid domain.""" + return np.power(x, exp) if exp >= 1 else np.power(np.maximum(x, 0), exp) + diff --git a/fn_gen/ones_noisy_scale/17/loss.png b/fn_gen/ones_noisy_scale/17/loss.png new file mode 100644 index 0000000000000000000000000000000000000000..c7a7d3487ff720c9b4a17eda35f728340a0fa44c Binary files /dev/null and b/fn_gen/ones_noisy_scale/17/loss.png differ diff --git a/fn_gen/ones_noisy_scale/17/quantization.png b/fn_gen/ones_noisy_scale/17/quantization.png new file mode 100644 index 0000000000000000000000000000000000000000..ac2ff3d78afc189aa2f7a74bd6025e8766915345 Binary files /dev/null and b/fn_gen/ones_noisy_scale/17/quantization.png differ diff --git a/fn_gen/ones_noisy_scale/18/__pycache__/fn.cpython-310.pyc b/fn_gen/ones_noisy_scale/18/__pycache__/fn.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fb42e4bdd8c05d5540230cbb36f63c59aabed23d Binary files /dev/null and b/fn_gen/ones_noisy_scale/18/__pycache__/fn.cpython-310.pyc differ diff --git a/fn_gen/ones_noisy_scale/18/distortion.png b/fn_gen/ones_noisy_scale/18/distortion.png new file mode 100644 index 0000000000000000000000000000000000000000..9d0fa242a7b0ba69b5fc34c0c7d1220cca146016 Binary files /dev/null and b/fn_gen/ones_noisy_scale/18/distortion.png differ diff --git a/fn_gen/ones_noisy_scale/18/expressions.txt b/fn_gen/ones_noisy_scale/18/expressions.txt new file mode 100644 index 0000000000000000000000000000000000000000..576ec6a351e26f9982eb17e394804ca906d4b067 --- /dev/null +++ b/fn_gen/ones_noisy_scale/18/expressions.txt @@ -0,0 +1,2 @@ +acos(_0*x)/_s +cos(_s*x)/_0 \ No newline at end of file diff --git a/fn_gen/ones_noisy_scale/18/fn.py b/fn_gen/ones_noisy_scale/18/fn.py new file mode 100644 index 0000000000000000000000000000000000000000..af103d8d238913e87f3aed9341bd7316f012cff5 --- /dev/null +++ b/fn_gen/ones_noisy_scale/18/fn.py @@ -0,0 +1,489 @@ +from __future__ import annotations + +import torch +from torch import amin # Necessary for arcsin +import copy +import torch.nn as nn +import numpy as np + +from scipy.optimize import curve_fit +from typing import Dict, Any, Tuple, List, Callable + + +def quantization(x, **params): + return (torch.div(1, replace_num(params['_s'], num=0, to=10000)) * torch.acos(domain_guard((params['_0'] * x), min=-0.99999, max=0.99999, nan=0))) + + +def dequantization(x, **params): + return (torch.div(1, replace_num(params['_0'], num=0, to=10000)) * torch.cos((params['_s'] * x))) + + +def init_linear_scale( # Symmetric scale. From the study folder + x: torch.Tensor, + **kwargs: Dict[str, Any], + ) -> torch.Tensor: + assert "bits" in kwargs, "bits must be provided." + assert "params" in kwargs, "params must be provided." + assert "qtz_func" in kwargs, "qtz_func must be provided." + + bits = kwargs.get('bits') + params = kwargs.get('params') + qtz_func = kwargs.get('qtz_func') + + x_ = x.transpose(0, 1) + x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs)) + x_ = x_.transpose(0, 1) + + quant_min, quant_max = get_min_max_from_bits_signed(bits) + min_vals, max_vals = torch.aminmax(x_, dim=1) + min_vals = torch.min(min_vals, torch.zeros_like(min_vals)) + max_vals = torch.max(max_vals, torch.zeros_like(max_vals)) + + eps = torch.finfo(torch.float32).eps + + abs_max_val_per_ch = torch.max(-min_vals, max_vals) + scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2) + + scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device) + + # Introduces some noise in scale + # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything + scale = scale + 0.01 * torch.randn_like(scale) + return scale + + +def init_params(x: torch.Tensor, **kwargs: Dict[str, Any]) -> Dict[str, nn.Parameter]: + params = { + '_0': init_ones(x, qtz_func=quantization, deqtz_func=dequantization, param='_0', params_list=['_0', '_s'], **kwargs), + } + params['_s'] = init_linear_scale(x, params=params, qtz_func=quantization, **kwargs) + params = {k: nn.Parameter(v, requires_grad=False) for k, v in params.items()} + + if 'post_init_hook' in kwargs: + kwargs['post_init_hook'](parameters=params) + + + if 'post_train_hook' in kwargs: + kwargs['post_train_hook'](parameters=params) + + return params + + +############### Numpy Qtz ############### + + +def np_quantization(x, _0, _s): + return (np.divide(1, np_replace_num(_s, num=0, to=10000)) * np.arccos(np_domain_guard((_0 * x), min=-0.99999, max=0.99999, nan=0))) + + +def np_dequantization(x, _0, _s): + return (np.divide(1, np_replace_num(_0, num=0, to=10000)) * np.cos((_s * x))) + + +def fit_func(x, _0, _s): + x_ = np_quantization(x, _0, _s) + x_ = np_dequantization(x_, _0, _s) + return x_ + + + +############### HELPERS ############### + +def domain_guard( + x: torch.Tensor, + min: float = None, + max: float = None, + posinf: float = None, + neginf: float = None, + nan: float = None + ) -> torch.Tensor: + """Guard a tensor to a valid domain.""" + x = torch.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan) + if min is not None or max is not None: + x = torch.clamp(x, min=min, max=max) + return x + + +def replace_num(x: torch.Tensor, num: float, to: float) -> torch.Tensor: + """Replace a number in a tensor with another number. + + Args: + x (torch.Tensor): The input tensor. + num (float): The number to replace. + to (float): The number to replace with. + + Returns: + torch.Tensor: The tensor with the number replaced. + """ + return torch.where(x == num, to, x) + + +def guarded_torch_power(x: torch.Tensor, exp: float) -> torch.Tensor: + """Guard the power operation to a valid domain.""" + return torch.pow(x, exp) if exp >= 1 else torch.pow(torch.relu(x), exp) + + +def init_ones(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor: + val = torch.amin(x, dim=1) + return torch.ones_like(val, dtype=torch.float32, device=x.device) + + +def init_rand(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor: + val = torch.amin(x, dim=1) + return torch.randn_like(val, dtype=torch.float32, device=x.device) + + +def init_space_search( + x: torch.Tensor, + **kwargs: Dict[str, Any], + ) -> torch.Tensor: + + def _build_initial_param(tensor: torch.Tensor, max_initial: int, n_params: int): + """Generates the initial set of parameters. The first iteration generates 10 times more parameters.""" + for _ in range(n_params * 10): # The first iteration generates 10 times more parameters + yield init_rand(tensor) * max_initial # Generates n_params in range [-max_initial, max_initial] + + def _search_param(tensors: List[torch.tensor], n_params): + """Takes the best parameters and generates new parameters around the mean of the best parameters.""" + torch_tensors = torch.stack(tensors) + min_vals, max_vals = torch.aminmax(torch_tensors, dim=0) + abs_max_val_per_ch = torch.max(-min_vals, max_vals) + mean = torch.mean(torch_tensors, dim=0) + for _ in range(n_params): # Generates n_params around the mean of the tensors + yield torch.randn_like(min_vals) * abs_max_val_per_ch + mean + + def _calc(x, qtz_func, deqtz_func, **params): + x_ = x.transpose(0, 1) + x_ = qtz_func(x=x_, **params) + x_ = deqtz_func(x=x_, **params) + x_ = x_.transpose(0, 1) + return x_ + + assert "qtz_func" in kwargs, "qtz_func must be provided." + assert "deqtz_func" in kwargs, "deqtz_func must be provided." + assert "params_list" in kwargs, "params list must be provided." + assert "param" in kwargs, "param must be provided." + + qtz_func = kwargs.get('qtz_func') + deqtz_func = kwargs.get('deqtz_func') + params_list = kwargs.get('params_list') + param = kwargs.get('param') + + n_runs = 50 # Number of runs to try to find the best parameters + n_random_params = 50 # Number of random parameters to generate + n_best_to_pick = 5 # Number of best parameters to pick after each run + max_initial = 10000 # Maximum value to initialize the parameters + + # Initializes the parameters + base_params = { p: init_ones(x, **kwargs) for p in params_list if p != param } + params = _build_initial_param(x, max_initial, n_random_params) + + # Performs the search + for _ in range(n_runs): + + best_params = [] + for param_ in params: + try: + x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: param_}) + loss_ones = nn.MSELoss()(x, x_) + + if len(best_params) < n_best_to_pick: + best_params.append((param_, loss_ones.item())) + best_params = sorted(best_params, key=lambda x: x[1]) + elif loss_ones < best_params[-1][1]: + best_params[-1] = (param_, loss_ones.item()) + best_params = sorted(best_params, key=lambda x: x[1]) + + except Exception: # The parameters might not be valid for the function's domain + continue + + # Generates new parameters around the mean + params = _search_param([p for p, _ in best_params], n_random_params) + + # Checks if the best parameter is better than the init_ones + p_ones = init_ones(x, **kwargs) + x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_ones}) + loss_ones = nn.MSELoss()(x, x_) + + # Checks if the best parameter is better than the init_rand + p_rand = init_rand(x, **kwargs) + x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_rand}) + loss_rand = nn.MSELoss()(x, x_) + + if loss_rand < best_params[0][1] and loss_rand < loss_ones: + return p_rand + elif loss_ones < best_params[0][1] and loss_ones < loss_rand: + return p_ones + else: + return best_params[0][0] + + +def init_linear_scale( # Symmetric scale. From the study folder + x: torch.Tensor, + **kwargs: Dict[str, Any], + ) -> torch.Tensor: + assert "bits" in kwargs, "bits must be provided." + assert "params" in kwargs, "params must be provided." + assert "qtz_func" in kwargs, "qtz_func must be provided." + + bits = kwargs.get('bits') + params = kwargs.get('params') + qtz_func = kwargs.get('qtz_func') + + x_ = x.transpose(0, 1) + x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs)) + x_ = x_.transpose(0, 1) + + quant_min, quant_max = get_min_max_from_bits_signed(bits) + min_vals, max_vals = torch.aminmax(x_, dim=1) + min_vals = torch.min(min_vals, torch.zeros_like(min_vals)) + max_vals = torch.max(max_vals, torch.zeros_like(max_vals)) + + eps = torch.finfo(torch.float32).eps + + abs_max_val_per_ch = torch.max(-min_vals, max_vals) + scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2) + + scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device) + + # Introduces some noise in scale + # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything + scale = scale + 0.01 * torch.randn_like(scale) + return scale + + +def init_non_linear_regression_fit( + x: torch.Tensor, + **kwargs: Dict[str, Any], + ) -> torch.Tensor: + + assert "params_list" in kwargs, "params list must be provided." + assert "np_fit_func" in kwargs, "np_fit_func must be provided." + assert "p0" in kwargs, "p0 must be provided." + np_fit_func = kwargs.get('np_fit_func') + params_list = kwargs.get('params_list') + p0 = kwargs.get('p0') + + def _fit(xdata: np.ndarray, ydata: np.ndarray, func: Callable, p0: List[float]): + popt, _ = curve_fit( + func, + xdata, + ydata, + maxfev=1000, + p0=p0, + method='lm' + ) + return popt + + # 1. Needs to convert the torch tensor to numpy tensor + xdata = x.cpu().numpy() + + # 2. Sorts the data so that it makes it easier to fit to it + sorted_xdata = np.sort(xdata, axis=-1) + + p0 = {k: v.cpu().numpy() for k, v in p0.items()} + params_list = sorted(params_list) # We need to make sure that it matches the numpy fit func arg order + + # 3. Finds the best parameters for each channel + try: + params = [] + for i in range(sorted_xdata.shape[0]): + xdata_ = sorted_xdata[i] + p0_ = [p0[p][i] for p in params_list] + ch_params = _fit(xdata_, xdata_, np_fit_func, p0_) + params.append(ch_params) + + # 4. Builds the parameters + result = {} + for i, p in enumerate(params_list): + result[p] = torch.tensor([p_[i] for p_ in params], dtype=torch.float32).to(x.device) + + return result + + except ValueError as e: + print(f"Could not fit the function with error: {e}") + print(f"Using fallback result...") + return { + k: torch.tensor(v, dtype=torch.float32).to(x.device) for k, v in p0.items() + } + + +def init_zeros(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor: + val = torch.amin(x, dim=1) + return torch.zeros_like(val, dtype=torch.float32, device=x.device) + + +def init_inner_scale(tensor: torch.Tensor, _min: float = torch.inf, _max: float = torch.inf) -> torch.Tensor: + # Calculate the original minimum and maximum values + min_vals, max_vals = torch.aminmax(tensor, dim=-1) + x_min = torch.min(min_vals, torch.zeros_like(min_vals)) + x_max = torch.max(max_vals, torch.zeros_like(max_vals)) + + if _max is torch.inf: # We do not need to scale the tensor. Just need to move it + return torch.ones_like(x_min) + + # Calculate the scale factor + scale = (_max - _min) / (x_max - x_min) + return scale + + + +############## Quant ############### + +@torch.enable_grad() +def learn_parameters( + x: torch.Tensor, + params: Dict[str, nn.Parameter], + qtz_func: nn.Module, + deqtz_func: nn.Module, + bits: int, + target_dtype: torch.dtype, + epochs: int = 1000, + early_stop: bool = True, + do_report: bool = False +) -> Tuple[Dict[str, nn.Parameter], torch.Tensor]: + + # Requires gradients in the parameters + for p in params.values(): + p.requires_grad = True + p.grad = None + + param_keys = list(params.keys()) + param_values = list(params.values()) + + # Defines optimizer and loss function + optimizer = torch.optim.Adam(param_values, lr=0.001) + scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=1.0, end_factor=0.001, total_iters=epochs // 10) + loss_fn = nn.MSELoss() + + # Contains the best loss and the best parameters + best_loss = float("inf") + best_params = None + + # Used to stop the search early + min_delta = 1e-7 + acc_loss = [] + percent_epochs_before_stop = 0.1 + + for i in range(epochs): + optimizer.zero_grad() + + quant = quantize(x, params, qtz_func, bits, target_dtype) + dequant = dequantize(quant, params, deqtz_func, bits, x.dtype) + loss = loss_fn(x, dequant) + + if loss.isnan() or loss.isinf(): + raise Exception("Loss is NaN or Inf. Stopping the search.") + + loss.backward() + optimizer.step() + scheduler.step() + + acc_loss.append(loss.item()) + + # Reports loss every 10 steps + if i % 10 == 0 and do_report: + print(f"Epoch {i}: Loss {loss.item()}") + + # Optimizes the parameter search by storing the best loss and the parameters + if loss.item() < best_loss: + best_loss = loss.item() + best_params = copy.deepcopy({ + k: v for k, v in params.items() if k in param_keys + }) + + # We also stop the search if the loss has not considerably during the last 10% epochs + if early_stop: + epochs_before_stop = int(epochs * percent_epochs_before_stop) + if i > epochs_before_stop and abs(acc_loss[i - epochs_before_stop] - acc_loss[i]) < min_delta: + break + + # No longer requires gradients in the parameters + for p in best_params.values(): + p.requires_grad = False + p.grad = None + + if do_report: + return best_params, acc_loss + else: + return best_params + + +def quantize( + x: torch.Tensor, + params: Dict[str, nn.Parameter], + func: nn.Module, + bits: int, + target_dtype: torch.dtype = torch.int8 +) -> torch.Tensor: + quant_min, quant_max = get_min_max_from_bits_signed(bits) + x = x.transpose(0, 1) # Aligns shapes + x = func(x=x, **params) + x = x.transpose(0, 1) + x = torch.clamp(round_func_BPDA(x), quant_min, quant_max).to(target_dtype) + return x + + +def dequantize( + x: torch.Tensor, + params: Dict[str, nn.Parameter], + func: nn.Module, + bits: int, + out_dtype: torch.dtype +) -> torch.Tensor: + x = x.to(dtype=out_dtype) + x = x.transpose(0, 1) + x = func(x=x, **params) + x = x.transpose(0, 1) + return x + + +def round_func_BPDA(input): + # This is equivalent to replacing round function (non-differentiable) with + # an identity function (differentiable) only when backward. + forward_value = torch.round(input) + out = input.clone() + out.data = forward_value.data + return out + + +def get_min_max_from_bits_signed(bit_width: int) -> Tuple[int, int]: + return -2 ** (bit_width - 1), 2 ** (bit_width - 1) - 1 + + + +############## Numpy ############### + +def np_domain_guard( + x: np.ndarray, + min: float = None, + max: float = None, + posinf: float = None, + neginf: float = None, + nan: float = None + ) -> np.ndarray: + """Guard a tensor to a valid domain.""" + x = np.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan) + if min is not None or max is not None: + x = np.clip(x, min, max) + return x + + +def np_replace_num(x: np.ndarray, num: float, to: float) -> np.ndarray: + """Replace a number in a tensor with another number. + + Args: + x (np.ndarray): The input tensor. + num (float): The number to replace. + to (float): The number to replace with. + + Returns: + np.ndarray: The tensor with the number replaced. + """ + return np.where(x == num, to, x) + + +def np_guarded_power(x: np.ndarray, exp: float) -> np.ndarray: + """Guard the power operation to a valid domain.""" + return np.power(x, exp) if exp >= 1 else np.power(np.maximum(x, 0), exp) + diff --git a/fn_gen/ones_noisy_scale/18/loss.png b/fn_gen/ones_noisy_scale/18/loss.png new file mode 100644 index 0000000000000000000000000000000000000000..634273cb597733124b3a0d1e6c587455d64db0a3 Binary files /dev/null and b/fn_gen/ones_noisy_scale/18/loss.png differ diff --git a/fn_gen/ones_noisy_scale/18/quantization.png b/fn_gen/ones_noisy_scale/18/quantization.png new file mode 100644 index 0000000000000000000000000000000000000000..d7430266da44c401673d5a1e4428597a2ad4deda Binary files /dev/null and b/fn_gen/ones_noisy_scale/18/quantization.png differ diff --git a/fn_gen/ones_noisy_scale/2/__pycache__/fn.cpython-310.pyc b/fn_gen/ones_noisy_scale/2/__pycache__/fn.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..efbc95ca8f2780bd0229adf8c5f3d5d86dca533d Binary files /dev/null and b/fn_gen/ones_noisy_scale/2/__pycache__/fn.cpython-310.pyc differ diff --git a/fn_gen/ones_noisy_scale/2/distortion.png b/fn_gen/ones_noisy_scale/2/distortion.png new file mode 100644 index 0000000000000000000000000000000000000000..80302720bf11851fc55b66b7c88d7af08cdd216b Binary files /dev/null and b/fn_gen/ones_noisy_scale/2/distortion.png differ diff --git a/fn_gen/ones_noisy_scale/2/expressions.txt b/fn_gen/ones_noisy_scale/2/expressions.txt new file mode 100644 index 0000000000000000000000000000000000000000..03413827fa8f4c8ad49a40b543460cf31d1ce803 --- /dev/null +++ b/fn_gen/ones_noisy_scale/2/expressions.txt @@ -0,0 +1,2 @@ +asin(_0*x)/_s +sin(_s*x)/_0 \ No newline at end of file diff --git a/fn_gen/ones_noisy_scale/2/fn.py b/fn_gen/ones_noisy_scale/2/fn.py new file mode 100644 index 0000000000000000000000000000000000000000..5d831428362c2c001e86f08ed62a48ac56ddd226 --- /dev/null +++ b/fn_gen/ones_noisy_scale/2/fn.py @@ -0,0 +1,489 @@ +from __future__ import annotations + +import torch +from torch import amin # Necessary for arcsin +import copy +import torch.nn as nn +import numpy as np + +from scipy.optimize import curve_fit +from typing import Dict, Any, Tuple, List, Callable + + +def quantization(x, **params): + return (torch.div(1, replace_num(params['_s'], num=0, to=10000)) * torch.asin(domain_guard((params['_0'] * x), min=-0.99999, max=0.99999, nan=0))) + + +def dequantization(x, **params): + return (torch.div(1, replace_num(params['_0'], num=0, to=10000)) * torch.sin((params['_s'] * x))) + + +def init_linear_scale( # Symmetric scale. From the study folder + x: torch.Tensor, + **kwargs: Dict[str, Any], + ) -> torch.Tensor: + assert "bits" in kwargs, "bits must be provided." + assert "params" in kwargs, "params must be provided." + assert "qtz_func" in kwargs, "qtz_func must be provided." + + bits = kwargs.get('bits') + params = kwargs.get('params') + qtz_func = kwargs.get('qtz_func') + + x_ = x.transpose(0, 1) + x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs)) + x_ = x_.transpose(0, 1) + + quant_min, quant_max = get_min_max_from_bits_signed(bits) + min_vals, max_vals = torch.aminmax(x_, dim=1) + min_vals = torch.min(min_vals, torch.zeros_like(min_vals)) + max_vals = torch.max(max_vals, torch.zeros_like(max_vals)) + + eps = torch.finfo(torch.float32).eps + + abs_max_val_per_ch = torch.max(-min_vals, max_vals) + scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2) + + scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device) + + # Introduces some noise in scale + # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything + scale = scale + 0.01 * torch.randn_like(scale) + return scale + + +def init_params(x: torch.Tensor, **kwargs: Dict[str, Any]) -> Dict[str, nn.Parameter]: + params = { + '_0': init_ones(x, qtz_func=quantization, deqtz_func=dequantization, param='_0', params_list=['_0', '_s'], **kwargs), + } + params['_s'] = init_linear_scale(x, params=params, qtz_func=quantization, **kwargs) + params = {k: nn.Parameter(v, requires_grad=False) for k, v in params.items()} + + if 'post_init_hook' in kwargs: + kwargs['post_init_hook'](parameters=params) + + + if 'post_train_hook' in kwargs: + kwargs['post_train_hook'](parameters=params) + + return params + + +############### Numpy Qtz ############### + + +def np_quantization(x, _0, _s): + return (np.divide(1, np_replace_num(_s, num=0, to=10000)) * np.arcsin(np_domain_guard((_0 * x), min=-0.99999, max=0.99999, nan=0))) + + +def np_dequantization(x, _0, _s): + return (np.divide(1, np_replace_num(_0, num=0, to=10000)) * np.sin((_s * x))) + + +def fit_func(x, _0, _s): + x_ = np_quantization(x, _0, _s) + x_ = np_dequantization(x_, _0, _s) + return x_ + + + +############### HELPERS ############### + +def domain_guard( + x: torch.Tensor, + min: float = None, + max: float = None, + posinf: float = None, + neginf: float = None, + nan: float = None + ) -> torch.Tensor: + """Guard a tensor to a valid domain.""" + x = torch.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan) + if min is not None or max is not None: + x = torch.clamp(x, min=min, max=max) + return x + + +def replace_num(x: torch.Tensor, num: float, to: float) -> torch.Tensor: + """Replace a number in a tensor with another number. + + Args: + x (torch.Tensor): The input tensor. + num (float): The number to replace. + to (float): The number to replace with. + + Returns: + torch.Tensor: The tensor with the number replaced. + """ + return torch.where(x == num, to, x) + + +def guarded_torch_power(x: torch.Tensor, exp: float) -> torch.Tensor: + """Guard the power operation to a valid domain.""" + return torch.pow(x, exp) if exp >= 1 else torch.pow(torch.relu(x), exp) + + +def init_ones(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor: + val = torch.amin(x, dim=1) + return torch.ones_like(val, dtype=torch.float32, device=x.device) + + +def init_rand(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor: + val = torch.amin(x, dim=1) + return torch.randn_like(val, dtype=torch.float32, device=x.device) + + +def init_space_search( + x: torch.Tensor, + **kwargs: Dict[str, Any], + ) -> torch.Tensor: + + def _build_initial_param(tensor: torch.Tensor, max_initial: int, n_params: int): + """Generates the initial set of parameters. The first iteration generates 10 times more parameters.""" + for _ in range(n_params * 10): # The first iteration generates 10 times more parameters + yield init_rand(tensor) * max_initial # Generates n_params in range [-max_initial, max_initial] + + def _search_param(tensors: List[torch.tensor], n_params): + """Takes the best parameters and generates new parameters around the mean of the best parameters.""" + torch_tensors = torch.stack(tensors) + min_vals, max_vals = torch.aminmax(torch_tensors, dim=0) + abs_max_val_per_ch = torch.max(-min_vals, max_vals) + mean = torch.mean(torch_tensors, dim=0) + for _ in range(n_params): # Generates n_params around the mean of the tensors + yield torch.randn_like(min_vals) * abs_max_val_per_ch + mean + + def _calc(x, qtz_func, deqtz_func, **params): + x_ = x.transpose(0, 1) + x_ = qtz_func(x=x_, **params) + x_ = deqtz_func(x=x_, **params) + x_ = x_.transpose(0, 1) + return x_ + + assert "qtz_func" in kwargs, "qtz_func must be provided." + assert "deqtz_func" in kwargs, "deqtz_func must be provided." + assert "params_list" in kwargs, "params list must be provided." + assert "param" in kwargs, "param must be provided." + + qtz_func = kwargs.get('qtz_func') + deqtz_func = kwargs.get('deqtz_func') + params_list = kwargs.get('params_list') + param = kwargs.get('param') + + n_runs = 50 # Number of runs to try to find the best parameters + n_random_params = 50 # Number of random parameters to generate + n_best_to_pick = 5 # Number of best parameters to pick after each run + max_initial = 10000 # Maximum value to initialize the parameters + + # Initializes the parameters + base_params = { p: init_ones(x, **kwargs) for p in params_list if p != param } + params = _build_initial_param(x, max_initial, n_random_params) + + # Performs the search + for _ in range(n_runs): + + best_params = [] + for param_ in params: + try: + x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: param_}) + loss_ones = nn.MSELoss()(x, x_) + + if len(best_params) < n_best_to_pick: + best_params.append((param_, loss_ones.item())) + best_params = sorted(best_params, key=lambda x: x[1]) + elif loss_ones < best_params[-1][1]: + best_params[-1] = (param_, loss_ones.item()) + best_params = sorted(best_params, key=lambda x: x[1]) + + except Exception: # The parameters might not be valid for the function's domain + continue + + # Generates new parameters around the mean + params = _search_param([p for p, _ in best_params], n_random_params) + + # Checks if the best parameter is better than the init_ones + p_ones = init_ones(x, **kwargs) + x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_ones}) + loss_ones = nn.MSELoss()(x, x_) + + # Checks if the best parameter is better than the init_rand + p_rand = init_rand(x, **kwargs) + x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_rand}) + loss_rand = nn.MSELoss()(x, x_) + + if loss_rand < best_params[0][1] and loss_rand < loss_ones: + return p_rand + elif loss_ones < best_params[0][1] and loss_ones < loss_rand: + return p_ones + else: + return best_params[0][0] + + +def init_linear_scale( # Symmetric scale. From the study folder + x: torch.Tensor, + **kwargs: Dict[str, Any], + ) -> torch.Tensor: + assert "bits" in kwargs, "bits must be provided." + assert "params" in kwargs, "params must be provided." + assert "qtz_func" in kwargs, "qtz_func must be provided." + + bits = kwargs.get('bits') + params = kwargs.get('params') + qtz_func = kwargs.get('qtz_func') + + x_ = x.transpose(0, 1) + x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs)) + x_ = x_.transpose(0, 1) + + quant_min, quant_max = get_min_max_from_bits_signed(bits) + min_vals, max_vals = torch.aminmax(x_, dim=1) + min_vals = torch.min(min_vals, torch.zeros_like(min_vals)) + max_vals = torch.max(max_vals, torch.zeros_like(max_vals)) + + eps = torch.finfo(torch.float32).eps + + abs_max_val_per_ch = torch.max(-min_vals, max_vals) + scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2) + + scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device) + + # Introduces some noise in scale + # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything + scale = scale + 0.01 * torch.randn_like(scale) + return scale + + +def init_non_linear_regression_fit( + x: torch.Tensor, + **kwargs: Dict[str, Any], + ) -> torch.Tensor: + + assert "params_list" in kwargs, "params list must be provided." + assert "np_fit_func" in kwargs, "np_fit_func must be provided." + assert "p0" in kwargs, "p0 must be provided." + np_fit_func = kwargs.get('np_fit_func') + params_list = kwargs.get('params_list') + p0 = kwargs.get('p0') + + def _fit(xdata: np.ndarray, ydata: np.ndarray, func: Callable, p0: List[float]): + popt, _ = curve_fit( + func, + xdata, + ydata, + maxfev=1000, + p0=p0, + method='lm' + ) + return popt + + # 1. Needs to convert the torch tensor to numpy tensor + xdata = x.cpu().numpy() + + # 2. Sorts the data so that it makes it easier to fit to it + sorted_xdata = np.sort(xdata, axis=-1) + + p0 = {k: v.cpu().numpy() for k, v in p0.items()} + params_list = sorted(params_list) # We need to make sure that it matches the numpy fit func arg order + + # 3. Finds the best parameters for each channel + try: + params = [] + for i in range(sorted_xdata.shape[0]): + xdata_ = sorted_xdata[i] + p0_ = [p0[p][i] for p in params_list] + ch_params = _fit(xdata_, xdata_, np_fit_func, p0_) + params.append(ch_params) + + # 4. Builds the parameters + result = {} + for i, p in enumerate(params_list): + result[p] = torch.tensor([p_[i] for p_ in params], dtype=torch.float32).to(x.device) + + return result + + except ValueError as e: + print(f"Could not fit the function with error: {e}") + print(f"Using fallback result...") + return { + k: torch.tensor(v, dtype=torch.float32).to(x.device) for k, v in p0.items() + } + + +def init_zeros(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor: + val = torch.amin(x, dim=1) + return torch.zeros_like(val, dtype=torch.float32, device=x.device) + + +def init_inner_scale(tensor: torch.Tensor, _min: float = torch.inf, _max: float = torch.inf) -> torch.Tensor: + # Calculate the original minimum and maximum values + min_vals, max_vals = torch.aminmax(tensor, dim=-1) + x_min = torch.min(min_vals, torch.zeros_like(min_vals)) + x_max = torch.max(max_vals, torch.zeros_like(max_vals)) + + if _max is torch.inf: # We do not need to scale the tensor. Just need to move it + return torch.ones_like(x_min) + + # Calculate the scale factor + scale = (_max - _min) / (x_max - x_min) + return scale + + + +############## Quant ############### + +@torch.enable_grad() +def learn_parameters( + x: torch.Tensor, + params: Dict[str, nn.Parameter], + qtz_func: nn.Module, + deqtz_func: nn.Module, + bits: int, + target_dtype: torch.dtype, + epochs: int = 1000, + early_stop: bool = True, + do_report: bool = False +) -> Tuple[Dict[str, nn.Parameter], torch.Tensor]: + + # Requires gradients in the parameters + for p in params.values(): + p.requires_grad = True + p.grad = None + + param_keys = list(params.keys()) + param_values = list(params.values()) + + # Defines optimizer and loss function + optimizer = torch.optim.Adam(param_values, lr=0.001) + scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=1.0, end_factor=0.001, total_iters=epochs // 10) + loss_fn = nn.MSELoss() + + # Contains the best loss and the best parameters + best_loss = float("inf") + best_params = None + + # Used to stop the search early + min_delta = 1e-7 + acc_loss = [] + percent_epochs_before_stop = 0.1 + + for i in range(epochs): + optimizer.zero_grad() + + quant = quantize(x, params, qtz_func, bits, target_dtype) + dequant = dequantize(quant, params, deqtz_func, bits, x.dtype) + loss = loss_fn(x, dequant) + + if loss.isnan() or loss.isinf(): + raise Exception("Loss is NaN or Inf. Stopping the search.") + + loss.backward() + optimizer.step() + scheduler.step() + + acc_loss.append(loss.item()) + + # Reports loss every 10 steps + if i % 10 == 0 and do_report: + print(f"Epoch {i}: Loss {loss.item()}") + + # Optimizes the parameter search by storing the best loss and the parameters + if loss.item() < best_loss: + best_loss = loss.item() + best_params = copy.deepcopy({ + k: v for k, v in params.items() if k in param_keys + }) + + # We also stop the search if the loss has not considerably during the last 10% epochs + if early_stop: + epochs_before_stop = int(epochs * percent_epochs_before_stop) + if i > epochs_before_stop and abs(acc_loss[i - epochs_before_stop] - acc_loss[i]) < min_delta: + break + + # No longer requires gradients in the parameters + for p in best_params.values(): + p.requires_grad = False + p.grad = None + + if do_report: + return best_params, acc_loss + else: + return best_params + + +def quantize( + x: torch.Tensor, + params: Dict[str, nn.Parameter], + func: nn.Module, + bits: int, + target_dtype: torch.dtype = torch.int8 +) -> torch.Tensor: + quant_min, quant_max = get_min_max_from_bits_signed(bits) + x = x.transpose(0, 1) # Aligns shapes + x = func(x=x, **params) + x = x.transpose(0, 1) + x = torch.clamp(round_func_BPDA(x), quant_min, quant_max).to(target_dtype) + return x + + +def dequantize( + x: torch.Tensor, + params: Dict[str, nn.Parameter], + func: nn.Module, + bits: int, + out_dtype: torch.dtype +) -> torch.Tensor: + x = x.to(dtype=out_dtype) + x = x.transpose(0, 1) + x = func(x=x, **params) + x = x.transpose(0, 1) + return x + + +def round_func_BPDA(input): + # This is equivalent to replacing round function (non-differentiable) with + # an identity function (differentiable) only when backward. + forward_value = torch.round(input) + out = input.clone() + out.data = forward_value.data + return out + + +def get_min_max_from_bits_signed(bit_width: int) -> Tuple[int, int]: + return -2 ** (bit_width - 1), 2 ** (bit_width - 1) - 1 + + + +############## Numpy ############### + +def np_domain_guard( + x: np.ndarray, + min: float = None, + max: float = None, + posinf: float = None, + neginf: float = None, + nan: float = None + ) -> np.ndarray: + """Guard a tensor to a valid domain.""" + x = np.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan) + if min is not None or max is not None: + x = np.clip(x, min, max) + return x + + +def np_replace_num(x: np.ndarray, num: float, to: float) -> np.ndarray: + """Replace a number in a tensor with another number. + + Args: + x (np.ndarray): The input tensor. + num (float): The number to replace. + to (float): The number to replace with. + + Returns: + np.ndarray: The tensor with the number replaced. + """ + return np.where(x == num, to, x) + + +def np_guarded_power(x: np.ndarray, exp: float) -> np.ndarray: + """Guard the power operation to a valid domain.""" + return np.power(x, exp) if exp >= 1 else np.power(np.maximum(x, 0), exp) + diff --git a/fn_gen/ones_noisy_scale/2/loss.png b/fn_gen/ones_noisy_scale/2/loss.png new file mode 100644 index 0000000000000000000000000000000000000000..83a11f7f4ce0bcefb9152b10a49e6cbd09b1d5b7 Binary files /dev/null and b/fn_gen/ones_noisy_scale/2/loss.png differ diff --git a/fn_gen/ones_noisy_scale/2/quantization.png b/fn_gen/ones_noisy_scale/2/quantization.png new file mode 100644 index 0000000000000000000000000000000000000000..6cf042233e4cbcf40e9974745fd24023ad909fb5 Binary files /dev/null and b/fn_gen/ones_noisy_scale/2/quantization.png differ diff --git a/fn_gen/ones_noisy_scale/3/__pycache__/fn.cpython-310.pyc b/fn_gen/ones_noisy_scale/3/__pycache__/fn.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..001cf3716e51cd99cb27741c5f33be94168f0c2d Binary files /dev/null and b/fn_gen/ones_noisy_scale/3/__pycache__/fn.cpython-310.pyc differ diff --git a/fn_gen/ones_noisy_scale/3/distortion.png b/fn_gen/ones_noisy_scale/3/distortion.png new file mode 100644 index 0000000000000000000000000000000000000000..3a33db8f97b991d5b0d832624b1580ff112c67b2 Binary files /dev/null and b/fn_gen/ones_noisy_scale/3/distortion.png differ diff --git a/fn_gen/ones_noisy_scale/3/expressions.txt b/fn_gen/ones_noisy_scale/3/expressions.txt new file mode 100644 index 0000000000000000000000000000000000000000..c7b68c388fdf6e1b6e2be8076f1d4b8d7bcef4f9 --- /dev/null +++ b/fn_gen/ones_noisy_scale/3/expressions.txt @@ -0,0 +1,2 @@ +(_0*x)**(1/3)/_s +_s**3*x**3/_0 \ No newline at end of file diff --git a/fn_gen/ones_noisy_scale/3/fn.py b/fn_gen/ones_noisy_scale/3/fn.py new file mode 100644 index 0000000000000000000000000000000000000000..7916af84f376dedc5b826cd809da490a0e03fa9a --- /dev/null +++ b/fn_gen/ones_noisy_scale/3/fn.py @@ -0,0 +1,489 @@ +from __future__ import annotations + +import torch +from torch import amin # Necessary for arcsin +import copy +import torch.nn as nn +import numpy as np + +from scipy.optimize import curve_fit +from typing import Dict, Any, Tuple, List, Callable + + +def quantization(x, **params): + return (torch.div(1, replace_num(params['_s'], num=0, to=10000)) * guarded_torch_power((params['_0'] * x), 1 / 3)) + + +def dequantization(x, **params): + return (torch.div(1, replace_num(params['_0'], num=0, to=10000)) * guarded_torch_power(params['_s'], torch.tensor(3)) * guarded_torch_power(x, torch.tensor(3))) + + +def init_linear_scale( # Symmetric scale. From the study folder + x: torch.Tensor, + **kwargs: Dict[str, Any], + ) -> torch.Tensor: + assert "bits" in kwargs, "bits must be provided." + assert "params" in kwargs, "params must be provided." + assert "qtz_func" in kwargs, "qtz_func must be provided." + + bits = kwargs.get('bits') + params = kwargs.get('params') + qtz_func = kwargs.get('qtz_func') + + x_ = x.transpose(0, 1) + x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs)) + x_ = x_.transpose(0, 1) + + quant_min, quant_max = get_min_max_from_bits_signed(bits) + min_vals, max_vals = torch.aminmax(x_, dim=1) + min_vals = torch.min(min_vals, torch.zeros_like(min_vals)) + max_vals = torch.max(max_vals, torch.zeros_like(max_vals)) + + eps = torch.finfo(torch.float32).eps + + abs_max_val_per_ch = torch.max(-min_vals, max_vals) + scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2) + + scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device) + + # Introduces some noise in scale + # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything + scale = scale + 0.01 * torch.randn_like(scale) + return scale + + +def init_params(x: torch.Tensor, **kwargs: Dict[str, Any]) -> Dict[str, nn.Parameter]: + params = { + '_0': init_ones(x, qtz_func=quantization, deqtz_func=dequantization, param='_0', params_list=['_0', '_s'], **kwargs), + } + params['_s'] = init_linear_scale(x, params=params, qtz_func=quantization, **kwargs) + params = {k: nn.Parameter(v, requires_grad=False) for k, v in params.items()} + + if 'post_init_hook' in kwargs: + kwargs['post_init_hook'](parameters=params) + + + if 'post_train_hook' in kwargs: + kwargs['post_train_hook'](parameters=params) + + return params + + +############### Numpy Qtz ############### + + +def np_quantization(x, _0, _s): + return (np.divide(1, np_replace_num(_s, num=0, to=10000)) * np_guarded_power((_0 * x), 1 / 3)) + + +def np_dequantization(x, _0, _s): + return (np.divide(1, np_replace_num(_0, num=0, to=10000)) * np_guarded_power(_s, np.array(3)) * np_guarded_power(x, np.array(3))) + + +def fit_func(x, _0, _s): + x_ = np_quantization(x, _0, _s) + x_ = np_dequantization(x_, _0, _s) + return x_ + + + +############### HELPERS ############### + +def domain_guard( + x: torch.Tensor, + min: float = None, + max: float = None, + posinf: float = None, + neginf: float = None, + nan: float = None + ) -> torch.Tensor: + """Guard a tensor to a valid domain.""" + x = torch.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan) + if min is not None or max is not None: + x = torch.clamp(x, min=min, max=max) + return x + + +def replace_num(x: torch.Tensor, num: float, to: float) -> torch.Tensor: + """Replace a number in a tensor with another number. + + Args: + x (torch.Tensor): The input tensor. + num (float): The number to replace. + to (float): The number to replace with. + + Returns: + torch.Tensor: The tensor with the number replaced. + """ + return torch.where(x == num, to, x) + + +def guarded_torch_power(x: torch.Tensor, exp: float) -> torch.Tensor: + """Guard the power operation to a valid domain.""" + return torch.pow(x, exp) if exp >= 1 else torch.pow(torch.relu(x), exp) + + +def init_ones(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor: + val = torch.amin(x, dim=1) + return torch.ones_like(val, dtype=torch.float32, device=x.device) + + +def init_rand(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor: + val = torch.amin(x, dim=1) + return torch.randn_like(val, dtype=torch.float32, device=x.device) + + +def init_space_search( + x: torch.Tensor, + **kwargs: Dict[str, Any], + ) -> torch.Tensor: + + def _build_initial_param(tensor: torch.Tensor, max_initial: int, n_params: int): + """Generates the initial set of parameters. The first iteration generates 10 times more parameters.""" + for _ in range(n_params * 10): # The first iteration generates 10 times more parameters + yield init_rand(tensor) * max_initial # Generates n_params in range [-max_initial, max_initial] + + def _search_param(tensors: List[torch.tensor], n_params): + """Takes the best parameters and generates new parameters around the mean of the best parameters.""" + torch_tensors = torch.stack(tensors) + min_vals, max_vals = torch.aminmax(torch_tensors, dim=0) + abs_max_val_per_ch = torch.max(-min_vals, max_vals) + mean = torch.mean(torch_tensors, dim=0) + for _ in range(n_params): # Generates n_params around the mean of the tensors + yield torch.randn_like(min_vals) * abs_max_val_per_ch + mean + + def _calc(x, qtz_func, deqtz_func, **params): + x_ = x.transpose(0, 1) + x_ = qtz_func(x=x_, **params) + x_ = deqtz_func(x=x_, **params) + x_ = x_.transpose(0, 1) + return x_ + + assert "qtz_func" in kwargs, "qtz_func must be provided." + assert "deqtz_func" in kwargs, "deqtz_func must be provided." + assert "params_list" in kwargs, "params list must be provided." + assert "param" in kwargs, "param must be provided." + + qtz_func = kwargs.get('qtz_func') + deqtz_func = kwargs.get('deqtz_func') + params_list = kwargs.get('params_list') + param = kwargs.get('param') + + n_runs = 50 # Number of runs to try to find the best parameters + n_random_params = 50 # Number of random parameters to generate + n_best_to_pick = 5 # Number of best parameters to pick after each run + max_initial = 10000 # Maximum value to initialize the parameters + + # Initializes the parameters + base_params = { p: init_ones(x, **kwargs) for p in params_list if p != param } + params = _build_initial_param(x, max_initial, n_random_params) + + # Performs the search + for _ in range(n_runs): + + best_params = [] + for param_ in params: + try: + x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: param_}) + loss_ones = nn.MSELoss()(x, x_) + + if len(best_params) < n_best_to_pick: + best_params.append((param_, loss_ones.item())) + best_params = sorted(best_params, key=lambda x: x[1]) + elif loss_ones < best_params[-1][1]: + best_params[-1] = (param_, loss_ones.item()) + best_params = sorted(best_params, key=lambda x: x[1]) + + except Exception: # The parameters might not be valid for the function's domain + continue + + # Generates new parameters around the mean + params = _search_param([p for p, _ in best_params], n_random_params) + + # Checks if the best parameter is better than the init_ones + p_ones = init_ones(x, **kwargs) + x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_ones}) + loss_ones = nn.MSELoss()(x, x_) + + # Checks if the best parameter is better than the init_rand + p_rand = init_rand(x, **kwargs) + x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_rand}) + loss_rand = nn.MSELoss()(x, x_) + + if loss_rand < best_params[0][1] and loss_rand < loss_ones: + return p_rand + elif loss_ones < best_params[0][1] and loss_ones < loss_rand: + return p_ones + else: + return best_params[0][0] + + +def init_linear_scale( # Symmetric scale. From the study folder + x: torch.Tensor, + **kwargs: Dict[str, Any], + ) -> torch.Tensor: + assert "bits" in kwargs, "bits must be provided." + assert "params" in kwargs, "params must be provided." + assert "qtz_func" in kwargs, "qtz_func must be provided." + + bits = kwargs.get('bits') + params = kwargs.get('params') + qtz_func = kwargs.get('qtz_func') + + x_ = x.transpose(0, 1) + x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs)) + x_ = x_.transpose(0, 1) + + quant_min, quant_max = get_min_max_from_bits_signed(bits) + min_vals, max_vals = torch.aminmax(x_, dim=1) + min_vals = torch.min(min_vals, torch.zeros_like(min_vals)) + max_vals = torch.max(max_vals, torch.zeros_like(max_vals)) + + eps = torch.finfo(torch.float32).eps + + abs_max_val_per_ch = torch.max(-min_vals, max_vals) + scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2) + + scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device) + + # Introduces some noise in scale + # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything + scale = scale + 0.01 * torch.randn_like(scale) + return scale + + +def init_non_linear_regression_fit( + x: torch.Tensor, + **kwargs: Dict[str, Any], + ) -> torch.Tensor: + + assert "params_list" in kwargs, "params list must be provided." + assert "np_fit_func" in kwargs, "np_fit_func must be provided." + assert "p0" in kwargs, "p0 must be provided." + np_fit_func = kwargs.get('np_fit_func') + params_list = kwargs.get('params_list') + p0 = kwargs.get('p0') + + def _fit(xdata: np.ndarray, ydata: np.ndarray, func: Callable, p0: List[float]): + popt, _ = curve_fit( + func, + xdata, + ydata, + maxfev=1000, + p0=p0, + method='lm' + ) + return popt + + # 1. Needs to convert the torch tensor to numpy tensor + xdata = x.cpu().numpy() + + # 2. Sorts the data so that it makes it easier to fit to it + sorted_xdata = np.sort(xdata, axis=-1) + + p0 = {k: v.cpu().numpy() for k, v in p0.items()} + params_list = sorted(params_list) # We need to make sure that it matches the numpy fit func arg order + + # 3. Finds the best parameters for each channel + try: + params = [] + for i in range(sorted_xdata.shape[0]): + xdata_ = sorted_xdata[i] + p0_ = [p0[p][i] for p in params_list] + ch_params = _fit(xdata_, xdata_, np_fit_func, p0_) + params.append(ch_params) + + # 4. Builds the parameters + result = {} + for i, p in enumerate(params_list): + result[p] = torch.tensor([p_[i] for p_ in params], dtype=torch.float32).to(x.device) + + return result + + except ValueError as e: + print(f"Could not fit the function with error: {e}") + print(f"Using fallback result...") + return { + k: torch.tensor(v, dtype=torch.float32).to(x.device) for k, v in p0.items() + } + + +def init_zeros(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor: + val = torch.amin(x, dim=1) + return torch.zeros_like(val, dtype=torch.float32, device=x.device) + + +def init_inner_scale(tensor: torch.Tensor, _min: float = torch.inf, _max: float = torch.inf) -> torch.Tensor: + # Calculate the original minimum and maximum values + min_vals, max_vals = torch.aminmax(tensor, dim=-1) + x_min = torch.min(min_vals, torch.zeros_like(min_vals)) + x_max = torch.max(max_vals, torch.zeros_like(max_vals)) + + if _max is torch.inf: # We do not need to scale the tensor. Just need to move it + return torch.ones_like(x_min) + + # Calculate the scale factor + scale = (_max - _min) / (x_max - x_min) + return scale + + + +############## Quant ############### + +@torch.enable_grad() +def learn_parameters( + x: torch.Tensor, + params: Dict[str, nn.Parameter], + qtz_func: nn.Module, + deqtz_func: nn.Module, + bits: int, + target_dtype: torch.dtype, + epochs: int = 1000, + early_stop: bool = True, + do_report: bool = False +) -> Tuple[Dict[str, nn.Parameter], torch.Tensor]: + + # Requires gradients in the parameters + for p in params.values(): + p.requires_grad = True + p.grad = None + + param_keys = list(params.keys()) + param_values = list(params.values()) + + # Defines optimizer and loss function + optimizer = torch.optim.Adam(param_values, lr=0.001) + scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=1.0, end_factor=0.001, total_iters=epochs // 10) + loss_fn = nn.MSELoss() + + # Contains the best loss and the best parameters + best_loss = float("inf") + best_params = None + + # Used to stop the search early + min_delta = 1e-7 + acc_loss = [] + percent_epochs_before_stop = 0.1 + + for i in range(epochs): + optimizer.zero_grad() + + quant = quantize(x, params, qtz_func, bits, target_dtype) + dequant = dequantize(quant, params, deqtz_func, bits, x.dtype) + loss = loss_fn(x, dequant) + + if loss.isnan() or loss.isinf(): + raise Exception("Loss is NaN or Inf. Stopping the search.") + + loss.backward() + optimizer.step() + scheduler.step() + + acc_loss.append(loss.item()) + + # Reports loss every 10 steps + if i % 10 == 0 and do_report: + print(f"Epoch {i}: Loss {loss.item()}") + + # Optimizes the parameter search by storing the best loss and the parameters + if loss.item() < best_loss: + best_loss = loss.item() + best_params = copy.deepcopy({ + k: v for k, v in params.items() if k in param_keys + }) + + # We also stop the search if the loss has not considerably during the last 10% epochs + if early_stop: + epochs_before_stop = int(epochs * percent_epochs_before_stop) + if i > epochs_before_stop and abs(acc_loss[i - epochs_before_stop] - acc_loss[i]) < min_delta: + break + + # No longer requires gradients in the parameters + for p in best_params.values(): + p.requires_grad = False + p.grad = None + + if do_report: + return best_params, acc_loss + else: + return best_params + + +def quantize( + x: torch.Tensor, + params: Dict[str, nn.Parameter], + func: nn.Module, + bits: int, + target_dtype: torch.dtype = torch.int8 +) -> torch.Tensor: + quant_min, quant_max = get_min_max_from_bits_signed(bits) + x = x.transpose(0, 1) # Aligns shapes + x = func(x=x, **params) + x = x.transpose(0, 1) + x = torch.clamp(round_func_BPDA(x), quant_min, quant_max).to(target_dtype) + return x + + +def dequantize( + x: torch.Tensor, + params: Dict[str, nn.Parameter], + func: nn.Module, + bits: int, + out_dtype: torch.dtype +) -> torch.Tensor: + x = x.to(dtype=out_dtype) + x = x.transpose(0, 1) + x = func(x=x, **params) + x = x.transpose(0, 1) + return x + + +def round_func_BPDA(input): + # This is equivalent to replacing round function (non-differentiable) with + # an identity function (differentiable) only when backward. + forward_value = torch.round(input) + out = input.clone() + out.data = forward_value.data + return out + + +def get_min_max_from_bits_signed(bit_width: int) -> Tuple[int, int]: + return -2 ** (bit_width - 1), 2 ** (bit_width - 1) - 1 + + + +############## Numpy ############### + +def np_domain_guard( + x: np.ndarray, + min: float = None, + max: float = None, + posinf: float = None, + neginf: float = None, + nan: float = None + ) -> np.ndarray: + """Guard a tensor to a valid domain.""" + x = np.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan) + if min is not None or max is not None: + x = np.clip(x, min, max) + return x + + +def np_replace_num(x: np.ndarray, num: float, to: float) -> np.ndarray: + """Replace a number in a tensor with another number. + + Args: + x (np.ndarray): The input tensor. + num (float): The number to replace. + to (float): The number to replace with. + + Returns: + np.ndarray: The tensor with the number replaced. + """ + return np.where(x == num, to, x) + + +def np_guarded_power(x: np.ndarray, exp: float) -> np.ndarray: + """Guard the power operation to a valid domain.""" + return np.power(x, exp) if exp >= 1 else np.power(np.maximum(x, 0), exp) + diff --git a/fn_gen/ones_noisy_scale/3/loss.png b/fn_gen/ones_noisy_scale/3/loss.png new file mode 100644 index 0000000000000000000000000000000000000000..9a84d0ef31d88bf03e1a3afa45759645bd2cddd3 Binary files /dev/null and b/fn_gen/ones_noisy_scale/3/loss.png differ diff --git a/fn_gen/ones_noisy_scale/3/quantization.png b/fn_gen/ones_noisy_scale/3/quantization.png new file mode 100644 index 0000000000000000000000000000000000000000..9c3191093eb2df6057236982a0f9ca9b38b5e57f Binary files /dev/null and b/fn_gen/ones_noisy_scale/3/quantization.png differ diff --git a/fn_gen/ones_noisy_scale/4/__pycache__/fn.cpython-310.pyc b/fn_gen/ones_noisy_scale/4/__pycache__/fn.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0abaaa59a90433a3081d84c8c0d33e3d1c722a99 Binary files /dev/null and b/fn_gen/ones_noisy_scale/4/__pycache__/fn.cpython-310.pyc differ diff --git a/fn_gen/ones_noisy_scale/4/distortion.png b/fn_gen/ones_noisy_scale/4/distortion.png new file mode 100644 index 0000000000000000000000000000000000000000..b6e615529b667ac9aa4f13f30fb85231b828f2cb Binary files /dev/null and b/fn_gen/ones_noisy_scale/4/distortion.png differ diff --git a/fn_gen/ones_noisy_scale/4/expressions.txt b/fn_gen/ones_noisy_scale/4/expressions.txt new file mode 100644 index 0000000000000000000000000000000000000000..6d6553d091cd1d343d7aa9b52b85ef6ec88ea854 --- /dev/null +++ b/fn_gen/ones_noisy_scale/4/expressions.txt @@ -0,0 +1,2 @@ +x/_s +_s*x \ No newline at end of file diff --git a/fn_gen/ones_noisy_scale/4/fn.py b/fn_gen/ones_noisy_scale/4/fn.py new file mode 100644 index 0000000000000000000000000000000000000000..b2e9750eabeb36caefecea6401df6ce9a0f064da --- /dev/null +++ b/fn_gen/ones_noisy_scale/4/fn.py @@ -0,0 +1,488 @@ +from __future__ import annotations + +import torch +from torch import amin # Necessary for arcsin +import copy +import torch.nn as nn +import numpy as np + +from scipy.optimize import curve_fit +from typing import Dict, Any, Tuple, List, Callable + + +def quantization(x, **params): + return (x * torch.div(1, replace_num(params['_s'], num=0, to=10000))) + + +def dequantization(x, **params): + return (params['_s'] * x) + + +def init_linear_scale( # Symmetric scale. From the study folder + x: torch.Tensor, + **kwargs: Dict[str, Any], + ) -> torch.Tensor: + assert "bits" in kwargs, "bits must be provided." + assert "params" in kwargs, "params must be provided." + assert "qtz_func" in kwargs, "qtz_func must be provided." + + bits = kwargs.get('bits') + params = kwargs.get('params') + qtz_func = kwargs.get('qtz_func') + + x_ = x.transpose(0, 1) + x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs)) + x_ = x_.transpose(0, 1) + + quant_min, quant_max = get_min_max_from_bits_signed(bits) + min_vals, max_vals = torch.aminmax(x_, dim=1) + min_vals = torch.min(min_vals, torch.zeros_like(min_vals)) + max_vals = torch.max(max_vals, torch.zeros_like(max_vals)) + + eps = torch.finfo(torch.float32).eps + + abs_max_val_per_ch = torch.max(-min_vals, max_vals) + scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2) + + scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device) + + # Introduces some noise in scale + # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything + scale = scale + 0.01 * torch.randn_like(scale) + return scale + + +def init_params(x: torch.Tensor, **kwargs: Dict[str, Any]) -> Dict[str, nn.Parameter]: + params = { + } + params['_s'] = init_linear_scale(x, params=params, qtz_func=quantization, **kwargs) + params = {k: nn.Parameter(v, requires_grad=False) for k, v in params.items()} + + if 'post_init_hook' in kwargs: + kwargs['post_init_hook'](parameters=params) + + + if 'post_train_hook' in kwargs: + kwargs['post_train_hook'](parameters=params) + + return params + + +############### Numpy Qtz ############### + + +def np_quantization(x, _s): + return (x * np.divide(1, np_replace_num(_s, num=0, to=10000))) + + +def np_dequantization(x, _s): + return (_s * x) + + +def fit_func(x, _s): + x_ = np_quantization(x, _s) + x_ = np_dequantization(x_, _s) + return x_ + + + +############### HELPERS ############### + +def domain_guard( + x: torch.Tensor, + min: float = None, + max: float = None, + posinf: float = None, + neginf: float = None, + nan: float = None + ) -> torch.Tensor: + """Guard a tensor to a valid domain.""" + x = torch.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan) + if min is not None or max is not None: + x = torch.clamp(x, min=min, max=max) + return x + + +def replace_num(x: torch.Tensor, num: float, to: float) -> torch.Tensor: + """Replace a number in a tensor with another number. + + Args: + x (torch.Tensor): The input tensor. + num (float): The number to replace. + to (float): The number to replace with. + + Returns: + torch.Tensor: The tensor with the number replaced. + """ + return torch.where(x == num, to, x) + + +def guarded_torch_power(x: torch.Tensor, exp: float) -> torch.Tensor: + """Guard the power operation to a valid domain.""" + return torch.pow(x, exp) if exp >= 1 else torch.pow(torch.relu(x), exp) + + +def init_ones(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor: + val = torch.amin(x, dim=1) + return torch.ones_like(val, dtype=torch.float32, device=x.device) + + +def init_rand(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor: + val = torch.amin(x, dim=1) + return torch.randn_like(val, dtype=torch.float32, device=x.device) + + +def init_space_search( + x: torch.Tensor, + **kwargs: Dict[str, Any], + ) -> torch.Tensor: + + def _build_initial_param(tensor: torch.Tensor, max_initial: int, n_params: int): + """Generates the initial set of parameters. The first iteration generates 10 times more parameters.""" + for _ in range(n_params * 10): # The first iteration generates 10 times more parameters + yield init_rand(tensor) * max_initial # Generates n_params in range [-max_initial, max_initial] + + def _search_param(tensors: List[torch.tensor], n_params): + """Takes the best parameters and generates new parameters around the mean of the best parameters.""" + torch_tensors = torch.stack(tensors) + min_vals, max_vals = torch.aminmax(torch_tensors, dim=0) + abs_max_val_per_ch = torch.max(-min_vals, max_vals) + mean = torch.mean(torch_tensors, dim=0) + for _ in range(n_params): # Generates n_params around the mean of the tensors + yield torch.randn_like(min_vals) * abs_max_val_per_ch + mean + + def _calc(x, qtz_func, deqtz_func, **params): + x_ = x.transpose(0, 1) + x_ = qtz_func(x=x_, **params) + x_ = deqtz_func(x=x_, **params) + x_ = x_.transpose(0, 1) + return x_ + + assert "qtz_func" in kwargs, "qtz_func must be provided." + assert "deqtz_func" in kwargs, "deqtz_func must be provided." + assert "params_list" in kwargs, "params list must be provided." + assert "param" in kwargs, "param must be provided." + + qtz_func = kwargs.get('qtz_func') + deqtz_func = kwargs.get('deqtz_func') + params_list = kwargs.get('params_list') + param = kwargs.get('param') + + n_runs = 50 # Number of runs to try to find the best parameters + n_random_params = 50 # Number of random parameters to generate + n_best_to_pick = 5 # Number of best parameters to pick after each run + max_initial = 10000 # Maximum value to initialize the parameters + + # Initializes the parameters + base_params = { p: init_ones(x, **kwargs) for p in params_list if p != param } + params = _build_initial_param(x, max_initial, n_random_params) + + # Performs the search + for _ in range(n_runs): + + best_params = [] + for param_ in params: + try: + x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: param_}) + loss_ones = nn.MSELoss()(x, x_) + + if len(best_params) < n_best_to_pick: + best_params.append((param_, loss_ones.item())) + best_params = sorted(best_params, key=lambda x: x[1]) + elif loss_ones < best_params[-1][1]: + best_params[-1] = (param_, loss_ones.item()) + best_params = sorted(best_params, key=lambda x: x[1]) + + except Exception: # The parameters might not be valid for the function's domain + continue + + # Generates new parameters around the mean + params = _search_param([p for p, _ in best_params], n_random_params) + + # Checks if the best parameter is better than the init_ones + p_ones = init_ones(x, **kwargs) + x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_ones}) + loss_ones = nn.MSELoss()(x, x_) + + # Checks if the best parameter is better than the init_rand + p_rand = init_rand(x, **kwargs) + x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_rand}) + loss_rand = nn.MSELoss()(x, x_) + + if loss_rand < best_params[0][1] and loss_rand < loss_ones: + return p_rand + elif loss_ones < best_params[0][1] and loss_ones < loss_rand: + return p_ones + else: + return best_params[0][0] + + +def init_linear_scale( # Symmetric scale. From the study folder + x: torch.Tensor, + **kwargs: Dict[str, Any], + ) -> torch.Tensor: + assert "bits" in kwargs, "bits must be provided." + assert "params" in kwargs, "params must be provided." + assert "qtz_func" in kwargs, "qtz_func must be provided." + + bits = kwargs.get('bits') + params = kwargs.get('params') + qtz_func = kwargs.get('qtz_func') + + x_ = x.transpose(0, 1) + x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs)) + x_ = x_.transpose(0, 1) + + quant_min, quant_max = get_min_max_from_bits_signed(bits) + min_vals, max_vals = torch.aminmax(x_, dim=1) + min_vals = torch.min(min_vals, torch.zeros_like(min_vals)) + max_vals = torch.max(max_vals, torch.zeros_like(max_vals)) + + eps = torch.finfo(torch.float32).eps + + abs_max_val_per_ch = torch.max(-min_vals, max_vals) + scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2) + + scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device) + + # Introduces some noise in scale + # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything + scale = scale + 0.01 * torch.randn_like(scale) + return scale + + +def init_non_linear_regression_fit( + x: torch.Tensor, + **kwargs: Dict[str, Any], + ) -> torch.Tensor: + + assert "params_list" in kwargs, "params list must be provided." + assert "np_fit_func" in kwargs, "np_fit_func must be provided." + assert "p0" in kwargs, "p0 must be provided." + np_fit_func = kwargs.get('np_fit_func') + params_list = kwargs.get('params_list') + p0 = kwargs.get('p0') + + def _fit(xdata: np.ndarray, ydata: np.ndarray, func: Callable, p0: List[float]): + popt, _ = curve_fit( + func, + xdata, + ydata, + maxfev=1000, + p0=p0, + method='lm' + ) + return popt + + # 1. Needs to convert the torch tensor to numpy tensor + xdata = x.cpu().numpy() + + # 2. Sorts the data so that it makes it easier to fit to it + sorted_xdata = np.sort(xdata, axis=-1) + + p0 = {k: v.cpu().numpy() for k, v in p0.items()} + params_list = sorted(params_list) # We need to make sure that it matches the numpy fit func arg order + + # 3. Finds the best parameters for each channel + try: + params = [] + for i in range(sorted_xdata.shape[0]): + xdata_ = sorted_xdata[i] + p0_ = [p0[p][i] for p in params_list] + ch_params = _fit(xdata_, xdata_, np_fit_func, p0_) + params.append(ch_params) + + # 4. Builds the parameters + result = {} + for i, p in enumerate(params_list): + result[p] = torch.tensor([p_[i] for p_ in params], dtype=torch.float32).to(x.device) + + return result + + except ValueError as e: + print(f"Could not fit the function with error: {e}") + print(f"Using fallback result...") + return { + k: torch.tensor(v, dtype=torch.float32).to(x.device) for k, v in p0.items() + } + + +def init_zeros(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor: + val = torch.amin(x, dim=1) + return torch.zeros_like(val, dtype=torch.float32, device=x.device) + + +def init_inner_scale(tensor: torch.Tensor, _min: float = torch.inf, _max: float = torch.inf) -> torch.Tensor: + # Calculate the original minimum and maximum values + min_vals, max_vals = torch.aminmax(tensor, dim=-1) + x_min = torch.min(min_vals, torch.zeros_like(min_vals)) + x_max = torch.max(max_vals, torch.zeros_like(max_vals)) + + if _max is torch.inf: # We do not need to scale the tensor. Just need to move it + return torch.ones_like(x_min) + + # Calculate the scale factor + scale = (_max - _min) / (x_max - x_min) + return scale + + + +############## Quant ############### + +@torch.enable_grad() +def learn_parameters( + x: torch.Tensor, + params: Dict[str, nn.Parameter], + qtz_func: nn.Module, + deqtz_func: nn.Module, + bits: int, + target_dtype: torch.dtype, + epochs: int = 1000, + early_stop: bool = True, + do_report: bool = False +) -> Tuple[Dict[str, nn.Parameter], torch.Tensor]: + + # Requires gradients in the parameters + for p in params.values(): + p.requires_grad = True + p.grad = None + + param_keys = list(params.keys()) + param_values = list(params.values()) + + # Defines optimizer and loss function + optimizer = torch.optim.Adam(param_values, lr=0.001) + scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=1.0, end_factor=0.001, total_iters=epochs // 10) + loss_fn = nn.MSELoss() + + # Contains the best loss and the best parameters + best_loss = float("inf") + best_params = None + + # Used to stop the search early + min_delta = 1e-7 + acc_loss = [] + percent_epochs_before_stop = 0.1 + + for i in range(epochs): + optimizer.zero_grad() + + quant = quantize(x, params, qtz_func, bits, target_dtype) + dequant = dequantize(quant, params, deqtz_func, bits, x.dtype) + loss = loss_fn(x, dequant) + + if loss.isnan() or loss.isinf(): + raise Exception("Loss is NaN or Inf. Stopping the search.") + + loss.backward() + optimizer.step() + scheduler.step() + + acc_loss.append(loss.item()) + + # Reports loss every 10 steps + if i % 10 == 0 and do_report: + print(f"Epoch {i}: Loss {loss.item()}") + + # Optimizes the parameter search by storing the best loss and the parameters + if loss.item() < best_loss: + best_loss = loss.item() + best_params = copy.deepcopy({ + k: v for k, v in params.items() if k in param_keys + }) + + # We also stop the search if the loss has not considerably during the last 10% epochs + if early_stop: + epochs_before_stop = int(epochs * percent_epochs_before_stop) + if i > epochs_before_stop and abs(acc_loss[i - epochs_before_stop] - acc_loss[i]) < min_delta: + break + + # No longer requires gradients in the parameters + for p in best_params.values(): + p.requires_grad = False + p.grad = None + + if do_report: + return best_params, acc_loss + else: + return best_params + + +def quantize( + x: torch.Tensor, + params: Dict[str, nn.Parameter], + func: nn.Module, + bits: int, + target_dtype: torch.dtype = torch.int8 +) -> torch.Tensor: + quant_min, quant_max = get_min_max_from_bits_signed(bits) + x = x.transpose(0, 1) # Aligns shapes + x = func(x=x, **params) + x = x.transpose(0, 1) + x = torch.clamp(round_func_BPDA(x), quant_min, quant_max).to(target_dtype) + return x + + +def dequantize( + x: torch.Tensor, + params: Dict[str, nn.Parameter], + func: nn.Module, + bits: int, + out_dtype: torch.dtype +) -> torch.Tensor: + x = x.to(dtype=out_dtype) + x = x.transpose(0, 1) + x = func(x=x, **params) + x = x.transpose(0, 1) + return x + + +def round_func_BPDA(input): + # This is equivalent to replacing round function (non-differentiable) with + # an identity function (differentiable) only when backward. + forward_value = torch.round(input) + out = input.clone() + out.data = forward_value.data + return out + + +def get_min_max_from_bits_signed(bit_width: int) -> Tuple[int, int]: + return -2 ** (bit_width - 1), 2 ** (bit_width - 1) - 1 + + + +############## Numpy ############### + +def np_domain_guard( + x: np.ndarray, + min: float = None, + max: float = None, + posinf: float = None, + neginf: float = None, + nan: float = None + ) -> np.ndarray: + """Guard a tensor to a valid domain.""" + x = np.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan) + if min is not None or max is not None: + x = np.clip(x, min, max) + return x + + +def np_replace_num(x: np.ndarray, num: float, to: float) -> np.ndarray: + """Replace a number in a tensor with another number. + + Args: + x (np.ndarray): The input tensor. + num (float): The number to replace. + to (float): The number to replace with. + + Returns: + np.ndarray: The tensor with the number replaced. + """ + return np.where(x == num, to, x) + + +def np_guarded_power(x: np.ndarray, exp: float) -> np.ndarray: + """Guard the power operation to a valid domain.""" + return np.power(x, exp) if exp >= 1 else np.power(np.maximum(x, 0), exp) + diff --git a/fn_gen/ones_noisy_scale/4/loss.png b/fn_gen/ones_noisy_scale/4/loss.png new file mode 100644 index 0000000000000000000000000000000000000000..8bbab0d13dbff130888c3c80bac8719f812579bc Binary files /dev/null and b/fn_gen/ones_noisy_scale/4/loss.png differ diff --git a/fn_gen/ones_noisy_scale/4/quantization.png b/fn_gen/ones_noisy_scale/4/quantization.png new file mode 100644 index 0000000000000000000000000000000000000000..d5649883e466d615ca0bb98d6bfba1ceb6019d43 Binary files /dev/null and b/fn_gen/ones_noisy_scale/4/quantization.png differ diff --git a/fn_gen/ones_noisy_scale/5/__pycache__/fn.cpython-310.pyc b/fn_gen/ones_noisy_scale/5/__pycache__/fn.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ec06310c9f4f73146d3bfee4853294c56e1f2e6c Binary files /dev/null and b/fn_gen/ones_noisy_scale/5/__pycache__/fn.cpython-310.pyc differ diff --git a/fn_gen/ones_noisy_scale/5/distortion.png b/fn_gen/ones_noisy_scale/5/distortion.png new file mode 100644 index 0000000000000000000000000000000000000000..cf69fc4e626c989beff0f128b3122bfc6b7e4a25 Binary files /dev/null and b/fn_gen/ones_noisy_scale/5/distortion.png differ diff --git a/fn_gen/ones_noisy_scale/5/expressions.txt b/fn_gen/ones_noisy_scale/5/expressions.txt new file mode 100644 index 0000000000000000000000000000000000000000..aa32b575e8c654dbc457c94f36222e70d86dc940 --- /dev/null +++ b/fn_gen/ones_noisy_scale/5/expressions.txt @@ -0,0 +1,2 @@ +atan(_0*x)/_s +tan(_s*x)/_0 \ No newline at end of file diff --git a/fn_gen/ones_noisy_scale/5/fn.py b/fn_gen/ones_noisy_scale/5/fn.py new file mode 100644 index 0000000000000000000000000000000000000000..dc40bcb61599701bbd58e4092c2cd9271b20e1d7 --- /dev/null +++ b/fn_gen/ones_noisy_scale/5/fn.py @@ -0,0 +1,489 @@ +from __future__ import annotations + +import torch +from torch import amin # Necessary for arcsin +import copy +import torch.nn as nn +import numpy as np + +from scipy.optimize import curve_fit +from typing import Dict, Any, Tuple, List, Callable + + +def quantization(x, **params): + return (torch.div(1, replace_num(params['_s'], num=0, to=10000)) * torch.atan((params['_0'] * x))) + + +def dequantization(x, **params): + return (torch.div(1, replace_num(params['_0'], num=0, to=10000)) * torch.tan(domain_guard((params['_s'] * x), posinf=1, neginf=-1, nan=0))) + + +def init_linear_scale( # Symmetric scale. From the study folder + x: torch.Tensor, + **kwargs: Dict[str, Any], + ) -> torch.Tensor: + assert "bits" in kwargs, "bits must be provided." + assert "params" in kwargs, "params must be provided." + assert "qtz_func" in kwargs, "qtz_func must be provided." + + bits = kwargs.get('bits') + params = kwargs.get('params') + qtz_func = kwargs.get('qtz_func') + + x_ = x.transpose(0, 1) + x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs)) + x_ = x_.transpose(0, 1) + + quant_min, quant_max = get_min_max_from_bits_signed(bits) + min_vals, max_vals = torch.aminmax(x_, dim=1) + min_vals = torch.min(min_vals, torch.zeros_like(min_vals)) + max_vals = torch.max(max_vals, torch.zeros_like(max_vals)) + + eps = torch.finfo(torch.float32).eps + + abs_max_val_per_ch = torch.max(-min_vals, max_vals) + scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2) + + scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device) + + # Introduces some noise in scale + # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything + scale = scale + 0.01 * torch.randn_like(scale) + return scale + + +def init_params(x: torch.Tensor, **kwargs: Dict[str, Any]) -> Dict[str, nn.Parameter]: + params = { + '_0': init_ones(x, qtz_func=quantization, deqtz_func=dequantization, param='_0', params_list=['_0', '_s'], **kwargs), + } + params['_s'] = init_linear_scale(x, params=params, qtz_func=quantization, **kwargs) + params = {k: nn.Parameter(v, requires_grad=False) for k, v in params.items()} + + if 'post_init_hook' in kwargs: + kwargs['post_init_hook'](parameters=params) + + + if 'post_train_hook' in kwargs: + kwargs['post_train_hook'](parameters=params) + + return params + + +############### Numpy Qtz ############### + + +def np_quantization(x, _0, _s): + return (np.divide(1, np_replace_num(_s, num=0, to=10000)) * np.arctan((_0 * x))) + + +def np_dequantization(x, _0, _s): + return (np.divide(1, np_replace_num(_0, num=0, to=10000)) * np.tan(np_domain_guard((_s * x), posinf=1, neginf=-1, nan=0))) + + +def fit_func(x, _0, _s): + x_ = np_quantization(x, _0, _s) + x_ = np_dequantization(x_, _0, _s) + return x_ + + + +############### HELPERS ############### + +def domain_guard( + x: torch.Tensor, + min: float = None, + max: float = None, + posinf: float = None, + neginf: float = None, + nan: float = None + ) -> torch.Tensor: + """Guard a tensor to a valid domain.""" + x = torch.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan) + if min is not None or max is not None: + x = torch.clamp(x, min=min, max=max) + return x + + +def replace_num(x: torch.Tensor, num: float, to: float) -> torch.Tensor: + """Replace a number in a tensor with another number. + + Args: + x (torch.Tensor): The input tensor. + num (float): The number to replace. + to (float): The number to replace with. + + Returns: + torch.Tensor: The tensor with the number replaced. + """ + return torch.where(x == num, to, x) + + +def guarded_torch_power(x: torch.Tensor, exp: float) -> torch.Tensor: + """Guard the power operation to a valid domain.""" + return torch.pow(x, exp) if exp >= 1 else torch.pow(torch.relu(x), exp) + + +def init_ones(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor: + val = torch.amin(x, dim=1) + return torch.ones_like(val, dtype=torch.float32, device=x.device) + + +def init_rand(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor: + val = torch.amin(x, dim=1) + return torch.randn_like(val, dtype=torch.float32, device=x.device) + + +def init_space_search( + x: torch.Tensor, + **kwargs: Dict[str, Any], + ) -> torch.Tensor: + + def _build_initial_param(tensor: torch.Tensor, max_initial: int, n_params: int): + """Generates the initial set of parameters. The first iteration generates 10 times more parameters.""" + for _ in range(n_params * 10): # The first iteration generates 10 times more parameters + yield init_rand(tensor) * max_initial # Generates n_params in range [-max_initial, max_initial] + + def _search_param(tensors: List[torch.tensor], n_params): + """Takes the best parameters and generates new parameters around the mean of the best parameters.""" + torch_tensors = torch.stack(tensors) + min_vals, max_vals = torch.aminmax(torch_tensors, dim=0) + abs_max_val_per_ch = torch.max(-min_vals, max_vals) + mean = torch.mean(torch_tensors, dim=0) + for _ in range(n_params): # Generates n_params around the mean of the tensors + yield torch.randn_like(min_vals) * abs_max_val_per_ch + mean + + def _calc(x, qtz_func, deqtz_func, **params): + x_ = x.transpose(0, 1) + x_ = qtz_func(x=x_, **params) + x_ = deqtz_func(x=x_, **params) + x_ = x_.transpose(0, 1) + return x_ + + assert "qtz_func" in kwargs, "qtz_func must be provided." + assert "deqtz_func" in kwargs, "deqtz_func must be provided." + assert "params_list" in kwargs, "params list must be provided." + assert "param" in kwargs, "param must be provided." + + qtz_func = kwargs.get('qtz_func') + deqtz_func = kwargs.get('deqtz_func') + params_list = kwargs.get('params_list') + param = kwargs.get('param') + + n_runs = 50 # Number of runs to try to find the best parameters + n_random_params = 50 # Number of random parameters to generate + n_best_to_pick = 5 # Number of best parameters to pick after each run + max_initial = 10000 # Maximum value to initialize the parameters + + # Initializes the parameters + base_params = { p: init_ones(x, **kwargs) for p in params_list if p != param } + params = _build_initial_param(x, max_initial, n_random_params) + + # Performs the search + for _ in range(n_runs): + + best_params = [] + for param_ in params: + try: + x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: param_}) + loss_ones = nn.MSELoss()(x, x_) + + if len(best_params) < n_best_to_pick: + best_params.append((param_, loss_ones.item())) + best_params = sorted(best_params, key=lambda x: x[1]) + elif loss_ones < best_params[-1][1]: + best_params[-1] = (param_, loss_ones.item()) + best_params = sorted(best_params, key=lambda x: x[1]) + + except Exception: # The parameters might not be valid for the function's domain + continue + + # Generates new parameters around the mean + params = _search_param([p for p, _ in best_params], n_random_params) + + # Checks if the best parameter is better than the init_ones + p_ones = init_ones(x, **kwargs) + x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_ones}) + loss_ones = nn.MSELoss()(x, x_) + + # Checks if the best parameter is better than the init_rand + p_rand = init_rand(x, **kwargs) + x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_rand}) + loss_rand = nn.MSELoss()(x, x_) + + if loss_rand < best_params[0][1] and loss_rand < loss_ones: + return p_rand + elif loss_ones < best_params[0][1] and loss_ones < loss_rand: + return p_ones + else: + return best_params[0][0] + + +def init_linear_scale( # Symmetric scale. From the study folder + x: torch.Tensor, + **kwargs: Dict[str, Any], + ) -> torch.Tensor: + assert "bits" in kwargs, "bits must be provided." + assert "params" in kwargs, "params must be provided." + assert "qtz_func" in kwargs, "qtz_func must be provided." + + bits = kwargs.get('bits') + params = kwargs.get('params') + qtz_func = kwargs.get('qtz_func') + + x_ = x.transpose(0, 1) + x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs)) + x_ = x_.transpose(0, 1) + + quant_min, quant_max = get_min_max_from_bits_signed(bits) + min_vals, max_vals = torch.aminmax(x_, dim=1) + min_vals = torch.min(min_vals, torch.zeros_like(min_vals)) + max_vals = torch.max(max_vals, torch.zeros_like(max_vals)) + + eps = torch.finfo(torch.float32).eps + + abs_max_val_per_ch = torch.max(-min_vals, max_vals) + scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2) + + scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device) + + # Introduces some noise in scale + # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything + scale = scale + 0.01 * torch.randn_like(scale) + return scale + + +def init_non_linear_regression_fit( + x: torch.Tensor, + **kwargs: Dict[str, Any], + ) -> torch.Tensor: + + assert "params_list" in kwargs, "params list must be provided." + assert "np_fit_func" in kwargs, "np_fit_func must be provided." + assert "p0" in kwargs, "p0 must be provided." + np_fit_func = kwargs.get('np_fit_func') + params_list = kwargs.get('params_list') + p0 = kwargs.get('p0') + + def _fit(xdata: np.ndarray, ydata: np.ndarray, func: Callable, p0: List[float]): + popt, _ = curve_fit( + func, + xdata, + ydata, + maxfev=1000, + p0=p0, + method='lm' + ) + return popt + + # 1. Needs to convert the torch tensor to numpy tensor + xdata = x.cpu().numpy() + + # 2. Sorts the data so that it makes it easier to fit to it + sorted_xdata = np.sort(xdata, axis=-1) + + p0 = {k: v.cpu().numpy() for k, v in p0.items()} + params_list = sorted(params_list) # We need to make sure that it matches the numpy fit func arg order + + # 3. Finds the best parameters for each channel + try: + params = [] + for i in range(sorted_xdata.shape[0]): + xdata_ = sorted_xdata[i] + p0_ = [p0[p][i] for p in params_list] + ch_params = _fit(xdata_, xdata_, np_fit_func, p0_) + params.append(ch_params) + + # 4. Builds the parameters + result = {} + for i, p in enumerate(params_list): + result[p] = torch.tensor([p_[i] for p_ in params], dtype=torch.float32).to(x.device) + + return result + + except ValueError as e: + print(f"Could not fit the function with error: {e}") + print(f"Using fallback result...") + return { + k: torch.tensor(v, dtype=torch.float32).to(x.device) for k, v in p0.items() + } + + +def init_zeros(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor: + val = torch.amin(x, dim=1) + return torch.zeros_like(val, dtype=torch.float32, device=x.device) + + +def init_inner_scale(tensor: torch.Tensor, _min: float = torch.inf, _max: float = torch.inf) -> torch.Tensor: + # Calculate the original minimum and maximum values + min_vals, max_vals = torch.aminmax(tensor, dim=-1) + x_min = torch.min(min_vals, torch.zeros_like(min_vals)) + x_max = torch.max(max_vals, torch.zeros_like(max_vals)) + + if _max is torch.inf: # We do not need to scale the tensor. Just need to move it + return torch.ones_like(x_min) + + # Calculate the scale factor + scale = (_max - _min) / (x_max - x_min) + return scale + + + +############## Quant ############### + +@torch.enable_grad() +def learn_parameters( + x: torch.Tensor, + params: Dict[str, nn.Parameter], + qtz_func: nn.Module, + deqtz_func: nn.Module, + bits: int, + target_dtype: torch.dtype, + epochs: int = 1000, + early_stop: bool = True, + do_report: bool = False +) -> Tuple[Dict[str, nn.Parameter], torch.Tensor]: + + # Requires gradients in the parameters + for p in params.values(): + p.requires_grad = True + p.grad = None + + param_keys = list(params.keys()) + param_values = list(params.values()) + + # Defines optimizer and loss function + optimizer = torch.optim.Adam(param_values, lr=0.001) + scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=1.0, end_factor=0.001, total_iters=epochs // 10) + loss_fn = nn.MSELoss() + + # Contains the best loss and the best parameters + best_loss = float("inf") + best_params = None + + # Used to stop the search early + min_delta = 1e-7 + acc_loss = [] + percent_epochs_before_stop = 0.1 + + for i in range(epochs): + optimizer.zero_grad() + + quant = quantize(x, params, qtz_func, bits, target_dtype) + dequant = dequantize(quant, params, deqtz_func, bits, x.dtype) + loss = loss_fn(x, dequant) + + if loss.isnan() or loss.isinf(): + raise Exception("Loss is NaN or Inf. Stopping the search.") + + loss.backward() + optimizer.step() + scheduler.step() + + acc_loss.append(loss.item()) + + # Reports loss every 10 steps + if i % 10 == 0 and do_report: + print(f"Epoch {i}: Loss {loss.item()}") + + # Optimizes the parameter search by storing the best loss and the parameters + if loss.item() < best_loss: + best_loss = loss.item() + best_params = copy.deepcopy({ + k: v for k, v in params.items() if k in param_keys + }) + + # We also stop the search if the loss has not considerably during the last 10% epochs + if early_stop: + epochs_before_stop = int(epochs * percent_epochs_before_stop) + if i > epochs_before_stop and abs(acc_loss[i - epochs_before_stop] - acc_loss[i]) < min_delta: + break + + # No longer requires gradients in the parameters + for p in best_params.values(): + p.requires_grad = False + p.grad = None + + if do_report: + return best_params, acc_loss + else: + return best_params + + +def quantize( + x: torch.Tensor, + params: Dict[str, nn.Parameter], + func: nn.Module, + bits: int, + target_dtype: torch.dtype = torch.int8 +) -> torch.Tensor: + quant_min, quant_max = get_min_max_from_bits_signed(bits) + x = x.transpose(0, 1) # Aligns shapes + x = func(x=x, **params) + x = x.transpose(0, 1) + x = torch.clamp(round_func_BPDA(x), quant_min, quant_max).to(target_dtype) + return x + + +def dequantize( + x: torch.Tensor, + params: Dict[str, nn.Parameter], + func: nn.Module, + bits: int, + out_dtype: torch.dtype +) -> torch.Tensor: + x = x.to(dtype=out_dtype) + x = x.transpose(0, 1) + x = func(x=x, **params) + x = x.transpose(0, 1) + return x + + +def round_func_BPDA(input): + # This is equivalent to replacing round function (non-differentiable) with + # an identity function (differentiable) only when backward. + forward_value = torch.round(input) + out = input.clone() + out.data = forward_value.data + return out + + +def get_min_max_from_bits_signed(bit_width: int) -> Tuple[int, int]: + return -2 ** (bit_width - 1), 2 ** (bit_width - 1) - 1 + + + +############## Numpy ############### + +def np_domain_guard( + x: np.ndarray, + min: float = None, + max: float = None, + posinf: float = None, + neginf: float = None, + nan: float = None + ) -> np.ndarray: + """Guard a tensor to a valid domain.""" + x = np.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan) + if min is not None or max is not None: + x = np.clip(x, min, max) + return x + + +def np_replace_num(x: np.ndarray, num: float, to: float) -> np.ndarray: + """Replace a number in a tensor with another number. + + Args: + x (np.ndarray): The input tensor. + num (float): The number to replace. + to (float): The number to replace with. + + Returns: + np.ndarray: The tensor with the number replaced. + """ + return np.where(x == num, to, x) + + +def np_guarded_power(x: np.ndarray, exp: float) -> np.ndarray: + """Guard the power operation to a valid domain.""" + return np.power(x, exp) if exp >= 1 else np.power(np.maximum(x, 0), exp) + diff --git a/fn_gen/ones_noisy_scale/5/loss.png b/fn_gen/ones_noisy_scale/5/loss.png new file mode 100644 index 0000000000000000000000000000000000000000..040ab1422644fe835cd32bdf81f3fe4b0c918daf Binary files /dev/null and b/fn_gen/ones_noisy_scale/5/loss.png differ diff --git a/fn_gen/ones_noisy_scale/5/quantization.png b/fn_gen/ones_noisy_scale/5/quantization.png new file mode 100644 index 0000000000000000000000000000000000000000..266bbaedc7dd8ffa791bf2fd5da097cd18a62e68 Binary files /dev/null and b/fn_gen/ones_noisy_scale/5/quantization.png differ diff --git a/fn_gen/ones_noisy_scale/6/__pycache__/fn.cpython-310.pyc b/fn_gen/ones_noisy_scale/6/__pycache__/fn.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f89b833eddcc28eecf73d71a52e3536370899240 Binary files /dev/null and b/fn_gen/ones_noisy_scale/6/__pycache__/fn.cpython-310.pyc differ diff --git a/fn_gen/ones_noisy_scale/6/distortion.png b/fn_gen/ones_noisy_scale/6/distortion.png new file mode 100644 index 0000000000000000000000000000000000000000..cc413167e551a11f5f072d7461772582980af699 Binary files /dev/null and b/fn_gen/ones_noisy_scale/6/distortion.png differ diff --git a/fn_gen/ones_noisy_scale/6/expressions.txt b/fn_gen/ones_noisy_scale/6/expressions.txt new file mode 100644 index 0000000000000000000000000000000000000000..ec55493201f7b2b8effaefed75e0a9258fc25c56 --- /dev/null +++ b/fn_gen/ones_noisy_scale/6/expressions.txt @@ -0,0 +1,2 @@ +tanh(_0*x)/_s +log((-_s*x - 1)/(_s*x - 1))/_0 \ No newline at end of file diff --git a/fn_gen/ones_noisy_scale/6/fn.py b/fn_gen/ones_noisy_scale/6/fn.py new file mode 100644 index 0000000000000000000000000000000000000000..ec5684f903c0d668936bac02695ef473ce09c6ba --- /dev/null +++ b/fn_gen/ones_noisy_scale/6/fn.py @@ -0,0 +1,489 @@ +from __future__ import annotations + +import torch +from torch import amin # Necessary for arcsin +import copy +import torch.nn as nn +import numpy as np + +from scipy.optimize import curve_fit +from typing import Dict, Any, Tuple, List, Callable + + +def quantization(x, **params): + return (torch.div(1, replace_num(params['_s'], num=0, to=10000)) * torch.tanh((params['_0'] * x))) + + +def dequantization(x, **params): + return (torch.div(1, replace_num(params['_0'], num=0, to=10000)) * torch.log(domain_guard((torch.div(1, replace_num((torch.tensor(-1) + (params['_s'] * x)), num=0, to=10000)) * (torch.tensor(-1) + (torch.tensor(-1) * params['_s'] * x))), min=1e-5, nan=1e-5))) + + +def init_linear_scale( # Symmetric scale. From the study folder + x: torch.Tensor, + **kwargs: Dict[str, Any], + ) -> torch.Tensor: + assert "bits" in kwargs, "bits must be provided." + assert "params" in kwargs, "params must be provided." + assert "qtz_func" in kwargs, "qtz_func must be provided." + + bits = kwargs.get('bits') + params = kwargs.get('params') + qtz_func = kwargs.get('qtz_func') + + x_ = x.transpose(0, 1) + x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs)) + x_ = x_.transpose(0, 1) + + quant_min, quant_max = get_min_max_from_bits_signed(bits) + min_vals, max_vals = torch.aminmax(x_, dim=1) + min_vals = torch.min(min_vals, torch.zeros_like(min_vals)) + max_vals = torch.max(max_vals, torch.zeros_like(max_vals)) + + eps = torch.finfo(torch.float32).eps + + abs_max_val_per_ch = torch.max(-min_vals, max_vals) + scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2) + + scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device) + + # Introduces some noise in scale + # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything + scale = scale + 0.01 * torch.randn_like(scale) + return scale + + +def init_params(x: torch.Tensor, **kwargs: Dict[str, Any]) -> Dict[str, nn.Parameter]: + params = { + '_0': init_ones(x, qtz_func=quantization, deqtz_func=dequantization, param='_0', params_list=['_0', '_s'], **kwargs), + } + params['_s'] = init_linear_scale(x, params=params, qtz_func=quantization, **kwargs) + params = {k: nn.Parameter(v, requires_grad=False) for k, v in params.items()} + + if 'post_init_hook' in kwargs: + kwargs['post_init_hook'](parameters=params) + + + if 'post_train_hook' in kwargs: + kwargs['post_train_hook'](parameters=params) + + return params + + +############### Numpy Qtz ############### + + +def np_quantization(x, _0, _s): + return (np.divide(1, np_replace_num(_s, num=0, to=10000)) * np.tanh((_0 * x))) + + +def np_dequantization(x, _0, _s): + return (np.divide(1, np_replace_num(_0, num=0, to=10000)) * np.log(np_domain_guard((np.divide(1, np_replace_num((np.array(-1) + (_s * x)), num=0, to=10000)) * (np.array(-1) + (np.array(-1) * _s * x))), min=1e-5, nan=1e-5))) + + +def fit_func(x, _0, _s): + x_ = np_quantization(x, _0, _s) + x_ = np_dequantization(x_, _0, _s) + return x_ + + + +############### HELPERS ############### + +def domain_guard( + x: torch.Tensor, + min: float = None, + max: float = None, + posinf: float = None, + neginf: float = None, + nan: float = None + ) -> torch.Tensor: + """Guard a tensor to a valid domain.""" + x = torch.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan) + if min is not None or max is not None: + x = torch.clamp(x, min=min, max=max) + return x + + +def replace_num(x: torch.Tensor, num: float, to: float) -> torch.Tensor: + """Replace a number in a tensor with another number. + + Args: + x (torch.Tensor): The input tensor. + num (float): The number to replace. + to (float): The number to replace with. + + Returns: + torch.Tensor: The tensor with the number replaced. + """ + return torch.where(x == num, to, x) + + +def guarded_torch_power(x: torch.Tensor, exp: float) -> torch.Tensor: + """Guard the power operation to a valid domain.""" + return torch.pow(x, exp) if exp >= 1 else torch.pow(torch.relu(x), exp) + + +def init_ones(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor: + val = torch.amin(x, dim=1) + return torch.ones_like(val, dtype=torch.float32, device=x.device) + + +def init_rand(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor: + val = torch.amin(x, dim=1) + return torch.randn_like(val, dtype=torch.float32, device=x.device) + + +def init_space_search( + x: torch.Tensor, + **kwargs: Dict[str, Any], + ) -> torch.Tensor: + + def _build_initial_param(tensor: torch.Tensor, max_initial: int, n_params: int): + """Generates the initial set of parameters. The first iteration generates 10 times more parameters.""" + for _ in range(n_params * 10): # The first iteration generates 10 times more parameters + yield init_rand(tensor) * max_initial # Generates n_params in range [-max_initial, max_initial] + + def _search_param(tensors: List[torch.tensor], n_params): + """Takes the best parameters and generates new parameters around the mean of the best parameters.""" + torch_tensors = torch.stack(tensors) + min_vals, max_vals = torch.aminmax(torch_tensors, dim=0) + abs_max_val_per_ch = torch.max(-min_vals, max_vals) + mean = torch.mean(torch_tensors, dim=0) + for _ in range(n_params): # Generates n_params around the mean of the tensors + yield torch.randn_like(min_vals) * abs_max_val_per_ch + mean + + def _calc(x, qtz_func, deqtz_func, **params): + x_ = x.transpose(0, 1) + x_ = qtz_func(x=x_, **params) + x_ = deqtz_func(x=x_, **params) + x_ = x_.transpose(0, 1) + return x_ + + assert "qtz_func" in kwargs, "qtz_func must be provided." + assert "deqtz_func" in kwargs, "deqtz_func must be provided." + assert "params_list" in kwargs, "params list must be provided." + assert "param" in kwargs, "param must be provided." + + qtz_func = kwargs.get('qtz_func') + deqtz_func = kwargs.get('deqtz_func') + params_list = kwargs.get('params_list') + param = kwargs.get('param') + + n_runs = 50 # Number of runs to try to find the best parameters + n_random_params = 50 # Number of random parameters to generate + n_best_to_pick = 5 # Number of best parameters to pick after each run + max_initial = 10000 # Maximum value to initialize the parameters + + # Initializes the parameters + base_params = { p: init_ones(x, **kwargs) for p in params_list if p != param } + params = _build_initial_param(x, max_initial, n_random_params) + + # Performs the search + for _ in range(n_runs): + + best_params = [] + for param_ in params: + try: + x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: param_}) + loss_ones = nn.MSELoss()(x, x_) + + if len(best_params) < n_best_to_pick: + best_params.append((param_, loss_ones.item())) + best_params = sorted(best_params, key=lambda x: x[1]) + elif loss_ones < best_params[-1][1]: + best_params[-1] = (param_, loss_ones.item()) + best_params = sorted(best_params, key=lambda x: x[1]) + + except Exception: # The parameters might not be valid for the function's domain + continue + + # Generates new parameters around the mean + params = _search_param([p for p, _ in best_params], n_random_params) + + # Checks if the best parameter is better than the init_ones + p_ones = init_ones(x, **kwargs) + x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_ones}) + loss_ones = nn.MSELoss()(x, x_) + + # Checks if the best parameter is better than the init_rand + p_rand = init_rand(x, **kwargs) + x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_rand}) + loss_rand = nn.MSELoss()(x, x_) + + if loss_rand < best_params[0][1] and loss_rand < loss_ones: + return p_rand + elif loss_ones < best_params[0][1] and loss_ones < loss_rand: + return p_ones + else: + return best_params[0][0] + + +def init_linear_scale( # Symmetric scale. From the study folder + x: torch.Tensor, + **kwargs: Dict[str, Any], + ) -> torch.Tensor: + assert "bits" in kwargs, "bits must be provided." + assert "params" in kwargs, "params must be provided." + assert "qtz_func" in kwargs, "qtz_func must be provided." + + bits = kwargs.get('bits') + params = kwargs.get('params') + qtz_func = kwargs.get('qtz_func') + + x_ = x.transpose(0, 1) + x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs)) + x_ = x_.transpose(0, 1) + + quant_min, quant_max = get_min_max_from_bits_signed(bits) + min_vals, max_vals = torch.aminmax(x_, dim=1) + min_vals = torch.min(min_vals, torch.zeros_like(min_vals)) + max_vals = torch.max(max_vals, torch.zeros_like(max_vals)) + + eps = torch.finfo(torch.float32).eps + + abs_max_val_per_ch = torch.max(-min_vals, max_vals) + scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2) + + scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device) + + # Introduces some noise in scale + # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything + scale = scale + 0.01 * torch.randn_like(scale) + return scale + + +def init_non_linear_regression_fit( + x: torch.Tensor, + **kwargs: Dict[str, Any], + ) -> torch.Tensor: + + assert "params_list" in kwargs, "params list must be provided." + assert "np_fit_func" in kwargs, "np_fit_func must be provided." + assert "p0" in kwargs, "p0 must be provided." + np_fit_func = kwargs.get('np_fit_func') + params_list = kwargs.get('params_list') + p0 = kwargs.get('p0') + + def _fit(xdata: np.ndarray, ydata: np.ndarray, func: Callable, p0: List[float]): + popt, _ = curve_fit( + func, + xdata, + ydata, + maxfev=1000, + p0=p0, + method='lm' + ) + return popt + + # 1. Needs to convert the torch tensor to numpy tensor + xdata = x.cpu().numpy() + + # 2. Sorts the data so that it makes it easier to fit to it + sorted_xdata = np.sort(xdata, axis=-1) + + p0 = {k: v.cpu().numpy() for k, v in p0.items()} + params_list = sorted(params_list) # We need to make sure that it matches the numpy fit func arg order + + # 3. Finds the best parameters for each channel + try: + params = [] + for i in range(sorted_xdata.shape[0]): + xdata_ = sorted_xdata[i] + p0_ = [p0[p][i] for p in params_list] + ch_params = _fit(xdata_, xdata_, np_fit_func, p0_) + params.append(ch_params) + + # 4. Builds the parameters + result = {} + for i, p in enumerate(params_list): + result[p] = torch.tensor([p_[i] for p_ in params], dtype=torch.float32).to(x.device) + + return result + + except ValueError as e: + print(f"Could not fit the function with error: {e}") + print(f"Using fallback result...") + return { + k: torch.tensor(v, dtype=torch.float32).to(x.device) for k, v in p0.items() + } + + +def init_zeros(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor: + val = torch.amin(x, dim=1) + return torch.zeros_like(val, dtype=torch.float32, device=x.device) + + +def init_inner_scale(tensor: torch.Tensor, _min: float = torch.inf, _max: float = torch.inf) -> torch.Tensor: + # Calculate the original minimum and maximum values + min_vals, max_vals = torch.aminmax(tensor, dim=-1) + x_min = torch.min(min_vals, torch.zeros_like(min_vals)) + x_max = torch.max(max_vals, torch.zeros_like(max_vals)) + + if _max is torch.inf: # We do not need to scale the tensor. Just need to move it + return torch.ones_like(x_min) + + # Calculate the scale factor + scale = (_max - _min) / (x_max - x_min) + return scale + + + +############## Quant ############### + +@torch.enable_grad() +def learn_parameters( + x: torch.Tensor, + params: Dict[str, nn.Parameter], + qtz_func: nn.Module, + deqtz_func: nn.Module, + bits: int, + target_dtype: torch.dtype, + epochs: int = 1000, + early_stop: bool = True, + do_report: bool = False +) -> Tuple[Dict[str, nn.Parameter], torch.Tensor]: + + # Requires gradients in the parameters + for p in params.values(): + p.requires_grad = True + p.grad = None + + param_keys = list(params.keys()) + param_values = list(params.values()) + + # Defines optimizer and loss function + optimizer = torch.optim.Adam(param_values, lr=0.001) + scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=1.0, end_factor=0.001, total_iters=epochs // 10) + loss_fn = nn.MSELoss() + + # Contains the best loss and the best parameters + best_loss = float("inf") + best_params = None + + # Used to stop the search early + min_delta = 1e-7 + acc_loss = [] + percent_epochs_before_stop = 0.1 + + for i in range(epochs): + optimizer.zero_grad() + + quant = quantize(x, params, qtz_func, bits, target_dtype) + dequant = dequantize(quant, params, deqtz_func, bits, x.dtype) + loss = loss_fn(x, dequant) + + if loss.isnan() or loss.isinf(): + raise Exception("Loss is NaN or Inf. Stopping the search.") + + loss.backward() + optimizer.step() + scheduler.step() + + acc_loss.append(loss.item()) + + # Reports loss every 10 steps + if i % 10 == 0 and do_report: + print(f"Epoch {i}: Loss {loss.item()}") + + # Optimizes the parameter search by storing the best loss and the parameters + if loss.item() < best_loss: + best_loss = loss.item() + best_params = copy.deepcopy({ + k: v for k, v in params.items() if k in param_keys + }) + + # We also stop the search if the loss has not considerably during the last 10% epochs + if early_stop: + epochs_before_stop = int(epochs * percent_epochs_before_stop) + if i > epochs_before_stop and abs(acc_loss[i - epochs_before_stop] - acc_loss[i]) < min_delta: + break + + # No longer requires gradients in the parameters + for p in best_params.values(): + p.requires_grad = False + p.grad = None + + if do_report: + return best_params, acc_loss + else: + return best_params + + +def quantize( + x: torch.Tensor, + params: Dict[str, nn.Parameter], + func: nn.Module, + bits: int, + target_dtype: torch.dtype = torch.int8 +) -> torch.Tensor: + quant_min, quant_max = get_min_max_from_bits_signed(bits) + x = x.transpose(0, 1) # Aligns shapes + x = func(x=x, **params) + x = x.transpose(0, 1) + x = torch.clamp(round_func_BPDA(x), quant_min, quant_max).to(target_dtype) + return x + + +def dequantize( + x: torch.Tensor, + params: Dict[str, nn.Parameter], + func: nn.Module, + bits: int, + out_dtype: torch.dtype +) -> torch.Tensor: + x = x.to(dtype=out_dtype) + x = x.transpose(0, 1) + x = func(x=x, **params) + x = x.transpose(0, 1) + return x + + +def round_func_BPDA(input): + # This is equivalent to replacing round function (non-differentiable) with + # an identity function (differentiable) only when backward. + forward_value = torch.round(input) + out = input.clone() + out.data = forward_value.data + return out + + +def get_min_max_from_bits_signed(bit_width: int) -> Tuple[int, int]: + return -2 ** (bit_width - 1), 2 ** (bit_width - 1) - 1 + + + +############## Numpy ############### + +def np_domain_guard( + x: np.ndarray, + min: float = None, + max: float = None, + posinf: float = None, + neginf: float = None, + nan: float = None + ) -> np.ndarray: + """Guard a tensor to a valid domain.""" + x = np.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan) + if min is not None or max is not None: + x = np.clip(x, min, max) + return x + + +def np_replace_num(x: np.ndarray, num: float, to: float) -> np.ndarray: + """Replace a number in a tensor with another number. + + Args: + x (np.ndarray): The input tensor. + num (float): The number to replace. + to (float): The number to replace with. + + Returns: + np.ndarray: The tensor with the number replaced. + """ + return np.where(x == num, to, x) + + +def np_guarded_power(x: np.ndarray, exp: float) -> np.ndarray: + """Guard the power operation to a valid domain.""" + return np.power(x, exp) if exp >= 1 else np.power(np.maximum(x, 0), exp) + diff --git a/fn_gen/ones_noisy_scale/6/loss.png b/fn_gen/ones_noisy_scale/6/loss.png new file mode 100644 index 0000000000000000000000000000000000000000..4e5bb4cfb76326e88c13e29265883a6b1187756b Binary files /dev/null and b/fn_gen/ones_noisy_scale/6/loss.png differ diff --git a/fn_gen/ones_noisy_scale/6/quantization.png b/fn_gen/ones_noisy_scale/6/quantization.png new file mode 100644 index 0000000000000000000000000000000000000000..effca0c0d0ba973927289ce219f9a0a84f05e46d Binary files /dev/null and b/fn_gen/ones_noisy_scale/6/quantization.png differ diff --git a/fn_gen/ones_noisy_scale/7/__pycache__/fn.cpython-310.pyc b/fn_gen/ones_noisy_scale/7/__pycache__/fn.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7583d0f21a3cf5d4b2ba9cefcb9b7a88e0c3f219 Binary files /dev/null and b/fn_gen/ones_noisy_scale/7/__pycache__/fn.cpython-310.pyc differ diff --git a/fn_gen/ones_noisy_scale/7/distortion.png b/fn_gen/ones_noisy_scale/7/distortion.png new file mode 100644 index 0000000000000000000000000000000000000000..2cce44c32d8af01f6a09010f4e11fe358d6f5dcb Binary files /dev/null and b/fn_gen/ones_noisy_scale/7/distortion.png differ diff --git a/fn_gen/ones_noisy_scale/7/expressions.txt b/fn_gen/ones_noisy_scale/7/expressions.txt new file mode 100644 index 0000000000000000000000000000000000000000..ecd6e238827dcdb95f4bcb390c1c300696f34254 --- /dev/null +++ b/fn_gen/ones_noisy_scale/7/expressions.txt @@ -0,0 +1,2 @@ +sin(_0*x)/_s +asin(_s*x)/_0 \ No newline at end of file diff --git a/fn_gen/ones_noisy_scale/7/fn.py b/fn_gen/ones_noisy_scale/7/fn.py new file mode 100644 index 0000000000000000000000000000000000000000..8444881f5291708efc726572af1c59eecff7700c --- /dev/null +++ b/fn_gen/ones_noisy_scale/7/fn.py @@ -0,0 +1,489 @@ +from __future__ import annotations + +import torch +from torch import amin # Necessary for arcsin +import copy +import torch.nn as nn +import numpy as np + +from scipy.optimize import curve_fit +from typing import Dict, Any, Tuple, List, Callable + + +def quantization(x, **params): + return (torch.div(1, replace_num(params['_s'], num=0, to=10000)) * torch.sin((params['_0'] * x))) + + +def dequantization(x, **params): + return (torch.div(1, replace_num(params['_0'], num=0, to=10000)) * torch.asin(domain_guard((params['_s'] * x), min=-0.99999, max=0.99999, nan=0))) + + +def init_linear_scale( # Symmetric scale. From the study folder + x: torch.Tensor, + **kwargs: Dict[str, Any], + ) -> torch.Tensor: + assert "bits" in kwargs, "bits must be provided." + assert "params" in kwargs, "params must be provided." + assert "qtz_func" in kwargs, "qtz_func must be provided." + + bits = kwargs.get('bits') + params = kwargs.get('params') + qtz_func = kwargs.get('qtz_func') + + x_ = x.transpose(0, 1) + x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs)) + x_ = x_.transpose(0, 1) + + quant_min, quant_max = get_min_max_from_bits_signed(bits) + min_vals, max_vals = torch.aminmax(x_, dim=1) + min_vals = torch.min(min_vals, torch.zeros_like(min_vals)) + max_vals = torch.max(max_vals, torch.zeros_like(max_vals)) + + eps = torch.finfo(torch.float32).eps + + abs_max_val_per_ch = torch.max(-min_vals, max_vals) + scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2) + + scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device) + + # Introduces some noise in scale + # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything + scale = scale + 0.01 * torch.randn_like(scale) + return scale + + +def init_params(x: torch.Tensor, **kwargs: Dict[str, Any]) -> Dict[str, nn.Parameter]: + params = { + '_0': init_ones(x, qtz_func=quantization, deqtz_func=dequantization, param='_0', params_list=['_0', '_s'], **kwargs), + } + params['_s'] = init_linear_scale(x, params=params, qtz_func=quantization, **kwargs) + params = {k: nn.Parameter(v, requires_grad=False) for k, v in params.items()} + + if 'post_init_hook' in kwargs: + kwargs['post_init_hook'](parameters=params) + + + if 'post_train_hook' in kwargs: + kwargs['post_train_hook'](parameters=params) + + return params + + +############### Numpy Qtz ############### + + +def np_quantization(x, _0, _s): + return (np.divide(1, np_replace_num(_s, num=0, to=10000)) * np.sin((_0 * x))) + + +def np_dequantization(x, _0, _s): + return (np.divide(1, np_replace_num(_0, num=0, to=10000)) * np.arcsin(np_domain_guard((_s * x), min=-0.99999, max=0.99999, nan=0))) + + +def fit_func(x, _0, _s): + x_ = np_quantization(x, _0, _s) + x_ = np_dequantization(x_, _0, _s) + return x_ + + + +############### HELPERS ############### + +def domain_guard( + x: torch.Tensor, + min: float = None, + max: float = None, + posinf: float = None, + neginf: float = None, + nan: float = None + ) -> torch.Tensor: + """Guard a tensor to a valid domain.""" + x = torch.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan) + if min is not None or max is not None: + x = torch.clamp(x, min=min, max=max) + return x + + +def replace_num(x: torch.Tensor, num: float, to: float) -> torch.Tensor: + """Replace a number in a tensor with another number. + + Args: + x (torch.Tensor): The input tensor. + num (float): The number to replace. + to (float): The number to replace with. + + Returns: + torch.Tensor: The tensor with the number replaced. + """ + return torch.where(x == num, to, x) + + +def guarded_torch_power(x: torch.Tensor, exp: float) -> torch.Tensor: + """Guard the power operation to a valid domain.""" + return torch.pow(x, exp) if exp >= 1 else torch.pow(torch.relu(x), exp) + + +def init_ones(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor: + val = torch.amin(x, dim=1) + return torch.ones_like(val, dtype=torch.float32, device=x.device) + + +def init_rand(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor: + val = torch.amin(x, dim=1) + return torch.randn_like(val, dtype=torch.float32, device=x.device) + + +def init_space_search( + x: torch.Tensor, + **kwargs: Dict[str, Any], + ) -> torch.Tensor: + + def _build_initial_param(tensor: torch.Tensor, max_initial: int, n_params: int): + """Generates the initial set of parameters. The first iteration generates 10 times more parameters.""" + for _ in range(n_params * 10): # The first iteration generates 10 times more parameters + yield init_rand(tensor) * max_initial # Generates n_params in range [-max_initial, max_initial] + + def _search_param(tensors: List[torch.tensor], n_params): + """Takes the best parameters and generates new parameters around the mean of the best parameters.""" + torch_tensors = torch.stack(tensors) + min_vals, max_vals = torch.aminmax(torch_tensors, dim=0) + abs_max_val_per_ch = torch.max(-min_vals, max_vals) + mean = torch.mean(torch_tensors, dim=0) + for _ in range(n_params): # Generates n_params around the mean of the tensors + yield torch.randn_like(min_vals) * abs_max_val_per_ch + mean + + def _calc(x, qtz_func, deqtz_func, **params): + x_ = x.transpose(0, 1) + x_ = qtz_func(x=x_, **params) + x_ = deqtz_func(x=x_, **params) + x_ = x_.transpose(0, 1) + return x_ + + assert "qtz_func" in kwargs, "qtz_func must be provided." + assert "deqtz_func" in kwargs, "deqtz_func must be provided." + assert "params_list" in kwargs, "params list must be provided." + assert "param" in kwargs, "param must be provided." + + qtz_func = kwargs.get('qtz_func') + deqtz_func = kwargs.get('deqtz_func') + params_list = kwargs.get('params_list') + param = kwargs.get('param') + + n_runs = 50 # Number of runs to try to find the best parameters + n_random_params = 50 # Number of random parameters to generate + n_best_to_pick = 5 # Number of best parameters to pick after each run + max_initial = 10000 # Maximum value to initialize the parameters + + # Initializes the parameters + base_params = { p: init_ones(x, **kwargs) for p in params_list if p != param } + params = _build_initial_param(x, max_initial, n_random_params) + + # Performs the search + for _ in range(n_runs): + + best_params = [] + for param_ in params: + try: + x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: param_}) + loss_ones = nn.MSELoss()(x, x_) + + if len(best_params) < n_best_to_pick: + best_params.append((param_, loss_ones.item())) + best_params = sorted(best_params, key=lambda x: x[1]) + elif loss_ones < best_params[-1][1]: + best_params[-1] = (param_, loss_ones.item()) + best_params = sorted(best_params, key=lambda x: x[1]) + + except Exception: # The parameters might not be valid for the function's domain + continue + + # Generates new parameters around the mean + params = _search_param([p for p, _ in best_params], n_random_params) + + # Checks if the best parameter is better than the init_ones + p_ones = init_ones(x, **kwargs) + x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_ones}) + loss_ones = nn.MSELoss()(x, x_) + + # Checks if the best parameter is better than the init_rand + p_rand = init_rand(x, **kwargs) + x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_rand}) + loss_rand = nn.MSELoss()(x, x_) + + if loss_rand < best_params[0][1] and loss_rand < loss_ones: + return p_rand + elif loss_ones < best_params[0][1] and loss_ones < loss_rand: + return p_ones + else: + return best_params[0][0] + + +def init_linear_scale( # Symmetric scale. From the study folder + x: torch.Tensor, + **kwargs: Dict[str, Any], + ) -> torch.Tensor: + assert "bits" in kwargs, "bits must be provided." + assert "params" in kwargs, "params must be provided." + assert "qtz_func" in kwargs, "qtz_func must be provided." + + bits = kwargs.get('bits') + params = kwargs.get('params') + qtz_func = kwargs.get('qtz_func') + + x_ = x.transpose(0, 1) + x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs)) + x_ = x_.transpose(0, 1) + + quant_min, quant_max = get_min_max_from_bits_signed(bits) + min_vals, max_vals = torch.aminmax(x_, dim=1) + min_vals = torch.min(min_vals, torch.zeros_like(min_vals)) + max_vals = torch.max(max_vals, torch.zeros_like(max_vals)) + + eps = torch.finfo(torch.float32).eps + + abs_max_val_per_ch = torch.max(-min_vals, max_vals) + scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2) + + scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device) + + # Introduces some noise in scale + # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything + scale = scale + 0.01 * torch.randn_like(scale) + return scale + + +def init_non_linear_regression_fit( + x: torch.Tensor, + **kwargs: Dict[str, Any], + ) -> torch.Tensor: + + assert "params_list" in kwargs, "params list must be provided." + assert "np_fit_func" in kwargs, "np_fit_func must be provided." + assert "p0" in kwargs, "p0 must be provided." + np_fit_func = kwargs.get('np_fit_func') + params_list = kwargs.get('params_list') + p0 = kwargs.get('p0') + + def _fit(xdata: np.ndarray, ydata: np.ndarray, func: Callable, p0: List[float]): + popt, _ = curve_fit( + func, + xdata, + ydata, + maxfev=1000, + p0=p0, + method='lm' + ) + return popt + + # 1. Needs to convert the torch tensor to numpy tensor + xdata = x.cpu().numpy() + + # 2. Sorts the data so that it makes it easier to fit to it + sorted_xdata = np.sort(xdata, axis=-1) + + p0 = {k: v.cpu().numpy() for k, v in p0.items()} + params_list = sorted(params_list) # We need to make sure that it matches the numpy fit func arg order + + # 3. Finds the best parameters for each channel + try: + params = [] + for i in range(sorted_xdata.shape[0]): + xdata_ = sorted_xdata[i] + p0_ = [p0[p][i] for p in params_list] + ch_params = _fit(xdata_, xdata_, np_fit_func, p0_) + params.append(ch_params) + + # 4. Builds the parameters + result = {} + for i, p in enumerate(params_list): + result[p] = torch.tensor([p_[i] for p_ in params], dtype=torch.float32).to(x.device) + + return result + + except ValueError as e: + print(f"Could not fit the function with error: {e}") + print(f"Using fallback result...") + return { + k: torch.tensor(v, dtype=torch.float32).to(x.device) for k, v in p0.items() + } + + +def init_zeros(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor: + val = torch.amin(x, dim=1) + return torch.zeros_like(val, dtype=torch.float32, device=x.device) + + +def init_inner_scale(tensor: torch.Tensor, _min: float = torch.inf, _max: float = torch.inf) -> torch.Tensor: + # Calculate the original minimum and maximum values + min_vals, max_vals = torch.aminmax(tensor, dim=-1) + x_min = torch.min(min_vals, torch.zeros_like(min_vals)) + x_max = torch.max(max_vals, torch.zeros_like(max_vals)) + + if _max is torch.inf: # We do not need to scale the tensor. Just need to move it + return torch.ones_like(x_min) + + # Calculate the scale factor + scale = (_max - _min) / (x_max - x_min) + return scale + + + +############## Quant ############### + +@torch.enable_grad() +def learn_parameters( + x: torch.Tensor, + params: Dict[str, nn.Parameter], + qtz_func: nn.Module, + deqtz_func: nn.Module, + bits: int, + target_dtype: torch.dtype, + epochs: int = 1000, + early_stop: bool = True, + do_report: bool = False +) -> Tuple[Dict[str, nn.Parameter], torch.Tensor]: + + # Requires gradients in the parameters + for p in params.values(): + p.requires_grad = True + p.grad = None + + param_keys = list(params.keys()) + param_values = list(params.values()) + + # Defines optimizer and loss function + optimizer = torch.optim.Adam(param_values, lr=0.001) + scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=1.0, end_factor=0.001, total_iters=epochs // 10) + loss_fn = nn.MSELoss() + + # Contains the best loss and the best parameters + best_loss = float("inf") + best_params = None + + # Used to stop the search early + min_delta = 1e-7 + acc_loss = [] + percent_epochs_before_stop = 0.1 + + for i in range(epochs): + optimizer.zero_grad() + + quant = quantize(x, params, qtz_func, bits, target_dtype) + dequant = dequantize(quant, params, deqtz_func, bits, x.dtype) + loss = loss_fn(x, dequant) + + if loss.isnan() or loss.isinf(): + raise Exception("Loss is NaN or Inf. Stopping the search.") + + loss.backward() + optimizer.step() + scheduler.step() + + acc_loss.append(loss.item()) + + # Reports loss every 10 steps + if i % 10 == 0 and do_report: + print(f"Epoch {i}: Loss {loss.item()}") + + # Optimizes the parameter search by storing the best loss and the parameters + if loss.item() < best_loss: + best_loss = loss.item() + best_params = copy.deepcopy({ + k: v for k, v in params.items() if k in param_keys + }) + + # We also stop the search if the loss has not considerably during the last 10% epochs + if early_stop: + epochs_before_stop = int(epochs * percent_epochs_before_stop) + if i > epochs_before_stop and abs(acc_loss[i - epochs_before_stop] - acc_loss[i]) < min_delta: + break + + # No longer requires gradients in the parameters + for p in best_params.values(): + p.requires_grad = False + p.grad = None + + if do_report: + return best_params, acc_loss + else: + return best_params + + +def quantize( + x: torch.Tensor, + params: Dict[str, nn.Parameter], + func: nn.Module, + bits: int, + target_dtype: torch.dtype = torch.int8 +) -> torch.Tensor: + quant_min, quant_max = get_min_max_from_bits_signed(bits) + x = x.transpose(0, 1) # Aligns shapes + x = func(x=x, **params) + x = x.transpose(0, 1) + x = torch.clamp(round_func_BPDA(x), quant_min, quant_max).to(target_dtype) + return x + + +def dequantize( + x: torch.Tensor, + params: Dict[str, nn.Parameter], + func: nn.Module, + bits: int, + out_dtype: torch.dtype +) -> torch.Tensor: + x = x.to(dtype=out_dtype) + x = x.transpose(0, 1) + x = func(x=x, **params) + x = x.transpose(0, 1) + return x + + +def round_func_BPDA(input): + # This is equivalent to replacing round function (non-differentiable) with + # an identity function (differentiable) only when backward. + forward_value = torch.round(input) + out = input.clone() + out.data = forward_value.data + return out + + +def get_min_max_from_bits_signed(bit_width: int) -> Tuple[int, int]: + return -2 ** (bit_width - 1), 2 ** (bit_width - 1) - 1 + + + +############## Numpy ############### + +def np_domain_guard( + x: np.ndarray, + min: float = None, + max: float = None, + posinf: float = None, + neginf: float = None, + nan: float = None + ) -> np.ndarray: + """Guard a tensor to a valid domain.""" + x = np.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan) + if min is not None or max is not None: + x = np.clip(x, min, max) + return x + + +def np_replace_num(x: np.ndarray, num: float, to: float) -> np.ndarray: + """Replace a number in a tensor with another number. + + Args: + x (np.ndarray): The input tensor. + num (float): The number to replace. + to (float): The number to replace with. + + Returns: + np.ndarray: The tensor with the number replaced. + """ + return np.where(x == num, to, x) + + +def np_guarded_power(x: np.ndarray, exp: float) -> np.ndarray: + """Guard the power operation to a valid domain.""" + return np.power(x, exp) if exp >= 1 else np.power(np.maximum(x, 0), exp) + diff --git a/fn_gen/ones_noisy_scale/7/loss.png b/fn_gen/ones_noisy_scale/7/loss.png new file mode 100644 index 0000000000000000000000000000000000000000..9b7aa2573c6e84b040aa500a7aa839da7e609e43 Binary files /dev/null and b/fn_gen/ones_noisy_scale/7/loss.png differ diff --git a/fn_gen/ones_noisy_scale/7/quantization.png b/fn_gen/ones_noisy_scale/7/quantization.png new file mode 100644 index 0000000000000000000000000000000000000000..6efe9272f4f81e7a907cd22e58d968d3ac4c1dd8 Binary files /dev/null and b/fn_gen/ones_noisy_scale/7/quantization.png differ diff --git a/fn_gen/ones_noisy_scale/8/__pycache__/fn.cpython-310.pyc b/fn_gen/ones_noisy_scale/8/__pycache__/fn.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..843f097ed88e3201c7a1c9cd0fc77a5b9dbe0365 Binary files /dev/null and b/fn_gen/ones_noisy_scale/8/__pycache__/fn.cpython-310.pyc differ diff --git a/fn_gen/ones_noisy_scale/8/distortion.png b/fn_gen/ones_noisy_scale/8/distortion.png new file mode 100644 index 0000000000000000000000000000000000000000..565474acd2815137b1c64fbba0d6b08c5e3792fb Binary files /dev/null and b/fn_gen/ones_noisy_scale/8/distortion.png differ diff --git a/fn_gen/ones_noisy_scale/8/expressions.txt b/fn_gen/ones_noisy_scale/8/expressions.txt new file mode 100644 index 0000000000000000000000000000000000000000..e8458af52eb4cfce21cf8459f3c454003cd78158 --- /dev/null +++ b/fn_gen/ones_noisy_scale/8/expressions.txt @@ -0,0 +1,2 @@ +sqrt(_0*x)/_s +_s**2*x**2/_0 \ No newline at end of file diff --git a/fn_gen/ones_noisy_scale/8/fn.py b/fn_gen/ones_noisy_scale/8/fn.py new file mode 100644 index 0000000000000000000000000000000000000000..c8a904a9968676f1b5dc7efd7ac341ee699e7e74 --- /dev/null +++ b/fn_gen/ones_noisy_scale/8/fn.py @@ -0,0 +1,489 @@ +from __future__ import annotations + +import torch +from torch import amin # Necessary for arcsin +import copy +import torch.nn as nn +import numpy as np + +from scipy.optimize import curve_fit +from typing import Dict, Any, Tuple, List, Callable + + +def quantization(x, **params): + return (torch.div(1, replace_num(params['_s'], num=0, to=10000)) * torch.sqrt(domain_guard((params['_0'] * x), min=0.1, nan=0.1))) + + +def dequantization(x, **params): + return (torch.div(1, replace_num(params['_0'], num=0, to=10000)) * guarded_torch_power(params['_s'], torch.tensor(2)) * guarded_torch_power(x, torch.tensor(2))) + + +def init_linear_scale( # Symmetric scale. From the study folder + x: torch.Tensor, + **kwargs: Dict[str, Any], + ) -> torch.Tensor: + assert "bits" in kwargs, "bits must be provided." + assert "params" in kwargs, "params must be provided." + assert "qtz_func" in kwargs, "qtz_func must be provided." + + bits = kwargs.get('bits') + params = kwargs.get('params') + qtz_func = kwargs.get('qtz_func') + + x_ = x.transpose(0, 1) + x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs)) + x_ = x_.transpose(0, 1) + + quant_min, quant_max = get_min_max_from_bits_signed(bits) + min_vals, max_vals = torch.aminmax(x_, dim=1) + min_vals = torch.min(min_vals, torch.zeros_like(min_vals)) + max_vals = torch.max(max_vals, torch.zeros_like(max_vals)) + + eps = torch.finfo(torch.float32).eps + + abs_max_val_per_ch = torch.max(-min_vals, max_vals) + scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2) + + scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device) + + # Introduces some noise in scale + # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything + scale = scale + 0.01 * torch.randn_like(scale) + return scale + + +def init_params(x: torch.Tensor, **kwargs: Dict[str, Any]) -> Dict[str, nn.Parameter]: + params = { + '_0': init_ones(x, qtz_func=quantization, deqtz_func=dequantization, param='_0', params_list=['_0', '_s'], **kwargs), + } + params['_s'] = init_linear_scale(x, params=params, qtz_func=quantization, **kwargs) + params = {k: nn.Parameter(v, requires_grad=False) for k, v in params.items()} + + if 'post_init_hook' in kwargs: + kwargs['post_init_hook'](parameters=params) + + + if 'post_train_hook' in kwargs: + kwargs['post_train_hook'](parameters=params) + + return params + + +############### Numpy Qtz ############### + + +def np_quantization(x, _0, _s): + return (np.divide(1, np_replace_num(_s, num=0, to=10000)) * np.sqrt(np_domain_guard((_0 * x), min=0.1, nan=0.1))) + + +def np_dequantization(x, _0, _s): + return (np.divide(1, np_replace_num(_0, num=0, to=10000)) * np_guarded_power(_s, np.array(2)) * np_guarded_power(x, np.array(2))) + + +def fit_func(x, _0, _s): + x_ = np_quantization(x, _0, _s) + x_ = np_dequantization(x_, _0, _s) + return x_ + + + +############### HELPERS ############### + +def domain_guard( + x: torch.Tensor, + min: float = None, + max: float = None, + posinf: float = None, + neginf: float = None, + nan: float = None + ) -> torch.Tensor: + """Guard a tensor to a valid domain.""" + x = torch.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan) + if min is not None or max is not None: + x = torch.clamp(x, min=min, max=max) + return x + + +def replace_num(x: torch.Tensor, num: float, to: float) -> torch.Tensor: + """Replace a number in a tensor with another number. + + Args: + x (torch.Tensor): The input tensor. + num (float): The number to replace. + to (float): The number to replace with. + + Returns: + torch.Tensor: The tensor with the number replaced. + """ + return torch.where(x == num, to, x) + + +def guarded_torch_power(x: torch.Tensor, exp: float) -> torch.Tensor: + """Guard the power operation to a valid domain.""" + return torch.pow(x, exp) if exp >= 1 else torch.pow(torch.relu(x), exp) + + +def init_ones(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor: + val = torch.amin(x, dim=1) + return torch.ones_like(val, dtype=torch.float32, device=x.device) + + +def init_rand(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor: + val = torch.amin(x, dim=1) + return torch.randn_like(val, dtype=torch.float32, device=x.device) + + +def init_space_search( + x: torch.Tensor, + **kwargs: Dict[str, Any], + ) -> torch.Tensor: + + def _build_initial_param(tensor: torch.Tensor, max_initial: int, n_params: int): + """Generates the initial set of parameters. The first iteration generates 10 times more parameters.""" + for _ in range(n_params * 10): # The first iteration generates 10 times more parameters + yield init_rand(tensor) * max_initial # Generates n_params in range [-max_initial, max_initial] + + def _search_param(tensors: List[torch.tensor], n_params): + """Takes the best parameters and generates new parameters around the mean of the best parameters.""" + torch_tensors = torch.stack(tensors) + min_vals, max_vals = torch.aminmax(torch_tensors, dim=0) + abs_max_val_per_ch = torch.max(-min_vals, max_vals) + mean = torch.mean(torch_tensors, dim=0) + for _ in range(n_params): # Generates n_params around the mean of the tensors + yield torch.randn_like(min_vals) * abs_max_val_per_ch + mean + + def _calc(x, qtz_func, deqtz_func, **params): + x_ = x.transpose(0, 1) + x_ = qtz_func(x=x_, **params) + x_ = deqtz_func(x=x_, **params) + x_ = x_.transpose(0, 1) + return x_ + + assert "qtz_func" in kwargs, "qtz_func must be provided." + assert "deqtz_func" in kwargs, "deqtz_func must be provided." + assert "params_list" in kwargs, "params list must be provided." + assert "param" in kwargs, "param must be provided." + + qtz_func = kwargs.get('qtz_func') + deqtz_func = kwargs.get('deqtz_func') + params_list = kwargs.get('params_list') + param = kwargs.get('param') + + n_runs = 50 # Number of runs to try to find the best parameters + n_random_params = 50 # Number of random parameters to generate + n_best_to_pick = 5 # Number of best parameters to pick after each run + max_initial = 10000 # Maximum value to initialize the parameters + + # Initializes the parameters + base_params = { p: init_ones(x, **kwargs) for p in params_list if p != param } + params = _build_initial_param(x, max_initial, n_random_params) + + # Performs the search + for _ in range(n_runs): + + best_params = [] + for param_ in params: + try: + x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: param_}) + loss_ones = nn.MSELoss()(x, x_) + + if len(best_params) < n_best_to_pick: + best_params.append((param_, loss_ones.item())) + best_params = sorted(best_params, key=lambda x: x[1]) + elif loss_ones < best_params[-1][1]: + best_params[-1] = (param_, loss_ones.item()) + best_params = sorted(best_params, key=lambda x: x[1]) + + except Exception: # The parameters might not be valid for the function's domain + continue + + # Generates new parameters around the mean + params = _search_param([p for p, _ in best_params], n_random_params) + + # Checks if the best parameter is better than the init_ones + p_ones = init_ones(x, **kwargs) + x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_ones}) + loss_ones = nn.MSELoss()(x, x_) + + # Checks if the best parameter is better than the init_rand + p_rand = init_rand(x, **kwargs) + x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_rand}) + loss_rand = nn.MSELoss()(x, x_) + + if loss_rand < best_params[0][1] and loss_rand < loss_ones: + return p_rand + elif loss_ones < best_params[0][1] and loss_ones < loss_rand: + return p_ones + else: + return best_params[0][0] + + +def init_linear_scale( # Symmetric scale. From the study folder + x: torch.Tensor, + **kwargs: Dict[str, Any], + ) -> torch.Tensor: + assert "bits" in kwargs, "bits must be provided." + assert "params" in kwargs, "params must be provided." + assert "qtz_func" in kwargs, "qtz_func must be provided." + + bits = kwargs.get('bits') + params = kwargs.get('params') + qtz_func = kwargs.get('qtz_func') + + x_ = x.transpose(0, 1) + x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs)) + x_ = x_.transpose(0, 1) + + quant_min, quant_max = get_min_max_from_bits_signed(bits) + min_vals, max_vals = torch.aminmax(x_, dim=1) + min_vals = torch.min(min_vals, torch.zeros_like(min_vals)) + max_vals = torch.max(max_vals, torch.zeros_like(max_vals)) + + eps = torch.finfo(torch.float32).eps + + abs_max_val_per_ch = torch.max(-min_vals, max_vals) + scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2) + + scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device) + + # Introduces some noise in scale + # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything + scale = scale + 0.01 * torch.randn_like(scale) + return scale + + +def init_non_linear_regression_fit( + x: torch.Tensor, + **kwargs: Dict[str, Any], + ) -> torch.Tensor: + + assert "params_list" in kwargs, "params list must be provided." + assert "np_fit_func" in kwargs, "np_fit_func must be provided." + assert "p0" in kwargs, "p0 must be provided." + np_fit_func = kwargs.get('np_fit_func') + params_list = kwargs.get('params_list') + p0 = kwargs.get('p0') + + def _fit(xdata: np.ndarray, ydata: np.ndarray, func: Callable, p0: List[float]): + popt, _ = curve_fit( + func, + xdata, + ydata, + maxfev=1000, + p0=p0, + method='lm' + ) + return popt + + # 1. Needs to convert the torch tensor to numpy tensor + xdata = x.cpu().numpy() + + # 2. Sorts the data so that it makes it easier to fit to it + sorted_xdata = np.sort(xdata, axis=-1) + + p0 = {k: v.cpu().numpy() for k, v in p0.items()} + params_list = sorted(params_list) # We need to make sure that it matches the numpy fit func arg order + + # 3. Finds the best parameters for each channel + try: + params = [] + for i in range(sorted_xdata.shape[0]): + xdata_ = sorted_xdata[i] + p0_ = [p0[p][i] for p in params_list] + ch_params = _fit(xdata_, xdata_, np_fit_func, p0_) + params.append(ch_params) + + # 4. Builds the parameters + result = {} + for i, p in enumerate(params_list): + result[p] = torch.tensor([p_[i] for p_ in params], dtype=torch.float32).to(x.device) + + return result + + except ValueError as e: + print(f"Could not fit the function with error: {e}") + print(f"Using fallback result...") + return { + k: torch.tensor(v, dtype=torch.float32).to(x.device) for k, v in p0.items() + } + + +def init_zeros(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor: + val = torch.amin(x, dim=1) + return torch.zeros_like(val, dtype=torch.float32, device=x.device) + + +def init_inner_scale(tensor: torch.Tensor, _min: float = torch.inf, _max: float = torch.inf) -> torch.Tensor: + # Calculate the original minimum and maximum values + min_vals, max_vals = torch.aminmax(tensor, dim=-1) + x_min = torch.min(min_vals, torch.zeros_like(min_vals)) + x_max = torch.max(max_vals, torch.zeros_like(max_vals)) + + if _max is torch.inf: # We do not need to scale the tensor. Just need to move it + return torch.ones_like(x_min) + + # Calculate the scale factor + scale = (_max - _min) / (x_max - x_min) + return scale + + + +############## Quant ############### + +@torch.enable_grad() +def learn_parameters( + x: torch.Tensor, + params: Dict[str, nn.Parameter], + qtz_func: nn.Module, + deqtz_func: nn.Module, + bits: int, + target_dtype: torch.dtype, + epochs: int = 1000, + early_stop: bool = True, + do_report: bool = False +) -> Tuple[Dict[str, nn.Parameter], torch.Tensor]: + + # Requires gradients in the parameters + for p in params.values(): + p.requires_grad = True + p.grad = None + + param_keys = list(params.keys()) + param_values = list(params.values()) + + # Defines optimizer and loss function + optimizer = torch.optim.Adam(param_values, lr=0.001) + scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=1.0, end_factor=0.001, total_iters=epochs // 10) + loss_fn = nn.MSELoss() + + # Contains the best loss and the best parameters + best_loss = float("inf") + best_params = None + + # Used to stop the search early + min_delta = 1e-7 + acc_loss = [] + percent_epochs_before_stop = 0.1 + + for i in range(epochs): + optimizer.zero_grad() + + quant = quantize(x, params, qtz_func, bits, target_dtype) + dequant = dequantize(quant, params, deqtz_func, bits, x.dtype) + loss = loss_fn(x, dequant) + + if loss.isnan() or loss.isinf(): + raise Exception("Loss is NaN or Inf. Stopping the search.") + + loss.backward() + optimizer.step() + scheduler.step() + + acc_loss.append(loss.item()) + + # Reports loss every 10 steps + if i % 10 == 0 and do_report: + print(f"Epoch {i}: Loss {loss.item()}") + + # Optimizes the parameter search by storing the best loss and the parameters + if loss.item() < best_loss: + best_loss = loss.item() + best_params = copy.deepcopy({ + k: v for k, v in params.items() if k in param_keys + }) + + # We also stop the search if the loss has not considerably during the last 10% epochs + if early_stop: + epochs_before_stop = int(epochs * percent_epochs_before_stop) + if i > epochs_before_stop and abs(acc_loss[i - epochs_before_stop] - acc_loss[i]) < min_delta: + break + + # No longer requires gradients in the parameters + for p in best_params.values(): + p.requires_grad = False + p.grad = None + + if do_report: + return best_params, acc_loss + else: + return best_params + + +def quantize( + x: torch.Tensor, + params: Dict[str, nn.Parameter], + func: nn.Module, + bits: int, + target_dtype: torch.dtype = torch.int8 +) -> torch.Tensor: + quant_min, quant_max = get_min_max_from_bits_signed(bits) + x = x.transpose(0, 1) # Aligns shapes + x = func(x=x, **params) + x = x.transpose(0, 1) + x = torch.clamp(round_func_BPDA(x), quant_min, quant_max).to(target_dtype) + return x + + +def dequantize( + x: torch.Tensor, + params: Dict[str, nn.Parameter], + func: nn.Module, + bits: int, + out_dtype: torch.dtype +) -> torch.Tensor: + x = x.to(dtype=out_dtype) + x = x.transpose(0, 1) + x = func(x=x, **params) + x = x.transpose(0, 1) + return x + + +def round_func_BPDA(input): + # This is equivalent to replacing round function (non-differentiable) with + # an identity function (differentiable) only when backward. + forward_value = torch.round(input) + out = input.clone() + out.data = forward_value.data + return out + + +def get_min_max_from_bits_signed(bit_width: int) -> Tuple[int, int]: + return -2 ** (bit_width - 1), 2 ** (bit_width - 1) - 1 + + + +############## Numpy ############### + +def np_domain_guard( + x: np.ndarray, + min: float = None, + max: float = None, + posinf: float = None, + neginf: float = None, + nan: float = None + ) -> np.ndarray: + """Guard a tensor to a valid domain.""" + x = np.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan) + if min is not None or max is not None: + x = np.clip(x, min, max) + return x + + +def np_replace_num(x: np.ndarray, num: float, to: float) -> np.ndarray: + """Replace a number in a tensor with another number. + + Args: + x (np.ndarray): The input tensor. + num (float): The number to replace. + to (float): The number to replace with. + + Returns: + np.ndarray: The tensor with the number replaced. + """ + return np.where(x == num, to, x) + + +def np_guarded_power(x: np.ndarray, exp: float) -> np.ndarray: + """Guard the power operation to a valid domain.""" + return np.power(x, exp) if exp >= 1 else np.power(np.maximum(x, 0), exp) + diff --git a/fn_gen/ones_noisy_scale/8/loss.png b/fn_gen/ones_noisy_scale/8/loss.png new file mode 100644 index 0000000000000000000000000000000000000000..fcfbecfcb13a7d90161aad379c5b6c9f7709e24c Binary files /dev/null and b/fn_gen/ones_noisy_scale/8/loss.png differ diff --git a/fn_gen/ones_noisy_scale/8/quantization.png b/fn_gen/ones_noisy_scale/8/quantization.png new file mode 100644 index 0000000000000000000000000000000000000000..e2ff36182f45ca4dec1e693e2b6aa826782c0ebf Binary files /dev/null and b/fn_gen/ones_noisy_scale/8/quantization.png differ diff --git a/fn_gen/ones_noisy_scale/9/__pycache__/fn.cpython-310.pyc b/fn_gen/ones_noisy_scale/9/__pycache__/fn.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2676dda7ba667b1bf5e896e504cf7abb2ea60c5e Binary files /dev/null and b/fn_gen/ones_noisy_scale/9/__pycache__/fn.cpython-310.pyc differ diff --git a/fn_gen/ones_noisy_scale/9/distortion.png b/fn_gen/ones_noisy_scale/9/distortion.png new file mode 100644 index 0000000000000000000000000000000000000000..dd15fe9dd11c4ac984071b7e1aa590d6c8ea5c75 Binary files /dev/null and b/fn_gen/ones_noisy_scale/9/distortion.png differ diff --git a/fn_gen/ones_noisy_scale/9/expressions.txt b/fn_gen/ones_noisy_scale/9/expressions.txt new file mode 100644 index 0000000000000000000000000000000000000000..5a7abbbdac98c7d53123fe0b9807e7644bc00acf --- /dev/null +++ b/fn_gen/ones_noisy_scale/9/expressions.txt @@ -0,0 +1,2 @@ +acosh(_0*x)/_s +cosh(_s*x)/_0 \ No newline at end of file diff --git a/fn_gen/ones_noisy_scale/9/fn.py b/fn_gen/ones_noisy_scale/9/fn.py new file mode 100644 index 0000000000000000000000000000000000000000..4358a8687f385c2bc3ff53fdfee8238415b82add --- /dev/null +++ b/fn_gen/ones_noisy_scale/9/fn.py @@ -0,0 +1,489 @@ +from __future__ import annotations + +import torch +from torch import amin # Necessary for arcsin +import copy +import torch.nn as nn +import numpy as np + +from scipy.optimize import curve_fit +from typing import Dict, Any, Tuple, List, Callable + + +def quantization(x, **params): + return (torch.div(1, replace_num(params['_s'], num=0, to=10000)) * torch.acosh(domain_guard((params['_0'] * x), min=1, nan=1))) + + +def dequantization(x, **params): + return (torch.div(1, replace_num(params['_0'], num=0, to=10000)) * torch.cosh((params['_s'] * x))) + + +def init_linear_scale( # Symmetric scale. From the study folder + x: torch.Tensor, + **kwargs: Dict[str, Any], + ) -> torch.Tensor: + assert "bits" in kwargs, "bits must be provided." + assert "params" in kwargs, "params must be provided." + assert "qtz_func" in kwargs, "qtz_func must be provided." + + bits = kwargs.get('bits') + params = kwargs.get('params') + qtz_func = kwargs.get('qtz_func') + + x_ = x.transpose(0, 1) + x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs)) + x_ = x_.transpose(0, 1) + + quant_min, quant_max = get_min_max_from_bits_signed(bits) + min_vals, max_vals = torch.aminmax(x_, dim=1) + min_vals = torch.min(min_vals, torch.zeros_like(min_vals)) + max_vals = torch.max(max_vals, torch.zeros_like(max_vals)) + + eps = torch.finfo(torch.float32).eps + + abs_max_val_per_ch = torch.max(-min_vals, max_vals) + scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2) + + scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device) + + # Introduces some noise in scale + # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything + scale = scale + 0.01 * torch.randn_like(scale) + return scale + + +def init_params(x: torch.Tensor, **kwargs: Dict[str, Any]) -> Dict[str, nn.Parameter]: + params = { + '_0': init_ones(x, qtz_func=quantization, deqtz_func=dequantization, param='_0', params_list=['_0', '_s'], **kwargs), + } + params['_s'] = init_linear_scale(x, params=params, qtz_func=quantization, **kwargs) + params = {k: nn.Parameter(v, requires_grad=False) for k, v in params.items()} + + if 'post_init_hook' in kwargs: + kwargs['post_init_hook'](parameters=params) + + + if 'post_train_hook' in kwargs: + kwargs['post_train_hook'](parameters=params) + + return params + + +############### Numpy Qtz ############### + + +def np_quantization(x, _0, _s): + return (np.divide(1, np_replace_num(_s, num=0, to=10000)) * np.arccosh(np_domain_guard((_0 * x), min=1, nan=1))) + + +def np_dequantization(x, _0, _s): + return (np.divide(1, np_replace_num(_0, num=0, to=10000)) * np.cosh((_s * x))) + + +def fit_func(x, _0, _s): + x_ = np_quantization(x, _0, _s) + x_ = np_dequantization(x_, _0, _s) + return x_ + + + +############### HELPERS ############### + +def domain_guard( + x: torch.Tensor, + min: float = None, + max: float = None, + posinf: float = None, + neginf: float = None, + nan: float = None + ) -> torch.Tensor: + """Guard a tensor to a valid domain.""" + x = torch.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan) + if min is not None or max is not None: + x = torch.clamp(x, min=min, max=max) + return x + + +def replace_num(x: torch.Tensor, num: float, to: float) -> torch.Tensor: + """Replace a number in a tensor with another number. + + Args: + x (torch.Tensor): The input tensor. + num (float): The number to replace. + to (float): The number to replace with. + + Returns: + torch.Tensor: The tensor with the number replaced. + """ + return torch.where(x == num, to, x) + + +def guarded_torch_power(x: torch.Tensor, exp: float) -> torch.Tensor: + """Guard the power operation to a valid domain.""" + return torch.pow(x, exp) if exp >= 1 else torch.pow(torch.relu(x), exp) + + +def init_ones(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor: + val = torch.amin(x, dim=1) + return torch.ones_like(val, dtype=torch.float32, device=x.device) + + +def init_rand(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor: + val = torch.amin(x, dim=1) + return torch.randn_like(val, dtype=torch.float32, device=x.device) + + +def init_space_search( + x: torch.Tensor, + **kwargs: Dict[str, Any], + ) -> torch.Tensor: + + def _build_initial_param(tensor: torch.Tensor, max_initial: int, n_params: int): + """Generates the initial set of parameters. The first iteration generates 10 times more parameters.""" + for _ in range(n_params * 10): # The first iteration generates 10 times more parameters + yield init_rand(tensor) * max_initial # Generates n_params in range [-max_initial, max_initial] + + def _search_param(tensors: List[torch.tensor], n_params): + """Takes the best parameters and generates new parameters around the mean of the best parameters.""" + torch_tensors = torch.stack(tensors) + min_vals, max_vals = torch.aminmax(torch_tensors, dim=0) + abs_max_val_per_ch = torch.max(-min_vals, max_vals) + mean = torch.mean(torch_tensors, dim=0) + for _ in range(n_params): # Generates n_params around the mean of the tensors + yield torch.randn_like(min_vals) * abs_max_val_per_ch + mean + + def _calc(x, qtz_func, deqtz_func, **params): + x_ = x.transpose(0, 1) + x_ = qtz_func(x=x_, **params) + x_ = deqtz_func(x=x_, **params) + x_ = x_.transpose(0, 1) + return x_ + + assert "qtz_func" in kwargs, "qtz_func must be provided." + assert "deqtz_func" in kwargs, "deqtz_func must be provided." + assert "params_list" in kwargs, "params list must be provided." + assert "param" in kwargs, "param must be provided." + + qtz_func = kwargs.get('qtz_func') + deqtz_func = kwargs.get('deqtz_func') + params_list = kwargs.get('params_list') + param = kwargs.get('param') + + n_runs = 50 # Number of runs to try to find the best parameters + n_random_params = 50 # Number of random parameters to generate + n_best_to_pick = 5 # Number of best parameters to pick after each run + max_initial = 10000 # Maximum value to initialize the parameters + + # Initializes the parameters + base_params = { p: init_ones(x, **kwargs) for p in params_list if p != param } + params = _build_initial_param(x, max_initial, n_random_params) + + # Performs the search + for _ in range(n_runs): + + best_params = [] + for param_ in params: + try: + x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: param_}) + loss_ones = nn.MSELoss()(x, x_) + + if len(best_params) < n_best_to_pick: + best_params.append((param_, loss_ones.item())) + best_params = sorted(best_params, key=lambda x: x[1]) + elif loss_ones < best_params[-1][1]: + best_params[-1] = (param_, loss_ones.item()) + best_params = sorted(best_params, key=lambda x: x[1]) + + except Exception: # The parameters might not be valid for the function's domain + continue + + # Generates new parameters around the mean + params = _search_param([p for p, _ in best_params], n_random_params) + + # Checks if the best parameter is better than the init_ones + p_ones = init_ones(x, **kwargs) + x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_ones}) + loss_ones = nn.MSELoss()(x, x_) + + # Checks if the best parameter is better than the init_rand + p_rand = init_rand(x, **kwargs) + x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_rand}) + loss_rand = nn.MSELoss()(x, x_) + + if loss_rand < best_params[0][1] and loss_rand < loss_ones: + return p_rand + elif loss_ones < best_params[0][1] and loss_ones < loss_rand: + return p_ones + else: + return best_params[0][0] + + +def init_linear_scale( # Symmetric scale. From the study folder + x: torch.Tensor, + **kwargs: Dict[str, Any], + ) -> torch.Tensor: + assert "bits" in kwargs, "bits must be provided." + assert "params" in kwargs, "params must be provided." + assert "qtz_func" in kwargs, "qtz_func must be provided." + + bits = kwargs.get('bits') + params = kwargs.get('params') + qtz_func = kwargs.get('qtz_func') + + x_ = x.transpose(0, 1) + x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs)) + x_ = x_.transpose(0, 1) + + quant_min, quant_max = get_min_max_from_bits_signed(bits) + min_vals, max_vals = torch.aminmax(x_, dim=1) + min_vals = torch.min(min_vals, torch.zeros_like(min_vals)) + max_vals = torch.max(max_vals, torch.zeros_like(max_vals)) + + eps = torch.finfo(torch.float32).eps + + abs_max_val_per_ch = torch.max(-min_vals, max_vals) + scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2) + + scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device) + + # Introduces some noise in scale + # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything + scale = scale + 0.01 * torch.randn_like(scale) + return scale + + +def init_non_linear_regression_fit( + x: torch.Tensor, + **kwargs: Dict[str, Any], + ) -> torch.Tensor: + + assert "params_list" in kwargs, "params list must be provided." + assert "np_fit_func" in kwargs, "np_fit_func must be provided." + assert "p0" in kwargs, "p0 must be provided." + np_fit_func = kwargs.get('np_fit_func') + params_list = kwargs.get('params_list') + p0 = kwargs.get('p0') + + def _fit(xdata: np.ndarray, ydata: np.ndarray, func: Callable, p0: List[float]): + popt, _ = curve_fit( + func, + xdata, + ydata, + maxfev=1000, + p0=p0, + method='lm' + ) + return popt + + # 1. Needs to convert the torch tensor to numpy tensor + xdata = x.cpu().numpy() + + # 2. Sorts the data so that it makes it easier to fit to it + sorted_xdata = np.sort(xdata, axis=-1) + + p0 = {k: v.cpu().numpy() for k, v in p0.items()} + params_list = sorted(params_list) # We need to make sure that it matches the numpy fit func arg order + + # 3. Finds the best parameters for each channel + try: + params = [] + for i in range(sorted_xdata.shape[0]): + xdata_ = sorted_xdata[i] + p0_ = [p0[p][i] for p in params_list] + ch_params = _fit(xdata_, xdata_, np_fit_func, p0_) + params.append(ch_params) + + # 4. Builds the parameters + result = {} + for i, p in enumerate(params_list): + result[p] = torch.tensor([p_[i] for p_ in params], dtype=torch.float32).to(x.device) + + return result + + except ValueError as e: + print(f"Could not fit the function with error: {e}") + print(f"Using fallback result...") + return { + k: torch.tensor(v, dtype=torch.float32).to(x.device) for k, v in p0.items() + } + + +def init_zeros(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor: + val = torch.amin(x, dim=1) + return torch.zeros_like(val, dtype=torch.float32, device=x.device) + + +def init_inner_scale(tensor: torch.Tensor, _min: float = torch.inf, _max: float = torch.inf) -> torch.Tensor: + # Calculate the original minimum and maximum values + min_vals, max_vals = torch.aminmax(tensor, dim=-1) + x_min = torch.min(min_vals, torch.zeros_like(min_vals)) + x_max = torch.max(max_vals, torch.zeros_like(max_vals)) + + if _max is torch.inf: # We do not need to scale the tensor. Just need to move it + return torch.ones_like(x_min) + + # Calculate the scale factor + scale = (_max - _min) / (x_max - x_min) + return scale + + + +############## Quant ############### + +@torch.enable_grad() +def learn_parameters( + x: torch.Tensor, + params: Dict[str, nn.Parameter], + qtz_func: nn.Module, + deqtz_func: nn.Module, + bits: int, + target_dtype: torch.dtype, + epochs: int = 1000, + early_stop: bool = True, + do_report: bool = False +) -> Tuple[Dict[str, nn.Parameter], torch.Tensor]: + + # Requires gradients in the parameters + for p in params.values(): + p.requires_grad = True + p.grad = None + + param_keys = list(params.keys()) + param_values = list(params.values()) + + # Defines optimizer and loss function + optimizer = torch.optim.Adam(param_values, lr=0.001) + scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=1.0, end_factor=0.001, total_iters=epochs // 10) + loss_fn = nn.MSELoss() + + # Contains the best loss and the best parameters + best_loss = float("inf") + best_params = None + + # Used to stop the search early + min_delta = 1e-7 + acc_loss = [] + percent_epochs_before_stop = 0.1 + + for i in range(epochs): + optimizer.zero_grad() + + quant = quantize(x, params, qtz_func, bits, target_dtype) + dequant = dequantize(quant, params, deqtz_func, bits, x.dtype) + loss = loss_fn(x, dequant) + + if loss.isnan() or loss.isinf(): + raise Exception("Loss is NaN or Inf. Stopping the search.") + + loss.backward() + optimizer.step() + scheduler.step() + + acc_loss.append(loss.item()) + + # Reports loss every 10 steps + if i % 10 == 0 and do_report: + print(f"Epoch {i}: Loss {loss.item()}") + + # Optimizes the parameter search by storing the best loss and the parameters + if loss.item() < best_loss: + best_loss = loss.item() + best_params = copy.deepcopy({ + k: v for k, v in params.items() if k in param_keys + }) + + # We also stop the search if the loss has not considerably during the last 10% epochs + if early_stop: + epochs_before_stop = int(epochs * percent_epochs_before_stop) + if i > epochs_before_stop and abs(acc_loss[i - epochs_before_stop] - acc_loss[i]) < min_delta: + break + + # No longer requires gradients in the parameters + for p in best_params.values(): + p.requires_grad = False + p.grad = None + + if do_report: + return best_params, acc_loss + else: + return best_params + + +def quantize( + x: torch.Tensor, + params: Dict[str, nn.Parameter], + func: nn.Module, + bits: int, + target_dtype: torch.dtype = torch.int8 +) -> torch.Tensor: + quant_min, quant_max = get_min_max_from_bits_signed(bits) + x = x.transpose(0, 1) # Aligns shapes + x = func(x=x, **params) + x = x.transpose(0, 1) + x = torch.clamp(round_func_BPDA(x), quant_min, quant_max).to(target_dtype) + return x + + +def dequantize( + x: torch.Tensor, + params: Dict[str, nn.Parameter], + func: nn.Module, + bits: int, + out_dtype: torch.dtype +) -> torch.Tensor: + x = x.to(dtype=out_dtype) + x = x.transpose(0, 1) + x = func(x=x, **params) + x = x.transpose(0, 1) + return x + + +def round_func_BPDA(input): + # This is equivalent to replacing round function (non-differentiable) with + # an identity function (differentiable) only when backward. + forward_value = torch.round(input) + out = input.clone() + out.data = forward_value.data + return out + + +def get_min_max_from_bits_signed(bit_width: int) -> Tuple[int, int]: + return -2 ** (bit_width - 1), 2 ** (bit_width - 1) - 1 + + + +############## Numpy ############### + +def np_domain_guard( + x: np.ndarray, + min: float = None, + max: float = None, + posinf: float = None, + neginf: float = None, + nan: float = None + ) -> np.ndarray: + """Guard a tensor to a valid domain.""" + x = np.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan) + if min is not None or max is not None: + x = np.clip(x, min, max) + return x + + +def np_replace_num(x: np.ndarray, num: float, to: float) -> np.ndarray: + """Replace a number in a tensor with another number. + + Args: + x (np.ndarray): The input tensor. + num (float): The number to replace. + to (float): The number to replace with. + + Returns: + np.ndarray: The tensor with the number replaced. + """ + return np.where(x == num, to, x) + + +def np_guarded_power(x: np.ndarray, exp: float) -> np.ndarray: + """Guard the power operation to a valid domain.""" + return np.power(x, exp) if exp >= 1 else np.power(np.maximum(x, 0), exp) + diff --git a/fn_gen/ones_noisy_scale/9/loss.png b/fn_gen/ones_noisy_scale/9/loss.png new file mode 100644 index 0000000000000000000000000000000000000000..472e8d0c685b2fabf893a41f6ac178edebc5acee Binary files /dev/null and b/fn_gen/ones_noisy_scale/9/loss.png differ diff --git a/fn_gen/ones_noisy_scale/9/quantization.png b/fn_gen/ones_noisy_scale/9/quantization.png new file mode 100644 index 0000000000000000000000000000000000000000..14c2e2f410a2f4b55c06ad5e1d223d7c4e054aee Binary files /dev/null and b/fn_gen/ones_noisy_scale/9/quantization.png differ diff --git a/fn_gen/rnd_noisy_scale/0/__pycache__/fn.cpython-310.pyc b/fn_gen/rnd_noisy_scale/0/__pycache__/fn.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c5f42d1087256af52ab41f1550ed8929d7b44926 Binary files /dev/null and b/fn_gen/rnd_noisy_scale/0/__pycache__/fn.cpython-310.pyc differ diff --git a/fn_gen/rnd_noisy_scale/0/distortion.png b/fn_gen/rnd_noisy_scale/0/distortion.png new file mode 100644 index 0000000000000000000000000000000000000000..052f4a35c72accdd430dc8b240e91e0b457199cc Binary files /dev/null and b/fn_gen/rnd_noisy_scale/0/distortion.png differ diff --git a/fn_gen/rnd_noisy_scale/0/expressions.txt b/fn_gen/rnd_noisy_scale/0/expressions.txt new file mode 100644 index 0000000000000000000000000000000000000000..ec55493201f7b2b8effaefed75e0a9258fc25c56 --- /dev/null +++ b/fn_gen/rnd_noisy_scale/0/expressions.txt @@ -0,0 +1,2 @@ +tanh(_0*x)/_s +log((-_s*x - 1)/(_s*x - 1))/_0 \ No newline at end of file diff --git a/fn_gen/rnd_noisy_scale/0/fn.py b/fn_gen/rnd_noisy_scale/0/fn.py new file mode 100644 index 0000000000000000000000000000000000000000..438e32172d7209d095e077d39aaec840b866b921 --- /dev/null +++ b/fn_gen/rnd_noisy_scale/0/fn.py @@ -0,0 +1,489 @@ +from __future__ import annotations + +import torch +from torch import amin # Necessary for arcsin +import copy +import torch.nn as nn +import numpy as np + +from scipy.optimize import curve_fit +from typing import Dict, Any, Tuple, List, Callable + + +def quantization(x, **params): + return (torch.div(1, replace_num(params['_s'], num=0, to=10000)) * torch.tanh((params['_0'] * x))) + + +def dequantization(x, **params): + return (torch.div(1, replace_num(params['_0'], num=0, to=10000)) * torch.log(domain_guard((torch.div(1, replace_num((torch.tensor(-1) + (params['_s'] * x)), num=0, to=10000)) * (torch.tensor(-1) + (torch.tensor(-1) * params['_s'] * x))), min=1e-5, nan=1e-5))) + + +def init_linear_scale( # Symmetric scale. From the study folder + x: torch.Tensor, + **kwargs: Dict[str, Any], + ) -> torch.Tensor: + assert "bits" in kwargs, "bits must be provided." + assert "params" in kwargs, "params must be provided." + assert "qtz_func" in kwargs, "qtz_func must be provided." + + bits = kwargs.get('bits') + params = kwargs.get('params') + qtz_func = kwargs.get('qtz_func') + + x_ = x.transpose(0, 1) + x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs)) + x_ = x_.transpose(0, 1) + + quant_min, quant_max = get_min_max_from_bits_signed(bits) + min_vals, max_vals = torch.aminmax(x_, dim=1) + min_vals = torch.min(min_vals, torch.zeros_like(min_vals)) + max_vals = torch.max(max_vals, torch.zeros_like(max_vals)) + + eps = torch.finfo(torch.float32).eps + + abs_max_val_per_ch = torch.max(-min_vals, max_vals) + scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2) + + scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device) + + # Introduces some noise in scale + # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything + scale = scale + 0.01 * torch.randn_like(scale) + return scale + + +def init_params(x: torch.Tensor, **kwargs: Dict[str, Any]) -> Dict[str, nn.Parameter]: + params = { + '_0': init_rand(x, qtz_func=quantization, deqtz_func=dequantization, param='_0', params_list=['_0', '_s'], **kwargs), + } + params['_s'] = init_linear_scale(x, params=params, qtz_func=quantization, **kwargs) + params = {k: nn.Parameter(v, requires_grad=False) for k, v in params.items()} + + if 'post_init_hook' in kwargs: + kwargs['post_init_hook'](parameters=params) + + + if 'post_train_hook' in kwargs: + kwargs['post_train_hook'](parameters=params) + + return params + + +############### Numpy Qtz ############### + + +def np_quantization(x, _0, _s): + return (np.divide(1, np_replace_num(_s, num=0, to=10000)) * np.tanh((_0 * x))) + + +def np_dequantization(x, _0, _s): + return (np.divide(1, np_replace_num(_0, num=0, to=10000)) * np.log(np_domain_guard((np.divide(1, np_replace_num((np.array(-1) + (_s * x)), num=0, to=10000)) * (np.array(-1) + (np.array(-1) * _s * x))), min=1e-5, nan=1e-5))) + + +def fit_func(x, _0, _s): + x_ = np_quantization(x, _0, _s) + x_ = np_dequantization(x_, _0, _s) + return x_ + + + +############### HELPERS ############### + +def domain_guard( + x: torch.Tensor, + min: float = None, + max: float = None, + posinf: float = None, + neginf: float = None, + nan: float = None + ) -> torch.Tensor: + """Guard a tensor to a valid domain.""" + x = torch.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan) + if min is not None or max is not None: + x = torch.clamp(x, min=min, max=max) + return x + + +def replace_num(x: torch.Tensor, num: float, to: float) -> torch.Tensor: + """Replace a number in a tensor with another number. + + Args: + x (torch.Tensor): The input tensor. + num (float): The number to replace. + to (float): The number to replace with. + + Returns: + torch.Tensor: The tensor with the number replaced. + """ + return torch.where(x == num, to, x) + + +def guarded_torch_power(x: torch.Tensor, exp: float) -> torch.Tensor: + """Guard the power operation to a valid domain.""" + return torch.pow(x, exp) if exp >= 1 else torch.pow(torch.relu(x), exp) + + +def init_ones(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor: + val = torch.amin(x, dim=1) + return torch.ones_like(val, dtype=torch.float32, device=x.device) + + +def init_rand(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor: + val = torch.amin(x, dim=1) + return torch.randn_like(val, dtype=torch.float32, device=x.device) + + +def init_space_search( + x: torch.Tensor, + **kwargs: Dict[str, Any], + ) -> torch.Tensor: + + def _build_initial_param(tensor: torch.Tensor, max_initial: int, n_params: int): + """Generates the initial set of parameters. The first iteration generates 10 times more parameters.""" + for _ in range(n_params * 10): # The first iteration generates 10 times more parameters + yield init_rand(tensor) * max_initial # Generates n_params in range [-max_initial, max_initial] + + def _search_param(tensors: List[torch.tensor], n_params): + """Takes the best parameters and generates new parameters around the mean of the best parameters.""" + torch_tensors = torch.stack(tensors) + min_vals, max_vals = torch.aminmax(torch_tensors, dim=0) + abs_max_val_per_ch = torch.max(-min_vals, max_vals) + mean = torch.mean(torch_tensors, dim=0) + for _ in range(n_params): # Generates n_params around the mean of the tensors + yield torch.randn_like(min_vals) * abs_max_val_per_ch + mean + + def _calc(x, qtz_func, deqtz_func, **params): + x_ = x.transpose(0, 1) + x_ = qtz_func(x=x_, **params) + x_ = deqtz_func(x=x_, **params) + x_ = x_.transpose(0, 1) + return x_ + + assert "qtz_func" in kwargs, "qtz_func must be provided." + assert "deqtz_func" in kwargs, "deqtz_func must be provided." + assert "params_list" in kwargs, "params list must be provided." + assert "param" in kwargs, "param must be provided." + + qtz_func = kwargs.get('qtz_func') + deqtz_func = kwargs.get('deqtz_func') + params_list = kwargs.get('params_list') + param = kwargs.get('param') + + n_runs = 50 # Number of runs to try to find the best parameters + n_random_params = 50 # Number of random parameters to generate + n_best_to_pick = 5 # Number of best parameters to pick after each run + max_initial = 10000 # Maximum value to initialize the parameters + + # Initializes the parameters + base_params = { p: init_ones(x, **kwargs) for p in params_list if p != param } + params = _build_initial_param(x, max_initial, n_random_params) + + # Performs the search + for _ in range(n_runs): + + best_params = [] + for param_ in params: + try: + x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: param_}) + loss_ones = nn.MSELoss()(x, x_) + + if len(best_params) < n_best_to_pick: + best_params.append((param_, loss_ones.item())) + best_params = sorted(best_params, key=lambda x: x[1]) + elif loss_ones < best_params[-1][1]: + best_params[-1] = (param_, loss_ones.item()) + best_params = sorted(best_params, key=lambda x: x[1]) + + except Exception: # The parameters might not be valid for the function's domain + continue + + # Generates new parameters around the mean + params = _search_param([p for p, _ in best_params], n_random_params) + + # Checks if the best parameter is better than the init_ones + p_ones = init_ones(x, **kwargs) + x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_ones}) + loss_ones = nn.MSELoss()(x, x_) + + # Checks if the best parameter is better than the init_rand + p_rand = init_rand(x, **kwargs) + x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_rand}) + loss_rand = nn.MSELoss()(x, x_) + + if loss_rand < best_params[0][1] and loss_rand < loss_ones: + return p_rand + elif loss_ones < best_params[0][1] and loss_ones < loss_rand: + return p_ones + else: + return best_params[0][0] + + +def init_linear_scale( # Symmetric scale. From the study folder + x: torch.Tensor, + **kwargs: Dict[str, Any], + ) -> torch.Tensor: + assert "bits" in kwargs, "bits must be provided." + assert "params" in kwargs, "params must be provided." + assert "qtz_func" in kwargs, "qtz_func must be provided." + + bits = kwargs.get('bits') + params = kwargs.get('params') + qtz_func = kwargs.get('qtz_func') + + x_ = x.transpose(0, 1) + x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs)) + x_ = x_.transpose(0, 1) + + quant_min, quant_max = get_min_max_from_bits_signed(bits) + min_vals, max_vals = torch.aminmax(x_, dim=1) + min_vals = torch.min(min_vals, torch.zeros_like(min_vals)) + max_vals = torch.max(max_vals, torch.zeros_like(max_vals)) + + eps = torch.finfo(torch.float32).eps + + abs_max_val_per_ch = torch.max(-min_vals, max_vals) + scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2) + + scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device) + + # Introduces some noise in scale + # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything + scale = scale + 0.01 * torch.randn_like(scale) + return scale + + +def init_non_linear_regression_fit( + x: torch.Tensor, + **kwargs: Dict[str, Any], + ) -> torch.Tensor: + + assert "params_list" in kwargs, "params list must be provided." + assert "np_fit_func" in kwargs, "np_fit_func must be provided." + assert "p0" in kwargs, "p0 must be provided." + np_fit_func = kwargs.get('np_fit_func') + params_list = kwargs.get('params_list') + p0 = kwargs.get('p0') + + def _fit(xdata: np.ndarray, ydata: np.ndarray, func: Callable, p0: List[float]): + popt, _ = curve_fit( + func, + xdata, + ydata, + maxfev=1000, + p0=p0, + method='lm' + ) + return popt + + # 1. Needs to convert the torch tensor to numpy tensor + xdata = x.cpu().numpy() + + # 2. Sorts the data so that it makes it easier to fit to it + sorted_xdata = np.sort(xdata, axis=-1) + + p0 = {k: v.cpu().numpy() for k, v in p0.items()} + params_list = sorted(params_list) # We need to make sure that it matches the numpy fit func arg order + + # 3. Finds the best parameters for each channel + try: + params = [] + for i in range(sorted_xdata.shape[0]): + xdata_ = sorted_xdata[i] + p0_ = [p0[p][i] for p in params_list] + ch_params = _fit(xdata_, xdata_, np_fit_func, p0_) + params.append(ch_params) + + # 4. Builds the parameters + result = {} + for i, p in enumerate(params_list): + result[p] = torch.tensor([p_[i] for p_ in params], dtype=torch.float32).to(x.device) + + return result + + except ValueError as e: + print(f"Could not fit the function with error: {e}") + print(f"Using fallback result...") + return { + k: torch.tensor(v, dtype=torch.float32).to(x.device) for k, v in p0.items() + } + + +def init_zeros(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor: + val = torch.amin(x, dim=1) + return torch.zeros_like(val, dtype=torch.float32, device=x.device) + + +def init_inner_scale(tensor: torch.Tensor, _min: float = torch.inf, _max: float = torch.inf) -> torch.Tensor: + # Calculate the original minimum and maximum values + min_vals, max_vals = torch.aminmax(tensor, dim=-1) + x_min = torch.min(min_vals, torch.zeros_like(min_vals)) + x_max = torch.max(max_vals, torch.zeros_like(max_vals)) + + if _max is torch.inf: # We do not need to scale the tensor. Just need to move it + return torch.ones_like(x_min) + + # Calculate the scale factor + scale = (_max - _min) / (x_max - x_min) + return scale + + + +############## Quant ############### + +@torch.enable_grad() +def learn_parameters( + x: torch.Tensor, + params: Dict[str, nn.Parameter], + qtz_func: nn.Module, + deqtz_func: nn.Module, + bits: int, + target_dtype: torch.dtype, + epochs: int = 1000, + early_stop: bool = True, + do_report: bool = False +) -> Tuple[Dict[str, nn.Parameter], torch.Tensor]: + + # Requires gradients in the parameters + for p in params.values(): + p.requires_grad = True + p.grad = None + + param_keys = list(params.keys()) + param_values = list(params.values()) + + # Defines optimizer and loss function + optimizer = torch.optim.Adam(param_values, lr=0.001) + scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=1.0, end_factor=0.001, total_iters=epochs // 10) + loss_fn = nn.MSELoss() + + # Contains the best loss and the best parameters + best_loss = float("inf") + best_params = None + + # Used to stop the search early + min_delta = 1e-7 + acc_loss = [] + percent_epochs_before_stop = 0.1 + + for i in range(epochs): + optimizer.zero_grad() + + quant = quantize(x, params, qtz_func, bits, target_dtype) + dequant = dequantize(quant, params, deqtz_func, bits, x.dtype) + loss = loss_fn(x, dequant) + + if loss.isnan() or loss.isinf(): + raise Exception("Loss is NaN or Inf. Stopping the search.") + + loss.backward() + optimizer.step() + scheduler.step() + + acc_loss.append(loss.item()) + + # Reports loss every 10 steps + if i % 10 == 0 and do_report: + print(f"Epoch {i}: Loss {loss.item()}") + + # Optimizes the parameter search by storing the best loss and the parameters + if loss.item() < best_loss: + best_loss = loss.item() + best_params = copy.deepcopy({ + k: v for k, v in params.items() if k in param_keys + }) + + # We also stop the search if the loss has not considerably during the last 10% epochs + if early_stop: + epochs_before_stop = int(epochs * percent_epochs_before_stop) + if i > epochs_before_stop and abs(acc_loss[i - epochs_before_stop] - acc_loss[i]) < min_delta: + break + + # No longer requires gradients in the parameters + for p in best_params.values(): + p.requires_grad = False + p.grad = None + + if do_report: + return best_params, acc_loss + else: + return best_params + + +def quantize( + x: torch.Tensor, + params: Dict[str, nn.Parameter], + func: nn.Module, + bits: int, + target_dtype: torch.dtype = torch.int8 +) -> torch.Tensor: + quant_min, quant_max = get_min_max_from_bits_signed(bits) + x = x.transpose(0, 1) # Aligns shapes + x = func(x=x, **params) + x = x.transpose(0, 1) + x = torch.clamp(round_func_BPDA(x), quant_min, quant_max).to(target_dtype) + return x + + +def dequantize( + x: torch.Tensor, + params: Dict[str, nn.Parameter], + func: nn.Module, + bits: int, + out_dtype: torch.dtype +) -> torch.Tensor: + x = x.to(dtype=out_dtype) + x = x.transpose(0, 1) + x = func(x=x, **params) + x = x.transpose(0, 1) + return x + + +def round_func_BPDA(input): + # This is equivalent to replacing round function (non-differentiable) with + # an identity function (differentiable) only when backward. + forward_value = torch.round(input) + out = input.clone() + out.data = forward_value.data + return out + + +def get_min_max_from_bits_signed(bit_width: int) -> Tuple[int, int]: + return -2 ** (bit_width - 1), 2 ** (bit_width - 1) - 1 + + + +############## Numpy ############### + +def np_domain_guard( + x: np.ndarray, + min: float = None, + max: float = None, + posinf: float = None, + neginf: float = None, + nan: float = None + ) -> np.ndarray: + """Guard a tensor to a valid domain.""" + x = np.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan) + if min is not None or max is not None: + x = np.clip(x, min, max) + return x + + +def np_replace_num(x: np.ndarray, num: float, to: float) -> np.ndarray: + """Replace a number in a tensor with another number. + + Args: + x (np.ndarray): The input tensor. + num (float): The number to replace. + to (float): The number to replace with. + + Returns: + np.ndarray: The tensor with the number replaced. + """ + return np.where(x == num, to, x) + + +def np_guarded_power(x: np.ndarray, exp: float) -> np.ndarray: + """Guard the power operation to a valid domain.""" + return np.power(x, exp) if exp >= 1 else np.power(np.maximum(x, 0), exp) + diff --git a/fn_gen/rnd_noisy_scale/0/loss.png b/fn_gen/rnd_noisy_scale/0/loss.png new file mode 100644 index 0000000000000000000000000000000000000000..972fc1200b43b64201a2e9e8797c75fe85c08706 Binary files /dev/null and b/fn_gen/rnd_noisy_scale/0/loss.png differ diff --git a/fn_gen/rnd_noisy_scale/0/quantization.png b/fn_gen/rnd_noisy_scale/0/quantization.png new file mode 100644 index 0000000000000000000000000000000000000000..f1e18ffbfff73bc68b07362aa3f7d92b702312a1 Binary files /dev/null and b/fn_gen/rnd_noisy_scale/0/quantization.png differ diff --git a/fn_gen/rnd_noisy_scale/1/__pycache__/fn.cpython-310.pyc b/fn_gen/rnd_noisy_scale/1/__pycache__/fn.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5339bb7b9ef19848ffc931a4a74291ed66eaf846 Binary files /dev/null and b/fn_gen/rnd_noisy_scale/1/__pycache__/fn.cpython-310.pyc differ diff --git a/fn_gen/rnd_noisy_scale/1/distortion.png b/fn_gen/rnd_noisy_scale/1/distortion.png new file mode 100644 index 0000000000000000000000000000000000000000..43c9df9a09298e832e330e452ad718e425ec9786 Binary files /dev/null and b/fn_gen/rnd_noisy_scale/1/distortion.png differ diff --git a/fn_gen/rnd_noisy_scale/1/expressions.txt b/fn_gen/rnd_noisy_scale/1/expressions.txt new file mode 100644 index 0000000000000000000000000000000000000000..c7b68c388fdf6e1b6e2be8076f1d4b8d7bcef4f9 --- /dev/null +++ b/fn_gen/rnd_noisy_scale/1/expressions.txt @@ -0,0 +1,2 @@ +(_0*x)**(1/3)/_s +_s**3*x**3/_0 \ No newline at end of file diff --git a/fn_gen/rnd_noisy_scale/1/fn.py b/fn_gen/rnd_noisy_scale/1/fn.py new file mode 100644 index 0000000000000000000000000000000000000000..ab429fec271fb64442e8e942f6291ec85bafc183 --- /dev/null +++ b/fn_gen/rnd_noisy_scale/1/fn.py @@ -0,0 +1,489 @@ +from __future__ import annotations + +import torch +from torch import amin # Necessary for arcsin +import copy +import torch.nn as nn +import numpy as np + +from scipy.optimize import curve_fit +from typing import Dict, Any, Tuple, List, Callable + + +def quantization(x, **params): + return (torch.div(1, replace_num(params['_s'], num=0, to=10000)) * guarded_torch_power((params['_0'] * x), 1 / 3)) + + +def dequantization(x, **params): + return (torch.div(1, replace_num(params['_0'], num=0, to=10000)) * guarded_torch_power(params['_s'], torch.tensor(3)) * guarded_torch_power(x, torch.tensor(3))) + + +def init_linear_scale( # Symmetric scale. From the study folder + x: torch.Tensor, + **kwargs: Dict[str, Any], + ) -> torch.Tensor: + assert "bits" in kwargs, "bits must be provided." + assert "params" in kwargs, "params must be provided." + assert "qtz_func" in kwargs, "qtz_func must be provided." + + bits = kwargs.get('bits') + params = kwargs.get('params') + qtz_func = kwargs.get('qtz_func') + + x_ = x.transpose(0, 1) + x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs)) + x_ = x_.transpose(0, 1) + + quant_min, quant_max = get_min_max_from_bits_signed(bits) + min_vals, max_vals = torch.aminmax(x_, dim=1) + min_vals = torch.min(min_vals, torch.zeros_like(min_vals)) + max_vals = torch.max(max_vals, torch.zeros_like(max_vals)) + + eps = torch.finfo(torch.float32).eps + + abs_max_val_per_ch = torch.max(-min_vals, max_vals) + scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2) + + scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device) + + # Introduces some noise in scale + # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything + scale = scale + 0.01 * torch.randn_like(scale) + return scale + + +def init_params(x: torch.Tensor, **kwargs: Dict[str, Any]) -> Dict[str, nn.Parameter]: + params = { + '_0': init_rand(x, qtz_func=quantization, deqtz_func=dequantization, param='_0', params_list=['_0', '_s'], **kwargs), + } + params['_s'] = init_linear_scale(x, params=params, qtz_func=quantization, **kwargs) + params = {k: nn.Parameter(v, requires_grad=False) for k, v in params.items()} + + if 'post_init_hook' in kwargs: + kwargs['post_init_hook'](parameters=params) + + + if 'post_train_hook' in kwargs: + kwargs['post_train_hook'](parameters=params) + + return params + + +############### Numpy Qtz ############### + + +def np_quantization(x, _0, _s): + return (np.divide(1, np_replace_num(_s, num=0, to=10000)) * np_guarded_power((_0 * x), 1 / 3)) + + +def np_dequantization(x, _0, _s): + return (np.divide(1, np_replace_num(_0, num=0, to=10000)) * np_guarded_power(_s, np.array(3)) * np_guarded_power(x, np.array(3))) + + +def fit_func(x, _0, _s): + x_ = np_quantization(x, _0, _s) + x_ = np_dequantization(x_, _0, _s) + return x_ + + + +############### HELPERS ############### + +def domain_guard( + x: torch.Tensor, + min: float = None, + max: float = None, + posinf: float = None, + neginf: float = None, + nan: float = None + ) -> torch.Tensor: + """Guard a tensor to a valid domain.""" + x = torch.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan) + if min is not None or max is not None: + x = torch.clamp(x, min=min, max=max) + return x + + +def replace_num(x: torch.Tensor, num: float, to: float) -> torch.Tensor: + """Replace a number in a tensor with another number. + + Args: + x (torch.Tensor): The input tensor. + num (float): The number to replace. + to (float): The number to replace with. + + Returns: + torch.Tensor: The tensor with the number replaced. + """ + return torch.where(x == num, to, x) + + +def guarded_torch_power(x: torch.Tensor, exp: float) -> torch.Tensor: + """Guard the power operation to a valid domain.""" + return torch.pow(x, exp) if exp >= 1 else torch.pow(torch.relu(x), exp) + + +def init_ones(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor: + val = torch.amin(x, dim=1) + return torch.ones_like(val, dtype=torch.float32, device=x.device) + + +def init_rand(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor: + val = torch.amin(x, dim=1) + return torch.randn_like(val, dtype=torch.float32, device=x.device) + + +def init_space_search( + x: torch.Tensor, + **kwargs: Dict[str, Any], + ) -> torch.Tensor: + + def _build_initial_param(tensor: torch.Tensor, max_initial: int, n_params: int): + """Generates the initial set of parameters. The first iteration generates 10 times more parameters.""" + for _ in range(n_params * 10): # The first iteration generates 10 times more parameters + yield init_rand(tensor) * max_initial # Generates n_params in range [-max_initial, max_initial] + + def _search_param(tensors: List[torch.tensor], n_params): + """Takes the best parameters and generates new parameters around the mean of the best parameters.""" + torch_tensors = torch.stack(tensors) + min_vals, max_vals = torch.aminmax(torch_tensors, dim=0) + abs_max_val_per_ch = torch.max(-min_vals, max_vals) + mean = torch.mean(torch_tensors, dim=0) + for _ in range(n_params): # Generates n_params around the mean of the tensors + yield torch.randn_like(min_vals) * abs_max_val_per_ch + mean + + def _calc(x, qtz_func, deqtz_func, **params): + x_ = x.transpose(0, 1) + x_ = qtz_func(x=x_, **params) + x_ = deqtz_func(x=x_, **params) + x_ = x_.transpose(0, 1) + return x_ + + assert "qtz_func" in kwargs, "qtz_func must be provided." + assert "deqtz_func" in kwargs, "deqtz_func must be provided." + assert "params_list" in kwargs, "params list must be provided." + assert "param" in kwargs, "param must be provided." + + qtz_func = kwargs.get('qtz_func') + deqtz_func = kwargs.get('deqtz_func') + params_list = kwargs.get('params_list') + param = kwargs.get('param') + + n_runs = 50 # Number of runs to try to find the best parameters + n_random_params = 50 # Number of random parameters to generate + n_best_to_pick = 5 # Number of best parameters to pick after each run + max_initial = 10000 # Maximum value to initialize the parameters + + # Initializes the parameters + base_params = { p: init_ones(x, **kwargs) for p in params_list if p != param } + params = _build_initial_param(x, max_initial, n_random_params) + + # Performs the search + for _ in range(n_runs): + + best_params = [] + for param_ in params: + try: + x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: param_}) + loss_ones = nn.MSELoss()(x, x_) + + if len(best_params) < n_best_to_pick: + best_params.append((param_, loss_ones.item())) + best_params = sorted(best_params, key=lambda x: x[1]) + elif loss_ones < best_params[-1][1]: + best_params[-1] = (param_, loss_ones.item()) + best_params = sorted(best_params, key=lambda x: x[1]) + + except Exception: # The parameters might not be valid for the function's domain + continue + + # Generates new parameters around the mean + params = _search_param([p for p, _ in best_params], n_random_params) + + # Checks if the best parameter is better than the init_ones + p_ones = init_ones(x, **kwargs) + x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_ones}) + loss_ones = nn.MSELoss()(x, x_) + + # Checks if the best parameter is better than the init_rand + p_rand = init_rand(x, **kwargs) + x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_rand}) + loss_rand = nn.MSELoss()(x, x_) + + if loss_rand < best_params[0][1] and loss_rand < loss_ones: + return p_rand + elif loss_ones < best_params[0][1] and loss_ones < loss_rand: + return p_ones + else: + return best_params[0][0] + + +def init_linear_scale( # Symmetric scale. From the study folder + x: torch.Tensor, + **kwargs: Dict[str, Any], + ) -> torch.Tensor: + assert "bits" in kwargs, "bits must be provided." + assert "params" in kwargs, "params must be provided." + assert "qtz_func" in kwargs, "qtz_func must be provided." + + bits = kwargs.get('bits') + params = kwargs.get('params') + qtz_func = kwargs.get('qtz_func') + + x_ = x.transpose(0, 1) + x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs)) + x_ = x_.transpose(0, 1) + + quant_min, quant_max = get_min_max_from_bits_signed(bits) + min_vals, max_vals = torch.aminmax(x_, dim=1) + min_vals = torch.min(min_vals, torch.zeros_like(min_vals)) + max_vals = torch.max(max_vals, torch.zeros_like(max_vals)) + + eps = torch.finfo(torch.float32).eps + + abs_max_val_per_ch = torch.max(-min_vals, max_vals) + scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2) + + scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device) + + # Introduces some noise in scale + # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything + scale = scale + 0.01 * torch.randn_like(scale) + return scale + + +def init_non_linear_regression_fit( + x: torch.Tensor, + **kwargs: Dict[str, Any], + ) -> torch.Tensor: + + assert "params_list" in kwargs, "params list must be provided." + assert "np_fit_func" in kwargs, "np_fit_func must be provided." + assert "p0" in kwargs, "p0 must be provided." + np_fit_func = kwargs.get('np_fit_func') + params_list = kwargs.get('params_list') + p0 = kwargs.get('p0') + + def _fit(xdata: np.ndarray, ydata: np.ndarray, func: Callable, p0: List[float]): + popt, _ = curve_fit( + func, + xdata, + ydata, + maxfev=1000, + p0=p0, + method='lm' + ) + return popt + + # 1. Needs to convert the torch tensor to numpy tensor + xdata = x.cpu().numpy() + + # 2. Sorts the data so that it makes it easier to fit to it + sorted_xdata = np.sort(xdata, axis=-1) + + p0 = {k: v.cpu().numpy() for k, v in p0.items()} + params_list = sorted(params_list) # We need to make sure that it matches the numpy fit func arg order + + # 3. Finds the best parameters for each channel + try: + params = [] + for i in range(sorted_xdata.shape[0]): + xdata_ = sorted_xdata[i] + p0_ = [p0[p][i] for p in params_list] + ch_params = _fit(xdata_, xdata_, np_fit_func, p0_) + params.append(ch_params) + + # 4. Builds the parameters + result = {} + for i, p in enumerate(params_list): + result[p] = torch.tensor([p_[i] for p_ in params], dtype=torch.float32).to(x.device) + + return result + + except ValueError as e: + print(f"Could not fit the function with error: {e}") + print(f"Using fallback result...") + return { + k: torch.tensor(v, dtype=torch.float32).to(x.device) for k, v in p0.items() + } + + +def init_zeros(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor: + val = torch.amin(x, dim=1) + return torch.zeros_like(val, dtype=torch.float32, device=x.device) + + +def init_inner_scale(tensor: torch.Tensor, _min: float = torch.inf, _max: float = torch.inf) -> torch.Tensor: + # Calculate the original minimum and maximum values + min_vals, max_vals = torch.aminmax(tensor, dim=-1) + x_min = torch.min(min_vals, torch.zeros_like(min_vals)) + x_max = torch.max(max_vals, torch.zeros_like(max_vals)) + + if _max is torch.inf: # We do not need to scale the tensor. Just need to move it + return torch.ones_like(x_min) + + # Calculate the scale factor + scale = (_max - _min) / (x_max - x_min) + return scale + + + +############## Quant ############### + +@torch.enable_grad() +def learn_parameters( + x: torch.Tensor, + params: Dict[str, nn.Parameter], + qtz_func: nn.Module, + deqtz_func: nn.Module, + bits: int, + target_dtype: torch.dtype, + epochs: int = 1000, + early_stop: bool = True, + do_report: bool = False +) -> Tuple[Dict[str, nn.Parameter], torch.Tensor]: + + # Requires gradients in the parameters + for p in params.values(): + p.requires_grad = True + p.grad = None + + param_keys = list(params.keys()) + param_values = list(params.values()) + + # Defines optimizer and loss function + optimizer = torch.optim.Adam(param_values, lr=0.001) + scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=1.0, end_factor=0.001, total_iters=epochs // 10) + loss_fn = nn.MSELoss() + + # Contains the best loss and the best parameters + best_loss = float("inf") + best_params = None + + # Used to stop the search early + min_delta = 1e-7 + acc_loss = [] + percent_epochs_before_stop = 0.1 + + for i in range(epochs): + optimizer.zero_grad() + + quant = quantize(x, params, qtz_func, bits, target_dtype) + dequant = dequantize(quant, params, deqtz_func, bits, x.dtype) + loss = loss_fn(x, dequant) + + if loss.isnan() or loss.isinf(): + raise Exception("Loss is NaN or Inf. Stopping the search.") + + loss.backward() + optimizer.step() + scheduler.step() + + acc_loss.append(loss.item()) + + # Reports loss every 10 steps + if i % 10 == 0 and do_report: + print(f"Epoch {i}: Loss {loss.item()}") + + # Optimizes the parameter search by storing the best loss and the parameters + if loss.item() < best_loss: + best_loss = loss.item() + best_params = copy.deepcopy({ + k: v for k, v in params.items() if k in param_keys + }) + + # We also stop the search if the loss has not considerably during the last 10% epochs + if early_stop: + epochs_before_stop = int(epochs * percent_epochs_before_stop) + if i > epochs_before_stop and abs(acc_loss[i - epochs_before_stop] - acc_loss[i]) < min_delta: + break + + # No longer requires gradients in the parameters + for p in best_params.values(): + p.requires_grad = False + p.grad = None + + if do_report: + return best_params, acc_loss + else: + return best_params + + +def quantize( + x: torch.Tensor, + params: Dict[str, nn.Parameter], + func: nn.Module, + bits: int, + target_dtype: torch.dtype = torch.int8 +) -> torch.Tensor: + quant_min, quant_max = get_min_max_from_bits_signed(bits) + x = x.transpose(0, 1) # Aligns shapes + x = func(x=x, **params) + x = x.transpose(0, 1) + x = torch.clamp(round_func_BPDA(x), quant_min, quant_max).to(target_dtype) + return x + + +def dequantize( + x: torch.Tensor, + params: Dict[str, nn.Parameter], + func: nn.Module, + bits: int, + out_dtype: torch.dtype +) -> torch.Tensor: + x = x.to(dtype=out_dtype) + x = x.transpose(0, 1) + x = func(x=x, **params) + x = x.transpose(0, 1) + return x + + +def round_func_BPDA(input): + # This is equivalent to replacing round function (non-differentiable) with + # an identity function (differentiable) only when backward. + forward_value = torch.round(input) + out = input.clone() + out.data = forward_value.data + return out + + +def get_min_max_from_bits_signed(bit_width: int) -> Tuple[int, int]: + return -2 ** (bit_width - 1), 2 ** (bit_width - 1) - 1 + + + +############## Numpy ############### + +def np_domain_guard( + x: np.ndarray, + min: float = None, + max: float = None, + posinf: float = None, + neginf: float = None, + nan: float = None + ) -> np.ndarray: + """Guard a tensor to a valid domain.""" + x = np.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan) + if min is not None or max is not None: + x = np.clip(x, min, max) + return x + + +def np_replace_num(x: np.ndarray, num: float, to: float) -> np.ndarray: + """Replace a number in a tensor with another number. + + Args: + x (np.ndarray): The input tensor. + num (float): The number to replace. + to (float): The number to replace with. + + Returns: + np.ndarray: The tensor with the number replaced. + """ + return np.where(x == num, to, x) + + +def np_guarded_power(x: np.ndarray, exp: float) -> np.ndarray: + """Guard the power operation to a valid domain.""" + return np.power(x, exp) if exp >= 1 else np.power(np.maximum(x, 0), exp) + diff --git a/fn_gen/rnd_noisy_scale/1/loss.png b/fn_gen/rnd_noisy_scale/1/loss.png new file mode 100644 index 0000000000000000000000000000000000000000..7c431df0fa61e264987d965bb168a4a275adac65 Binary files /dev/null and b/fn_gen/rnd_noisy_scale/1/loss.png differ diff --git a/fn_gen/rnd_noisy_scale/1/quantization.png b/fn_gen/rnd_noisy_scale/1/quantization.png new file mode 100644 index 0000000000000000000000000000000000000000..7d06a43308bc76a1772544fa4038165439731d8e Binary files /dev/null and b/fn_gen/rnd_noisy_scale/1/quantization.png differ diff --git a/fn_gen/rnd_noisy_scale/10/__pycache__/fn.cpython-310.pyc b/fn_gen/rnd_noisy_scale/10/__pycache__/fn.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9be01b6e40f37a2b5910c30c987b4d910c4abfcb Binary files /dev/null and b/fn_gen/rnd_noisy_scale/10/__pycache__/fn.cpython-310.pyc differ diff --git a/fn_gen/rnd_noisy_scale/10/distortion.png b/fn_gen/rnd_noisy_scale/10/distortion.png new file mode 100644 index 0000000000000000000000000000000000000000..87c2b1f18e0c570cb99d9505d3ce4c1d79470782 Binary files /dev/null and b/fn_gen/rnd_noisy_scale/10/distortion.png differ diff --git a/fn_gen/rnd_noisy_scale/10/expressions.txt b/fn_gen/rnd_noisy_scale/10/expressions.txt new file mode 100644 index 0000000000000000000000000000000000000000..ecd6e238827dcdb95f4bcb390c1c300696f34254 --- /dev/null +++ b/fn_gen/rnd_noisy_scale/10/expressions.txt @@ -0,0 +1,2 @@ +sin(_0*x)/_s +asin(_s*x)/_0 \ No newline at end of file diff --git a/fn_gen/rnd_noisy_scale/10/fn.py b/fn_gen/rnd_noisy_scale/10/fn.py new file mode 100644 index 0000000000000000000000000000000000000000..121f572a3a2586b4e0805988aa9c731bf5570652 --- /dev/null +++ b/fn_gen/rnd_noisy_scale/10/fn.py @@ -0,0 +1,489 @@ +from __future__ import annotations + +import torch +from torch import amin # Necessary for arcsin +import copy +import torch.nn as nn +import numpy as np + +from scipy.optimize import curve_fit +from typing import Dict, Any, Tuple, List, Callable + + +def quantization(x, **params): + return (torch.div(1, replace_num(params['_s'], num=0, to=10000)) * torch.sin((params['_0'] * x))) + + +def dequantization(x, **params): + return (torch.div(1, replace_num(params['_0'], num=0, to=10000)) * torch.asin(domain_guard((params['_s'] * x), min=-0.99999, max=0.99999, nan=0))) + + +def init_linear_scale( # Symmetric scale. From the study folder + x: torch.Tensor, + **kwargs: Dict[str, Any], + ) -> torch.Tensor: + assert "bits" in kwargs, "bits must be provided." + assert "params" in kwargs, "params must be provided." + assert "qtz_func" in kwargs, "qtz_func must be provided." + + bits = kwargs.get('bits') + params = kwargs.get('params') + qtz_func = kwargs.get('qtz_func') + + x_ = x.transpose(0, 1) + x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs)) + x_ = x_.transpose(0, 1) + + quant_min, quant_max = get_min_max_from_bits_signed(bits) + min_vals, max_vals = torch.aminmax(x_, dim=1) + min_vals = torch.min(min_vals, torch.zeros_like(min_vals)) + max_vals = torch.max(max_vals, torch.zeros_like(max_vals)) + + eps = torch.finfo(torch.float32).eps + + abs_max_val_per_ch = torch.max(-min_vals, max_vals) + scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2) + + scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device) + + # Introduces some noise in scale + # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything + scale = scale + 0.01 * torch.randn_like(scale) + return scale + + +def init_params(x: torch.Tensor, **kwargs: Dict[str, Any]) -> Dict[str, nn.Parameter]: + params = { + '_0': init_rand(x, qtz_func=quantization, deqtz_func=dequantization, param='_0', params_list=['_0', '_s'], **kwargs), + } + params['_s'] = init_linear_scale(x, params=params, qtz_func=quantization, **kwargs) + params = {k: nn.Parameter(v, requires_grad=False) for k, v in params.items()} + + if 'post_init_hook' in kwargs: + kwargs['post_init_hook'](parameters=params) + + + if 'post_train_hook' in kwargs: + kwargs['post_train_hook'](parameters=params) + + return params + + +############### Numpy Qtz ############### + + +def np_quantization(x, _0, _s): + return (np.divide(1, np_replace_num(_s, num=0, to=10000)) * np.sin((_0 * x))) + + +def np_dequantization(x, _0, _s): + return (np.divide(1, np_replace_num(_0, num=0, to=10000)) * np.arcsin(np_domain_guard((_s * x), min=-0.99999, max=0.99999, nan=0))) + + +def fit_func(x, _0, _s): + x_ = np_quantization(x, _0, _s) + x_ = np_dequantization(x_, _0, _s) + return x_ + + + +############### HELPERS ############### + +def domain_guard( + x: torch.Tensor, + min: float = None, + max: float = None, + posinf: float = None, + neginf: float = None, + nan: float = None + ) -> torch.Tensor: + """Guard a tensor to a valid domain.""" + x = torch.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan) + if min is not None or max is not None: + x = torch.clamp(x, min=min, max=max) + return x + + +def replace_num(x: torch.Tensor, num: float, to: float) -> torch.Tensor: + """Replace a number in a tensor with another number. + + Args: + x (torch.Tensor): The input tensor. + num (float): The number to replace. + to (float): The number to replace with. + + Returns: + torch.Tensor: The tensor with the number replaced. + """ + return torch.where(x == num, to, x) + + +def guarded_torch_power(x: torch.Tensor, exp: float) -> torch.Tensor: + """Guard the power operation to a valid domain.""" + return torch.pow(x, exp) if exp >= 1 else torch.pow(torch.relu(x), exp) + + +def init_ones(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor: + val = torch.amin(x, dim=1) + return torch.ones_like(val, dtype=torch.float32, device=x.device) + + +def init_rand(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor: + val = torch.amin(x, dim=1) + return torch.randn_like(val, dtype=torch.float32, device=x.device) + + +def init_space_search( + x: torch.Tensor, + **kwargs: Dict[str, Any], + ) -> torch.Tensor: + + def _build_initial_param(tensor: torch.Tensor, max_initial: int, n_params: int): + """Generates the initial set of parameters. The first iteration generates 10 times more parameters.""" + for _ in range(n_params * 10): # The first iteration generates 10 times more parameters + yield init_rand(tensor) * max_initial # Generates n_params in range [-max_initial, max_initial] + + def _search_param(tensors: List[torch.tensor], n_params): + """Takes the best parameters and generates new parameters around the mean of the best parameters.""" + torch_tensors = torch.stack(tensors) + min_vals, max_vals = torch.aminmax(torch_tensors, dim=0) + abs_max_val_per_ch = torch.max(-min_vals, max_vals) + mean = torch.mean(torch_tensors, dim=0) + for _ in range(n_params): # Generates n_params around the mean of the tensors + yield torch.randn_like(min_vals) * abs_max_val_per_ch + mean + + def _calc(x, qtz_func, deqtz_func, **params): + x_ = x.transpose(0, 1) + x_ = qtz_func(x=x_, **params) + x_ = deqtz_func(x=x_, **params) + x_ = x_.transpose(0, 1) + return x_ + + assert "qtz_func" in kwargs, "qtz_func must be provided." + assert "deqtz_func" in kwargs, "deqtz_func must be provided." + assert "params_list" in kwargs, "params list must be provided." + assert "param" in kwargs, "param must be provided." + + qtz_func = kwargs.get('qtz_func') + deqtz_func = kwargs.get('deqtz_func') + params_list = kwargs.get('params_list') + param = kwargs.get('param') + + n_runs = 50 # Number of runs to try to find the best parameters + n_random_params = 50 # Number of random parameters to generate + n_best_to_pick = 5 # Number of best parameters to pick after each run + max_initial = 10000 # Maximum value to initialize the parameters + + # Initializes the parameters + base_params = { p: init_ones(x, **kwargs) for p in params_list if p != param } + params = _build_initial_param(x, max_initial, n_random_params) + + # Performs the search + for _ in range(n_runs): + + best_params = [] + for param_ in params: + try: + x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: param_}) + loss_ones = nn.MSELoss()(x, x_) + + if len(best_params) < n_best_to_pick: + best_params.append((param_, loss_ones.item())) + best_params = sorted(best_params, key=lambda x: x[1]) + elif loss_ones < best_params[-1][1]: + best_params[-1] = (param_, loss_ones.item()) + best_params = sorted(best_params, key=lambda x: x[1]) + + except Exception: # The parameters might not be valid for the function's domain + continue + + # Generates new parameters around the mean + params = _search_param([p for p, _ in best_params], n_random_params) + + # Checks if the best parameter is better than the init_ones + p_ones = init_ones(x, **kwargs) + x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_ones}) + loss_ones = nn.MSELoss()(x, x_) + + # Checks if the best parameter is better than the init_rand + p_rand = init_rand(x, **kwargs) + x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_rand}) + loss_rand = nn.MSELoss()(x, x_) + + if loss_rand < best_params[0][1] and loss_rand < loss_ones: + return p_rand + elif loss_ones < best_params[0][1] and loss_ones < loss_rand: + return p_ones + else: + return best_params[0][0] + + +def init_linear_scale( # Symmetric scale. From the study folder + x: torch.Tensor, + **kwargs: Dict[str, Any], + ) -> torch.Tensor: + assert "bits" in kwargs, "bits must be provided." + assert "params" in kwargs, "params must be provided." + assert "qtz_func" in kwargs, "qtz_func must be provided." + + bits = kwargs.get('bits') + params = kwargs.get('params') + qtz_func = kwargs.get('qtz_func') + + x_ = x.transpose(0, 1) + x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs)) + x_ = x_.transpose(0, 1) + + quant_min, quant_max = get_min_max_from_bits_signed(bits) + min_vals, max_vals = torch.aminmax(x_, dim=1) + min_vals = torch.min(min_vals, torch.zeros_like(min_vals)) + max_vals = torch.max(max_vals, torch.zeros_like(max_vals)) + + eps = torch.finfo(torch.float32).eps + + abs_max_val_per_ch = torch.max(-min_vals, max_vals) + scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2) + + scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device) + + # Introduces some noise in scale + # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything + scale = scale + 0.01 * torch.randn_like(scale) + return scale + + +def init_non_linear_regression_fit( + x: torch.Tensor, + **kwargs: Dict[str, Any], + ) -> torch.Tensor: + + assert "params_list" in kwargs, "params list must be provided." + assert "np_fit_func" in kwargs, "np_fit_func must be provided." + assert "p0" in kwargs, "p0 must be provided." + np_fit_func = kwargs.get('np_fit_func') + params_list = kwargs.get('params_list') + p0 = kwargs.get('p0') + + def _fit(xdata: np.ndarray, ydata: np.ndarray, func: Callable, p0: List[float]): + popt, _ = curve_fit( + func, + xdata, + ydata, + maxfev=1000, + p0=p0, + method='lm' + ) + return popt + + # 1. Needs to convert the torch tensor to numpy tensor + xdata = x.cpu().numpy() + + # 2. Sorts the data so that it makes it easier to fit to it + sorted_xdata = np.sort(xdata, axis=-1) + + p0 = {k: v.cpu().numpy() for k, v in p0.items()} + params_list = sorted(params_list) # We need to make sure that it matches the numpy fit func arg order + + # 3. Finds the best parameters for each channel + try: + params = [] + for i in range(sorted_xdata.shape[0]): + xdata_ = sorted_xdata[i] + p0_ = [p0[p][i] for p in params_list] + ch_params = _fit(xdata_, xdata_, np_fit_func, p0_) + params.append(ch_params) + + # 4. Builds the parameters + result = {} + for i, p in enumerate(params_list): + result[p] = torch.tensor([p_[i] for p_ in params], dtype=torch.float32).to(x.device) + + return result + + except ValueError as e: + print(f"Could not fit the function with error: {e}") + print(f"Using fallback result...") + return { + k: torch.tensor(v, dtype=torch.float32).to(x.device) for k, v in p0.items() + } + + +def init_zeros(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor: + val = torch.amin(x, dim=1) + return torch.zeros_like(val, dtype=torch.float32, device=x.device) + + +def init_inner_scale(tensor: torch.Tensor, _min: float = torch.inf, _max: float = torch.inf) -> torch.Tensor: + # Calculate the original minimum and maximum values + min_vals, max_vals = torch.aminmax(tensor, dim=-1) + x_min = torch.min(min_vals, torch.zeros_like(min_vals)) + x_max = torch.max(max_vals, torch.zeros_like(max_vals)) + + if _max is torch.inf: # We do not need to scale the tensor. Just need to move it + return torch.ones_like(x_min) + + # Calculate the scale factor + scale = (_max - _min) / (x_max - x_min) + return scale + + + +############## Quant ############### + +@torch.enable_grad() +def learn_parameters( + x: torch.Tensor, + params: Dict[str, nn.Parameter], + qtz_func: nn.Module, + deqtz_func: nn.Module, + bits: int, + target_dtype: torch.dtype, + epochs: int = 1000, + early_stop: bool = True, + do_report: bool = False +) -> Tuple[Dict[str, nn.Parameter], torch.Tensor]: + + # Requires gradients in the parameters + for p in params.values(): + p.requires_grad = True + p.grad = None + + param_keys = list(params.keys()) + param_values = list(params.values()) + + # Defines optimizer and loss function + optimizer = torch.optim.Adam(param_values, lr=0.001) + scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=1.0, end_factor=0.001, total_iters=epochs // 10) + loss_fn = nn.MSELoss() + + # Contains the best loss and the best parameters + best_loss = float("inf") + best_params = None + + # Used to stop the search early + min_delta = 1e-7 + acc_loss = [] + percent_epochs_before_stop = 0.1 + + for i in range(epochs): + optimizer.zero_grad() + + quant = quantize(x, params, qtz_func, bits, target_dtype) + dequant = dequantize(quant, params, deqtz_func, bits, x.dtype) + loss = loss_fn(x, dequant) + + if loss.isnan() or loss.isinf(): + raise Exception("Loss is NaN or Inf. Stopping the search.") + + loss.backward() + optimizer.step() + scheduler.step() + + acc_loss.append(loss.item()) + + # Reports loss every 10 steps + if i % 10 == 0 and do_report: + print(f"Epoch {i}: Loss {loss.item()}") + + # Optimizes the parameter search by storing the best loss and the parameters + if loss.item() < best_loss: + best_loss = loss.item() + best_params = copy.deepcopy({ + k: v for k, v in params.items() if k in param_keys + }) + + # We also stop the search if the loss has not considerably during the last 10% epochs + if early_stop: + epochs_before_stop = int(epochs * percent_epochs_before_stop) + if i > epochs_before_stop and abs(acc_loss[i - epochs_before_stop] - acc_loss[i]) < min_delta: + break + + # No longer requires gradients in the parameters + for p in best_params.values(): + p.requires_grad = False + p.grad = None + + if do_report: + return best_params, acc_loss + else: + return best_params + + +def quantize( + x: torch.Tensor, + params: Dict[str, nn.Parameter], + func: nn.Module, + bits: int, + target_dtype: torch.dtype = torch.int8 +) -> torch.Tensor: + quant_min, quant_max = get_min_max_from_bits_signed(bits) + x = x.transpose(0, 1) # Aligns shapes + x = func(x=x, **params) + x = x.transpose(0, 1) + x = torch.clamp(round_func_BPDA(x), quant_min, quant_max).to(target_dtype) + return x + + +def dequantize( + x: torch.Tensor, + params: Dict[str, nn.Parameter], + func: nn.Module, + bits: int, + out_dtype: torch.dtype +) -> torch.Tensor: + x = x.to(dtype=out_dtype) + x = x.transpose(0, 1) + x = func(x=x, **params) + x = x.transpose(0, 1) + return x + + +def round_func_BPDA(input): + # This is equivalent to replacing round function (non-differentiable) with + # an identity function (differentiable) only when backward. + forward_value = torch.round(input) + out = input.clone() + out.data = forward_value.data + return out + + +def get_min_max_from_bits_signed(bit_width: int) -> Tuple[int, int]: + return -2 ** (bit_width - 1), 2 ** (bit_width - 1) - 1 + + + +############## Numpy ############### + +def np_domain_guard( + x: np.ndarray, + min: float = None, + max: float = None, + posinf: float = None, + neginf: float = None, + nan: float = None + ) -> np.ndarray: + """Guard a tensor to a valid domain.""" + x = np.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan) + if min is not None or max is not None: + x = np.clip(x, min, max) + return x + + +def np_replace_num(x: np.ndarray, num: float, to: float) -> np.ndarray: + """Replace a number in a tensor with another number. + + Args: + x (np.ndarray): The input tensor. + num (float): The number to replace. + to (float): The number to replace with. + + Returns: + np.ndarray: The tensor with the number replaced. + """ + return np.where(x == num, to, x) + + +def np_guarded_power(x: np.ndarray, exp: float) -> np.ndarray: + """Guard the power operation to a valid domain.""" + return np.power(x, exp) if exp >= 1 else np.power(np.maximum(x, 0), exp) + diff --git a/fn_gen/rnd_noisy_scale/10/loss.png b/fn_gen/rnd_noisy_scale/10/loss.png new file mode 100644 index 0000000000000000000000000000000000000000..da6bcd99f1df0c1e965568f381c3b16fe20c46d7 Binary files /dev/null and b/fn_gen/rnd_noisy_scale/10/loss.png differ diff --git a/fn_gen/rnd_noisy_scale/10/quantization.png b/fn_gen/rnd_noisy_scale/10/quantization.png new file mode 100644 index 0000000000000000000000000000000000000000..d96531a1e7be7d2b592baf909be5e4720827e9af Binary files /dev/null and b/fn_gen/rnd_noisy_scale/10/quantization.png differ diff --git a/fn_gen/rnd_noisy_scale/11/__pycache__/fn.cpython-310.pyc b/fn_gen/rnd_noisy_scale/11/__pycache__/fn.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dbfc55690d8290745117194d6846c30ef20ebd1b Binary files /dev/null and b/fn_gen/rnd_noisy_scale/11/__pycache__/fn.cpython-310.pyc differ diff --git a/fn_gen/rnd_noisy_scale/11/distortion.png b/fn_gen/rnd_noisy_scale/11/distortion.png new file mode 100644 index 0000000000000000000000000000000000000000..b62ed2db56eb8925c3cb7c3b867c5c598648a976 Binary files /dev/null and b/fn_gen/rnd_noisy_scale/11/distortion.png differ diff --git a/fn_gen/rnd_noisy_scale/11/expressions.txt b/fn_gen/rnd_noisy_scale/11/expressions.txt new file mode 100644 index 0000000000000000000000000000000000000000..5a7abbbdac98c7d53123fe0b9807e7644bc00acf --- /dev/null +++ b/fn_gen/rnd_noisy_scale/11/expressions.txt @@ -0,0 +1,2 @@ +acosh(_0*x)/_s +cosh(_s*x)/_0 \ No newline at end of file diff --git a/fn_gen/rnd_noisy_scale/11/fn.py b/fn_gen/rnd_noisy_scale/11/fn.py new file mode 100644 index 0000000000000000000000000000000000000000..c453e6916db059af0893e9ca392361d71c2ab9fe --- /dev/null +++ b/fn_gen/rnd_noisy_scale/11/fn.py @@ -0,0 +1,489 @@ +from __future__ import annotations + +import torch +from torch import amin # Necessary for arcsin +import copy +import torch.nn as nn +import numpy as np + +from scipy.optimize import curve_fit +from typing import Dict, Any, Tuple, List, Callable + + +def quantization(x, **params): + return (torch.div(1, replace_num(params['_s'], num=0, to=10000)) * torch.acosh(domain_guard((params['_0'] * x), min=1, nan=1))) + + +def dequantization(x, **params): + return (torch.div(1, replace_num(params['_0'], num=0, to=10000)) * torch.cosh((params['_s'] * x))) + + +def init_linear_scale( # Symmetric scale. From the study folder + x: torch.Tensor, + **kwargs: Dict[str, Any], + ) -> torch.Tensor: + assert "bits" in kwargs, "bits must be provided." + assert "params" in kwargs, "params must be provided." + assert "qtz_func" in kwargs, "qtz_func must be provided." + + bits = kwargs.get('bits') + params = kwargs.get('params') + qtz_func = kwargs.get('qtz_func') + + x_ = x.transpose(0, 1) + x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs)) + x_ = x_.transpose(0, 1) + + quant_min, quant_max = get_min_max_from_bits_signed(bits) + min_vals, max_vals = torch.aminmax(x_, dim=1) + min_vals = torch.min(min_vals, torch.zeros_like(min_vals)) + max_vals = torch.max(max_vals, torch.zeros_like(max_vals)) + + eps = torch.finfo(torch.float32).eps + + abs_max_val_per_ch = torch.max(-min_vals, max_vals) + scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2) + + scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device) + + # Introduces some noise in scale + # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything + scale = scale + 0.01 * torch.randn_like(scale) + return scale + + +def init_params(x: torch.Tensor, **kwargs: Dict[str, Any]) -> Dict[str, nn.Parameter]: + params = { + '_0': init_rand(x, qtz_func=quantization, deqtz_func=dequantization, param='_0', params_list=['_0', '_s'], **kwargs), + } + params['_s'] = init_linear_scale(x, params=params, qtz_func=quantization, **kwargs) + params = {k: nn.Parameter(v, requires_grad=False) for k, v in params.items()} + + if 'post_init_hook' in kwargs: + kwargs['post_init_hook'](parameters=params) + + + if 'post_train_hook' in kwargs: + kwargs['post_train_hook'](parameters=params) + + return params + + +############### Numpy Qtz ############### + + +def np_quantization(x, _0, _s): + return (np.divide(1, np_replace_num(_s, num=0, to=10000)) * np.arccosh(np_domain_guard((_0 * x), min=1, nan=1))) + + +def np_dequantization(x, _0, _s): + return (np.divide(1, np_replace_num(_0, num=0, to=10000)) * np.cosh((_s * x))) + + +def fit_func(x, _0, _s): + x_ = np_quantization(x, _0, _s) + x_ = np_dequantization(x_, _0, _s) + return x_ + + + +############### HELPERS ############### + +def domain_guard( + x: torch.Tensor, + min: float = None, + max: float = None, + posinf: float = None, + neginf: float = None, + nan: float = None + ) -> torch.Tensor: + """Guard a tensor to a valid domain.""" + x = torch.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan) + if min is not None or max is not None: + x = torch.clamp(x, min=min, max=max) + return x + + +def replace_num(x: torch.Tensor, num: float, to: float) -> torch.Tensor: + """Replace a number in a tensor with another number. + + Args: + x (torch.Tensor): The input tensor. + num (float): The number to replace. + to (float): The number to replace with. + + Returns: + torch.Tensor: The tensor with the number replaced. + """ + return torch.where(x == num, to, x) + + +def guarded_torch_power(x: torch.Tensor, exp: float) -> torch.Tensor: + """Guard the power operation to a valid domain.""" + return torch.pow(x, exp) if exp >= 1 else torch.pow(torch.relu(x), exp) + + +def init_ones(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor: + val = torch.amin(x, dim=1) + return torch.ones_like(val, dtype=torch.float32, device=x.device) + + +def init_rand(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor: + val = torch.amin(x, dim=1) + return torch.randn_like(val, dtype=torch.float32, device=x.device) + + +def init_space_search( + x: torch.Tensor, + **kwargs: Dict[str, Any], + ) -> torch.Tensor: + + def _build_initial_param(tensor: torch.Tensor, max_initial: int, n_params: int): + """Generates the initial set of parameters. The first iteration generates 10 times more parameters.""" + for _ in range(n_params * 10): # The first iteration generates 10 times more parameters + yield init_rand(tensor) * max_initial # Generates n_params in range [-max_initial, max_initial] + + def _search_param(tensors: List[torch.tensor], n_params): + """Takes the best parameters and generates new parameters around the mean of the best parameters.""" + torch_tensors = torch.stack(tensors) + min_vals, max_vals = torch.aminmax(torch_tensors, dim=0) + abs_max_val_per_ch = torch.max(-min_vals, max_vals) + mean = torch.mean(torch_tensors, dim=0) + for _ in range(n_params): # Generates n_params around the mean of the tensors + yield torch.randn_like(min_vals) * abs_max_val_per_ch + mean + + def _calc(x, qtz_func, deqtz_func, **params): + x_ = x.transpose(0, 1) + x_ = qtz_func(x=x_, **params) + x_ = deqtz_func(x=x_, **params) + x_ = x_.transpose(0, 1) + return x_ + + assert "qtz_func" in kwargs, "qtz_func must be provided." + assert "deqtz_func" in kwargs, "deqtz_func must be provided." + assert "params_list" in kwargs, "params list must be provided." + assert "param" in kwargs, "param must be provided." + + qtz_func = kwargs.get('qtz_func') + deqtz_func = kwargs.get('deqtz_func') + params_list = kwargs.get('params_list') + param = kwargs.get('param') + + n_runs = 50 # Number of runs to try to find the best parameters + n_random_params = 50 # Number of random parameters to generate + n_best_to_pick = 5 # Number of best parameters to pick after each run + max_initial = 10000 # Maximum value to initialize the parameters + + # Initializes the parameters + base_params = { p: init_ones(x, **kwargs) for p in params_list if p != param } + params = _build_initial_param(x, max_initial, n_random_params) + + # Performs the search + for _ in range(n_runs): + + best_params = [] + for param_ in params: + try: + x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: param_}) + loss_ones = nn.MSELoss()(x, x_) + + if len(best_params) < n_best_to_pick: + best_params.append((param_, loss_ones.item())) + best_params = sorted(best_params, key=lambda x: x[1]) + elif loss_ones < best_params[-1][1]: + best_params[-1] = (param_, loss_ones.item()) + best_params = sorted(best_params, key=lambda x: x[1]) + + except Exception: # The parameters might not be valid for the function's domain + continue + + # Generates new parameters around the mean + params = _search_param([p for p, _ in best_params], n_random_params) + + # Checks if the best parameter is better than the init_ones + p_ones = init_ones(x, **kwargs) + x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_ones}) + loss_ones = nn.MSELoss()(x, x_) + + # Checks if the best parameter is better than the init_rand + p_rand = init_rand(x, **kwargs) + x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_rand}) + loss_rand = nn.MSELoss()(x, x_) + + if loss_rand < best_params[0][1] and loss_rand < loss_ones: + return p_rand + elif loss_ones < best_params[0][1] and loss_ones < loss_rand: + return p_ones + else: + return best_params[0][0] + + +def init_linear_scale( # Symmetric scale. From the study folder + x: torch.Tensor, + **kwargs: Dict[str, Any], + ) -> torch.Tensor: + assert "bits" in kwargs, "bits must be provided." + assert "params" in kwargs, "params must be provided." + assert "qtz_func" in kwargs, "qtz_func must be provided." + + bits = kwargs.get('bits') + params = kwargs.get('params') + qtz_func = kwargs.get('qtz_func') + + x_ = x.transpose(0, 1) + x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs)) + x_ = x_.transpose(0, 1) + + quant_min, quant_max = get_min_max_from_bits_signed(bits) + min_vals, max_vals = torch.aminmax(x_, dim=1) + min_vals = torch.min(min_vals, torch.zeros_like(min_vals)) + max_vals = torch.max(max_vals, torch.zeros_like(max_vals)) + + eps = torch.finfo(torch.float32).eps + + abs_max_val_per_ch = torch.max(-min_vals, max_vals) + scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2) + + scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device) + + # Introduces some noise in scale + # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything + scale = scale + 0.01 * torch.randn_like(scale) + return scale + + +def init_non_linear_regression_fit( + x: torch.Tensor, + **kwargs: Dict[str, Any], + ) -> torch.Tensor: + + assert "params_list" in kwargs, "params list must be provided." + assert "np_fit_func" in kwargs, "np_fit_func must be provided." + assert "p0" in kwargs, "p0 must be provided." + np_fit_func = kwargs.get('np_fit_func') + params_list = kwargs.get('params_list') + p0 = kwargs.get('p0') + + def _fit(xdata: np.ndarray, ydata: np.ndarray, func: Callable, p0: List[float]): + popt, _ = curve_fit( + func, + xdata, + ydata, + maxfev=1000, + p0=p0, + method='lm' + ) + return popt + + # 1. Needs to convert the torch tensor to numpy tensor + xdata = x.cpu().numpy() + + # 2. Sorts the data so that it makes it easier to fit to it + sorted_xdata = np.sort(xdata, axis=-1) + + p0 = {k: v.cpu().numpy() for k, v in p0.items()} + params_list = sorted(params_list) # We need to make sure that it matches the numpy fit func arg order + + # 3. Finds the best parameters for each channel + try: + params = [] + for i in range(sorted_xdata.shape[0]): + xdata_ = sorted_xdata[i] + p0_ = [p0[p][i] for p in params_list] + ch_params = _fit(xdata_, xdata_, np_fit_func, p0_) + params.append(ch_params) + + # 4. Builds the parameters + result = {} + for i, p in enumerate(params_list): + result[p] = torch.tensor([p_[i] for p_ in params], dtype=torch.float32).to(x.device) + + return result + + except ValueError as e: + print(f"Could not fit the function with error: {e}") + print(f"Using fallback result...") + return { + k: torch.tensor(v, dtype=torch.float32).to(x.device) for k, v in p0.items() + } + + +def init_zeros(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor: + val = torch.amin(x, dim=1) + return torch.zeros_like(val, dtype=torch.float32, device=x.device) + + +def init_inner_scale(tensor: torch.Tensor, _min: float = torch.inf, _max: float = torch.inf) -> torch.Tensor: + # Calculate the original minimum and maximum values + min_vals, max_vals = torch.aminmax(tensor, dim=-1) + x_min = torch.min(min_vals, torch.zeros_like(min_vals)) + x_max = torch.max(max_vals, torch.zeros_like(max_vals)) + + if _max is torch.inf: # We do not need to scale the tensor. Just need to move it + return torch.ones_like(x_min) + + # Calculate the scale factor + scale = (_max - _min) / (x_max - x_min) + return scale + + + +############## Quant ############### + +@torch.enable_grad() +def learn_parameters( + x: torch.Tensor, + params: Dict[str, nn.Parameter], + qtz_func: nn.Module, + deqtz_func: nn.Module, + bits: int, + target_dtype: torch.dtype, + epochs: int = 1000, + early_stop: bool = True, + do_report: bool = False +) -> Tuple[Dict[str, nn.Parameter], torch.Tensor]: + + # Requires gradients in the parameters + for p in params.values(): + p.requires_grad = True + p.grad = None + + param_keys = list(params.keys()) + param_values = list(params.values()) + + # Defines optimizer and loss function + optimizer = torch.optim.Adam(param_values, lr=0.001) + scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=1.0, end_factor=0.001, total_iters=epochs // 10) + loss_fn = nn.MSELoss() + + # Contains the best loss and the best parameters + best_loss = float("inf") + best_params = None + + # Used to stop the search early + min_delta = 1e-7 + acc_loss = [] + percent_epochs_before_stop = 0.1 + + for i in range(epochs): + optimizer.zero_grad() + + quant = quantize(x, params, qtz_func, bits, target_dtype) + dequant = dequantize(quant, params, deqtz_func, bits, x.dtype) + loss = loss_fn(x, dequant) + + if loss.isnan() or loss.isinf(): + raise Exception("Loss is NaN or Inf. Stopping the search.") + + loss.backward() + optimizer.step() + scheduler.step() + + acc_loss.append(loss.item()) + + # Reports loss every 10 steps + if i % 10 == 0 and do_report: + print(f"Epoch {i}: Loss {loss.item()}") + + # Optimizes the parameter search by storing the best loss and the parameters + if loss.item() < best_loss: + best_loss = loss.item() + best_params = copy.deepcopy({ + k: v for k, v in params.items() if k in param_keys + }) + + # We also stop the search if the loss has not considerably during the last 10% epochs + if early_stop: + epochs_before_stop = int(epochs * percent_epochs_before_stop) + if i > epochs_before_stop and abs(acc_loss[i - epochs_before_stop] - acc_loss[i]) < min_delta: + break + + # No longer requires gradients in the parameters + for p in best_params.values(): + p.requires_grad = False + p.grad = None + + if do_report: + return best_params, acc_loss + else: + return best_params + + +def quantize( + x: torch.Tensor, + params: Dict[str, nn.Parameter], + func: nn.Module, + bits: int, + target_dtype: torch.dtype = torch.int8 +) -> torch.Tensor: + quant_min, quant_max = get_min_max_from_bits_signed(bits) + x = x.transpose(0, 1) # Aligns shapes + x = func(x=x, **params) + x = x.transpose(0, 1) + x = torch.clamp(round_func_BPDA(x), quant_min, quant_max).to(target_dtype) + return x + + +def dequantize( + x: torch.Tensor, + params: Dict[str, nn.Parameter], + func: nn.Module, + bits: int, + out_dtype: torch.dtype +) -> torch.Tensor: + x = x.to(dtype=out_dtype) + x = x.transpose(0, 1) + x = func(x=x, **params) + x = x.transpose(0, 1) + return x + + +def round_func_BPDA(input): + # This is equivalent to replacing round function (non-differentiable) with + # an identity function (differentiable) only when backward. + forward_value = torch.round(input) + out = input.clone() + out.data = forward_value.data + return out + + +def get_min_max_from_bits_signed(bit_width: int) -> Tuple[int, int]: + return -2 ** (bit_width - 1), 2 ** (bit_width - 1) - 1 + + + +############## Numpy ############### + +def np_domain_guard( + x: np.ndarray, + min: float = None, + max: float = None, + posinf: float = None, + neginf: float = None, + nan: float = None + ) -> np.ndarray: + """Guard a tensor to a valid domain.""" + x = np.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan) + if min is not None or max is not None: + x = np.clip(x, min, max) + return x + + +def np_replace_num(x: np.ndarray, num: float, to: float) -> np.ndarray: + """Replace a number in a tensor with another number. + + Args: + x (np.ndarray): The input tensor. + num (float): The number to replace. + to (float): The number to replace with. + + Returns: + np.ndarray: The tensor with the number replaced. + """ + return np.where(x == num, to, x) + + +def np_guarded_power(x: np.ndarray, exp: float) -> np.ndarray: + """Guard the power operation to a valid domain.""" + return np.power(x, exp) if exp >= 1 else np.power(np.maximum(x, 0), exp) + diff --git a/fn_gen/rnd_noisy_scale/11/loss.png b/fn_gen/rnd_noisy_scale/11/loss.png new file mode 100644 index 0000000000000000000000000000000000000000..729869cb15a0cb759acab4d771e7e582d11417ac Binary files /dev/null and b/fn_gen/rnd_noisy_scale/11/loss.png differ diff --git a/fn_gen/rnd_noisy_scale/11/quantization.png b/fn_gen/rnd_noisy_scale/11/quantization.png new file mode 100644 index 0000000000000000000000000000000000000000..14c2e2f410a2f4b55c06ad5e1d223d7c4e054aee Binary files /dev/null and b/fn_gen/rnd_noisy_scale/11/quantization.png differ diff --git a/fn_gen/rnd_noisy_scale/12/__pycache__/fn.cpython-310.pyc b/fn_gen/rnd_noisy_scale/12/__pycache__/fn.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0a14564d7d8feba31b861b5a8c81bb95eb7fdf3d Binary files /dev/null and b/fn_gen/rnd_noisy_scale/12/__pycache__/fn.cpython-310.pyc differ diff --git a/fn_gen/rnd_noisy_scale/12/distortion.png b/fn_gen/rnd_noisy_scale/12/distortion.png new file mode 100644 index 0000000000000000000000000000000000000000..c0b4cd6314d0a7d6b37afed633372e08b50e3568 Binary files /dev/null and b/fn_gen/rnd_noisy_scale/12/distortion.png differ diff --git a/fn_gen/rnd_noisy_scale/12/expressions.txt b/fn_gen/rnd_noisy_scale/12/expressions.txt new file mode 100644 index 0000000000000000000000000000000000000000..576ec6a351e26f9982eb17e394804ca906d4b067 --- /dev/null +++ b/fn_gen/rnd_noisy_scale/12/expressions.txt @@ -0,0 +1,2 @@ +acos(_0*x)/_s +cos(_s*x)/_0 \ No newline at end of file diff --git a/fn_gen/rnd_noisy_scale/12/fn.py b/fn_gen/rnd_noisy_scale/12/fn.py new file mode 100644 index 0000000000000000000000000000000000000000..2bb729bb3feb058e7dfde919e58db23c1394678e --- /dev/null +++ b/fn_gen/rnd_noisy_scale/12/fn.py @@ -0,0 +1,489 @@ +from __future__ import annotations + +import torch +from torch import amin # Necessary for arcsin +import copy +import torch.nn as nn +import numpy as np + +from scipy.optimize import curve_fit +from typing import Dict, Any, Tuple, List, Callable + + +def quantization(x, **params): + return (torch.div(1, replace_num(params['_s'], num=0, to=10000)) * torch.acos(domain_guard((params['_0'] * x), min=-0.99999, max=0.99999, nan=0))) + + +def dequantization(x, **params): + return (torch.div(1, replace_num(params['_0'], num=0, to=10000)) * torch.cos((params['_s'] * x))) + + +def init_linear_scale( # Symmetric scale. From the study folder + x: torch.Tensor, + **kwargs: Dict[str, Any], + ) -> torch.Tensor: + assert "bits" in kwargs, "bits must be provided." + assert "params" in kwargs, "params must be provided." + assert "qtz_func" in kwargs, "qtz_func must be provided." + + bits = kwargs.get('bits') + params = kwargs.get('params') + qtz_func = kwargs.get('qtz_func') + + x_ = x.transpose(0, 1) + x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs)) + x_ = x_.transpose(0, 1) + + quant_min, quant_max = get_min_max_from_bits_signed(bits) + min_vals, max_vals = torch.aminmax(x_, dim=1) + min_vals = torch.min(min_vals, torch.zeros_like(min_vals)) + max_vals = torch.max(max_vals, torch.zeros_like(max_vals)) + + eps = torch.finfo(torch.float32).eps + + abs_max_val_per_ch = torch.max(-min_vals, max_vals) + scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2) + + scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device) + + # Introduces some noise in scale + # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything + scale = scale + 0.01 * torch.randn_like(scale) + return scale + + +def init_params(x: torch.Tensor, **kwargs: Dict[str, Any]) -> Dict[str, nn.Parameter]: + params = { + '_0': init_rand(x, qtz_func=quantization, deqtz_func=dequantization, param='_0', params_list=['_0', '_s'], **kwargs), + } + params['_s'] = init_linear_scale(x, params=params, qtz_func=quantization, **kwargs) + params = {k: nn.Parameter(v, requires_grad=False) for k, v in params.items()} + + if 'post_init_hook' in kwargs: + kwargs['post_init_hook'](parameters=params) + + + if 'post_train_hook' in kwargs: + kwargs['post_train_hook'](parameters=params) + + return params + + +############### Numpy Qtz ############### + + +def np_quantization(x, _0, _s): + return (np.divide(1, np_replace_num(_s, num=0, to=10000)) * np.arccos(np_domain_guard((_0 * x), min=-0.99999, max=0.99999, nan=0))) + + +def np_dequantization(x, _0, _s): + return (np.divide(1, np_replace_num(_0, num=0, to=10000)) * np.cos((_s * x))) + + +def fit_func(x, _0, _s): + x_ = np_quantization(x, _0, _s) + x_ = np_dequantization(x_, _0, _s) + return x_ + + + +############### HELPERS ############### + +def domain_guard( + x: torch.Tensor, + min: float = None, + max: float = None, + posinf: float = None, + neginf: float = None, + nan: float = None + ) -> torch.Tensor: + """Guard a tensor to a valid domain.""" + x = torch.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan) + if min is not None or max is not None: + x = torch.clamp(x, min=min, max=max) + return x + + +def replace_num(x: torch.Tensor, num: float, to: float) -> torch.Tensor: + """Replace a number in a tensor with another number. + + Args: + x (torch.Tensor): The input tensor. + num (float): The number to replace. + to (float): The number to replace with. + + Returns: + torch.Tensor: The tensor with the number replaced. + """ + return torch.where(x == num, to, x) + + +def guarded_torch_power(x: torch.Tensor, exp: float) -> torch.Tensor: + """Guard the power operation to a valid domain.""" + return torch.pow(x, exp) if exp >= 1 else torch.pow(torch.relu(x), exp) + + +def init_ones(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor: + val = torch.amin(x, dim=1) + return torch.ones_like(val, dtype=torch.float32, device=x.device) + + +def init_rand(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor: + val = torch.amin(x, dim=1) + return torch.randn_like(val, dtype=torch.float32, device=x.device) + + +def init_space_search( + x: torch.Tensor, + **kwargs: Dict[str, Any], + ) -> torch.Tensor: + + def _build_initial_param(tensor: torch.Tensor, max_initial: int, n_params: int): + """Generates the initial set of parameters. The first iteration generates 10 times more parameters.""" + for _ in range(n_params * 10): # The first iteration generates 10 times more parameters + yield init_rand(tensor) * max_initial # Generates n_params in range [-max_initial, max_initial] + + def _search_param(tensors: List[torch.tensor], n_params): + """Takes the best parameters and generates new parameters around the mean of the best parameters.""" + torch_tensors = torch.stack(tensors) + min_vals, max_vals = torch.aminmax(torch_tensors, dim=0) + abs_max_val_per_ch = torch.max(-min_vals, max_vals) + mean = torch.mean(torch_tensors, dim=0) + for _ in range(n_params): # Generates n_params around the mean of the tensors + yield torch.randn_like(min_vals) * abs_max_val_per_ch + mean + + def _calc(x, qtz_func, deqtz_func, **params): + x_ = x.transpose(0, 1) + x_ = qtz_func(x=x_, **params) + x_ = deqtz_func(x=x_, **params) + x_ = x_.transpose(0, 1) + return x_ + + assert "qtz_func" in kwargs, "qtz_func must be provided." + assert "deqtz_func" in kwargs, "deqtz_func must be provided." + assert "params_list" in kwargs, "params list must be provided." + assert "param" in kwargs, "param must be provided." + + qtz_func = kwargs.get('qtz_func') + deqtz_func = kwargs.get('deqtz_func') + params_list = kwargs.get('params_list') + param = kwargs.get('param') + + n_runs = 50 # Number of runs to try to find the best parameters + n_random_params = 50 # Number of random parameters to generate + n_best_to_pick = 5 # Number of best parameters to pick after each run + max_initial = 10000 # Maximum value to initialize the parameters + + # Initializes the parameters + base_params = { p: init_ones(x, **kwargs) for p in params_list if p != param } + params = _build_initial_param(x, max_initial, n_random_params) + + # Performs the search + for _ in range(n_runs): + + best_params = [] + for param_ in params: + try: + x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: param_}) + loss_ones = nn.MSELoss()(x, x_) + + if len(best_params) < n_best_to_pick: + best_params.append((param_, loss_ones.item())) + best_params = sorted(best_params, key=lambda x: x[1]) + elif loss_ones < best_params[-1][1]: + best_params[-1] = (param_, loss_ones.item()) + best_params = sorted(best_params, key=lambda x: x[1]) + + except Exception: # The parameters might not be valid for the function's domain + continue + + # Generates new parameters around the mean + params = _search_param([p for p, _ in best_params], n_random_params) + + # Checks if the best parameter is better than the init_ones + p_ones = init_ones(x, **kwargs) + x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_ones}) + loss_ones = nn.MSELoss()(x, x_) + + # Checks if the best parameter is better than the init_rand + p_rand = init_rand(x, **kwargs) + x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_rand}) + loss_rand = nn.MSELoss()(x, x_) + + if loss_rand < best_params[0][1] and loss_rand < loss_ones: + return p_rand + elif loss_ones < best_params[0][1] and loss_ones < loss_rand: + return p_ones + else: + return best_params[0][0] + + +def init_linear_scale( # Symmetric scale. From the study folder + x: torch.Tensor, + **kwargs: Dict[str, Any], + ) -> torch.Tensor: + assert "bits" in kwargs, "bits must be provided." + assert "params" in kwargs, "params must be provided." + assert "qtz_func" in kwargs, "qtz_func must be provided." + + bits = kwargs.get('bits') + params = kwargs.get('params') + qtz_func = kwargs.get('qtz_func') + + x_ = x.transpose(0, 1) + x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs)) + x_ = x_.transpose(0, 1) + + quant_min, quant_max = get_min_max_from_bits_signed(bits) + min_vals, max_vals = torch.aminmax(x_, dim=1) + min_vals = torch.min(min_vals, torch.zeros_like(min_vals)) + max_vals = torch.max(max_vals, torch.zeros_like(max_vals)) + + eps = torch.finfo(torch.float32).eps + + abs_max_val_per_ch = torch.max(-min_vals, max_vals) + scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2) + + scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device) + + # Introduces some noise in scale + # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything + scale = scale + 0.01 * torch.randn_like(scale) + return scale + + +def init_non_linear_regression_fit( + x: torch.Tensor, + **kwargs: Dict[str, Any], + ) -> torch.Tensor: + + assert "params_list" in kwargs, "params list must be provided." + assert "np_fit_func" in kwargs, "np_fit_func must be provided." + assert "p0" in kwargs, "p0 must be provided." + np_fit_func = kwargs.get('np_fit_func') + params_list = kwargs.get('params_list') + p0 = kwargs.get('p0') + + def _fit(xdata: np.ndarray, ydata: np.ndarray, func: Callable, p0: List[float]): + popt, _ = curve_fit( + func, + xdata, + ydata, + maxfev=1000, + p0=p0, + method='lm' + ) + return popt + + # 1. Needs to convert the torch tensor to numpy tensor + xdata = x.cpu().numpy() + + # 2. Sorts the data so that it makes it easier to fit to it + sorted_xdata = np.sort(xdata, axis=-1) + + p0 = {k: v.cpu().numpy() for k, v in p0.items()} + params_list = sorted(params_list) # We need to make sure that it matches the numpy fit func arg order + + # 3. Finds the best parameters for each channel + try: + params = [] + for i in range(sorted_xdata.shape[0]): + xdata_ = sorted_xdata[i] + p0_ = [p0[p][i] for p in params_list] + ch_params = _fit(xdata_, xdata_, np_fit_func, p0_) + params.append(ch_params) + + # 4. Builds the parameters + result = {} + for i, p in enumerate(params_list): + result[p] = torch.tensor([p_[i] for p_ in params], dtype=torch.float32).to(x.device) + + return result + + except ValueError as e: + print(f"Could not fit the function with error: {e}") + print(f"Using fallback result...") + return { + k: torch.tensor(v, dtype=torch.float32).to(x.device) for k, v in p0.items() + } + + +def init_zeros(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor: + val = torch.amin(x, dim=1) + return torch.zeros_like(val, dtype=torch.float32, device=x.device) + + +def init_inner_scale(tensor: torch.Tensor, _min: float = torch.inf, _max: float = torch.inf) -> torch.Tensor: + # Calculate the original minimum and maximum values + min_vals, max_vals = torch.aminmax(tensor, dim=-1) + x_min = torch.min(min_vals, torch.zeros_like(min_vals)) + x_max = torch.max(max_vals, torch.zeros_like(max_vals)) + + if _max is torch.inf: # We do not need to scale the tensor. Just need to move it + return torch.ones_like(x_min) + + # Calculate the scale factor + scale = (_max - _min) / (x_max - x_min) + return scale + + + +############## Quant ############### + +@torch.enable_grad() +def learn_parameters( + x: torch.Tensor, + params: Dict[str, nn.Parameter], + qtz_func: nn.Module, + deqtz_func: nn.Module, + bits: int, + target_dtype: torch.dtype, + epochs: int = 1000, + early_stop: bool = True, + do_report: bool = False +) -> Tuple[Dict[str, nn.Parameter], torch.Tensor]: + + # Requires gradients in the parameters + for p in params.values(): + p.requires_grad = True + p.grad = None + + param_keys = list(params.keys()) + param_values = list(params.values()) + + # Defines optimizer and loss function + optimizer = torch.optim.Adam(param_values, lr=0.001) + scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=1.0, end_factor=0.001, total_iters=epochs // 10) + loss_fn = nn.MSELoss() + + # Contains the best loss and the best parameters + best_loss = float("inf") + best_params = None + + # Used to stop the search early + min_delta = 1e-7 + acc_loss = [] + percent_epochs_before_stop = 0.1 + + for i in range(epochs): + optimizer.zero_grad() + + quant = quantize(x, params, qtz_func, bits, target_dtype) + dequant = dequantize(quant, params, deqtz_func, bits, x.dtype) + loss = loss_fn(x, dequant) + + if loss.isnan() or loss.isinf(): + raise Exception("Loss is NaN or Inf. Stopping the search.") + + loss.backward() + optimizer.step() + scheduler.step() + + acc_loss.append(loss.item()) + + # Reports loss every 10 steps + if i % 10 == 0 and do_report: + print(f"Epoch {i}: Loss {loss.item()}") + + # Optimizes the parameter search by storing the best loss and the parameters + if loss.item() < best_loss: + best_loss = loss.item() + best_params = copy.deepcopy({ + k: v for k, v in params.items() if k in param_keys + }) + + # We also stop the search if the loss has not considerably during the last 10% epochs + if early_stop: + epochs_before_stop = int(epochs * percent_epochs_before_stop) + if i > epochs_before_stop and abs(acc_loss[i - epochs_before_stop] - acc_loss[i]) < min_delta: + break + + # No longer requires gradients in the parameters + for p in best_params.values(): + p.requires_grad = False + p.grad = None + + if do_report: + return best_params, acc_loss + else: + return best_params + + +def quantize( + x: torch.Tensor, + params: Dict[str, nn.Parameter], + func: nn.Module, + bits: int, + target_dtype: torch.dtype = torch.int8 +) -> torch.Tensor: + quant_min, quant_max = get_min_max_from_bits_signed(bits) + x = x.transpose(0, 1) # Aligns shapes + x = func(x=x, **params) + x = x.transpose(0, 1) + x = torch.clamp(round_func_BPDA(x), quant_min, quant_max).to(target_dtype) + return x + + +def dequantize( + x: torch.Tensor, + params: Dict[str, nn.Parameter], + func: nn.Module, + bits: int, + out_dtype: torch.dtype +) -> torch.Tensor: + x = x.to(dtype=out_dtype) + x = x.transpose(0, 1) + x = func(x=x, **params) + x = x.transpose(0, 1) + return x + + +def round_func_BPDA(input): + # This is equivalent to replacing round function (non-differentiable) with + # an identity function (differentiable) only when backward. + forward_value = torch.round(input) + out = input.clone() + out.data = forward_value.data + return out + + +def get_min_max_from_bits_signed(bit_width: int) -> Tuple[int, int]: + return -2 ** (bit_width - 1), 2 ** (bit_width - 1) - 1 + + + +############## Numpy ############### + +def np_domain_guard( + x: np.ndarray, + min: float = None, + max: float = None, + posinf: float = None, + neginf: float = None, + nan: float = None + ) -> np.ndarray: + """Guard a tensor to a valid domain.""" + x = np.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan) + if min is not None or max is not None: + x = np.clip(x, min, max) + return x + + +def np_replace_num(x: np.ndarray, num: float, to: float) -> np.ndarray: + """Replace a number in a tensor with another number. + + Args: + x (np.ndarray): The input tensor. + num (float): The number to replace. + to (float): The number to replace with. + + Returns: + np.ndarray: The tensor with the number replaced. + """ + return np.where(x == num, to, x) + + +def np_guarded_power(x: np.ndarray, exp: float) -> np.ndarray: + """Guard the power operation to a valid domain.""" + return np.power(x, exp) if exp >= 1 else np.power(np.maximum(x, 0), exp) + diff --git a/fn_gen/rnd_noisy_scale/12/loss.png b/fn_gen/rnd_noisy_scale/12/loss.png new file mode 100644 index 0000000000000000000000000000000000000000..12d05a54c149137ab42b55a3cc954f8ef3d37c61 Binary files /dev/null and b/fn_gen/rnd_noisy_scale/12/loss.png differ diff --git a/fn_gen/rnd_noisy_scale/12/quantization.png b/fn_gen/rnd_noisy_scale/12/quantization.png new file mode 100644 index 0000000000000000000000000000000000000000..5aac78f21a048d81b9d9bd229b8064c35d6268d6 Binary files /dev/null and b/fn_gen/rnd_noisy_scale/12/quantization.png differ diff --git a/fn_gen/rnd_noisy_scale/13/__pycache__/fn.cpython-310.pyc b/fn_gen/rnd_noisy_scale/13/__pycache__/fn.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cca74ff715cb2bc57e095633c18de697d73bf2f9 Binary files /dev/null and b/fn_gen/rnd_noisy_scale/13/__pycache__/fn.cpython-310.pyc differ diff --git a/fn_gen/rnd_noisy_scale/13/distortion.png b/fn_gen/rnd_noisy_scale/13/distortion.png new file mode 100644 index 0000000000000000000000000000000000000000..0cd1b6511145b4f371ba5ce4370c7ff3a685fb0a Binary files /dev/null and b/fn_gen/rnd_noisy_scale/13/distortion.png differ diff --git a/fn_gen/rnd_noisy_scale/13/expressions.txt b/fn_gen/rnd_noisy_scale/13/expressions.txt new file mode 100644 index 0000000000000000000000000000000000000000..9aa25379a9d1d5a93d60659c6609b2e24e79234d --- /dev/null +++ b/fn_gen/rnd_noisy_scale/13/expressions.txt @@ -0,0 +1,2 @@ +exp(_0*x)/_s +log(_s*x)/_0 \ No newline at end of file diff --git a/fn_gen/rnd_noisy_scale/13/fn.py b/fn_gen/rnd_noisy_scale/13/fn.py new file mode 100644 index 0000000000000000000000000000000000000000..1771e0a1211125efb14f3787e4b09083233d3201 --- /dev/null +++ b/fn_gen/rnd_noisy_scale/13/fn.py @@ -0,0 +1,489 @@ +from __future__ import annotations + +import torch +from torch import amin # Necessary for arcsin +import copy +import torch.nn as nn +import numpy as np + +from scipy.optimize import curve_fit +from typing import Dict, Any, Tuple, List, Callable + + +def quantization(x, **params): + return (torch.div(1, replace_num(params['_s'], num=0, to=10000)) * torch.exp((params['_0'] * x))) + + +def dequantization(x, **params): + return (torch.div(1, replace_num(params['_0'], num=0, to=10000)) * torch.log(domain_guard((params['_s'] * x), min=1e-5, nan=1e-5))) + + +def init_linear_scale( # Symmetric scale. From the study folder + x: torch.Tensor, + **kwargs: Dict[str, Any], + ) -> torch.Tensor: + assert "bits" in kwargs, "bits must be provided." + assert "params" in kwargs, "params must be provided." + assert "qtz_func" in kwargs, "qtz_func must be provided." + + bits = kwargs.get('bits') + params = kwargs.get('params') + qtz_func = kwargs.get('qtz_func') + + x_ = x.transpose(0, 1) + x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs)) + x_ = x_.transpose(0, 1) + + quant_min, quant_max = get_min_max_from_bits_signed(bits) + min_vals, max_vals = torch.aminmax(x_, dim=1) + min_vals = torch.min(min_vals, torch.zeros_like(min_vals)) + max_vals = torch.max(max_vals, torch.zeros_like(max_vals)) + + eps = torch.finfo(torch.float32).eps + + abs_max_val_per_ch = torch.max(-min_vals, max_vals) + scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2) + + scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device) + + # Introduces some noise in scale + # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything + scale = scale + 0.01 * torch.randn_like(scale) + return scale + + +def init_params(x: torch.Tensor, **kwargs: Dict[str, Any]) -> Dict[str, nn.Parameter]: + params = { + '_0': init_rand(x, qtz_func=quantization, deqtz_func=dequantization, param='_0', params_list=['_0', '_s'], **kwargs), + } + params['_s'] = init_linear_scale(x, params=params, qtz_func=quantization, **kwargs) + params = {k: nn.Parameter(v, requires_grad=False) for k, v in params.items()} + + if 'post_init_hook' in kwargs: + kwargs['post_init_hook'](parameters=params) + + + if 'post_train_hook' in kwargs: + kwargs['post_train_hook'](parameters=params) + + return params + + +############### Numpy Qtz ############### + + +def np_quantization(x, _0, _s): + return (np.divide(1, np_replace_num(_s, num=0, to=10000)) * np.exp((_0 * x))) + + +def np_dequantization(x, _0, _s): + return (np.divide(1, np_replace_num(_0, num=0, to=10000)) * np.log(np_domain_guard((_s * x), min=1e-5, nan=1e-5))) + + +def fit_func(x, _0, _s): + x_ = np_quantization(x, _0, _s) + x_ = np_dequantization(x_, _0, _s) + return x_ + + + +############### HELPERS ############### + +def domain_guard( + x: torch.Tensor, + min: float = None, + max: float = None, + posinf: float = None, + neginf: float = None, + nan: float = None + ) -> torch.Tensor: + """Guard a tensor to a valid domain.""" + x = torch.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan) + if min is not None or max is not None: + x = torch.clamp(x, min=min, max=max) + return x + + +def replace_num(x: torch.Tensor, num: float, to: float) -> torch.Tensor: + """Replace a number in a tensor with another number. + + Args: + x (torch.Tensor): The input tensor. + num (float): The number to replace. + to (float): The number to replace with. + + Returns: + torch.Tensor: The tensor with the number replaced. + """ + return torch.where(x == num, to, x) + + +def guarded_torch_power(x: torch.Tensor, exp: float) -> torch.Tensor: + """Guard the power operation to a valid domain.""" + return torch.pow(x, exp) if exp >= 1 else torch.pow(torch.relu(x), exp) + + +def init_ones(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor: + val = torch.amin(x, dim=1) + return torch.ones_like(val, dtype=torch.float32, device=x.device) + + +def init_rand(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor: + val = torch.amin(x, dim=1) + return torch.randn_like(val, dtype=torch.float32, device=x.device) + + +def init_space_search( + x: torch.Tensor, + **kwargs: Dict[str, Any], + ) -> torch.Tensor: + + def _build_initial_param(tensor: torch.Tensor, max_initial: int, n_params: int): + """Generates the initial set of parameters. The first iteration generates 10 times more parameters.""" + for _ in range(n_params * 10): # The first iteration generates 10 times more parameters + yield init_rand(tensor) * max_initial # Generates n_params in range [-max_initial, max_initial] + + def _search_param(tensors: List[torch.tensor], n_params): + """Takes the best parameters and generates new parameters around the mean of the best parameters.""" + torch_tensors = torch.stack(tensors) + min_vals, max_vals = torch.aminmax(torch_tensors, dim=0) + abs_max_val_per_ch = torch.max(-min_vals, max_vals) + mean = torch.mean(torch_tensors, dim=0) + for _ in range(n_params): # Generates n_params around the mean of the tensors + yield torch.randn_like(min_vals) * abs_max_val_per_ch + mean + + def _calc(x, qtz_func, deqtz_func, **params): + x_ = x.transpose(0, 1) + x_ = qtz_func(x=x_, **params) + x_ = deqtz_func(x=x_, **params) + x_ = x_.transpose(0, 1) + return x_ + + assert "qtz_func" in kwargs, "qtz_func must be provided." + assert "deqtz_func" in kwargs, "deqtz_func must be provided." + assert "params_list" in kwargs, "params list must be provided." + assert "param" in kwargs, "param must be provided." + + qtz_func = kwargs.get('qtz_func') + deqtz_func = kwargs.get('deqtz_func') + params_list = kwargs.get('params_list') + param = kwargs.get('param') + + n_runs = 50 # Number of runs to try to find the best parameters + n_random_params = 50 # Number of random parameters to generate + n_best_to_pick = 5 # Number of best parameters to pick after each run + max_initial = 10000 # Maximum value to initialize the parameters + + # Initializes the parameters + base_params = { p: init_ones(x, **kwargs) for p in params_list if p != param } + params = _build_initial_param(x, max_initial, n_random_params) + + # Performs the search + for _ in range(n_runs): + + best_params = [] + for param_ in params: + try: + x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: param_}) + loss_ones = nn.MSELoss()(x, x_) + + if len(best_params) < n_best_to_pick: + best_params.append((param_, loss_ones.item())) + best_params = sorted(best_params, key=lambda x: x[1]) + elif loss_ones < best_params[-1][1]: + best_params[-1] = (param_, loss_ones.item()) + best_params = sorted(best_params, key=lambda x: x[1]) + + except Exception: # The parameters might not be valid for the function's domain + continue + + # Generates new parameters around the mean + params = _search_param([p for p, _ in best_params], n_random_params) + + # Checks if the best parameter is better than the init_ones + p_ones = init_ones(x, **kwargs) + x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_ones}) + loss_ones = nn.MSELoss()(x, x_) + + # Checks if the best parameter is better than the init_rand + p_rand = init_rand(x, **kwargs) + x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_rand}) + loss_rand = nn.MSELoss()(x, x_) + + if loss_rand < best_params[0][1] and loss_rand < loss_ones: + return p_rand + elif loss_ones < best_params[0][1] and loss_ones < loss_rand: + return p_ones + else: + return best_params[0][0] + + +def init_linear_scale( # Symmetric scale. From the study folder + x: torch.Tensor, + **kwargs: Dict[str, Any], + ) -> torch.Tensor: + assert "bits" in kwargs, "bits must be provided." + assert "params" in kwargs, "params must be provided." + assert "qtz_func" in kwargs, "qtz_func must be provided." + + bits = kwargs.get('bits') + params = kwargs.get('params') + qtz_func = kwargs.get('qtz_func') + + x_ = x.transpose(0, 1) + x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs)) + x_ = x_.transpose(0, 1) + + quant_min, quant_max = get_min_max_from_bits_signed(bits) + min_vals, max_vals = torch.aminmax(x_, dim=1) + min_vals = torch.min(min_vals, torch.zeros_like(min_vals)) + max_vals = torch.max(max_vals, torch.zeros_like(max_vals)) + + eps = torch.finfo(torch.float32).eps + + abs_max_val_per_ch = torch.max(-min_vals, max_vals) + scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2) + + scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device) + + # Introduces some noise in scale + # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything + scale = scale + 0.01 * torch.randn_like(scale) + return scale + + +def init_non_linear_regression_fit( + x: torch.Tensor, + **kwargs: Dict[str, Any], + ) -> torch.Tensor: + + assert "params_list" in kwargs, "params list must be provided." + assert "np_fit_func" in kwargs, "np_fit_func must be provided." + assert "p0" in kwargs, "p0 must be provided." + np_fit_func = kwargs.get('np_fit_func') + params_list = kwargs.get('params_list') + p0 = kwargs.get('p0') + + def _fit(xdata: np.ndarray, ydata: np.ndarray, func: Callable, p0: List[float]): + popt, _ = curve_fit( + func, + xdata, + ydata, + maxfev=1000, + p0=p0, + method='lm' + ) + return popt + + # 1. Needs to convert the torch tensor to numpy tensor + xdata = x.cpu().numpy() + + # 2. Sorts the data so that it makes it easier to fit to it + sorted_xdata = np.sort(xdata, axis=-1) + + p0 = {k: v.cpu().numpy() for k, v in p0.items()} + params_list = sorted(params_list) # We need to make sure that it matches the numpy fit func arg order + + # 3. Finds the best parameters for each channel + try: + params = [] + for i in range(sorted_xdata.shape[0]): + xdata_ = sorted_xdata[i] + p0_ = [p0[p][i] for p in params_list] + ch_params = _fit(xdata_, xdata_, np_fit_func, p0_) + params.append(ch_params) + + # 4. Builds the parameters + result = {} + for i, p in enumerate(params_list): + result[p] = torch.tensor([p_[i] for p_ in params], dtype=torch.float32).to(x.device) + + return result + + except ValueError as e: + print(f"Could not fit the function with error: {e}") + print(f"Using fallback result...") + return { + k: torch.tensor(v, dtype=torch.float32).to(x.device) for k, v in p0.items() + } + + +def init_zeros(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor: + val = torch.amin(x, dim=1) + return torch.zeros_like(val, dtype=torch.float32, device=x.device) + + +def init_inner_scale(tensor: torch.Tensor, _min: float = torch.inf, _max: float = torch.inf) -> torch.Tensor: + # Calculate the original minimum and maximum values + min_vals, max_vals = torch.aminmax(tensor, dim=-1) + x_min = torch.min(min_vals, torch.zeros_like(min_vals)) + x_max = torch.max(max_vals, torch.zeros_like(max_vals)) + + if _max is torch.inf: # We do not need to scale the tensor. Just need to move it + return torch.ones_like(x_min) + + # Calculate the scale factor + scale = (_max - _min) / (x_max - x_min) + return scale + + + +############## Quant ############### + +@torch.enable_grad() +def learn_parameters( + x: torch.Tensor, + params: Dict[str, nn.Parameter], + qtz_func: nn.Module, + deqtz_func: nn.Module, + bits: int, + target_dtype: torch.dtype, + epochs: int = 1000, + early_stop: bool = True, + do_report: bool = False +) -> Tuple[Dict[str, nn.Parameter], torch.Tensor]: + + # Requires gradients in the parameters + for p in params.values(): + p.requires_grad = True + p.grad = None + + param_keys = list(params.keys()) + param_values = list(params.values()) + + # Defines optimizer and loss function + optimizer = torch.optim.Adam(param_values, lr=0.001) + scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=1.0, end_factor=0.001, total_iters=epochs // 10) + loss_fn = nn.MSELoss() + + # Contains the best loss and the best parameters + best_loss = float("inf") + best_params = None + + # Used to stop the search early + min_delta = 1e-7 + acc_loss = [] + percent_epochs_before_stop = 0.1 + + for i in range(epochs): + optimizer.zero_grad() + + quant = quantize(x, params, qtz_func, bits, target_dtype) + dequant = dequantize(quant, params, deqtz_func, bits, x.dtype) + loss = loss_fn(x, dequant) + + if loss.isnan() or loss.isinf(): + raise Exception("Loss is NaN or Inf. Stopping the search.") + + loss.backward() + optimizer.step() + scheduler.step() + + acc_loss.append(loss.item()) + + # Reports loss every 10 steps + if i % 10 == 0 and do_report: + print(f"Epoch {i}: Loss {loss.item()}") + + # Optimizes the parameter search by storing the best loss and the parameters + if loss.item() < best_loss: + best_loss = loss.item() + best_params = copy.deepcopy({ + k: v for k, v in params.items() if k in param_keys + }) + + # We also stop the search if the loss has not considerably during the last 10% epochs + if early_stop: + epochs_before_stop = int(epochs * percent_epochs_before_stop) + if i > epochs_before_stop and abs(acc_loss[i - epochs_before_stop] - acc_loss[i]) < min_delta: + break + + # No longer requires gradients in the parameters + for p in best_params.values(): + p.requires_grad = False + p.grad = None + + if do_report: + return best_params, acc_loss + else: + return best_params + + +def quantize( + x: torch.Tensor, + params: Dict[str, nn.Parameter], + func: nn.Module, + bits: int, + target_dtype: torch.dtype = torch.int8 +) -> torch.Tensor: + quant_min, quant_max = get_min_max_from_bits_signed(bits) + x = x.transpose(0, 1) # Aligns shapes + x = func(x=x, **params) + x = x.transpose(0, 1) + x = torch.clamp(round_func_BPDA(x), quant_min, quant_max).to(target_dtype) + return x + + +def dequantize( + x: torch.Tensor, + params: Dict[str, nn.Parameter], + func: nn.Module, + bits: int, + out_dtype: torch.dtype +) -> torch.Tensor: + x = x.to(dtype=out_dtype) + x = x.transpose(0, 1) + x = func(x=x, **params) + x = x.transpose(0, 1) + return x + + +def round_func_BPDA(input): + # This is equivalent to replacing round function (non-differentiable) with + # an identity function (differentiable) only when backward. + forward_value = torch.round(input) + out = input.clone() + out.data = forward_value.data + return out + + +def get_min_max_from_bits_signed(bit_width: int) -> Tuple[int, int]: + return -2 ** (bit_width - 1), 2 ** (bit_width - 1) - 1 + + + +############## Numpy ############### + +def np_domain_guard( + x: np.ndarray, + min: float = None, + max: float = None, + posinf: float = None, + neginf: float = None, + nan: float = None + ) -> np.ndarray: + """Guard a tensor to a valid domain.""" + x = np.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan) + if min is not None or max is not None: + x = np.clip(x, min, max) + return x + + +def np_replace_num(x: np.ndarray, num: float, to: float) -> np.ndarray: + """Replace a number in a tensor with another number. + + Args: + x (np.ndarray): The input tensor. + num (float): The number to replace. + to (float): The number to replace with. + + Returns: + np.ndarray: The tensor with the number replaced. + """ + return np.where(x == num, to, x) + + +def np_guarded_power(x: np.ndarray, exp: float) -> np.ndarray: + """Guard the power operation to a valid domain.""" + return np.power(x, exp) if exp >= 1 else np.power(np.maximum(x, 0), exp) + diff --git a/fn_gen/rnd_noisy_scale/13/loss.png b/fn_gen/rnd_noisy_scale/13/loss.png new file mode 100644 index 0000000000000000000000000000000000000000..58efa86b5f9d64617241a7bde90e65b586719f7c Binary files /dev/null and b/fn_gen/rnd_noisy_scale/13/loss.png differ diff --git a/fn_gen/rnd_noisy_scale/13/quantization.png b/fn_gen/rnd_noisy_scale/13/quantization.png new file mode 100644 index 0000000000000000000000000000000000000000..58fd95a8d861ddfad3f699d11e728c23c2d1c742 Binary files /dev/null and b/fn_gen/rnd_noisy_scale/13/quantization.png differ diff --git a/fn_gen/rnd_noisy_scale/14/__pycache__/fn.cpython-310.pyc b/fn_gen/rnd_noisy_scale/14/__pycache__/fn.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..af48938b9d0c4a4b866d6a375e44ee5f39b94a81 Binary files /dev/null and b/fn_gen/rnd_noisy_scale/14/__pycache__/fn.cpython-310.pyc differ diff --git a/fn_gen/rnd_noisy_scale/14/distortion.png b/fn_gen/rnd_noisy_scale/14/distortion.png new file mode 100644 index 0000000000000000000000000000000000000000..61dcafe63da66c3efac260b587601dba698a7574 Binary files /dev/null and b/fn_gen/rnd_noisy_scale/14/distortion.png differ diff --git a/fn_gen/rnd_noisy_scale/14/expressions.txt b/fn_gen/rnd_noisy_scale/14/expressions.txt new file mode 100644 index 0000000000000000000000000000000000000000..03413827fa8f4c8ad49a40b543460cf31d1ce803 --- /dev/null +++ b/fn_gen/rnd_noisy_scale/14/expressions.txt @@ -0,0 +1,2 @@ +asin(_0*x)/_s +sin(_s*x)/_0 \ No newline at end of file diff --git a/fn_gen/rnd_noisy_scale/14/fn.py b/fn_gen/rnd_noisy_scale/14/fn.py new file mode 100644 index 0000000000000000000000000000000000000000..b4eccfd00a4564a0df3096b2c88f805e1d981af8 --- /dev/null +++ b/fn_gen/rnd_noisy_scale/14/fn.py @@ -0,0 +1,489 @@ +from __future__ import annotations + +import torch +from torch import amin # Necessary for arcsin +import copy +import torch.nn as nn +import numpy as np + +from scipy.optimize import curve_fit +from typing import Dict, Any, Tuple, List, Callable + + +def quantization(x, **params): + return (torch.div(1, replace_num(params['_s'], num=0, to=10000)) * torch.asin(domain_guard((params['_0'] * x), min=-0.99999, max=0.99999, nan=0))) + + +def dequantization(x, **params): + return (torch.div(1, replace_num(params['_0'], num=0, to=10000)) * torch.sin((params['_s'] * x))) + + +def init_linear_scale( # Symmetric scale. From the study folder + x: torch.Tensor, + **kwargs: Dict[str, Any], + ) -> torch.Tensor: + assert "bits" in kwargs, "bits must be provided." + assert "params" in kwargs, "params must be provided." + assert "qtz_func" in kwargs, "qtz_func must be provided." + + bits = kwargs.get('bits') + params = kwargs.get('params') + qtz_func = kwargs.get('qtz_func') + + x_ = x.transpose(0, 1) + x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs)) + x_ = x_.transpose(0, 1) + + quant_min, quant_max = get_min_max_from_bits_signed(bits) + min_vals, max_vals = torch.aminmax(x_, dim=1) + min_vals = torch.min(min_vals, torch.zeros_like(min_vals)) + max_vals = torch.max(max_vals, torch.zeros_like(max_vals)) + + eps = torch.finfo(torch.float32).eps + + abs_max_val_per_ch = torch.max(-min_vals, max_vals) + scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2) + + scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device) + + # Introduces some noise in scale + # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything + scale = scale + 0.01 * torch.randn_like(scale) + return scale + + +def init_params(x: torch.Tensor, **kwargs: Dict[str, Any]) -> Dict[str, nn.Parameter]: + params = { + '_0': init_rand(x, qtz_func=quantization, deqtz_func=dequantization, param='_0', params_list=['_0', '_s'], **kwargs), + } + params['_s'] = init_linear_scale(x, params=params, qtz_func=quantization, **kwargs) + params = {k: nn.Parameter(v, requires_grad=False) for k, v in params.items()} + + if 'post_init_hook' in kwargs: + kwargs['post_init_hook'](parameters=params) + + + if 'post_train_hook' in kwargs: + kwargs['post_train_hook'](parameters=params) + + return params + + +############### Numpy Qtz ############### + + +def np_quantization(x, _0, _s): + return (np.divide(1, np_replace_num(_s, num=0, to=10000)) * np.arcsin(np_domain_guard((_0 * x), min=-0.99999, max=0.99999, nan=0))) + + +def np_dequantization(x, _0, _s): + return (np.divide(1, np_replace_num(_0, num=0, to=10000)) * np.sin((_s * x))) + + +def fit_func(x, _0, _s): + x_ = np_quantization(x, _0, _s) + x_ = np_dequantization(x_, _0, _s) + return x_ + + + +############### HELPERS ############### + +def domain_guard( + x: torch.Tensor, + min: float = None, + max: float = None, + posinf: float = None, + neginf: float = None, + nan: float = None + ) -> torch.Tensor: + """Guard a tensor to a valid domain.""" + x = torch.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan) + if min is not None or max is not None: + x = torch.clamp(x, min=min, max=max) + return x + + +def replace_num(x: torch.Tensor, num: float, to: float) -> torch.Tensor: + """Replace a number in a tensor with another number. + + Args: + x (torch.Tensor): The input tensor. + num (float): The number to replace. + to (float): The number to replace with. + + Returns: + torch.Tensor: The tensor with the number replaced. + """ + return torch.where(x == num, to, x) + + +def guarded_torch_power(x: torch.Tensor, exp: float) -> torch.Tensor: + """Guard the power operation to a valid domain.""" + return torch.pow(x, exp) if exp >= 1 else torch.pow(torch.relu(x), exp) + + +def init_ones(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor: + val = torch.amin(x, dim=1) + return torch.ones_like(val, dtype=torch.float32, device=x.device) + + +def init_rand(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor: + val = torch.amin(x, dim=1) + return torch.randn_like(val, dtype=torch.float32, device=x.device) + + +def init_space_search( + x: torch.Tensor, + **kwargs: Dict[str, Any], + ) -> torch.Tensor: + + def _build_initial_param(tensor: torch.Tensor, max_initial: int, n_params: int): + """Generates the initial set of parameters. The first iteration generates 10 times more parameters.""" + for _ in range(n_params * 10): # The first iteration generates 10 times more parameters + yield init_rand(tensor) * max_initial # Generates n_params in range [-max_initial, max_initial] + + def _search_param(tensors: List[torch.tensor], n_params): + """Takes the best parameters and generates new parameters around the mean of the best parameters.""" + torch_tensors = torch.stack(tensors) + min_vals, max_vals = torch.aminmax(torch_tensors, dim=0) + abs_max_val_per_ch = torch.max(-min_vals, max_vals) + mean = torch.mean(torch_tensors, dim=0) + for _ in range(n_params): # Generates n_params around the mean of the tensors + yield torch.randn_like(min_vals) * abs_max_val_per_ch + mean + + def _calc(x, qtz_func, deqtz_func, **params): + x_ = x.transpose(0, 1) + x_ = qtz_func(x=x_, **params) + x_ = deqtz_func(x=x_, **params) + x_ = x_.transpose(0, 1) + return x_ + + assert "qtz_func" in kwargs, "qtz_func must be provided." + assert "deqtz_func" in kwargs, "deqtz_func must be provided." + assert "params_list" in kwargs, "params list must be provided." + assert "param" in kwargs, "param must be provided." + + qtz_func = kwargs.get('qtz_func') + deqtz_func = kwargs.get('deqtz_func') + params_list = kwargs.get('params_list') + param = kwargs.get('param') + + n_runs = 50 # Number of runs to try to find the best parameters + n_random_params = 50 # Number of random parameters to generate + n_best_to_pick = 5 # Number of best parameters to pick after each run + max_initial = 10000 # Maximum value to initialize the parameters + + # Initializes the parameters + base_params = { p: init_ones(x, **kwargs) for p in params_list if p != param } + params = _build_initial_param(x, max_initial, n_random_params) + + # Performs the search + for _ in range(n_runs): + + best_params = [] + for param_ in params: + try: + x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: param_}) + loss_ones = nn.MSELoss()(x, x_) + + if len(best_params) < n_best_to_pick: + best_params.append((param_, loss_ones.item())) + best_params = sorted(best_params, key=lambda x: x[1]) + elif loss_ones < best_params[-1][1]: + best_params[-1] = (param_, loss_ones.item()) + best_params = sorted(best_params, key=lambda x: x[1]) + + except Exception: # The parameters might not be valid for the function's domain + continue + + # Generates new parameters around the mean + params = _search_param([p for p, _ in best_params], n_random_params) + + # Checks if the best parameter is better than the init_ones + p_ones = init_ones(x, **kwargs) + x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_ones}) + loss_ones = nn.MSELoss()(x, x_) + + # Checks if the best parameter is better than the init_rand + p_rand = init_rand(x, **kwargs) + x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_rand}) + loss_rand = nn.MSELoss()(x, x_) + + if loss_rand < best_params[0][1] and loss_rand < loss_ones: + return p_rand + elif loss_ones < best_params[0][1] and loss_ones < loss_rand: + return p_ones + else: + return best_params[0][0] + + +def init_linear_scale( # Symmetric scale. From the study folder + x: torch.Tensor, + **kwargs: Dict[str, Any], + ) -> torch.Tensor: + assert "bits" in kwargs, "bits must be provided." + assert "params" in kwargs, "params must be provided." + assert "qtz_func" in kwargs, "qtz_func must be provided." + + bits = kwargs.get('bits') + params = kwargs.get('params') + qtz_func = kwargs.get('qtz_func') + + x_ = x.transpose(0, 1) + x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs)) + x_ = x_.transpose(0, 1) + + quant_min, quant_max = get_min_max_from_bits_signed(bits) + min_vals, max_vals = torch.aminmax(x_, dim=1) + min_vals = torch.min(min_vals, torch.zeros_like(min_vals)) + max_vals = torch.max(max_vals, torch.zeros_like(max_vals)) + + eps = torch.finfo(torch.float32).eps + + abs_max_val_per_ch = torch.max(-min_vals, max_vals) + scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2) + + scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device) + + # Introduces some noise in scale + # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything + scale = scale + 0.01 * torch.randn_like(scale) + return scale + + +def init_non_linear_regression_fit( + x: torch.Tensor, + **kwargs: Dict[str, Any], + ) -> torch.Tensor: + + assert "params_list" in kwargs, "params list must be provided." + assert "np_fit_func" in kwargs, "np_fit_func must be provided." + assert "p0" in kwargs, "p0 must be provided." + np_fit_func = kwargs.get('np_fit_func') + params_list = kwargs.get('params_list') + p0 = kwargs.get('p0') + + def _fit(xdata: np.ndarray, ydata: np.ndarray, func: Callable, p0: List[float]): + popt, _ = curve_fit( + func, + xdata, + ydata, + maxfev=1000, + p0=p0, + method='lm' + ) + return popt + + # 1. Needs to convert the torch tensor to numpy tensor + xdata = x.cpu().numpy() + + # 2. Sorts the data so that it makes it easier to fit to it + sorted_xdata = np.sort(xdata, axis=-1) + + p0 = {k: v.cpu().numpy() for k, v in p0.items()} + params_list = sorted(params_list) # We need to make sure that it matches the numpy fit func arg order + + # 3. Finds the best parameters for each channel + try: + params = [] + for i in range(sorted_xdata.shape[0]): + xdata_ = sorted_xdata[i] + p0_ = [p0[p][i] for p in params_list] + ch_params = _fit(xdata_, xdata_, np_fit_func, p0_) + params.append(ch_params) + + # 4. Builds the parameters + result = {} + for i, p in enumerate(params_list): + result[p] = torch.tensor([p_[i] for p_ in params], dtype=torch.float32).to(x.device) + + return result + + except ValueError as e: + print(f"Could not fit the function with error: {e}") + print(f"Using fallback result...") + return { + k: torch.tensor(v, dtype=torch.float32).to(x.device) for k, v in p0.items() + } + + +def init_zeros(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor: + val = torch.amin(x, dim=1) + return torch.zeros_like(val, dtype=torch.float32, device=x.device) + + +def init_inner_scale(tensor: torch.Tensor, _min: float = torch.inf, _max: float = torch.inf) -> torch.Tensor: + # Calculate the original minimum and maximum values + min_vals, max_vals = torch.aminmax(tensor, dim=-1) + x_min = torch.min(min_vals, torch.zeros_like(min_vals)) + x_max = torch.max(max_vals, torch.zeros_like(max_vals)) + + if _max is torch.inf: # We do not need to scale the tensor. Just need to move it + return torch.ones_like(x_min) + + # Calculate the scale factor + scale = (_max - _min) / (x_max - x_min) + return scale + + + +############## Quant ############### + +@torch.enable_grad() +def learn_parameters( + x: torch.Tensor, + params: Dict[str, nn.Parameter], + qtz_func: nn.Module, + deqtz_func: nn.Module, + bits: int, + target_dtype: torch.dtype, + epochs: int = 1000, + early_stop: bool = True, + do_report: bool = False +) -> Tuple[Dict[str, nn.Parameter], torch.Tensor]: + + # Requires gradients in the parameters + for p in params.values(): + p.requires_grad = True + p.grad = None + + param_keys = list(params.keys()) + param_values = list(params.values()) + + # Defines optimizer and loss function + optimizer = torch.optim.Adam(param_values, lr=0.001) + scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=1.0, end_factor=0.001, total_iters=epochs // 10) + loss_fn = nn.MSELoss() + + # Contains the best loss and the best parameters + best_loss = float("inf") + best_params = None + + # Used to stop the search early + min_delta = 1e-7 + acc_loss = [] + percent_epochs_before_stop = 0.1 + + for i in range(epochs): + optimizer.zero_grad() + + quant = quantize(x, params, qtz_func, bits, target_dtype) + dequant = dequantize(quant, params, deqtz_func, bits, x.dtype) + loss = loss_fn(x, dequant) + + if loss.isnan() or loss.isinf(): + raise Exception("Loss is NaN or Inf. Stopping the search.") + + loss.backward() + optimizer.step() + scheduler.step() + + acc_loss.append(loss.item()) + + # Reports loss every 10 steps + if i % 10 == 0 and do_report: + print(f"Epoch {i}: Loss {loss.item()}") + + # Optimizes the parameter search by storing the best loss and the parameters + if loss.item() < best_loss: + best_loss = loss.item() + best_params = copy.deepcopy({ + k: v for k, v in params.items() if k in param_keys + }) + + # We also stop the search if the loss has not considerably during the last 10% epochs + if early_stop: + epochs_before_stop = int(epochs * percent_epochs_before_stop) + if i > epochs_before_stop and abs(acc_loss[i - epochs_before_stop] - acc_loss[i]) < min_delta: + break + + # No longer requires gradients in the parameters + for p in best_params.values(): + p.requires_grad = False + p.grad = None + + if do_report: + return best_params, acc_loss + else: + return best_params + + +def quantize( + x: torch.Tensor, + params: Dict[str, nn.Parameter], + func: nn.Module, + bits: int, + target_dtype: torch.dtype = torch.int8 +) -> torch.Tensor: + quant_min, quant_max = get_min_max_from_bits_signed(bits) + x = x.transpose(0, 1) # Aligns shapes + x = func(x=x, **params) + x = x.transpose(0, 1) + x = torch.clamp(round_func_BPDA(x), quant_min, quant_max).to(target_dtype) + return x + + +def dequantize( + x: torch.Tensor, + params: Dict[str, nn.Parameter], + func: nn.Module, + bits: int, + out_dtype: torch.dtype +) -> torch.Tensor: + x = x.to(dtype=out_dtype) + x = x.transpose(0, 1) + x = func(x=x, **params) + x = x.transpose(0, 1) + return x + + +def round_func_BPDA(input): + # This is equivalent to replacing round function (non-differentiable) with + # an identity function (differentiable) only when backward. + forward_value = torch.round(input) + out = input.clone() + out.data = forward_value.data + return out + + +def get_min_max_from_bits_signed(bit_width: int) -> Tuple[int, int]: + return -2 ** (bit_width - 1), 2 ** (bit_width - 1) - 1 + + + +############## Numpy ############### + +def np_domain_guard( + x: np.ndarray, + min: float = None, + max: float = None, + posinf: float = None, + neginf: float = None, + nan: float = None + ) -> np.ndarray: + """Guard a tensor to a valid domain.""" + x = np.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan) + if min is not None or max is not None: + x = np.clip(x, min, max) + return x + + +def np_replace_num(x: np.ndarray, num: float, to: float) -> np.ndarray: + """Replace a number in a tensor with another number. + + Args: + x (np.ndarray): The input tensor. + num (float): The number to replace. + to (float): The number to replace with. + + Returns: + np.ndarray: The tensor with the number replaced. + """ + return np.where(x == num, to, x) + + +def np_guarded_power(x: np.ndarray, exp: float) -> np.ndarray: + """Guard the power operation to a valid domain.""" + return np.power(x, exp) if exp >= 1 else np.power(np.maximum(x, 0), exp) + diff --git a/fn_gen/rnd_noisy_scale/14/loss.png b/fn_gen/rnd_noisy_scale/14/loss.png new file mode 100644 index 0000000000000000000000000000000000000000..53a8d9428fc8ab2becb21e28deaf5034b04b8210 Binary files /dev/null and b/fn_gen/rnd_noisy_scale/14/loss.png differ diff --git a/fn_gen/rnd_noisy_scale/14/quantization.png b/fn_gen/rnd_noisy_scale/14/quantization.png new file mode 100644 index 0000000000000000000000000000000000000000..e68603ba379ef3c40977966fe5cce12ecfc78231 Binary files /dev/null and b/fn_gen/rnd_noisy_scale/14/quantization.png differ diff --git a/fn_gen/rnd_noisy_scale/15/__pycache__/fn.cpython-310.pyc b/fn_gen/rnd_noisy_scale/15/__pycache__/fn.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4dbb9d7f7d6393412033b20668fded30630f366b Binary files /dev/null and b/fn_gen/rnd_noisy_scale/15/__pycache__/fn.cpython-310.pyc differ diff --git a/fn_gen/rnd_noisy_scale/15/distortion.png b/fn_gen/rnd_noisy_scale/15/distortion.png new file mode 100644 index 0000000000000000000000000000000000000000..df3e9cad3a98c667b7ce94bef991dba8e8bf00ef Binary files /dev/null and b/fn_gen/rnd_noisy_scale/15/distortion.png differ diff --git a/fn_gen/rnd_noisy_scale/15/expressions.txt b/fn_gen/rnd_noisy_scale/15/expressions.txt new file mode 100644 index 0000000000000000000000000000000000000000..ed99293c42843616c361d59b23d32ae553cc0f8d --- /dev/null +++ b/fn_gen/rnd_noisy_scale/15/expressions.txt @@ -0,0 +1,2 @@ +atanh(_0*x)/_s +tanh(_s*x)/_0 \ No newline at end of file diff --git a/fn_gen/rnd_noisy_scale/15/fn.py b/fn_gen/rnd_noisy_scale/15/fn.py new file mode 100644 index 0000000000000000000000000000000000000000..f509472ab7c3700d127910a502fd11df9c04e5bd --- /dev/null +++ b/fn_gen/rnd_noisy_scale/15/fn.py @@ -0,0 +1,489 @@ +from __future__ import annotations + +import torch +from torch import amin # Necessary for arcsin +import copy +import torch.nn as nn +import numpy as np + +from scipy.optimize import curve_fit +from typing import Dict, Any, Tuple, List, Callable + + +def quantization(x, **params): + return (torch.div(1, replace_num(params['_s'], num=0, to=10000)) * torch.atanh(domain_guard((params['_0'] * x), min=-0.9999, max=0.9999, nan=0))) + + +def dequantization(x, **params): + return (torch.div(1, replace_num(params['_0'], num=0, to=10000)) * torch.tanh((params['_s'] * x))) + + +def init_linear_scale( # Symmetric scale. From the study folder + x: torch.Tensor, + **kwargs: Dict[str, Any], + ) -> torch.Tensor: + assert "bits" in kwargs, "bits must be provided." + assert "params" in kwargs, "params must be provided." + assert "qtz_func" in kwargs, "qtz_func must be provided." + + bits = kwargs.get('bits') + params = kwargs.get('params') + qtz_func = kwargs.get('qtz_func') + + x_ = x.transpose(0, 1) + x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs)) + x_ = x_.transpose(0, 1) + + quant_min, quant_max = get_min_max_from_bits_signed(bits) + min_vals, max_vals = torch.aminmax(x_, dim=1) + min_vals = torch.min(min_vals, torch.zeros_like(min_vals)) + max_vals = torch.max(max_vals, torch.zeros_like(max_vals)) + + eps = torch.finfo(torch.float32).eps + + abs_max_val_per_ch = torch.max(-min_vals, max_vals) + scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2) + + scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device) + + # Introduces some noise in scale + # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything + scale = scale + 0.01 * torch.randn_like(scale) + return scale + + +def init_params(x: torch.Tensor, **kwargs: Dict[str, Any]) -> Dict[str, nn.Parameter]: + params = { + '_0': init_rand(x, qtz_func=quantization, deqtz_func=dequantization, param='_0', params_list=['_0', '_s'], **kwargs), + } + params['_s'] = init_linear_scale(x, params=params, qtz_func=quantization, **kwargs) + params = {k: nn.Parameter(v, requires_grad=False) for k, v in params.items()} + + if 'post_init_hook' in kwargs: + kwargs['post_init_hook'](parameters=params) + + + if 'post_train_hook' in kwargs: + kwargs['post_train_hook'](parameters=params) + + return params + + +############### Numpy Qtz ############### + + +def np_quantization(x, _0, _s): + return (np.divide(1, np_replace_num(_s, num=0, to=10000)) * np.arctanh(np_domain_guard((_0 * x), min=-0.9999, max=0.9999, nan=0))) + + +def np_dequantization(x, _0, _s): + return (np.divide(1, np_replace_num(_0, num=0, to=10000)) * np.tanh((_s * x))) + + +def fit_func(x, _0, _s): + x_ = np_quantization(x, _0, _s) + x_ = np_dequantization(x_, _0, _s) + return x_ + + + +############### HELPERS ############### + +def domain_guard( + x: torch.Tensor, + min: float = None, + max: float = None, + posinf: float = None, + neginf: float = None, + nan: float = None + ) -> torch.Tensor: + """Guard a tensor to a valid domain.""" + x = torch.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan) + if min is not None or max is not None: + x = torch.clamp(x, min=min, max=max) + return x + + +def replace_num(x: torch.Tensor, num: float, to: float) -> torch.Tensor: + """Replace a number in a tensor with another number. + + Args: + x (torch.Tensor): The input tensor. + num (float): The number to replace. + to (float): The number to replace with. + + Returns: + torch.Tensor: The tensor with the number replaced. + """ + return torch.where(x == num, to, x) + + +def guarded_torch_power(x: torch.Tensor, exp: float) -> torch.Tensor: + """Guard the power operation to a valid domain.""" + return torch.pow(x, exp) if exp >= 1 else torch.pow(torch.relu(x), exp) + + +def init_ones(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor: + val = torch.amin(x, dim=1) + return torch.ones_like(val, dtype=torch.float32, device=x.device) + + +def init_rand(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor: + val = torch.amin(x, dim=1) + return torch.randn_like(val, dtype=torch.float32, device=x.device) + + +def init_space_search( + x: torch.Tensor, + **kwargs: Dict[str, Any], + ) -> torch.Tensor: + + def _build_initial_param(tensor: torch.Tensor, max_initial: int, n_params: int): + """Generates the initial set of parameters. The first iteration generates 10 times more parameters.""" + for _ in range(n_params * 10): # The first iteration generates 10 times more parameters + yield init_rand(tensor) * max_initial # Generates n_params in range [-max_initial, max_initial] + + def _search_param(tensors: List[torch.tensor], n_params): + """Takes the best parameters and generates new parameters around the mean of the best parameters.""" + torch_tensors = torch.stack(tensors) + min_vals, max_vals = torch.aminmax(torch_tensors, dim=0) + abs_max_val_per_ch = torch.max(-min_vals, max_vals) + mean = torch.mean(torch_tensors, dim=0) + for _ in range(n_params): # Generates n_params around the mean of the tensors + yield torch.randn_like(min_vals) * abs_max_val_per_ch + mean + + def _calc(x, qtz_func, deqtz_func, **params): + x_ = x.transpose(0, 1) + x_ = qtz_func(x=x_, **params) + x_ = deqtz_func(x=x_, **params) + x_ = x_.transpose(0, 1) + return x_ + + assert "qtz_func" in kwargs, "qtz_func must be provided." + assert "deqtz_func" in kwargs, "deqtz_func must be provided." + assert "params_list" in kwargs, "params list must be provided." + assert "param" in kwargs, "param must be provided." + + qtz_func = kwargs.get('qtz_func') + deqtz_func = kwargs.get('deqtz_func') + params_list = kwargs.get('params_list') + param = kwargs.get('param') + + n_runs = 50 # Number of runs to try to find the best parameters + n_random_params = 50 # Number of random parameters to generate + n_best_to_pick = 5 # Number of best parameters to pick after each run + max_initial = 10000 # Maximum value to initialize the parameters + + # Initializes the parameters + base_params = { p: init_ones(x, **kwargs) for p in params_list if p != param } + params = _build_initial_param(x, max_initial, n_random_params) + + # Performs the search + for _ in range(n_runs): + + best_params = [] + for param_ in params: + try: + x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: param_}) + loss_ones = nn.MSELoss()(x, x_) + + if len(best_params) < n_best_to_pick: + best_params.append((param_, loss_ones.item())) + best_params = sorted(best_params, key=lambda x: x[1]) + elif loss_ones < best_params[-1][1]: + best_params[-1] = (param_, loss_ones.item()) + best_params = sorted(best_params, key=lambda x: x[1]) + + except Exception: # The parameters might not be valid for the function's domain + continue + + # Generates new parameters around the mean + params = _search_param([p for p, _ in best_params], n_random_params) + + # Checks if the best parameter is better than the init_ones + p_ones = init_ones(x, **kwargs) + x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_ones}) + loss_ones = nn.MSELoss()(x, x_) + + # Checks if the best parameter is better than the init_rand + p_rand = init_rand(x, **kwargs) + x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_rand}) + loss_rand = nn.MSELoss()(x, x_) + + if loss_rand < best_params[0][1] and loss_rand < loss_ones: + return p_rand + elif loss_ones < best_params[0][1] and loss_ones < loss_rand: + return p_ones + else: + return best_params[0][0] + + +def init_linear_scale( # Symmetric scale. From the study folder + x: torch.Tensor, + **kwargs: Dict[str, Any], + ) -> torch.Tensor: + assert "bits" in kwargs, "bits must be provided." + assert "params" in kwargs, "params must be provided." + assert "qtz_func" in kwargs, "qtz_func must be provided." + + bits = kwargs.get('bits') + params = kwargs.get('params') + qtz_func = kwargs.get('qtz_func') + + x_ = x.transpose(0, 1) + x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs)) + x_ = x_.transpose(0, 1) + + quant_min, quant_max = get_min_max_from_bits_signed(bits) + min_vals, max_vals = torch.aminmax(x_, dim=1) + min_vals = torch.min(min_vals, torch.zeros_like(min_vals)) + max_vals = torch.max(max_vals, torch.zeros_like(max_vals)) + + eps = torch.finfo(torch.float32).eps + + abs_max_val_per_ch = torch.max(-min_vals, max_vals) + scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2) + + scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device) + + # Introduces some noise in scale + # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything + scale = scale + 0.01 * torch.randn_like(scale) + return scale + + +def init_non_linear_regression_fit( + x: torch.Tensor, + **kwargs: Dict[str, Any], + ) -> torch.Tensor: + + assert "params_list" in kwargs, "params list must be provided." + assert "np_fit_func" in kwargs, "np_fit_func must be provided." + assert "p0" in kwargs, "p0 must be provided." + np_fit_func = kwargs.get('np_fit_func') + params_list = kwargs.get('params_list') + p0 = kwargs.get('p0') + + def _fit(xdata: np.ndarray, ydata: np.ndarray, func: Callable, p0: List[float]): + popt, _ = curve_fit( + func, + xdata, + ydata, + maxfev=1000, + p0=p0, + method='lm' + ) + return popt + + # 1. Needs to convert the torch tensor to numpy tensor + xdata = x.cpu().numpy() + + # 2. Sorts the data so that it makes it easier to fit to it + sorted_xdata = np.sort(xdata, axis=-1) + + p0 = {k: v.cpu().numpy() for k, v in p0.items()} + params_list = sorted(params_list) # We need to make sure that it matches the numpy fit func arg order + + # 3. Finds the best parameters for each channel + try: + params = [] + for i in range(sorted_xdata.shape[0]): + xdata_ = sorted_xdata[i] + p0_ = [p0[p][i] for p in params_list] + ch_params = _fit(xdata_, xdata_, np_fit_func, p0_) + params.append(ch_params) + + # 4. Builds the parameters + result = {} + for i, p in enumerate(params_list): + result[p] = torch.tensor([p_[i] for p_ in params], dtype=torch.float32).to(x.device) + + return result + + except ValueError as e: + print(f"Could not fit the function with error: {e}") + print(f"Using fallback result...") + return { + k: torch.tensor(v, dtype=torch.float32).to(x.device) for k, v in p0.items() + } + + +def init_zeros(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor: + val = torch.amin(x, dim=1) + return torch.zeros_like(val, dtype=torch.float32, device=x.device) + + +def init_inner_scale(tensor: torch.Tensor, _min: float = torch.inf, _max: float = torch.inf) -> torch.Tensor: + # Calculate the original minimum and maximum values + min_vals, max_vals = torch.aminmax(tensor, dim=-1) + x_min = torch.min(min_vals, torch.zeros_like(min_vals)) + x_max = torch.max(max_vals, torch.zeros_like(max_vals)) + + if _max is torch.inf: # We do not need to scale the tensor. Just need to move it + return torch.ones_like(x_min) + + # Calculate the scale factor + scale = (_max - _min) / (x_max - x_min) + return scale + + + +############## Quant ############### + +@torch.enable_grad() +def learn_parameters( + x: torch.Tensor, + params: Dict[str, nn.Parameter], + qtz_func: nn.Module, + deqtz_func: nn.Module, + bits: int, + target_dtype: torch.dtype, + epochs: int = 1000, + early_stop: bool = True, + do_report: bool = False +) -> Tuple[Dict[str, nn.Parameter], torch.Tensor]: + + # Requires gradients in the parameters + for p in params.values(): + p.requires_grad = True + p.grad = None + + param_keys = list(params.keys()) + param_values = list(params.values()) + + # Defines optimizer and loss function + optimizer = torch.optim.Adam(param_values, lr=0.001) + scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=1.0, end_factor=0.001, total_iters=epochs // 10) + loss_fn = nn.MSELoss() + + # Contains the best loss and the best parameters + best_loss = float("inf") + best_params = None + + # Used to stop the search early + min_delta = 1e-7 + acc_loss = [] + percent_epochs_before_stop = 0.1 + + for i in range(epochs): + optimizer.zero_grad() + + quant = quantize(x, params, qtz_func, bits, target_dtype) + dequant = dequantize(quant, params, deqtz_func, bits, x.dtype) + loss = loss_fn(x, dequant) + + if loss.isnan() or loss.isinf(): + raise Exception("Loss is NaN or Inf. Stopping the search.") + + loss.backward() + optimizer.step() + scheduler.step() + + acc_loss.append(loss.item()) + + # Reports loss every 10 steps + if i % 10 == 0 and do_report: + print(f"Epoch {i}: Loss {loss.item()}") + + # Optimizes the parameter search by storing the best loss and the parameters + if loss.item() < best_loss: + best_loss = loss.item() + best_params = copy.deepcopy({ + k: v for k, v in params.items() if k in param_keys + }) + + # We also stop the search if the loss has not considerably during the last 10% epochs + if early_stop: + epochs_before_stop = int(epochs * percent_epochs_before_stop) + if i > epochs_before_stop and abs(acc_loss[i - epochs_before_stop] - acc_loss[i]) < min_delta: + break + + # No longer requires gradients in the parameters + for p in best_params.values(): + p.requires_grad = False + p.grad = None + + if do_report: + return best_params, acc_loss + else: + return best_params + + +def quantize( + x: torch.Tensor, + params: Dict[str, nn.Parameter], + func: nn.Module, + bits: int, + target_dtype: torch.dtype = torch.int8 +) -> torch.Tensor: + quant_min, quant_max = get_min_max_from_bits_signed(bits) + x = x.transpose(0, 1) # Aligns shapes + x = func(x=x, **params) + x = x.transpose(0, 1) + x = torch.clamp(round_func_BPDA(x), quant_min, quant_max).to(target_dtype) + return x + + +def dequantize( + x: torch.Tensor, + params: Dict[str, nn.Parameter], + func: nn.Module, + bits: int, + out_dtype: torch.dtype +) -> torch.Tensor: + x = x.to(dtype=out_dtype) + x = x.transpose(0, 1) + x = func(x=x, **params) + x = x.transpose(0, 1) + return x + + +def round_func_BPDA(input): + # This is equivalent to replacing round function (non-differentiable) with + # an identity function (differentiable) only when backward. + forward_value = torch.round(input) + out = input.clone() + out.data = forward_value.data + return out + + +def get_min_max_from_bits_signed(bit_width: int) -> Tuple[int, int]: + return -2 ** (bit_width - 1), 2 ** (bit_width - 1) - 1 + + + +############## Numpy ############### + +def np_domain_guard( + x: np.ndarray, + min: float = None, + max: float = None, + posinf: float = None, + neginf: float = None, + nan: float = None + ) -> np.ndarray: + """Guard a tensor to a valid domain.""" + x = np.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan) + if min is not None or max is not None: + x = np.clip(x, min, max) + return x + + +def np_replace_num(x: np.ndarray, num: float, to: float) -> np.ndarray: + """Replace a number in a tensor with another number. + + Args: + x (np.ndarray): The input tensor. + num (float): The number to replace. + to (float): The number to replace with. + + Returns: + np.ndarray: The tensor with the number replaced. + """ + return np.where(x == num, to, x) + + +def np_guarded_power(x: np.ndarray, exp: float) -> np.ndarray: + """Guard the power operation to a valid domain.""" + return np.power(x, exp) if exp >= 1 else np.power(np.maximum(x, 0), exp) + diff --git a/fn_gen/rnd_noisy_scale/15/loss.png b/fn_gen/rnd_noisy_scale/15/loss.png new file mode 100644 index 0000000000000000000000000000000000000000..a506929d40b7efdf0dad8c4a75141ee93bac3e2b Binary files /dev/null and b/fn_gen/rnd_noisy_scale/15/loss.png differ diff --git a/fn_gen/rnd_noisy_scale/15/quantization.png b/fn_gen/rnd_noisy_scale/15/quantization.png new file mode 100644 index 0000000000000000000000000000000000000000..34e5da1e47cfcdd2e41a3df752397bbaa5abccef Binary files /dev/null and b/fn_gen/rnd_noisy_scale/15/quantization.png differ diff --git a/fn_gen/rnd_noisy_scale/16/__pycache__/fn.cpython-310.pyc b/fn_gen/rnd_noisy_scale/16/__pycache__/fn.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2ae24e75911955dbe03220adcbba0c131d6bbcb9 Binary files /dev/null and b/fn_gen/rnd_noisy_scale/16/__pycache__/fn.cpython-310.pyc differ diff --git a/fn_gen/rnd_noisy_scale/16/distortion.png b/fn_gen/rnd_noisy_scale/16/distortion.png new file mode 100644 index 0000000000000000000000000000000000000000..8d7ff09f6ab42ec56cd591b7f8477be63d208e68 Binary files /dev/null and b/fn_gen/rnd_noisy_scale/16/distortion.png differ diff --git a/fn_gen/rnd_noisy_scale/16/expressions.txt b/fn_gen/rnd_noisy_scale/16/expressions.txt new file mode 100644 index 0000000000000000000000000000000000000000..c545adce8b3c320e195336b81461c79d0cc385e6 --- /dev/null +++ b/fn_gen/rnd_noisy_scale/16/expressions.txt @@ -0,0 +1,2 @@ +asinh(_0*x)/_s +sinh(_s*x)/_0 \ No newline at end of file diff --git a/fn_gen/rnd_noisy_scale/16/fn.py b/fn_gen/rnd_noisy_scale/16/fn.py new file mode 100644 index 0000000000000000000000000000000000000000..c2da72a66227ce4d96a9f8976fe3ae89cd08f082 --- /dev/null +++ b/fn_gen/rnd_noisy_scale/16/fn.py @@ -0,0 +1,489 @@ +from __future__ import annotations + +import torch +from torch import amin # Necessary for arcsin +import copy +import torch.nn as nn +import numpy as np + +from scipy.optimize import curve_fit +from typing import Dict, Any, Tuple, List, Callable + + +def quantization(x, **params): + return (torch.div(1, replace_num(params['_s'], num=0, to=10000)) * torch.asinh((params['_0'] * x))) + + +def dequantization(x, **params): + return (torch.div(1, replace_num(params['_0'], num=0, to=10000)) * torch.sinh((params['_s'] * x))) + + +def init_linear_scale( # Symmetric scale. From the study folder + x: torch.Tensor, + **kwargs: Dict[str, Any], + ) -> torch.Tensor: + assert "bits" in kwargs, "bits must be provided." + assert "params" in kwargs, "params must be provided." + assert "qtz_func" in kwargs, "qtz_func must be provided." + + bits = kwargs.get('bits') + params = kwargs.get('params') + qtz_func = kwargs.get('qtz_func') + + x_ = x.transpose(0, 1) + x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs)) + x_ = x_.transpose(0, 1) + + quant_min, quant_max = get_min_max_from_bits_signed(bits) + min_vals, max_vals = torch.aminmax(x_, dim=1) + min_vals = torch.min(min_vals, torch.zeros_like(min_vals)) + max_vals = torch.max(max_vals, torch.zeros_like(max_vals)) + + eps = torch.finfo(torch.float32).eps + + abs_max_val_per_ch = torch.max(-min_vals, max_vals) + scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2) + + scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device) + + # Introduces some noise in scale + # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything + scale = scale + 0.01 * torch.randn_like(scale) + return scale + + +def init_params(x: torch.Tensor, **kwargs: Dict[str, Any]) -> Dict[str, nn.Parameter]: + params = { + '_0': init_rand(x, qtz_func=quantization, deqtz_func=dequantization, param='_0', params_list=['_0', '_s'], **kwargs), + } + params['_s'] = init_linear_scale(x, params=params, qtz_func=quantization, **kwargs) + params = {k: nn.Parameter(v, requires_grad=False) for k, v in params.items()} + + if 'post_init_hook' in kwargs: + kwargs['post_init_hook'](parameters=params) + + + if 'post_train_hook' in kwargs: + kwargs['post_train_hook'](parameters=params) + + return params + + +############### Numpy Qtz ############### + + +def np_quantization(x, _0, _s): + return (np.divide(1, np_replace_num(_s, num=0, to=10000)) * np.arcsinh((_0 * x))) + + +def np_dequantization(x, _0, _s): + return (np.divide(1, np_replace_num(_0, num=0, to=10000)) * np.sinh((_s * x))) + + +def fit_func(x, _0, _s): + x_ = np_quantization(x, _0, _s) + x_ = np_dequantization(x_, _0, _s) + return x_ + + + +############### HELPERS ############### + +def domain_guard( + x: torch.Tensor, + min: float = None, + max: float = None, + posinf: float = None, + neginf: float = None, + nan: float = None + ) -> torch.Tensor: + """Guard a tensor to a valid domain.""" + x = torch.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan) + if min is not None or max is not None: + x = torch.clamp(x, min=min, max=max) + return x + + +def replace_num(x: torch.Tensor, num: float, to: float) -> torch.Tensor: + """Replace a number in a tensor with another number. + + Args: + x (torch.Tensor): The input tensor. + num (float): The number to replace. + to (float): The number to replace with. + + Returns: + torch.Tensor: The tensor with the number replaced. + """ + return torch.where(x == num, to, x) + + +def guarded_torch_power(x: torch.Tensor, exp: float) -> torch.Tensor: + """Guard the power operation to a valid domain.""" + return torch.pow(x, exp) if exp >= 1 else torch.pow(torch.relu(x), exp) + + +def init_ones(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor: + val = torch.amin(x, dim=1) + return torch.ones_like(val, dtype=torch.float32, device=x.device) + + +def init_rand(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor: + val = torch.amin(x, dim=1) + return torch.randn_like(val, dtype=torch.float32, device=x.device) + + +def init_space_search( + x: torch.Tensor, + **kwargs: Dict[str, Any], + ) -> torch.Tensor: + + def _build_initial_param(tensor: torch.Tensor, max_initial: int, n_params: int): + """Generates the initial set of parameters. The first iteration generates 10 times more parameters.""" + for _ in range(n_params * 10): # The first iteration generates 10 times more parameters + yield init_rand(tensor) * max_initial # Generates n_params in range [-max_initial, max_initial] + + def _search_param(tensors: List[torch.tensor], n_params): + """Takes the best parameters and generates new parameters around the mean of the best parameters.""" + torch_tensors = torch.stack(tensors) + min_vals, max_vals = torch.aminmax(torch_tensors, dim=0) + abs_max_val_per_ch = torch.max(-min_vals, max_vals) + mean = torch.mean(torch_tensors, dim=0) + for _ in range(n_params): # Generates n_params around the mean of the tensors + yield torch.randn_like(min_vals) * abs_max_val_per_ch + mean + + def _calc(x, qtz_func, deqtz_func, **params): + x_ = x.transpose(0, 1) + x_ = qtz_func(x=x_, **params) + x_ = deqtz_func(x=x_, **params) + x_ = x_.transpose(0, 1) + return x_ + + assert "qtz_func" in kwargs, "qtz_func must be provided." + assert "deqtz_func" in kwargs, "deqtz_func must be provided." + assert "params_list" in kwargs, "params list must be provided." + assert "param" in kwargs, "param must be provided." + + qtz_func = kwargs.get('qtz_func') + deqtz_func = kwargs.get('deqtz_func') + params_list = kwargs.get('params_list') + param = kwargs.get('param') + + n_runs = 50 # Number of runs to try to find the best parameters + n_random_params = 50 # Number of random parameters to generate + n_best_to_pick = 5 # Number of best parameters to pick after each run + max_initial = 10000 # Maximum value to initialize the parameters + + # Initializes the parameters + base_params = { p: init_ones(x, **kwargs) for p in params_list if p != param } + params = _build_initial_param(x, max_initial, n_random_params) + + # Performs the search + for _ in range(n_runs): + + best_params = [] + for param_ in params: + try: + x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: param_}) + loss_ones = nn.MSELoss()(x, x_) + + if len(best_params) < n_best_to_pick: + best_params.append((param_, loss_ones.item())) + best_params = sorted(best_params, key=lambda x: x[1]) + elif loss_ones < best_params[-1][1]: + best_params[-1] = (param_, loss_ones.item()) + best_params = sorted(best_params, key=lambda x: x[1]) + + except Exception: # The parameters might not be valid for the function's domain + continue + + # Generates new parameters around the mean + params = _search_param([p for p, _ in best_params], n_random_params) + + # Checks if the best parameter is better than the init_ones + p_ones = init_ones(x, **kwargs) + x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_ones}) + loss_ones = nn.MSELoss()(x, x_) + + # Checks if the best parameter is better than the init_rand + p_rand = init_rand(x, **kwargs) + x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_rand}) + loss_rand = nn.MSELoss()(x, x_) + + if loss_rand < best_params[0][1] and loss_rand < loss_ones: + return p_rand + elif loss_ones < best_params[0][1] and loss_ones < loss_rand: + return p_ones + else: + return best_params[0][0] + + +def init_linear_scale( # Symmetric scale. From the study folder + x: torch.Tensor, + **kwargs: Dict[str, Any], + ) -> torch.Tensor: + assert "bits" in kwargs, "bits must be provided." + assert "params" in kwargs, "params must be provided." + assert "qtz_func" in kwargs, "qtz_func must be provided." + + bits = kwargs.get('bits') + params = kwargs.get('params') + qtz_func = kwargs.get('qtz_func') + + x_ = x.transpose(0, 1) + x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs)) + x_ = x_.transpose(0, 1) + + quant_min, quant_max = get_min_max_from_bits_signed(bits) + min_vals, max_vals = torch.aminmax(x_, dim=1) + min_vals = torch.min(min_vals, torch.zeros_like(min_vals)) + max_vals = torch.max(max_vals, torch.zeros_like(max_vals)) + + eps = torch.finfo(torch.float32).eps + + abs_max_val_per_ch = torch.max(-min_vals, max_vals) + scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2) + + scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device) + + # Introduces some noise in scale + # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything + scale = scale + 0.01 * torch.randn_like(scale) + return scale + + +def init_non_linear_regression_fit( + x: torch.Tensor, + **kwargs: Dict[str, Any], + ) -> torch.Tensor: + + assert "params_list" in kwargs, "params list must be provided." + assert "np_fit_func" in kwargs, "np_fit_func must be provided." + assert "p0" in kwargs, "p0 must be provided." + np_fit_func = kwargs.get('np_fit_func') + params_list = kwargs.get('params_list') + p0 = kwargs.get('p0') + + def _fit(xdata: np.ndarray, ydata: np.ndarray, func: Callable, p0: List[float]): + popt, _ = curve_fit( + func, + xdata, + ydata, + maxfev=1000, + p0=p0, + method='lm' + ) + return popt + + # 1. Needs to convert the torch tensor to numpy tensor + xdata = x.cpu().numpy() + + # 2. Sorts the data so that it makes it easier to fit to it + sorted_xdata = np.sort(xdata, axis=-1) + + p0 = {k: v.cpu().numpy() for k, v in p0.items()} + params_list = sorted(params_list) # We need to make sure that it matches the numpy fit func arg order + + # 3. Finds the best parameters for each channel + try: + params = [] + for i in range(sorted_xdata.shape[0]): + xdata_ = sorted_xdata[i] + p0_ = [p0[p][i] for p in params_list] + ch_params = _fit(xdata_, xdata_, np_fit_func, p0_) + params.append(ch_params) + + # 4. Builds the parameters + result = {} + for i, p in enumerate(params_list): + result[p] = torch.tensor([p_[i] for p_ in params], dtype=torch.float32).to(x.device) + + return result + + except ValueError as e: + print(f"Could not fit the function with error: {e}") + print(f"Using fallback result...") + return { + k: torch.tensor(v, dtype=torch.float32).to(x.device) for k, v in p0.items() + } + + +def init_zeros(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor: + val = torch.amin(x, dim=1) + return torch.zeros_like(val, dtype=torch.float32, device=x.device) + + +def init_inner_scale(tensor: torch.Tensor, _min: float = torch.inf, _max: float = torch.inf) -> torch.Tensor: + # Calculate the original minimum and maximum values + min_vals, max_vals = torch.aminmax(tensor, dim=-1) + x_min = torch.min(min_vals, torch.zeros_like(min_vals)) + x_max = torch.max(max_vals, torch.zeros_like(max_vals)) + + if _max is torch.inf: # We do not need to scale the tensor. Just need to move it + return torch.ones_like(x_min) + + # Calculate the scale factor + scale = (_max - _min) / (x_max - x_min) + return scale + + + +############## Quant ############### + +@torch.enable_grad() +def learn_parameters( + x: torch.Tensor, + params: Dict[str, nn.Parameter], + qtz_func: nn.Module, + deqtz_func: nn.Module, + bits: int, + target_dtype: torch.dtype, + epochs: int = 1000, + early_stop: bool = True, + do_report: bool = False +) -> Tuple[Dict[str, nn.Parameter], torch.Tensor]: + + # Requires gradients in the parameters + for p in params.values(): + p.requires_grad = True + p.grad = None + + param_keys = list(params.keys()) + param_values = list(params.values()) + + # Defines optimizer and loss function + optimizer = torch.optim.Adam(param_values, lr=0.001) + scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=1.0, end_factor=0.001, total_iters=epochs // 10) + loss_fn = nn.MSELoss() + + # Contains the best loss and the best parameters + best_loss = float("inf") + best_params = None + + # Used to stop the search early + min_delta = 1e-7 + acc_loss = [] + percent_epochs_before_stop = 0.1 + + for i in range(epochs): + optimizer.zero_grad() + + quant = quantize(x, params, qtz_func, bits, target_dtype) + dequant = dequantize(quant, params, deqtz_func, bits, x.dtype) + loss = loss_fn(x, dequant) + + if loss.isnan() or loss.isinf(): + raise Exception("Loss is NaN or Inf. Stopping the search.") + + loss.backward() + optimizer.step() + scheduler.step() + + acc_loss.append(loss.item()) + + # Reports loss every 10 steps + if i % 10 == 0 and do_report: + print(f"Epoch {i}: Loss {loss.item()}") + + # Optimizes the parameter search by storing the best loss and the parameters + if loss.item() < best_loss: + best_loss = loss.item() + best_params = copy.deepcopy({ + k: v for k, v in params.items() if k in param_keys + }) + + # We also stop the search if the loss has not considerably during the last 10% epochs + if early_stop: + epochs_before_stop = int(epochs * percent_epochs_before_stop) + if i > epochs_before_stop and abs(acc_loss[i - epochs_before_stop] - acc_loss[i]) < min_delta: + break + + # No longer requires gradients in the parameters + for p in best_params.values(): + p.requires_grad = False + p.grad = None + + if do_report: + return best_params, acc_loss + else: + return best_params + + +def quantize( + x: torch.Tensor, + params: Dict[str, nn.Parameter], + func: nn.Module, + bits: int, + target_dtype: torch.dtype = torch.int8 +) -> torch.Tensor: + quant_min, quant_max = get_min_max_from_bits_signed(bits) + x = x.transpose(0, 1) # Aligns shapes + x = func(x=x, **params) + x = x.transpose(0, 1) + x = torch.clamp(round_func_BPDA(x), quant_min, quant_max).to(target_dtype) + return x + + +def dequantize( + x: torch.Tensor, + params: Dict[str, nn.Parameter], + func: nn.Module, + bits: int, + out_dtype: torch.dtype +) -> torch.Tensor: + x = x.to(dtype=out_dtype) + x = x.transpose(0, 1) + x = func(x=x, **params) + x = x.transpose(0, 1) + return x + + +def round_func_BPDA(input): + # This is equivalent to replacing round function (non-differentiable) with + # an identity function (differentiable) only when backward. + forward_value = torch.round(input) + out = input.clone() + out.data = forward_value.data + return out + + +def get_min_max_from_bits_signed(bit_width: int) -> Tuple[int, int]: + return -2 ** (bit_width - 1), 2 ** (bit_width - 1) - 1 + + + +############## Numpy ############### + +def np_domain_guard( + x: np.ndarray, + min: float = None, + max: float = None, + posinf: float = None, + neginf: float = None, + nan: float = None + ) -> np.ndarray: + """Guard a tensor to a valid domain.""" + x = np.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan) + if min is not None or max is not None: + x = np.clip(x, min, max) + return x + + +def np_replace_num(x: np.ndarray, num: float, to: float) -> np.ndarray: + """Replace a number in a tensor with another number. + + Args: + x (np.ndarray): The input tensor. + num (float): The number to replace. + to (float): The number to replace with. + + Returns: + np.ndarray: The tensor with the number replaced. + """ + return np.where(x == num, to, x) + + +def np_guarded_power(x: np.ndarray, exp: float) -> np.ndarray: + """Guard the power operation to a valid domain.""" + return np.power(x, exp) if exp >= 1 else np.power(np.maximum(x, 0), exp) + diff --git a/fn_gen/rnd_noisy_scale/16/loss.png b/fn_gen/rnd_noisy_scale/16/loss.png new file mode 100644 index 0000000000000000000000000000000000000000..70709965744ef16372a530d9950474506ee044e3 Binary files /dev/null and b/fn_gen/rnd_noisy_scale/16/loss.png differ diff --git a/fn_gen/rnd_noisy_scale/16/quantization.png b/fn_gen/rnd_noisy_scale/16/quantization.png new file mode 100644 index 0000000000000000000000000000000000000000..6360af68c9bb56e777578eaba39094b54483151e Binary files /dev/null and b/fn_gen/rnd_noisy_scale/16/quantization.png differ diff --git a/fn_gen/rnd_noisy_scale/17/__pycache__/fn.cpython-310.pyc b/fn_gen/rnd_noisy_scale/17/__pycache__/fn.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6f5e685aca7c76bd4f5b2dd7bb70d21c65192aa7 Binary files /dev/null and b/fn_gen/rnd_noisy_scale/17/__pycache__/fn.cpython-310.pyc differ diff --git a/fn_gen/rnd_noisy_scale/17/distortion.png b/fn_gen/rnd_noisy_scale/17/distortion.png new file mode 100644 index 0000000000000000000000000000000000000000..ae8c8d56ea3b7f0445691a6ff58e4a6ae3da2035 Binary files /dev/null and b/fn_gen/rnd_noisy_scale/17/distortion.png differ diff --git a/fn_gen/rnd_noisy_scale/17/expressions.txt b/fn_gen/rnd_noisy_scale/17/expressions.txt new file mode 100644 index 0000000000000000000000000000000000000000..b835531ccc3a3813012a9a9487415f4f73afabc7 --- /dev/null +++ b/fn_gen/rnd_noisy_scale/17/expressions.txt @@ -0,0 +1,2 @@ +sinh(_0*x)/_s +log(_s*x - sqrt(_s**2*x**2 + 1))/_0 \ No newline at end of file diff --git a/fn_gen/rnd_noisy_scale/17/fn.py b/fn_gen/rnd_noisy_scale/17/fn.py new file mode 100644 index 0000000000000000000000000000000000000000..35511223ec9a4081e90f781b743a214a6ae727de --- /dev/null +++ b/fn_gen/rnd_noisy_scale/17/fn.py @@ -0,0 +1,489 @@ +from __future__ import annotations + +import torch +from torch import amin # Necessary for arcsin +import copy +import torch.nn as nn +import numpy as np + +from scipy.optimize import curve_fit +from typing import Dict, Any, Tuple, List, Callable + + +def quantization(x, **params): + return (torch.div(1, replace_num(params['_s'], num=0, to=10000)) * torch.sinh((params['_0'] * x))) + + +def dequantization(x, **params): + return (torch.div(1, replace_num(params['_0'], num=0, to=10000)) * torch.log(domain_guard(((torch.tensor(-1) * torch.sqrt(domain_guard((torch.tensor(1) + (guarded_torch_power(params['_s'], torch.tensor(2)) * guarded_torch_power(x, torch.tensor(2)))), min=0.1, nan=0.1))) + (params['_s'] * x)), min=1e-5, nan=1e-5))) + + +def init_linear_scale( # Symmetric scale. From the study folder + x: torch.Tensor, + **kwargs: Dict[str, Any], + ) -> torch.Tensor: + assert "bits" in kwargs, "bits must be provided." + assert "params" in kwargs, "params must be provided." + assert "qtz_func" in kwargs, "qtz_func must be provided." + + bits = kwargs.get('bits') + params = kwargs.get('params') + qtz_func = kwargs.get('qtz_func') + + x_ = x.transpose(0, 1) + x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs)) + x_ = x_.transpose(0, 1) + + quant_min, quant_max = get_min_max_from_bits_signed(bits) + min_vals, max_vals = torch.aminmax(x_, dim=1) + min_vals = torch.min(min_vals, torch.zeros_like(min_vals)) + max_vals = torch.max(max_vals, torch.zeros_like(max_vals)) + + eps = torch.finfo(torch.float32).eps + + abs_max_val_per_ch = torch.max(-min_vals, max_vals) + scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2) + + scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device) + + # Introduces some noise in scale + # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything + scale = scale + 0.01 * torch.randn_like(scale) + return scale + + +def init_params(x: torch.Tensor, **kwargs: Dict[str, Any]) -> Dict[str, nn.Parameter]: + params = { + '_0': init_rand(x, qtz_func=quantization, deqtz_func=dequantization, param='_0', params_list=['_0', '_s'], **kwargs), + } + params['_s'] = init_linear_scale(x, params=params, qtz_func=quantization, **kwargs) + params = {k: nn.Parameter(v, requires_grad=False) for k, v in params.items()} + + if 'post_init_hook' in kwargs: + kwargs['post_init_hook'](parameters=params) + + + if 'post_train_hook' in kwargs: + kwargs['post_train_hook'](parameters=params) + + return params + + +############### Numpy Qtz ############### + + +def np_quantization(x, _0, _s): + return (np.divide(1, np_replace_num(_s, num=0, to=10000)) * np.sinh((_0 * x))) + + +def np_dequantization(x, _0, _s): + return (np.divide(1, np_replace_num(_0, num=0, to=10000)) * np.log(np_domain_guard(((np.array(-1) * np.sqrt(np_domain_guard((np.array(1) + (np_guarded_power(_s, np.array(2)) * np_guarded_power(x, np.array(2)))), min=0.1, nan=0.1))) + (_s * x)), min=1e-5, nan=1e-5))) + + +def fit_func(x, _0, _s): + x_ = np_quantization(x, _0, _s) + x_ = np_dequantization(x_, _0, _s) + return x_ + + + +############### HELPERS ############### + +def domain_guard( + x: torch.Tensor, + min: float = None, + max: float = None, + posinf: float = None, + neginf: float = None, + nan: float = None + ) -> torch.Tensor: + """Guard a tensor to a valid domain.""" + x = torch.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan) + if min is not None or max is not None: + x = torch.clamp(x, min=min, max=max) + return x + + +def replace_num(x: torch.Tensor, num: float, to: float) -> torch.Tensor: + """Replace a number in a tensor with another number. + + Args: + x (torch.Tensor): The input tensor. + num (float): The number to replace. + to (float): The number to replace with. + + Returns: + torch.Tensor: The tensor with the number replaced. + """ + return torch.where(x == num, to, x) + + +def guarded_torch_power(x: torch.Tensor, exp: float) -> torch.Tensor: + """Guard the power operation to a valid domain.""" + return torch.pow(x, exp) if exp >= 1 else torch.pow(torch.relu(x), exp) + + +def init_ones(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor: + val = torch.amin(x, dim=1) + return torch.ones_like(val, dtype=torch.float32, device=x.device) + + +def init_rand(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor: + val = torch.amin(x, dim=1) + return torch.randn_like(val, dtype=torch.float32, device=x.device) + + +def init_space_search( + x: torch.Tensor, + **kwargs: Dict[str, Any], + ) -> torch.Tensor: + + def _build_initial_param(tensor: torch.Tensor, max_initial: int, n_params: int): + """Generates the initial set of parameters. The first iteration generates 10 times more parameters.""" + for _ in range(n_params * 10): # The first iteration generates 10 times more parameters + yield init_rand(tensor) * max_initial # Generates n_params in range [-max_initial, max_initial] + + def _search_param(tensors: List[torch.tensor], n_params): + """Takes the best parameters and generates new parameters around the mean of the best parameters.""" + torch_tensors = torch.stack(tensors) + min_vals, max_vals = torch.aminmax(torch_tensors, dim=0) + abs_max_val_per_ch = torch.max(-min_vals, max_vals) + mean = torch.mean(torch_tensors, dim=0) + for _ in range(n_params): # Generates n_params around the mean of the tensors + yield torch.randn_like(min_vals) * abs_max_val_per_ch + mean + + def _calc(x, qtz_func, deqtz_func, **params): + x_ = x.transpose(0, 1) + x_ = qtz_func(x=x_, **params) + x_ = deqtz_func(x=x_, **params) + x_ = x_.transpose(0, 1) + return x_ + + assert "qtz_func" in kwargs, "qtz_func must be provided." + assert "deqtz_func" in kwargs, "deqtz_func must be provided." + assert "params_list" in kwargs, "params list must be provided." + assert "param" in kwargs, "param must be provided." + + qtz_func = kwargs.get('qtz_func') + deqtz_func = kwargs.get('deqtz_func') + params_list = kwargs.get('params_list') + param = kwargs.get('param') + + n_runs = 50 # Number of runs to try to find the best parameters + n_random_params = 50 # Number of random parameters to generate + n_best_to_pick = 5 # Number of best parameters to pick after each run + max_initial = 10000 # Maximum value to initialize the parameters + + # Initializes the parameters + base_params = { p: init_ones(x, **kwargs) for p in params_list if p != param } + params = _build_initial_param(x, max_initial, n_random_params) + + # Performs the search + for _ in range(n_runs): + + best_params = [] + for param_ in params: + try: + x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: param_}) + loss_ones = nn.MSELoss()(x, x_) + + if len(best_params) < n_best_to_pick: + best_params.append((param_, loss_ones.item())) + best_params = sorted(best_params, key=lambda x: x[1]) + elif loss_ones < best_params[-1][1]: + best_params[-1] = (param_, loss_ones.item()) + best_params = sorted(best_params, key=lambda x: x[1]) + + except Exception: # The parameters might not be valid for the function's domain + continue + + # Generates new parameters around the mean + params = _search_param([p for p, _ in best_params], n_random_params) + + # Checks if the best parameter is better than the init_ones + p_ones = init_ones(x, **kwargs) + x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_ones}) + loss_ones = nn.MSELoss()(x, x_) + + # Checks if the best parameter is better than the init_rand + p_rand = init_rand(x, **kwargs) + x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_rand}) + loss_rand = nn.MSELoss()(x, x_) + + if loss_rand < best_params[0][1] and loss_rand < loss_ones: + return p_rand + elif loss_ones < best_params[0][1] and loss_ones < loss_rand: + return p_ones + else: + return best_params[0][0] + + +def init_linear_scale( # Symmetric scale. From the study folder + x: torch.Tensor, + **kwargs: Dict[str, Any], + ) -> torch.Tensor: + assert "bits" in kwargs, "bits must be provided." + assert "params" in kwargs, "params must be provided." + assert "qtz_func" in kwargs, "qtz_func must be provided." + + bits = kwargs.get('bits') + params = kwargs.get('params') + qtz_func = kwargs.get('qtz_func') + + x_ = x.transpose(0, 1) + x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs)) + x_ = x_.transpose(0, 1) + + quant_min, quant_max = get_min_max_from_bits_signed(bits) + min_vals, max_vals = torch.aminmax(x_, dim=1) + min_vals = torch.min(min_vals, torch.zeros_like(min_vals)) + max_vals = torch.max(max_vals, torch.zeros_like(max_vals)) + + eps = torch.finfo(torch.float32).eps + + abs_max_val_per_ch = torch.max(-min_vals, max_vals) + scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2) + + scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device) + + # Introduces some noise in scale + # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything + scale = scale + 0.01 * torch.randn_like(scale) + return scale + + +def init_non_linear_regression_fit( + x: torch.Tensor, + **kwargs: Dict[str, Any], + ) -> torch.Tensor: + + assert "params_list" in kwargs, "params list must be provided." + assert "np_fit_func" in kwargs, "np_fit_func must be provided." + assert "p0" in kwargs, "p0 must be provided." + np_fit_func = kwargs.get('np_fit_func') + params_list = kwargs.get('params_list') + p0 = kwargs.get('p0') + + def _fit(xdata: np.ndarray, ydata: np.ndarray, func: Callable, p0: List[float]): + popt, _ = curve_fit( + func, + xdata, + ydata, + maxfev=1000, + p0=p0, + method='lm' + ) + return popt + + # 1. Needs to convert the torch tensor to numpy tensor + xdata = x.cpu().numpy() + + # 2. Sorts the data so that it makes it easier to fit to it + sorted_xdata = np.sort(xdata, axis=-1) + + p0 = {k: v.cpu().numpy() for k, v in p0.items()} + params_list = sorted(params_list) # We need to make sure that it matches the numpy fit func arg order + + # 3. Finds the best parameters for each channel + try: + params = [] + for i in range(sorted_xdata.shape[0]): + xdata_ = sorted_xdata[i] + p0_ = [p0[p][i] for p in params_list] + ch_params = _fit(xdata_, xdata_, np_fit_func, p0_) + params.append(ch_params) + + # 4. Builds the parameters + result = {} + for i, p in enumerate(params_list): + result[p] = torch.tensor([p_[i] for p_ in params], dtype=torch.float32).to(x.device) + + return result + + except ValueError as e: + print(f"Could not fit the function with error: {e}") + print(f"Using fallback result...") + return { + k: torch.tensor(v, dtype=torch.float32).to(x.device) for k, v in p0.items() + } + + +def init_zeros(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor: + val = torch.amin(x, dim=1) + return torch.zeros_like(val, dtype=torch.float32, device=x.device) + + +def init_inner_scale(tensor: torch.Tensor, _min: float = torch.inf, _max: float = torch.inf) -> torch.Tensor: + # Calculate the original minimum and maximum values + min_vals, max_vals = torch.aminmax(tensor, dim=-1) + x_min = torch.min(min_vals, torch.zeros_like(min_vals)) + x_max = torch.max(max_vals, torch.zeros_like(max_vals)) + + if _max is torch.inf: # We do not need to scale the tensor. Just need to move it + return torch.ones_like(x_min) + + # Calculate the scale factor + scale = (_max - _min) / (x_max - x_min) + return scale + + + +############## Quant ############### + +@torch.enable_grad() +def learn_parameters( + x: torch.Tensor, + params: Dict[str, nn.Parameter], + qtz_func: nn.Module, + deqtz_func: nn.Module, + bits: int, + target_dtype: torch.dtype, + epochs: int = 1000, + early_stop: bool = True, + do_report: bool = False +) -> Tuple[Dict[str, nn.Parameter], torch.Tensor]: + + # Requires gradients in the parameters + for p in params.values(): + p.requires_grad = True + p.grad = None + + param_keys = list(params.keys()) + param_values = list(params.values()) + + # Defines optimizer and loss function + optimizer = torch.optim.Adam(param_values, lr=0.001) + scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=1.0, end_factor=0.001, total_iters=epochs // 10) + loss_fn = nn.MSELoss() + + # Contains the best loss and the best parameters + best_loss = float("inf") + best_params = None + + # Used to stop the search early + min_delta = 1e-7 + acc_loss = [] + percent_epochs_before_stop = 0.1 + + for i in range(epochs): + optimizer.zero_grad() + + quant = quantize(x, params, qtz_func, bits, target_dtype) + dequant = dequantize(quant, params, deqtz_func, bits, x.dtype) + loss = loss_fn(x, dequant) + + if loss.isnan() or loss.isinf(): + raise Exception("Loss is NaN or Inf. Stopping the search.") + + loss.backward() + optimizer.step() + scheduler.step() + + acc_loss.append(loss.item()) + + # Reports loss every 10 steps + if i % 10 == 0 and do_report: + print(f"Epoch {i}: Loss {loss.item()}") + + # Optimizes the parameter search by storing the best loss and the parameters + if loss.item() < best_loss: + best_loss = loss.item() + best_params = copy.deepcopy({ + k: v for k, v in params.items() if k in param_keys + }) + + # We also stop the search if the loss has not considerably during the last 10% epochs + if early_stop: + epochs_before_stop = int(epochs * percent_epochs_before_stop) + if i > epochs_before_stop and abs(acc_loss[i - epochs_before_stop] - acc_loss[i]) < min_delta: + break + + # No longer requires gradients in the parameters + for p in best_params.values(): + p.requires_grad = False + p.grad = None + + if do_report: + return best_params, acc_loss + else: + return best_params + + +def quantize( + x: torch.Tensor, + params: Dict[str, nn.Parameter], + func: nn.Module, + bits: int, + target_dtype: torch.dtype = torch.int8 +) -> torch.Tensor: + quant_min, quant_max = get_min_max_from_bits_signed(bits) + x = x.transpose(0, 1) # Aligns shapes + x = func(x=x, **params) + x = x.transpose(0, 1) + x = torch.clamp(round_func_BPDA(x), quant_min, quant_max).to(target_dtype) + return x + + +def dequantize( + x: torch.Tensor, + params: Dict[str, nn.Parameter], + func: nn.Module, + bits: int, + out_dtype: torch.dtype +) -> torch.Tensor: + x = x.to(dtype=out_dtype) + x = x.transpose(0, 1) + x = func(x=x, **params) + x = x.transpose(0, 1) + return x + + +def round_func_BPDA(input): + # This is equivalent to replacing round function (non-differentiable) with + # an identity function (differentiable) only when backward. + forward_value = torch.round(input) + out = input.clone() + out.data = forward_value.data + return out + + +def get_min_max_from_bits_signed(bit_width: int) -> Tuple[int, int]: + return -2 ** (bit_width - 1), 2 ** (bit_width - 1) - 1 + + + +############## Numpy ############### + +def np_domain_guard( + x: np.ndarray, + min: float = None, + max: float = None, + posinf: float = None, + neginf: float = None, + nan: float = None + ) -> np.ndarray: + """Guard a tensor to a valid domain.""" + x = np.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan) + if min is not None or max is not None: + x = np.clip(x, min, max) + return x + + +def np_replace_num(x: np.ndarray, num: float, to: float) -> np.ndarray: + """Replace a number in a tensor with another number. + + Args: + x (np.ndarray): The input tensor. + num (float): The number to replace. + to (float): The number to replace with. + + Returns: + np.ndarray: The tensor with the number replaced. + """ + return np.where(x == num, to, x) + + +def np_guarded_power(x: np.ndarray, exp: float) -> np.ndarray: + """Guard the power operation to a valid domain.""" + return np.power(x, exp) if exp >= 1 else np.power(np.maximum(x, 0), exp) + diff --git a/fn_gen/rnd_noisy_scale/17/loss.png b/fn_gen/rnd_noisy_scale/17/loss.png new file mode 100644 index 0000000000000000000000000000000000000000..4fe3f7dc5fb18b1f4e3e330bb6979cd0783bc799 Binary files /dev/null and b/fn_gen/rnd_noisy_scale/17/loss.png differ diff --git a/fn_gen/rnd_noisy_scale/17/quantization.png b/fn_gen/rnd_noisy_scale/17/quantization.png new file mode 100644 index 0000000000000000000000000000000000000000..adf613477d7a0a443ff89c4735c63620484080aa Binary files /dev/null and b/fn_gen/rnd_noisy_scale/17/quantization.png differ diff --git a/fn_gen/rnd_noisy_scale/18/__pycache__/fn.cpython-310.pyc b/fn_gen/rnd_noisy_scale/18/__pycache__/fn.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d529a64589a81a0b9ae29e8bc002a0f1b2ce103a Binary files /dev/null and b/fn_gen/rnd_noisy_scale/18/__pycache__/fn.cpython-310.pyc differ diff --git a/fn_gen/rnd_noisy_scale/18/distortion.png b/fn_gen/rnd_noisy_scale/18/distortion.png new file mode 100644 index 0000000000000000000000000000000000000000..0e61a53c144793997a25eb8149cbb76cb24e3632 Binary files /dev/null and b/fn_gen/rnd_noisy_scale/18/distortion.png differ diff --git a/fn_gen/rnd_noisy_scale/18/expressions.txt b/fn_gen/rnd_noisy_scale/18/expressions.txt new file mode 100644 index 0000000000000000000000000000000000000000..8c0b1579c06c048d5603aa39c80e392c5906a879 --- /dev/null +++ b/fn_gen/rnd_noisy_scale/18/expressions.txt @@ -0,0 +1,2 @@ +cos(_0*x)/_s +acos(_s*x)/_0 \ No newline at end of file diff --git a/fn_gen/rnd_noisy_scale/18/fn.py b/fn_gen/rnd_noisy_scale/18/fn.py new file mode 100644 index 0000000000000000000000000000000000000000..aae0963cb7272048c537c2075fe29c4f8e4fd242 --- /dev/null +++ b/fn_gen/rnd_noisy_scale/18/fn.py @@ -0,0 +1,489 @@ +from __future__ import annotations + +import torch +from torch import amin # Necessary for arcsin +import copy +import torch.nn as nn +import numpy as np + +from scipy.optimize import curve_fit +from typing import Dict, Any, Tuple, List, Callable + + +def quantization(x, **params): + return (torch.div(1, replace_num(params['_s'], num=0, to=10000)) * torch.cos((params['_0'] * x))) + + +def dequantization(x, **params): + return (torch.div(1, replace_num(params['_0'], num=0, to=10000)) * torch.acos(domain_guard((params['_s'] * x), min=-0.99999, max=0.99999, nan=0))) + + +def init_linear_scale( # Symmetric scale. From the study folder + x: torch.Tensor, + **kwargs: Dict[str, Any], + ) -> torch.Tensor: + assert "bits" in kwargs, "bits must be provided." + assert "params" in kwargs, "params must be provided." + assert "qtz_func" in kwargs, "qtz_func must be provided." + + bits = kwargs.get('bits') + params = kwargs.get('params') + qtz_func = kwargs.get('qtz_func') + + x_ = x.transpose(0, 1) + x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs)) + x_ = x_.transpose(0, 1) + + quant_min, quant_max = get_min_max_from_bits_signed(bits) + min_vals, max_vals = torch.aminmax(x_, dim=1) + min_vals = torch.min(min_vals, torch.zeros_like(min_vals)) + max_vals = torch.max(max_vals, torch.zeros_like(max_vals)) + + eps = torch.finfo(torch.float32).eps + + abs_max_val_per_ch = torch.max(-min_vals, max_vals) + scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2) + + scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device) + + # Introduces some noise in scale + # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything + scale = scale + 0.01 * torch.randn_like(scale) + return scale + + +def init_params(x: torch.Tensor, **kwargs: Dict[str, Any]) -> Dict[str, nn.Parameter]: + params = { + '_0': init_rand(x, qtz_func=quantization, deqtz_func=dequantization, param='_0', params_list=['_0', '_s'], **kwargs), + } + params['_s'] = init_linear_scale(x, params=params, qtz_func=quantization, **kwargs) + params = {k: nn.Parameter(v, requires_grad=False) for k, v in params.items()} + + if 'post_init_hook' in kwargs: + kwargs['post_init_hook'](parameters=params) + + + if 'post_train_hook' in kwargs: + kwargs['post_train_hook'](parameters=params) + + return params + + +############### Numpy Qtz ############### + + +def np_quantization(x, _0, _s): + return (np.divide(1, np_replace_num(_s, num=0, to=10000)) * np.cos((_0 * x))) + + +def np_dequantization(x, _0, _s): + return (np.divide(1, np_replace_num(_0, num=0, to=10000)) * np.arccos(np_domain_guard((_s * x), min=-0.99999, max=0.99999, nan=0))) + + +def fit_func(x, _0, _s): + x_ = np_quantization(x, _0, _s) + x_ = np_dequantization(x_, _0, _s) + return x_ + + + +############### HELPERS ############### + +def domain_guard( + x: torch.Tensor, + min: float = None, + max: float = None, + posinf: float = None, + neginf: float = None, + nan: float = None + ) -> torch.Tensor: + """Guard a tensor to a valid domain.""" + x = torch.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan) + if min is not None or max is not None: + x = torch.clamp(x, min=min, max=max) + return x + + +def replace_num(x: torch.Tensor, num: float, to: float) -> torch.Tensor: + """Replace a number in a tensor with another number. + + Args: + x (torch.Tensor): The input tensor. + num (float): The number to replace. + to (float): The number to replace with. + + Returns: + torch.Tensor: The tensor with the number replaced. + """ + return torch.where(x == num, to, x) + + +def guarded_torch_power(x: torch.Tensor, exp: float) -> torch.Tensor: + """Guard the power operation to a valid domain.""" + return torch.pow(x, exp) if exp >= 1 else torch.pow(torch.relu(x), exp) + + +def init_ones(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor: + val = torch.amin(x, dim=1) + return torch.ones_like(val, dtype=torch.float32, device=x.device) + + +def init_rand(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor: + val = torch.amin(x, dim=1) + return torch.randn_like(val, dtype=torch.float32, device=x.device) + + +def init_space_search( + x: torch.Tensor, + **kwargs: Dict[str, Any], + ) -> torch.Tensor: + + def _build_initial_param(tensor: torch.Tensor, max_initial: int, n_params: int): + """Generates the initial set of parameters. The first iteration generates 10 times more parameters.""" + for _ in range(n_params * 10): # The first iteration generates 10 times more parameters + yield init_rand(tensor) * max_initial # Generates n_params in range [-max_initial, max_initial] + + def _search_param(tensors: List[torch.tensor], n_params): + """Takes the best parameters and generates new parameters around the mean of the best parameters.""" + torch_tensors = torch.stack(tensors) + min_vals, max_vals = torch.aminmax(torch_tensors, dim=0) + abs_max_val_per_ch = torch.max(-min_vals, max_vals) + mean = torch.mean(torch_tensors, dim=0) + for _ in range(n_params): # Generates n_params around the mean of the tensors + yield torch.randn_like(min_vals) * abs_max_val_per_ch + mean + + def _calc(x, qtz_func, deqtz_func, **params): + x_ = x.transpose(0, 1) + x_ = qtz_func(x=x_, **params) + x_ = deqtz_func(x=x_, **params) + x_ = x_.transpose(0, 1) + return x_ + + assert "qtz_func" in kwargs, "qtz_func must be provided." + assert "deqtz_func" in kwargs, "deqtz_func must be provided." + assert "params_list" in kwargs, "params list must be provided." + assert "param" in kwargs, "param must be provided." + + qtz_func = kwargs.get('qtz_func') + deqtz_func = kwargs.get('deqtz_func') + params_list = kwargs.get('params_list') + param = kwargs.get('param') + + n_runs = 50 # Number of runs to try to find the best parameters + n_random_params = 50 # Number of random parameters to generate + n_best_to_pick = 5 # Number of best parameters to pick after each run + max_initial = 10000 # Maximum value to initialize the parameters + + # Initializes the parameters + base_params = { p: init_ones(x, **kwargs) for p in params_list if p != param } + params = _build_initial_param(x, max_initial, n_random_params) + + # Performs the search + for _ in range(n_runs): + + best_params = [] + for param_ in params: + try: + x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: param_}) + loss_ones = nn.MSELoss()(x, x_) + + if len(best_params) < n_best_to_pick: + best_params.append((param_, loss_ones.item())) + best_params = sorted(best_params, key=lambda x: x[1]) + elif loss_ones < best_params[-1][1]: + best_params[-1] = (param_, loss_ones.item()) + best_params = sorted(best_params, key=lambda x: x[1]) + + except Exception: # The parameters might not be valid for the function's domain + continue + + # Generates new parameters around the mean + params = _search_param([p for p, _ in best_params], n_random_params) + + # Checks if the best parameter is better than the init_ones + p_ones = init_ones(x, **kwargs) + x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_ones}) + loss_ones = nn.MSELoss()(x, x_) + + # Checks if the best parameter is better than the init_rand + p_rand = init_rand(x, **kwargs) + x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_rand}) + loss_rand = nn.MSELoss()(x, x_) + + if loss_rand < best_params[0][1] and loss_rand < loss_ones: + return p_rand + elif loss_ones < best_params[0][1] and loss_ones < loss_rand: + return p_ones + else: + return best_params[0][0] + + +def init_linear_scale( # Symmetric scale. From the study folder + x: torch.Tensor, + **kwargs: Dict[str, Any], + ) -> torch.Tensor: + assert "bits" in kwargs, "bits must be provided." + assert "params" in kwargs, "params must be provided." + assert "qtz_func" in kwargs, "qtz_func must be provided." + + bits = kwargs.get('bits') + params = kwargs.get('params') + qtz_func = kwargs.get('qtz_func') + + x_ = x.transpose(0, 1) + x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs)) + x_ = x_.transpose(0, 1) + + quant_min, quant_max = get_min_max_from_bits_signed(bits) + min_vals, max_vals = torch.aminmax(x_, dim=1) + min_vals = torch.min(min_vals, torch.zeros_like(min_vals)) + max_vals = torch.max(max_vals, torch.zeros_like(max_vals)) + + eps = torch.finfo(torch.float32).eps + + abs_max_val_per_ch = torch.max(-min_vals, max_vals) + scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2) + + scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device) + + # Introduces some noise in scale + # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything + scale = scale + 0.01 * torch.randn_like(scale) + return scale + + +def init_non_linear_regression_fit( + x: torch.Tensor, + **kwargs: Dict[str, Any], + ) -> torch.Tensor: + + assert "params_list" in kwargs, "params list must be provided." + assert "np_fit_func" in kwargs, "np_fit_func must be provided." + assert "p0" in kwargs, "p0 must be provided." + np_fit_func = kwargs.get('np_fit_func') + params_list = kwargs.get('params_list') + p0 = kwargs.get('p0') + + def _fit(xdata: np.ndarray, ydata: np.ndarray, func: Callable, p0: List[float]): + popt, _ = curve_fit( + func, + xdata, + ydata, + maxfev=1000, + p0=p0, + method='lm' + ) + return popt + + # 1. Needs to convert the torch tensor to numpy tensor + xdata = x.cpu().numpy() + + # 2. Sorts the data so that it makes it easier to fit to it + sorted_xdata = np.sort(xdata, axis=-1) + + p0 = {k: v.cpu().numpy() for k, v in p0.items()} + params_list = sorted(params_list) # We need to make sure that it matches the numpy fit func arg order + + # 3. Finds the best parameters for each channel + try: + params = [] + for i in range(sorted_xdata.shape[0]): + xdata_ = sorted_xdata[i] + p0_ = [p0[p][i] for p in params_list] + ch_params = _fit(xdata_, xdata_, np_fit_func, p0_) + params.append(ch_params) + + # 4. Builds the parameters + result = {} + for i, p in enumerate(params_list): + result[p] = torch.tensor([p_[i] for p_ in params], dtype=torch.float32).to(x.device) + + return result + + except ValueError as e: + print(f"Could not fit the function with error: {e}") + print(f"Using fallback result...") + return { + k: torch.tensor(v, dtype=torch.float32).to(x.device) for k, v in p0.items() + } + + +def init_zeros(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor: + val = torch.amin(x, dim=1) + return torch.zeros_like(val, dtype=torch.float32, device=x.device) + + +def init_inner_scale(tensor: torch.Tensor, _min: float = torch.inf, _max: float = torch.inf) -> torch.Tensor: + # Calculate the original minimum and maximum values + min_vals, max_vals = torch.aminmax(tensor, dim=-1) + x_min = torch.min(min_vals, torch.zeros_like(min_vals)) + x_max = torch.max(max_vals, torch.zeros_like(max_vals)) + + if _max is torch.inf: # We do not need to scale the tensor. Just need to move it + return torch.ones_like(x_min) + + # Calculate the scale factor + scale = (_max - _min) / (x_max - x_min) + return scale + + + +############## Quant ############### + +@torch.enable_grad() +def learn_parameters( + x: torch.Tensor, + params: Dict[str, nn.Parameter], + qtz_func: nn.Module, + deqtz_func: nn.Module, + bits: int, + target_dtype: torch.dtype, + epochs: int = 1000, + early_stop: bool = True, + do_report: bool = False +) -> Tuple[Dict[str, nn.Parameter], torch.Tensor]: + + # Requires gradients in the parameters + for p in params.values(): + p.requires_grad = True + p.grad = None + + param_keys = list(params.keys()) + param_values = list(params.values()) + + # Defines optimizer and loss function + optimizer = torch.optim.Adam(param_values, lr=0.001) + scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=1.0, end_factor=0.001, total_iters=epochs // 10) + loss_fn = nn.MSELoss() + + # Contains the best loss and the best parameters + best_loss = float("inf") + best_params = None + + # Used to stop the search early + min_delta = 1e-7 + acc_loss = [] + percent_epochs_before_stop = 0.1 + + for i in range(epochs): + optimizer.zero_grad() + + quant = quantize(x, params, qtz_func, bits, target_dtype) + dequant = dequantize(quant, params, deqtz_func, bits, x.dtype) + loss = loss_fn(x, dequant) + + if loss.isnan() or loss.isinf(): + raise Exception("Loss is NaN or Inf. Stopping the search.") + + loss.backward() + optimizer.step() + scheduler.step() + + acc_loss.append(loss.item()) + + # Reports loss every 10 steps + if i % 10 == 0 and do_report: + print(f"Epoch {i}: Loss {loss.item()}") + + # Optimizes the parameter search by storing the best loss and the parameters + if loss.item() < best_loss: + best_loss = loss.item() + best_params = copy.deepcopy({ + k: v for k, v in params.items() if k in param_keys + }) + + # We also stop the search if the loss has not considerably during the last 10% epochs + if early_stop: + epochs_before_stop = int(epochs * percent_epochs_before_stop) + if i > epochs_before_stop and abs(acc_loss[i - epochs_before_stop] - acc_loss[i]) < min_delta: + break + + # No longer requires gradients in the parameters + for p in best_params.values(): + p.requires_grad = False + p.grad = None + + if do_report: + return best_params, acc_loss + else: + return best_params + + +def quantize( + x: torch.Tensor, + params: Dict[str, nn.Parameter], + func: nn.Module, + bits: int, + target_dtype: torch.dtype = torch.int8 +) -> torch.Tensor: + quant_min, quant_max = get_min_max_from_bits_signed(bits) + x = x.transpose(0, 1) # Aligns shapes + x = func(x=x, **params) + x = x.transpose(0, 1) + x = torch.clamp(round_func_BPDA(x), quant_min, quant_max).to(target_dtype) + return x + + +def dequantize( + x: torch.Tensor, + params: Dict[str, nn.Parameter], + func: nn.Module, + bits: int, + out_dtype: torch.dtype +) -> torch.Tensor: + x = x.to(dtype=out_dtype) + x = x.transpose(0, 1) + x = func(x=x, **params) + x = x.transpose(0, 1) + return x + + +def round_func_BPDA(input): + # This is equivalent to replacing round function (non-differentiable) with + # an identity function (differentiable) only when backward. + forward_value = torch.round(input) + out = input.clone() + out.data = forward_value.data + return out + + +def get_min_max_from_bits_signed(bit_width: int) -> Tuple[int, int]: + return -2 ** (bit_width - 1), 2 ** (bit_width - 1) - 1 + + + +############## Numpy ############### + +def np_domain_guard( + x: np.ndarray, + min: float = None, + max: float = None, + posinf: float = None, + neginf: float = None, + nan: float = None + ) -> np.ndarray: + """Guard a tensor to a valid domain.""" + x = np.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan) + if min is not None or max is not None: + x = np.clip(x, min, max) + return x + + +def np_replace_num(x: np.ndarray, num: float, to: float) -> np.ndarray: + """Replace a number in a tensor with another number. + + Args: + x (np.ndarray): The input tensor. + num (float): The number to replace. + to (float): The number to replace with. + + Returns: + np.ndarray: The tensor with the number replaced. + """ + return np.where(x == num, to, x) + + +def np_guarded_power(x: np.ndarray, exp: float) -> np.ndarray: + """Guard the power operation to a valid domain.""" + return np.power(x, exp) if exp >= 1 else np.power(np.maximum(x, 0), exp) + diff --git a/fn_gen/rnd_noisy_scale/18/loss.png b/fn_gen/rnd_noisy_scale/18/loss.png new file mode 100644 index 0000000000000000000000000000000000000000..6927b4b9cbf506908185de1d81f720db75c39900 Binary files /dev/null and b/fn_gen/rnd_noisy_scale/18/loss.png differ diff --git a/fn_gen/rnd_noisy_scale/18/quantization.png b/fn_gen/rnd_noisy_scale/18/quantization.png new file mode 100644 index 0000000000000000000000000000000000000000..772298aba60007b5310a252b10a4ff4091bf851f Binary files /dev/null and b/fn_gen/rnd_noisy_scale/18/quantization.png differ diff --git a/fn_gen/rnd_noisy_scale/2/__pycache__/fn.cpython-310.pyc b/fn_gen/rnd_noisy_scale/2/__pycache__/fn.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a3fb3a62a3ad3c3ef8e9f8aa957aaccd972d160f Binary files /dev/null and b/fn_gen/rnd_noisy_scale/2/__pycache__/fn.cpython-310.pyc differ diff --git a/fn_gen/rnd_noisy_scale/2/distortion.png b/fn_gen/rnd_noisy_scale/2/distortion.png new file mode 100644 index 0000000000000000000000000000000000000000..cf991008cac8e46ae1089e265cfee2e2078a7f69 Binary files /dev/null and b/fn_gen/rnd_noisy_scale/2/distortion.png differ diff --git a/fn_gen/rnd_noisy_scale/2/expressions.txt b/fn_gen/rnd_noisy_scale/2/expressions.txt new file mode 100644 index 0000000000000000000000000000000000000000..3758ee2a62aa8d95c3b7da1dd3fafa11b027ad9b --- /dev/null +++ b/fn_gen/rnd_noisy_scale/2/expressions.txt @@ -0,0 +1,2 @@ +cosh(_0*x)/_s +log(_s*x - sqrt(_s**2*x**2 - 1))/_0 \ No newline at end of file diff --git a/fn_gen/rnd_noisy_scale/2/fn.py b/fn_gen/rnd_noisy_scale/2/fn.py new file mode 100644 index 0000000000000000000000000000000000000000..c38c1d4e3594cb14ed572eb3eb1062fed49231bc --- /dev/null +++ b/fn_gen/rnd_noisy_scale/2/fn.py @@ -0,0 +1,489 @@ +from __future__ import annotations + +import torch +from torch import amin # Necessary for arcsin +import copy +import torch.nn as nn +import numpy as np + +from scipy.optimize import curve_fit +from typing import Dict, Any, Tuple, List, Callable + + +def quantization(x, **params): + return (torch.div(1, replace_num(params['_s'], num=0, to=10000)) * torch.cosh((params['_0'] * x))) + + +def dequantization(x, **params): + return (torch.div(1, replace_num(params['_0'], num=0, to=10000)) * torch.log(domain_guard(((torch.tensor(-1) * torch.sqrt(domain_guard((torch.tensor(-1) + (guarded_torch_power(params['_s'], torch.tensor(2)) * guarded_torch_power(x, torch.tensor(2)))), min=0.1, nan=0.1))) + (params['_s'] * x)), min=1e-5, nan=1e-5))) + + +def init_linear_scale( # Symmetric scale. From the study folder + x: torch.Tensor, + **kwargs: Dict[str, Any], + ) -> torch.Tensor: + assert "bits" in kwargs, "bits must be provided." + assert "params" in kwargs, "params must be provided." + assert "qtz_func" in kwargs, "qtz_func must be provided." + + bits = kwargs.get('bits') + params = kwargs.get('params') + qtz_func = kwargs.get('qtz_func') + + x_ = x.transpose(0, 1) + x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs)) + x_ = x_.transpose(0, 1) + + quant_min, quant_max = get_min_max_from_bits_signed(bits) + min_vals, max_vals = torch.aminmax(x_, dim=1) + min_vals = torch.min(min_vals, torch.zeros_like(min_vals)) + max_vals = torch.max(max_vals, torch.zeros_like(max_vals)) + + eps = torch.finfo(torch.float32).eps + + abs_max_val_per_ch = torch.max(-min_vals, max_vals) + scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2) + + scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device) + + # Introduces some noise in scale + # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything + scale = scale + 0.01 * torch.randn_like(scale) + return scale + + +def init_params(x: torch.Tensor, **kwargs: Dict[str, Any]) -> Dict[str, nn.Parameter]: + params = { + '_0': init_rand(x, qtz_func=quantization, deqtz_func=dequantization, param='_0', params_list=['_0', '_s'], **kwargs), + } + params['_s'] = init_linear_scale(x, params=params, qtz_func=quantization, **kwargs) + params = {k: nn.Parameter(v, requires_grad=False) for k, v in params.items()} + + if 'post_init_hook' in kwargs: + kwargs['post_init_hook'](parameters=params) + + + if 'post_train_hook' in kwargs: + kwargs['post_train_hook'](parameters=params) + + return params + + +############### Numpy Qtz ############### + + +def np_quantization(x, _0, _s): + return (np.divide(1, np_replace_num(_s, num=0, to=10000)) * np.cosh((_0 * x))) + + +def np_dequantization(x, _0, _s): + return (np.divide(1, np_replace_num(_0, num=0, to=10000)) * np.log(np_domain_guard(((np.array(-1) * np.sqrt(np_domain_guard((np.array(-1) + (np_guarded_power(_s, np.array(2)) * np_guarded_power(x, np.array(2)))), min=0.1, nan=0.1))) + (_s * x)), min=1e-5, nan=1e-5))) + + +def fit_func(x, _0, _s): + x_ = np_quantization(x, _0, _s) + x_ = np_dequantization(x_, _0, _s) + return x_ + + + +############### HELPERS ############### + +def domain_guard( + x: torch.Tensor, + min: float = None, + max: float = None, + posinf: float = None, + neginf: float = None, + nan: float = None + ) -> torch.Tensor: + """Guard a tensor to a valid domain.""" + x = torch.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan) + if min is not None or max is not None: + x = torch.clamp(x, min=min, max=max) + return x + + +def replace_num(x: torch.Tensor, num: float, to: float) -> torch.Tensor: + """Replace a number in a tensor with another number. + + Args: + x (torch.Tensor): The input tensor. + num (float): The number to replace. + to (float): The number to replace with. + + Returns: + torch.Tensor: The tensor with the number replaced. + """ + return torch.where(x == num, to, x) + + +def guarded_torch_power(x: torch.Tensor, exp: float) -> torch.Tensor: + """Guard the power operation to a valid domain.""" + return torch.pow(x, exp) if exp >= 1 else torch.pow(torch.relu(x), exp) + + +def init_ones(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor: + val = torch.amin(x, dim=1) + return torch.ones_like(val, dtype=torch.float32, device=x.device) + + +def init_rand(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor: + val = torch.amin(x, dim=1) + return torch.randn_like(val, dtype=torch.float32, device=x.device) + + +def init_space_search( + x: torch.Tensor, + **kwargs: Dict[str, Any], + ) -> torch.Tensor: + + def _build_initial_param(tensor: torch.Tensor, max_initial: int, n_params: int): + """Generates the initial set of parameters. The first iteration generates 10 times more parameters.""" + for _ in range(n_params * 10): # The first iteration generates 10 times more parameters + yield init_rand(tensor) * max_initial # Generates n_params in range [-max_initial, max_initial] + + def _search_param(tensors: List[torch.tensor], n_params): + """Takes the best parameters and generates new parameters around the mean of the best parameters.""" + torch_tensors = torch.stack(tensors) + min_vals, max_vals = torch.aminmax(torch_tensors, dim=0) + abs_max_val_per_ch = torch.max(-min_vals, max_vals) + mean = torch.mean(torch_tensors, dim=0) + for _ in range(n_params): # Generates n_params around the mean of the tensors + yield torch.randn_like(min_vals) * abs_max_val_per_ch + mean + + def _calc(x, qtz_func, deqtz_func, **params): + x_ = x.transpose(0, 1) + x_ = qtz_func(x=x_, **params) + x_ = deqtz_func(x=x_, **params) + x_ = x_.transpose(0, 1) + return x_ + + assert "qtz_func" in kwargs, "qtz_func must be provided." + assert "deqtz_func" in kwargs, "deqtz_func must be provided." + assert "params_list" in kwargs, "params list must be provided." + assert "param" in kwargs, "param must be provided." + + qtz_func = kwargs.get('qtz_func') + deqtz_func = kwargs.get('deqtz_func') + params_list = kwargs.get('params_list') + param = kwargs.get('param') + + n_runs = 50 # Number of runs to try to find the best parameters + n_random_params = 50 # Number of random parameters to generate + n_best_to_pick = 5 # Number of best parameters to pick after each run + max_initial = 10000 # Maximum value to initialize the parameters + + # Initializes the parameters + base_params = { p: init_ones(x, **kwargs) for p in params_list if p != param } + params = _build_initial_param(x, max_initial, n_random_params) + + # Performs the search + for _ in range(n_runs): + + best_params = [] + for param_ in params: + try: + x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: param_}) + loss_ones = nn.MSELoss()(x, x_) + + if len(best_params) < n_best_to_pick: + best_params.append((param_, loss_ones.item())) + best_params = sorted(best_params, key=lambda x: x[1]) + elif loss_ones < best_params[-1][1]: + best_params[-1] = (param_, loss_ones.item()) + best_params = sorted(best_params, key=lambda x: x[1]) + + except Exception: # The parameters might not be valid for the function's domain + continue + + # Generates new parameters around the mean + params = _search_param([p for p, _ in best_params], n_random_params) + + # Checks if the best parameter is better than the init_ones + p_ones = init_ones(x, **kwargs) + x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_ones}) + loss_ones = nn.MSELoss()(x, x_) + + # Checks if the best parameter is better than the init_rand + p_rand = init_rand(x, **kwargs) + x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_rand}) + loss_rand = nn.MSELoss()(x, x_) + + if loss_rand < best_params[0][1] and loss_rand < loss_ones: + return p_rand + elif loss_ones < best_params[0][1] and loss_ones < loss_rand: + return p_ones + else: + return best_params[0][0] + + +def init_linear_scale( # Symmetric scale. From the study folder + x: torch.Tensor, + **kwargs: Dict[str, Any], + ) -> torch.Tensor: + assert "bits" in kwargs, "bits must be provided." + assert "params" in kwargs, "params must be provided." + assert "qtz_func" in kwargs, "qtz_func must be provided." + + bits = kwargs.get('bits') + params = kwargs.get('params') + qtz_func = kwargs.get('qtz_func') + + x_ = x.transpose(0, 1) + x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs)) + x_ = x_.transpose(0, 1) + + quant_min, quant_max = get_min_max_from_bits_signed(bits) + min_vals, max_vals = torch.aminmax(x_, dim=1) + min_vals = torch.min(min_vals, torch.zeros_like(min_vals)) + max_vals = torch.max(max_vals, torch.zeros_like(max_vals)) + + eps = torch.finfo(torch.float32).eps + + abs_max_val_per_ch = torch.max(-min_vals, max_vals) + scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2) + + scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device) + + # Introduces some noise in scale + # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything + scale = scale + 0.01 * torch.randn_like(scale) + return scale + + +def init_non_linear_regression_fit( + x: torch.Tensor, + **kwargs: Dict[str, Any], + ) -> torch.Tensor: + + assert "params_list" in kwargs, "params list must be provided." + assert "np_fit_func" in kwargs, "np_fit_func must be provided." + assert "p0" in kwargs, "p0 must be provided." + np_fit_func = kwargs.get('np_fit_func') + params_list = kwargs.get('params_list') + p0 = kwargs.get('p0') + + def _fit(xdata: np.ndarray, ydata: np.ndarray, func: Callable, p0: List[float]): + popt, _ = curve_fit( + func, + xdata, + ydata, + maxfev=1000, + p0=p0, + method='lm' + ) + return popt + + # 1. Needs to convert the torch tensor to numpy tensor + xdata = x.cpu().numpy() + + # 2. Sorts the data so that it makes it easier to fit to it + sorted_xdata = np.sort(xdata, axis=-1) + + p0 = {k: v.cpu().numpy() for k, v in p0.items()} + params_list = sorted(params_list) # We need to make sure that it matches the numpy fit func arg order + + # 3. Finds the best parameters for each channel + try: + params = [] + for i in range(sorted_xdata.shape[0]): + xdata_ = sorted_xdata[i] + p0_ = [p0[p][i] for p in params_list] + ch_params = _fit(xdata_, xdata_, np_fit_func, p0_) + params.append(ch_params) + + # 4. Builds the parameters + result = {} + for i, p in enumerate(params_list): + result[p] = torch.tensor([p_[i] for p_ in params], dtype=torch.float32).to(x.device) + + return result + + except ValueError as e: + print(f"Could not fit the function with error: {e}") + print(f"Using fallback result...") + return { + k: torch.tensor(v, dtype=torch.float32).to(x.device) for k, v in p0.items() + } + + +def init_zeros(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor: + val = torch.amin(x, dim=1) + return torch.zeros_like(val, dtype=torch.float32, device=x.device) + + +def init_inner_scale(tensor: torch.Tensor, _min: float = torch.inf, _max: float = torch.inf) -> torch.Tensor: + # Calculate the original minimum and maximum values + min_vals, max_vals = torch.aminmax(tensor, dim=-1) + x_min = torch.min(min_vals, torch.zeros_like(min_vals)) + x_max = torch.max(max_vals, torch.zeros_like(max_vals)) + + if _max is torch.inf: # We do not need to scale the tensor. Just need to move it + return torch.ones_like(x_min) + + # Calculate the scale factor + scale = (_max - _min) / (x_max - x_min) + return scale + + + +############## Quant ############### + +@torch.enable_grad() +def learn_parameters( + x: torch.Tensor, + params: Dict[str, nn.Parameter], + qtz_func: nn.Module, + deqtz_func: nn.Module, + bits: int, + target_dtype: torch.dtype, + epochs: int = 1000, + early_stop: bool = True, + do_report: bool = False +) -> Tuple[Dict[str, nn.Parameter], torch.Tensor]: + + # Requires gradients in the parameters + for p in params.values(): + p.requires_grad = True + p.grad = None + + param_keys = list(params.keys()) + param_values = list(params.values()) + + # Defines optimizer and loss function + optimizer = torch.optim.Adam(param_values, lr=0.001) + scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=1.0, end_factor=0.001, total_iters=epochs // 10) + loss_fn = nn.MSELoss() + + # Contains the best loss and the best parameters + best_loss = float("inf") + best_params = None + + # Used to stop the search early + min_delta = 1e-7 + acc_loss = [] + percent_epochs_before_stop = 0.1 + + for i in range(epochs): + optimizer.zero_grad() + + quant = quantize(x, params, qtz_func, bits, target_dtype) + dequant = dequantize(quant, params, deqtz_func, bits, x.dtype) + loss = loss_fn(x, dequant) + + if loss.isnan() or loss.isinf(): + raise Exception("Loss is NaN or Inf. Stopping the search.") + + loss.backward() + optimizer.step() + scheduler.step() + + acc_loss.append(loss.item()) + + # Reports loss every 10 steps + if i % 10 == 0 and do_report: + print(f"Epoch {i}: Loss {loss.item()}") + + # Optimizes the parameter search by storing the best loss and the parameters + if loss.item() < best_loss: + best_loss = loss.item() + best_params = copy.deepcopy({ + k: v for k, v in params.items() if k in param_keys + }) + + # We also stop the search if the loss has not considerably during the last 10% epochs + if early_stop: + epochs_before_stop = int(epochs * percent_epochs_before_stop) + if i > epochs_before_stop and abs(acc_loss[i - epochs_before_stop] - acc_loss[i]) < min_delta: + break + + # No longer requires gradients in the parameters + for p in best_params.values(): + p.requires_grad = False + p.grad = None + + if do_report: + return best_params, acc_loss + else: + return best_params + + +def quantize( + x: torch.Tensor, + params: Dict[str, nn.Parameter], + func: nn.Module, + bits: int, + target_dtype: torch.dtype = torch.int8 +) -> torch.Tensor: + quant_min, quant_max = get_min_max_from_bits_signed(bits) + x = x.transpose(0, 1) # Aligns shapes + x = func(x=x, **params) + x = x.transpose(0, 1) + x = torch.clamp(round_func_BPDA(x), quant_min, quant_max).to(target_dtype) + return x + + +def dequantize( + x: torch.Tensor, + params: Dict[str, nn.Parameter], + func: nn.Module, + bits: int, + out_dtype: torch.dtype +) -> torch.Tensor: + x = x.to(dtype=out_dtype) + x = x.transpose(0, 1) + x = func(x=x, **params) + x = x.transpose(0, 1) + return x + + +def round_func_BPDA(input): + # This is equivalent to replacing round function (non-differentiable) with + # an identity function (differentiable) only when backward. + forward_value = torch.round(input) + out = input.clone() + out.data = forward_value.data + return out + + +def get_min_max_from_bits_signed(bit_width: int) -> Tuple[int, int]: + return -2 ** (bit_width - 1), 2 ** (bit_width - 1) - 1 + + + +############## Numpy ############### + +def np_domain_guard( + x: np.ndarray, + min: float = None, + max: float = None, + posinf: float = None, + neginf: float = None, + nan: float = None + ) -> np.ndarray: + """Guard a tensor to a valid domain.""" + x = np.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan) + if min is not None or max is not None: + x = np.clip(x, min, max) + return x + + +def np_replace_num(x: np.ndarray, num: float, to: float) -> np.ndarray: + """Replace a number in a tensor with another number. + + Args: + x (np.ndarray): The input tensor. + num (float): The number to replace. + to (float): The number to replace with. + + Returns: + np.ndarray: The tensor with the number replaced. + """ + return np.where(x == num, to, x) + + +def np_guarded_power(x: np.ndarray, exp: float) -> np.ndarray: + """Guard the power operation to a valid domain.""" + return np.power(x, exp) if exp >= 1 else np.power(np.maximum(x, 0), exp) + diff --git a/fn_gen/rnd_noisy_scale/2/loss.png b/fn_gen/rnd_noisy_scale/2/loss.png new file mode 100644 index 0000000000000000000000000000000000000000..bd6ad6df87ddf01b0dfd67d9b41078dfb6b167bb Binary files /dev/null and b/fn_gen/rnd_noisy_scale/2/loss.png differ diff --git a/fn_gen/rnd_noisy_scale/2/quantization.png b/fn_gen/rnd_noisy_scale/2/quantization.png new file mode 100644 index 0000000000000000000000000000000000000000..8f59c4f41110f66464e7d55c66067b10bfa26a82 Binary files /dev/null and b/fn_gen/rnd_noisy_scale/2/quantization.png differ diff --git a/fn_gen/rnd_noisy_scale/3/__pycache__/fn.cpython-310.pyc b/fn_gen/rnd_noisy_scale/3/__pycache__/fn.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c8b825147a0457ea849e1ddc271b14dcd909c7ac Binary files /dev/null and b/fn_gen/rnd_noisy_scale/3/__pycache__/fn.cpython-310.pyc differ diff --git a/fn_gen/rnd_noisy_scale/3/distortion.png b/fn_gen/rnd_noisy_scale/3/distortion.png new file mode 100644 index 0000000000000000000000000000000000000000..953359ce6c2245c685ce4a5f593c8078d3adf1e3 Binary files /dev/null and b/fn_gen/rnd_noisy_scale/3/distortion.png differ diff --git a/fn_gen/rnd_noisy_scale/3/expressions.txt b/fn_gen/rnd_noisy_scale/3/expressions.txt new file mode 100644 index 0000000000000000000000000000000000000000..74791fc40576643d62f6366a8b4eda20eb1ad252 --- /dev/null +++ b/fn_gen/rnd_noisy_scale/3/expressions.txt @@ -0,0 +1,2 @@ +x**3/_s +(_s*x)**(1/3) \ No newline at end of file diff --git a/fn_gen/rnd_noisy_scale/3/fn.py b/fn_gen/rnd_noisy_scale/3/fn.py new file mode 100644 index 0000000000000000000000000000000000000000..2e79759584ebbb7bb108d53b863065beb47370ab --- /dev/null +++ b/fn_gen/rnd_noisy_scale/3/fn.py @@ -0,0 +1,488 @@ +from __future__ import annotations + +import torch +from torch import amin # Necessary for arcsin +import copy +import torch.nn as nn +import numpy as np + +from scipy.optimize import curve_fit +from typing import Dict, Any, Tuple, List, Callable + + +def quantization(x, **params): + return (torch.div(1, replace_num(params['_s'], num=0, to=10000)) * guarded_torch_power(x, torch.tensor(3))) + + +def dequantization(x, **params): + return guarded_torch_power((params['_s'] * x), 1 / 3) + + +def init_linear_scale( # Symmetric scale. From the study folder + x: torch.Tensor, + **kwargs: Dict[str, Any], + ) -> torch.Tensor: + assert "bits" in kwargs, "bits must be provided." + assert "params" in kwargs, "params must be provided." + assert "qtz_func" in kwargs, "qtz_func must be provided." + + bits = kwargs.get('bits') + params = kwargs.get('params') + qtz_func = kwargs.get('qtz_func') + + x_ = x.transpose(0, 1) + x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs)) + x_ = x_.transpose(0, 1) + + quant_min, quant_max = get_min_max_from_bits_signed(bits) + min_vals, max_vals = torch.aminmax(x_, dim=1) + min_vals = torch.min(min_vals, torch.zeros_like(min_vals)) + max_vals = torch.max(max_vals, torch.zeros_like(max_vals)) + + eps = torch.finfo(torch.float32).eps + + abs_max_val_per_ch = torch.max(-min_vals, max_vals) + scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2) + + scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device) + + # Introduces some noise in scale + # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything + scale = scale + 0.01 * torch.randn_like(scale) + return scale + + +def init_params(x: torch.Tensor, **kwargs: Dict[str, Any]) -> Dict[str, nn.Parameter]: + params = { + } + params['_s'] = init_linear_scale(x, params=params, qtz_func=quantization, **kwargs) + params = {k: nn.Parameter(v, requires_grad=False) for k, v in params.items()} + + if 'post_init_hook' in kwargs: + kwargs['post_init_hook'](parameters=params) + + + if 'post_train_hook' in kwargs: + kwargs['post_train_hook'](parameters=params) + + return params + + +############### Numpy Qtz ############### + + +def np_quantization(x, _s): + return (np.divide(1, np_replace_num(_s, num=0, to=10000)) * np_guarded_power(x, np.array(3))) + + +def np_dequantization(x, _s): + return np_guarded_power((_s * x), 1 / 3) + + +def fit_func(x, _s): + x_ = np_quantization(x, _s) + x_ = np_dequantization(x_, _s) + return x_ + + + +############### HELPERS ############### + +def domain_guard( + x: torch.Tensor, + min: float = None, + max: float = None, + posinf: float = None, + neginf: float = None, + nan: float = None + ) -> torch.Tensor: + """Guard a tensor to a valid domain.""" + x = torch.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan) + if min is not None or max is not None: + x = torch.clamp(x, min=min, max=max) + return x + + +def replace_num(x: torch.Tensor, num: float, to: float) -> torch.Tensor: + """Replace a number in a tensor with another number. + + Args: + x (torch.Tensor): The input tensor. + num (float): The number to replace. + to (float): The number to replace with. + + Returns: + torch.Tensor: The tensor with the number replaced. + """ + return torch.where(x == num, to, x) + + +def guarded_torch_power(x: torch.Tensor, exp: float) -> torch.Tensor: + """Guard the power operation to a valid domain.""" + return torch.pow(x, exp) if exp >= 1 else torch.pow(torch.relu(x), exp) + + +def init_ones(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor: + val = torch.amin(x, dim=1) + return torch.ones_like(val, dtype=torch.float32, device=x.device) + + +def init_rand(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor: + val = torch.amin(x, dim=1) + return torch.randn_like(val, dtype=torch.float32, device=x.device) + + +def init_space_search( + x: torch.Tensor, + **kwargs: Dict[str, Any], + ) -> torch.Tensor: + + def _build_initial_param(tensor: torch.Tensor, max_initial: int, n_params: int): + """Generates the initial set of parameters. The first iteration generates 10 times more parameters.""" + for _ in range(n_params * 10): # The first iteration generates 10 times more parameters + yield init_rand(tensor) * max_initial # Generates n_params in range [-max_initial, max_initial] + + def _search_param(tensors: List[torch.tensor], n_params): + """Takes the best parameters and generates new parameters around the mean of the best parameters.""" + torch_tensors = torch.stack(tensors) + min_vals, max_vals = torch.aminmax(torch_tensors, dim=0) + abs_max_val_per_ch = torch.max(-min_vals, max_vals) + mean = torch.mean(torch_tensors, dim=0) + for _ in range(n_params): # Generates n_params around the mean of the tensors + yield torch.randn_like(min_vals) * abs_max_val_per_ch + mean + + def _calc(x, qtz_func, deqtz_func, **params): + x_ = x.transpose(0, 1) + x_ = qtz_func(x=x_, **params) + x_ = deqtz_func(x=x_, **params) + x_ = x_.transpose(0, 1) + return x_ + + assert "qtz_func" in kwargs, "qtz_func must be provided." + assert "deqtz_func" in kwargs, "deqtz_func must be provided." + assert "params_list" in kwargs, "params list must be provided." + assert "param" in kwargs, "param must be provided." + + qtz_func = kwargs.get('qtz_func') + deqtz_func = kwargs.get('deqtz_func') + params_list = kwargs.get('params_list') + param = kwargs.get('param') + + n_runs = 50 # Number of runs to try to find the best parameters + n_random_params = 50 # Number of random parameters to generate + n_best_to_pick = 5 # Number of best parameters to pick after each run + max_initial = 10000 # Maximum value to initialize the parameters + + # Initializes the parameters + base_params = { p: init_ones(x, **kwargs) for p in params_list if p != param } + params = _build_initial_param(x, max_initial, n_random_params) + + # Performs the search + for _ in range(n_runs): + + best_params = [] + for param_ in params: + try: + x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: param_}) + loss_ones = nn.MSELoss()(x, x_) + + if len(best_params) < n_best_to_pick: + best_params.append((param_, loss_ones.item())) + best_params = sorted(best_params, key=lambda x: x[1]) + elif loss_ones < best_params[-1][1]: + best_params[-1] = (param_, loss_ones.item()) + best_params = sorted(best_params, key=lambda x: x[1]) + + except Exception: # The parameters might not be valid for the function's domain + continue + + # Generates new parameters around the mean + params = _search_param([p for p, _ in best_params], n_random_params) + + # Checks if the best parameter is better than the init_ones + p_ones = init_ones(x, **kwargs) + x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_ones}) + loss_ones = nn.MSELoss()(x, x_) + + # Checks if the best parameter is better than the init_rand + p_rand = init_rand(x, **kwargs) + x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_rand}) + loss_rand = nn.MSELoss()(x, x_) + + if loss_rand < best_params[0][1] and loss_rand < loss_ones: + return p_rand + elif loss_ones < best_params[0][1] and loss_ones < loss_rand: + return p_ones + else: + return best_params[0][0] + + +def init_linear_scale( # Symmetric scale. From the study folder + x: torch.Tensor, + **kwargs: Dict[str, Any], + ) -> torch.Tensor: + assert "bits" in kwargs, "bits must be provided." + assert "params" in kwargs, "params must be provided." + assert "qtz_func" in kwargs, "qtz_func must be provided." + + bits = kwargs.get('bits') + params = kwargs.get('params') + qtz_func = kwargs.get('qtz_func') + + x_ = x.transpose(0, 1) + x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs)) + x_ = x_.transpose(0, 1) + + quant_min, quant_max = get_min_max_from_bits_signed(bits) + min_vals, max_vals = torch.aminmax(x_, dim=1) + min_vals = torch.min(min_vals, torch.zeros_like(min_vals)) + max_vals = torch.max(max_vals, torch.zeros_like(max_vals)) + + eps = torch.finfo(torch.float32).eps + + abs_max_val_per_ch = torch.max(-min_vals, max_vals) + scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2) + + scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device) + + # Introduces some noise in scale + # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything + scale = scale + 0.01 * torch.randn_like(scale) + return scale + + +def init_non_linear_regression_fit( + x: torch.Tensor, + **kwargs: Dict[str, Any], + ) -> torch.Tensor: + + assert "params_list" in kwargs, "params list must be provided." + assert "np_fit_func" in kwargs, "np_fit_func must be provided." + assert "p0" in kwargs, "p0 must be provided." + np_fit_func = kwargs.get('np_fit_func') + params_list = kwargs.get('params_list') + p0 = kwargs.get('p0') + + def _fit(xdata: np.ndarray, ydata: np.ndarray, func: Callable, p0: List[float]): + popt, _ = curve_fit( + func, + xdata, + ydata, + maxfev=1000, + p0=p0, + method='lm' + ) + return popt + + # 1. Needs to convert the torch tensor to numpy tensor + xdata = x.cpu().numpy() + + # 2. Sorts the data so that it makes it easier to fit to it + sorted_xdata = np.sort(xdata, axis=-1) + + p0 = {k: v.cpu().numpy() for k, v in p0.items()} + params_list = sorted(params_list) # We need to make sure that it matches the numpy fit func arg order + + # 3. Finds the best parameters for each channel + try: + params = [] + for i in range(sorted_xdata.shape[0]): + xdata_ = sorted_xdata[i] + p0_ = [p0[p][i] for p in params_list] + ch_params = _fit(xdata_, xdata_, np_fit_func, p0_) + params.append(ch_params) + + # 4. Builds the parameters + result = {} + for i, p in enumerate(params_list): + result[p] = torch.tensor([p_[i] for p_ in params], dtype=torch.float32).to(x.device) + + return result + + except ValueError as e: + print(f"Could not fit the function with error: {e}") + print(f"Using fallback result...") + return { + k: torch.tensor(v, dtype=torch.float32).to(x.device) for k, v in p0.items() + } + + +def init_zeros(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor: + val = torch.amin(x, dim=1) + return torch.zeros_like(val, dtype=torch.float32, device=x.device) + + +def init_inner_scale(tensor: torch.Tensor, _min: float = torch.inf, _max: float = torch.inf) -> torch.Tensor: + # Calculate the original minimum and maximum values + min_vals, max_vals = torch.aminmax(tensor, dim=-1) + x_min = torch.min(min_vals, torch.zeros_like(min_vals)) + x_max = torch.max(max_vals, torch.zeros_like(max_vals)) + + if _max is torch.inf: # We do not need to scale the tensor. Just need to move it + return torch.ones_like(x_min) + + # Calculate the scale factor + scale = (_max - _min) / (x_max - x_min) + return scale + + + +############## Quant ############### + +@torch.enable_grad() +def learn_parameters( + x: torch.Tensor, + params: Dict[str, nn.Parameter], + qtz_func: nn.Module, + deqtz_func: nn.Module, + bits: int, + target_dtype: torch.dtype, + epochs: int = 1000, + early_stop: bool = True, + do_report: bool = False +) -> Tuple[Dict[str, nn.Parameter], torch.Tensor]: + + # Requires gradients in the parameters + for p in params.values(): + p.requires_grad = True + p.grad = None + + param_keys = list(params.keys()) + param_values = list(params.values()) + + # Defines optimizer and loss function + optimizer = torch.optim.Adam(param_values, lr=0.001) + scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=1.0, end_factor=0.001, total_iters=epochs // 10) + loss_fn = nn.MSELoss() + + # Contains the best loss and the best parameters + best_loss = float("inf") + best_params = None + + # Used to stop the search early + min_delta = 1e-7 + acc_loss = [] + percent_epochs_before_stop = 0.1 + + for i in range(epochs): + optimizer.zero_grad() + + quant = quantize(x, params, qtz_func, bits, target_dtype) + dequant = dequantize(quant, params, deqtz_func, bits, x.dtype) + loss = loss_fn(x, dequant) + + if loss.isnan() or loss.isinf(): + raise Exception("Loss is NaN or Inf. Stopping the search.") + + loss.backward() + optimizer.step() + scheduler.step() + + acc_loss.append(loss.item()) + + # Reports loss every 10 steps + if i % 10 == 0 and do_report: + print(f"Epoch {i}: Loss {loss.item()}") + + # Optimizes the parameter search by storing the best loss and the parameters + if loss.item() < best_loss: + best_loss = loss.item() + best_params = copy.deepcopy({ + k: v for k, v in params.items() if k in param_keys + }) + + # We also stop the search if the loss has not considerably during the last 10% epochs + if early_stop: + epochs_before_stop = int(epochs * percent_epochs_before_stop) + if i > epochs_before_stop and abs(acc_loss[i - epochs_before_stop] - acc_loss[i]) < min_delta: + break + + # No longer requires gradients in the parameters + for p in best_params.values(): + p.requires_grad = False + p.grad = None + + if do_report: + return best_params, acc_loss + else: + return best_params + + +def quantize( + x: torch.Tensor, + params: Dict[str, nn.Parameter], + func: nn.Module, + bits: int, + target_dtype: torch.dtype = torch.int8 +) -> torch.Tensor: + quant_min, quant_max = get_min_max_from_bits_signed(bits) + x = x.transpose(0, 1) # Aligns shapes + x = func(x=x, **params) + x = x.transpose(0, 1) + x = torch.clamp(round_func_BPDA(x), quant_min, quant_max).to(target_dtype) + return x + + +def dequantize( + x: torch.Tensor, + params: Dict[str, nn.Parameter], + func: nn.Module, + bits: int, + out_dtype: torch.dtype +) -> torch.Tensor: + x = x.to(dtype=out_dtype) + x = x.transpose(0, 1) + x = func(x=x, **params) + x = x.transpose(0, 1) + return x + + +def round_func_BPDA(input): + # This is equivalent to replacing round function (non-differentiable) with + # an identity function (differentiable) only when backward. + forward_value = torch.round(input) + out = input.clone() + out.data = forward_value.data + return out + + +def get_min_max_from_bits_signed(bit_width: int) -> Tuple[int, int]: + return -2 ** (bit_width - 1), 2 ** (bit_width - 1) - 1 + + + +############## Numpy ############### + +def np_domain_guard( + x: np.ndarray, + min: float = None, + max: float = None, + posinf: float = None, + neginf: float = None, + nan: float = None + ) -> np.ndarray: + """Guard a tensor to a valid domain.""" + x = np.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan) + if min is not None or max is not None: + x = np.clip(x, min, max) + return x + + +def np_replace_num(x: np.ndarray, num: float, to: float) -> np.ndarray: + """Replace a number in a tensor with another number. + + Args: + x (np.ndarray): The input tensor. + num (float): The number to replace. + to (float): The number to replace with. + + Returns: + np.ndarray: The tensor with the number replaced. + """ + return np.where(x == num, to, x) + + +def np_guarded_power(x: np.ndarray, exp: float) -> np.ndarray: + """Guard the power operation to a valid domain.""" + return np.power(x, exp) if exp >= 1 else np.power(np.maximum(x, 0), exp) + diff --git a/fn_gen/rnd_noisy_scale/3/loss.png b/fn_gen/rnd_noisy_scale/3/loss.png new file mode 100644 index 0000000000000000000000000000000000000000..a85e119478664f129d2852b280001fe88411e357 Binary files /dev/null and b/fn_gen/rnd_noisy_scale/3/loss.png differ diff --git a/fn_gen/rnd_noisy_scale/3/quantization.png b/fn_gen/rnd_noisy_scale/3/quantization.png new file mode 100644 index 0000000000000000000000000000000000000000..c41a0ca31b8580fd3b8baa918c6c080fbbe1fbae Binary files /dev/null and b/fn_gen/rnd_noisy_scale/3/quantization.png differ diff --git a/fn_gen/rnd_noisy_scale/4/__pycache__/fn.cpython-310.pyc b/fn_gen/rnd_noisy_scale/4/__pycache__/fn.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6cacd225f3c0e25451e57efd478c7ede1e0ce966 Binary files /dev/null and b/fn_gen/rnd_noisy_scale/4/__pycache__/fn.cpython-310.pyc differ diff --git a/fn_gen/rnd_noisy_scale/4/distortion.png b/fn_gen/rnd_noisy_scale/4/distortion.png new file mode 100644 index 0000000000000000000000000000000000000000..d014489f0028c5c581fc021777fa3389405b2f6b Binary files /dev/null and b/fn_gen/rnd_noisy_scale/4/distortion.png differ diff --git a/fn_gen/rnd_noisy_scale/4/expressions.txt b/fn_gen/rnd_noisy_scale/4/expressions.txt new file mode 100644 index 0000000000000000000000000000000000000000..0a7e5be4566beeb4727d82f95d24241966d158dc --- /dev/null +++ b/fn_gen/rnd_noisy_scale/4/expressions.txt @@ -0,0 +1,2 @@ +log(_0*x)/_s +exp(_s*x)/_0 \ No newline at end of file diff --git a/fn_gen/rnd_noisy_scale/4/fn.py b/fn_gen/rnd_noisy_scale/4/fn.py new file mode 100644 index 0000000000000000000000000000000000000000..dda553ba161dfb60747df37ca1347c81759f8346 --- /dev/null +++ b/fn_gen/rnd_noisy_scale/4/fn.py @@ -0,0 +1,489 @@ +from __future__ import annotations + +import torch +from torch import amin # Necessary for arcsin +import copy +import torch.nn as nn +import numpy as np + +from scipy.optimize import curve_fit +from typing import Dict, Any, Tuple, List, Callable + + +def quantization(x, **params): + return (torch.div(1, replace_num(params['_s'], num=0, to=10000)) * torch.log(domain_guard((params['_0'] * x), min=1e-5, nan=1e-5))) + + +def dequantization(x, **params): + return (torch.div(1, replace_num(params['_0'], num=0, to=10000)) * torch.exp((params['_s'] * x))) + + +def init_linear_scale( # Symmetric scale. From the study folder + x: torch.Tensor, + **kwargs: Dict[str, Any], + ) -> torch.Tensor: + assert "bits" in kwargs, "bits must be provided." + assert "params" in kwargs, "params must be provided." + assert "qtz_func" in kwargs, "qtz_func must be provided." + + bits = kwargs.get('bits') + params = kwargs.get('params') + qtz_func = kwargs.get('qtz_func') + + x_ = x.transpose(0, 1) + x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs)) + x_ = x_.transpose(0, 1) + + quant_min, quant_max = get_min_max_from_bits_signed(bits) + min_vals, max_vals = torch.aminmax(x_, dim=1) + min_vals = torch.min(min_vals, torch.zeros_like(min_vals)) + max_vals = torch.max(max_vals, torch.zeros_like(max_vals)) + + eps = torch.finfo(torch.float32).eps + + abs_max_val_per_ch = torch.max(-min_vals, max_vals) + scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2) + + scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device) + + # Introduces some noise in scale + # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything + scale = scale + 0.01 * torch.randn_like(scale) + return scale + + +def init_params(x: torch.Tensor, **kwargs: Dict[str, Any]) -> Dict[str, nn.Parameter]: + params = { + '_0': init_rand(x, qtz_func=quantization, deqtz_func=dequantization, param='_0', params_list=['_0', '_s'], **kwargs), + } + params['_s'] = init_linear_scale(x, params=params, qtz_func=quantization, **kwargs) + params = {k: nn.Parameter(v, requires_grad=False) for k, v in params.items()} + + if 'post_init_hook' in kwargs: + kwargs['post_init_hook'](parameters=params) + + + if 'post_train_hook' in kwargs: + kwargs['post_train_hook'](parameters=params) + + return params + + +############### Numpy Qtz ############### + + +def np_quantization(x, _0, _s): + return (np.divide(1, np_replace_num(_s, num=0, to=10000)) * np.log(np_domain_guard((_0 * x), min=1e-5, nan=1e-5))) + + +def np_dequantization(x, _0, _s): + return (np.divide(1, np_replace_num(_0, num=0, to=10000)) * np.exp((_s * x))) + + +def fit_func(x, _0, _s): + x_ = np_quantization(x, _0, _s) + x_ = np_dequantization(x_, _0, _s) + return x_ + + + +############### HELPERS ############### + +def domain_guard( + x: torch.Tensor, + min: float = None, + max: float = None, + posinf: float = None, + neginf: float = None, + nan: float = None + ) -> torch.Tensor: + """Guard a tensor to a valid domain.""" + x = torch.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan) + if min is not None or max is not None: + x = torch.clamp(x, min=min, max=max) + return x + + +def replace_num(x: torch.Tensor, num: float, to: float) -> torch.Tensor: + """Replace a number in a tensor with another number. + + Args: + x (torch.Tensor): The input tensor. + num (float): The number to replace. + to (float): The number to replace with. + + Returns: + torch.Tensor: The tensor with the number replaced. + """ + return torch.where(x == num, to, x) + + +def guarded_torch_power(x: torch.Tensor, exp: float) -> torch.Tensor: + """Guard the power operation to a valid domain.""" + return torch.pow(x, exp) if exp >= 1 else torch.pow(torch.relu(x), exp) + + +def init_ones(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor: + val = torch.amin(x, dim=1) + return torch.ones_like(val, dtype=torch.float32, device=x.device) + + +def init_rand(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor: + val = torch.amin(x, dim=1) + return torch.randn_like(val, dtype=torch.float32, device=x.device) + + +def init_space_search( + x: torch.Tensor, + **kwargs: Dict[str, Any], + ) -> torch.Tensor: + + def _build_initial_param(tensor: torch.Tensor, max_initial: int, n_params: int): + """Generates the initial set of parameters. The first iteration generates 10 times more parameters.""" + for _ in range(n_params * 10): # The first iteration generates 10 times more parameters + yield init_rand(tensor) * max_initial # Generates n_params in range [-max_initial, max_initial] + + def _search_param(tensors: List[torch.tensor], n_params): + """Takes the best parameters and generates new parameters around the mean of the best parameters.""" + torch_tensors = torch.stack(tensors) + min_vals, max_vals = torch.aminmax(torch_tensors, dim=0) + abs_max_val_per_ch = torch.max(-min_vals, max_vals) + mean = torch.mean(torch_tensors, dim=0) + for _ in range(n_params): # Generates n_params around the mean of the tensors + yield torch.randn_like(min_vals) * abs_max_val_per_ch + mean + + def _calc(x, qtz_func, deqtz_func, **params): + x_ = x.transpose(0, 1) + x_ = qtz_func(x=x_, **params) + x_ = deqtz_func(x=x_, **params) + x_ = x_.transpose(0, 1) + return x_ + + assert "qtz_func" in kwargs, "qtz_func must be provided." + assert "deqtz_func" in kwargs, "deqtz_func must be provided." + assert "params_list" in kwargs, "params list must be provided." + assert "param" in kwargs, "param must be provided." + + qtz_func = kwargs.get('qtz_func') + deqtz_func = kwargs.get('deqtz_func') + params_list = kwargs.get('params_list') + param = kwargs.get('param') + + n_runs = 50 # Number of runs to try to find the best parameters + n_random_params = 50 # Number of random parameters to generate + n_best_to_pick = 5 # Number of best parameters to pick after each run + max_initial = 10000 # Maximum value to initialize the parameters + + # Initializes the parameters + base_params = { p: init_ones(x, **kwargs) for p in params_list if p != param } + params = _build_initial_param(x, max_initial, n_random_params) + + # Performs the search + for _ in range(n_runs): + + best_params = [] + for param_ in params: + try: + x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: param_}) + loss_ones = nn.MSELoss()(x, x_) + + if len(best_params) < n_best_to_pick: + best_params.append((param_, loss_ones.item())) + best_params = sorted(best_params, key=lambda x: x[1]) + elif loss_ones < best_params[-1][1]: + best_params[-1] = (param_, loss_ones.item()) + best_params = sorted(best_params, key=lambda x: x[1]) + + except Exception: # The parameters might not be valid for the function's domain + continue + + # Generates new parameters around the mean + params = _search_param([p for p, _ in best_params], n_random_params) + + # Checks if the best parameter is better than the init_ones + p_ones = init_ones(x, **kwargs) + x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_ones}) + loss_ones = nn.MSELoss()(x, x_) + + # Checks if the best parameter is better than the init_rand + p_rand = init_rand(x, **kwargs) + x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_rand}) + loss_rand = nn.MSELoss()(x, x_) + + if loss_rand < best_params[0][1] and loss_rand < loss_ones: + return p_rand + elif loss_ones < best_params[0][1] and loss_ones < loss_rand: + return p_ones + else: + return best_params[0][0] + + +def init_linear_scale( # Symmetric scale. From the study folder + x: torch.Tensor, + **kwargs: Dict[str, Any], + ) -> torch.Tensor: + assert "bits" in kwargs, "bits must be provided." + assert "params" in kwargs, "params must be provided." + assert "qtz_func" in kwargs, "qtz_func must be provided." + + bits = kwargs.get('bits') + params = kwargs.get('params') + qtz_func = kwargs.get('qtz_func') + + x_ = x.transpose(0, 1) + x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs)) + x_ = x_.transpose(0, 1) + + quant_min, quant_max = get_min_max_from_bits_signed(bits) + min_vals, max_vals = torch.aminmax(x_, dim=1) + min_vals = torch.min(min_vals, torch.zeros_like(min_vals)) + max_vals = torch.max(max_vals, torch.zeros_like(max_vals)) + + eps = torch.finfo(torch.float32).eps + + abs_max_val_per_ch = torch.max(-min_vals, max_vals) + scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2) + + scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device) + + # Introduces some noise in scale + # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything + scale = scale + 0.01 * torch.randn_like(scale) + return scale + + +def init_non_linear_regression_fit( + x: torch.Tensor, + **kwargs: Dict[str, Any], + ) -> torch.Tensor: + + assert "params_list" in kwargs, "params list must be provided." + assert "np_fit_func" in kwargs, "np_fit_func must be provided." + assert "p0" in kwargs, "p0 must be provided." + np_fit_func = kwargs.get('np_fit_func') + params_list = kwargs.get('params_list') + p0 = kwargs.get('p0') + + def _fit(xdata: np.ndarray, ydata: np.ndarray, func: Callable, p0: List[float]): + popt, _ = curve_fit( + func, + xdata, + ydata, + maxfev=1000, + p0=p0, + method='lm' + ) + return popt + + # 1. Needs to convert the torch tensor to numpy tensor + xdata = x.cpu().numpy() + + # 2. Sorts the data so that it makes it easier to fit to it + sorted_xdata = np.sort(xdata, axis=-1) + + p0 = {k: v.cpu().numpy() for k, v in p0.items()} + params_list = sorted(params_list) # We need to make sure that it matches the numpy fit func arg order + + # 3. Finds the best parameters for each channel + try: + params = [] + for i in range(sorted_xdata.shape[0]): + xdata_ = sorted_xdata[i] + p0_ = [p0[p][i] for p in params_list] + ch_params = _fit(xdata_, xdata_, np_fit_func, p0_) + params.append(ch_params) + + # 4. Builds the parameters + result = {} + for i, p in enumerate(params_list): + result[p] = torch.tensor([p_[i] for p_ in params], dtype=torch.float32).to(x.device) + + return result + + except ValueError as e: + print(f"Could not fit the function with error: {e}") + print(f"Using fallback result...") + return { + k: torch.tensor(v, dtype=torch.float32).to(x.device) for k, v in p0.items() + } + + +def init_zeros(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor: + val = torch.amin(x, dim=1) + return torch.zeros_like(val, dtype=torch.float32, device=x.device) + + +def init_inner_scale(tensor: torch.Tensor, _min: float = torch.inf, _max: float = torch.inf) -> torch.Tensor: + # Calculate the original minimum and maximum values + min_vals, max_vals = torch.aminmax(tensor, dim=-1) + x_min = torch.min(min_vals, torch.zeros_like(min_vals)) + x_max = torch.max(max_vals, torch.zeros_like(max_vals)) + + if _max is torch.inf: # We do not need to scale the tensor. Just need to move it + return torch.ones_like(x_min) + + # Calculate the scale factor + scale = (_max - _min) / (x_max - x_min) + return scale + + + +############## Quant ############### + +@torch.enable_grad() +def learn_parameters( + x: torch.Tensor, + params: Dict[str, nn.Parameter], + qtz_func: nn.Module, + deqtz_func: nn.Module, + bits: int, + target_dtype: torch.dtype, + epochs: int = 1000, + early_stop: bool = True, + do_report: bool = False +) -> Tuple[Dict[str, nn.Parameter], torch.Tensor]: + + # Requires gradients in the parameters + for p in params.values(): + p.requires_grad = True + p.grad = None + + param_keys = list(params.keys()) + param_values = list(params.values()) + + # Defines optimizer and loss function + optimizer = torch.optim.Adam(param_values, lr=0.001) + scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=1.0, end_factor=0.001, total_iters=epochs // 10) + loss_fn = nn.MSELoss() + + # Contains the best loss and the best parameters + best_loss = float("inf") + best_params = None + + # Used to stop the search early + min_delta = 1e-7 + acc_loss = [] + percent_epochs_before_stop = 0.1 + + for i in range(epochs): + optimizer.zero_grad() + + quant = quantize(x, params, qtz_func, bits, target_dtype) + dequant = dequantize(quant, params, deqtz_func, bits, x.dtype) + loss = loss_fn(x, dequant) + + if loss.isnan() or loss.isinf(): + raise Exception("Loss is NaN or Inf. Stopping the search.") + + loss.backward() + optimizer.step() + scheduler.step() + + acc_loss.append(loss.item()) + + # Reports loss every 10 steps + if i % 10 == 0 and do_report: + print(f"Epoch {i}: Loss {loss.item()}") + + # Optimizes the parameter search by storing the best loss and the parameters + if loss.item() < best_loss: + best_loss = loss.item() + best_params = copy.deepcopy({ + k: v for k, v in params.items() if k in param_keys + }) + + # We also stop the search if the loss has not considerably during the last 10% epochs + if early_stop: + epochs_before_stop = int(epochs * percent_epochs_before_stop) + if i > epochs_before_stop and abs(acc_loss[i - epochs_before_stop] - acc_loss[i]) < min_delta: + break + + # No longer requires gradients in the parameters + for p in best_params.values(): + p.requires_grad = False + p.grad = None + + if do_report: + return best_params, acc_loss + else: + return best_params + + +def quantize( + x: torch.Tensor, + params: Dict[str, nn.Parameter], + func: nn.Module, + bits: int, + target_dtype: torch.dtype = torch.int8 +) -> torch.Tensor: + quant_min, quant_max = get_min_max_from_bits_signed(bits) + x = x.transpose(0, 1) # Aligns shapes + x = func(x=x, **params) + x = x.transpose(0, 1) + x = torch.clamp(round_func_BPDA(x), quant_min, quant_max).to(target_dtype) + return x + + +def dequantize( + x: torch.Tensor, + params: Dict[str, nn.Parameter], + func: nn.Module, + bits: int, + out_dtype: torch.dtype +) -> torch.Tensor: + x = x.to(dtype=out_dtype) + x = x.transpose(0, 1) + x = func(x=x, **params) + x = x.transpose(0, 1) + return x + + +def round_func_BPDA(input): + # This is equivalent to replacing round function (non-differentiable) with + # an identity function (differentiable) only when backward. + forward_value = torch.round(input) + out = input.clone() + out.data = forward_value.data + return out + + +def get_min_max_from_bits_signed(bit_width: int) -> Tuple[int, int]: + return -2 ** (bit_width - 1), 2 ** (bit_width - 1) - 1 + + + +############## Numpy ############### + +def np_domain_guard( + x: np.ndarray, + min: float = None, + max: float = None, + posinf: float = None, + neginf: float = None, + nan: float = None + ) -> np.ndarray: + """Guard a tensor to a valid domain.""" + x = np.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan) + if min is not None or max is not None: + x = np.clip(x, min, max) + return x + + +def np_replace_num(x: np.ndarray, num: float, to: float) -> np.ndarray: + """Replace a number in a tensor with another number. + + Args: + x (np.ndarray): The input tensor. + num (float): The number to replace. + to (float): The number to replace with. + + Returns: + np.ndarray: The tensor with the number replaced. + """ + return np.where(x == num, to, x) + + +def np_guarded_power(x: np.ndarray, exp: float) -> np.ndarray: + """Guard the power operation to a valid domain.""" + return np.power(x, exp) if exp >= 1 else np.power(np.maximum(x, 0), exp) + diff --git a/fn_gen/rnd_noisy_scale/4/loss.png b/fn_gen/rnd_noisy_scale/4/loss.png new file mode 100644 index 0000000000000000000000000000000000000000..815f559dff7e67fb85f7d4c9c8563fe1480324fc Binary files /dev/null and b/fn_gen/rnd_noisy_scale/4/loss.png differ diff --git a/fn_gen/rnd_noisy_scale/4/quantization.png b/fn_gen/rnd_noisy_scale/4/quantization.png new file mode 100644 index 0000000000000000000000000000000000000000..d3345f17d4aed0f765f78517168989b9b56d9117 Binary files /dev/null and b/fn_gen/rnd_noisy_scale/4/quantization.png differ diff --git a/fn_gen/rnd_noisy_scale/5/__pycache__/fn.cpython-310.pyc b/fn_gen/rnd_noisy_scale/5/__pycache__/fn.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9a4ad83fcb4b3da98a5f3c5def1b76421b807f5d Binary files /dev/null and b/fn_gen/rnd_noisy_scale/5/__pycache__/fn.cpython-310.pyc differ diff --git a/fn_gen/rnd_noisy_scale/5/distortion.png b/fn_gen/rnd_noisy_scale/5/distortion.png new file mode 100644 index 0000000000000000000000000000000000000000..3493b222f508dbf04c28db5d07376d2f63173544 Binary files /dev/null and b/fn_gen/rnd_noisy_scale/5/distortion.png differ diff --git a/fn_gen/rnd_noisy_scale/5/expressions.txt b/fn_gen/rnd_noisy_scale/5/expressions.txt new file mode 100644 index 0000000000000000000000000000000000000000..aa32b575e8c654dbc457c94f36222e70d86dc940 --- /dev/null +++ b/fn_gen/rnd_noisy_scale/5/expressions.txt @@ -0,0 +1,2 @@ +atan(_0*x)/_s +tan(_s*x)/_0 \ No newline at end of file diff --git a/fn_gen/rnd_noisy_scale/5/fn.py b/fn_gen/rnd_noisy_scale/5/fn.py new file mode 100644 index 0000000000000000000000000000000000000000..42abf1d4f45a819b77e1e77c8e7a8bb0615913fe --- /dev/null +++ b/fn_gen/rnd_noisy_scale/5/fn.py @@ -0,0 +1,489 @@ +from __future__ import annotations + +import torch +from torch import amin # Necessary for arcsin +import copy +import torch.nn as nn +import numpy as np + +from scipy.optimize import curve_fit +from typing import Dict, Any, Tuple, List, Callable + + +def quantization(x, **params): + return (torch.div(1, replace_num(params['_s'], num=0, to=10000)) * torch.atan((params['_0'] * x))) + + +def dequantization(x, **params): + return (torch.div(1, replace_num(params['_0'], num=0, to=10000)) * torch.tan(domain_guard((params['_s'] * x), posinf=1, neginf=-1, nan=0))) + + +def init_linear_scale( # Symmetric scale. From the study folder + x: torch.Tensor, + **kwargs: Dict[str, Any], + ) -> torch.Tensor: + assert "bits" in kwargs, "bits must be provided." + assert "params" in kwargs, "params must be provided." + assert "qtz_func" in kwargs, "qtz_func must be provided." + + bits = kwargs.get('bits') + params = kwargs.get('params') + qtz_func = kwargs.get('qtz_func') + + x_ = x.transpose(0, 1) + x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs)) + x_ = x_.transpose(0, 1) + + quant_min, quant_max = get_min_max_from_bits_signed(bits) + min_vals, max_vals = torch.aminmax(x_, dim=1) + min_vals = torch.min(min_vals, torch.zeros_like(min_vals)) + max_vals = torch.max(max_vals, torch.zeros_like(max_vals)) + + eps = torch.finfo(torch.float32).eps + + abs_max_val_per_ch = torch.max(-min_vals, max_vals) + scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2) + + scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device) + + # Introduces some noise in scale + # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything + scale = scale + 0.01 * torch.randn_like(scale) + return scale + + +def init_params(x: torch.Tensor, **kwargs: Dict[str, Any]) -> Dict[str, nn.Parameter]: + params = { + '_0': init_rand(x, qtz_func=quantization, deqtz_func=dequantization, param='_0', params_list=['_0', '_s'], **kwargs), + } + params['_s'] = init_linear_scale(x, params=params, qtz_func=quantization, **kwargs) + params = {k: nn.Parameter(v, requires_grad=False) for k, v in params.items()} + + if 'post_init_hook' in kwargs: + kwargs['post_init_hook'](parameters=params) + + + if 'post_train_hook' in kwargs: + kwargs['post_train_hook'](parameters=params) + + return params + + +############### Numpy Qtz ############### + + +def np_quantization(x, _0, _s): + return (np.divide(1, np_replace_num(_s, num=0, to=10000)) * np.arctan((_0 * x))) + + +def np_dequantization(x, _0, _s): + return (np.divide(1, np_replace_num(_0, num=0, to=10000)) * np.tan(np_domain_guard((_s * x), posinf=1, neginf=-1, nan=0))) + + +def fit_func(x, _0, _s): + x_ = np_quantization(x, _0, _s) + x_ = np_dequantization(x_, _0, _s) + return x_ + + + +############### HELPERS ############### + +def domain_guard( + x: torch.Tensor, + min: float = None, + max: float = None, + posinf: float = None, + neginf: float = None, + nan: float = None + ) -> torch.Tensor: + """Guard a tensor to a valid domain.""" + x = torch.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan) + if min is not None or max is not None: + x = torch.clamp(x, min=min, max=max) + return x + + +def replace_num(x: torch.Tensor, num: float, to: float) -> torch.Tensor: + """Replace a number in a tensor with another number. + + Args: + x (torch.Tensor): The input tensor. + num (float): The number to replace. + to (float): The number to replace with. + + Returns: + torch.Tensor: The tensor with the number replaced. + """ + return torch.where(x == num, to, x) + + +def guarded_torch_power(x: torch.Tensor, exp: float) -> torch.Tensor: + """Guard the power operation to a valid domain.""" + return torch.pow(x, exp) if exp >= 1 else torch.pow(torch.relu(x), exp) + + +def init_ones(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor: + val = torch.amin(x, dim=1) + return torch.ones_like(val, dtype=torch.float32, device=x.device) + + +def init_rand(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor: + val = torch.amin(x, dim=1) + return torch.randn_like(val, dtype=torch.float32, device=x.device) + + +def init_space_search( + x: torch.Tensor, + **kwargs: Dict[str, Any], + ) -> torch.Tensor: + + def _build_initial_param(tensor: torch.Tensor, max_initial: int, n_params: int): + """Generates the initial set of parameters. The first iteration generates 10 times more parameters.""" + for _ in range(n_params * 10): # The first iteration generates 10 times more parameters + yield init_rand(tensor) * max_initial # Generates n_params in range [-max_initial, max_initial] + + def _search_param(tensors: List[torch.tensor], n_params): + """Takes the best parameters and generates new parameters around the mean of the best parameters.""" + torch_tensors = torch.stack(tensors) + min_vals, max_vals = torch.aminmax(torch_tensors, dim=0) + abs_max_val_per_ch = torch.max(-min_vals, max_vals) + mean = torch.mean(torch_tensors, dim=0) + for _ in range(n_params): # Generates n_params around the mean of the tensors + yield torch.randn_like(min_vals) * abs_max_val_per_ch + mean + + def _calc(x, qtz_func, deqtz_func, **params): + x_ = x.transpose(0, 1) + x_ = qtz_func(x=x_, **params) + x_ = deqtz_func(x=x_, **params) + x_ = x_.transpose(0, 1) + return x_ + + assert "qtz_func" in kwargs, "qtz_func must be provided." + assert "deqtz_func" in kwargs, "deqtz_func must be provided." + assert "params_list" in kwargs, "params list must be provided." + assert "param" in kwargs, "param must be provided." + + qtz_func = kwargs.get('qtz_func') + deqtz_func = kwargs.get('deqtz_func') + params_list = kwargs.get('params_list') + param = kwargs.get('param') + + n_runs = 50 # Number of runs to try to find the best parameters + n_random_params = 50 # Number of random parameters to generate + n_best_to_pick = 5 # Number of best parameters to pick after each run + max_initial = 10000 # Maximum value to initialize the parameters + + # Initializes the parameters + base_params = { p: init_ones(x, **kwargs) for p in params_list if p != param } + params = _build_initial_param(x, max_initial, n_random_params) + + # Performs the search + for _ in range(n_runs): + + best_params = [] + for param_ in params: + try: + x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: param_}) + loss_ones = nn.MSELoss()(x, x_) + + if len(best_params) < n_best_to_pick: + best_params.append((param_, loss_ones.item())) + best_params = sorted(best_params, key=lambda x: x[1]) + elif loss_ones < best_params[-1][1]: + best_params[-1] = (param_, loss_ones.item()) + best_params = sorted(best_params, key=lambda x: x[1]) + + except Exception: # The parameters might not be valid for the function's domain + continue + + # Generates new parameters around the mean + params = _search_param([p for p, _ in best_params], n_random_params) + + # Checks if the best parameter is better than the init_ones + p_ones = init_ones(x, **kwargs) + x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_ones}) + loss_ones = nn.MSELoss()(x, x_) + + # Checks if the best parameter is better than the init_rand + p_rand = init_rand(x, **kwargs) + x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_rand}) + loss_rand = nn.MSELoss()(x, x_) + + if loss_rand < best_params[0][1] and loss_rand < loss_ones: + return p_rand + elif loss_ones < best_params[0][1] and loss_ones < loss_rand: + return p_ones + else: + return best_params[0][0] + + +def init_linear_scale( # Symmetric scale. From the study folder + x: torch.Tensor, + **kwargs: Dict[str, Any], + ) -> torch.Tensor: + assert "bits" in kwargs, "bits must be provided." + assert "params" in kwargs, "params must be provided." + assert "qtz_func" in kwargs, "qtz_func must be provided." + + bits = kwargs.get('bits') + params = kwargs.get('params') + qtz_func = kwargs.get('qtz_func') + + x_ = x.transpose(0, 1) + x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs)) + x_ = x_.transpose(0, 1) + + quant_min, quant_max = get_min_max_from_bits_signed(bits) + min_vals, max_vals = torch.aminmax(x_, dim=1) + min_vals = torch.min(min_vals, torch.zeros_like(min_vals)) + max_vals = torch.max(max_vals, torch.zeros_like(max_vals)) + + eps = torch.finfo(torch.float32).eps + + abs_max_val_per_ch = torch.max(-min_vals, max_vals) + scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2) + + scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device) + + # Introduces some noise in scale + # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything + scale = scale + 0.01 * torch.randn_like(scale) + return scale + + +def init_non_linear_regression_fit( + x: torch.Tensor, + **kwargs: Dict[str, Any], + ) -> torch.Tensor: + + assert "params_list" in kwargs, "params list must be provided." + assert "np_fit_func" in kwargs, "np_fit_func must be provided." + assert "p0" in kwargs, "p0 must be provided." + np_fit_func = kwargs.get('np_fit_func') + params_list = kwargs.get('params_list') + p0 = kwargs.get('p0') + + def _fit(xdata: np.ndarray, ydata: np.ndarray, func: Callable, p0: List[float]): + popt, _ = curve_fit( + func, + xdata, + ydata, + maxfev=1000, + p0=p0, + method='lm' + ) + return popt + + # 1. Needs to convert the torch tensor to numpy tensor + xdata = x.cpu().numpy() + + # 2. Sorts the data so that it makes it easier to fit to it + sorted_xdata = np.sort(xdata, axis=-1) + + p0 = {k: v.cpu().numpy() for k, v in p0.items()} + params_list = sorted(params_list) # We need to make sure that it matches the numpy fit func arg order + + # 3. Finds the best parameters for each channel + try: + params = [] + for i in range(sorted_xdata.shape[0]): + xdata_ = sorted_xdata[i] + p0_ = [p0[p][i] for p in params_list] + ch_params = _fit(xdata_, xdata_, np_fit_func, p0_) + params.append(ch_params) + + # 4. Builds the parameters + result = {} + for i, p in enumerate(params_list): + result[p] = torch.tensor([p_[i] for p_ in params], dtype=torch.float32).to(x.device) + + return result + + except ValueError as e: + print(f"Could not fit the function with error: {e}") + print(f"Using fallback result...") + return { + k: torch.tensor(v, dtype=torch.float32).to(x.device) for k, v in p0.items() + } + + +def init_zeros(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor: + val = torch.amin(x, dim=1) + return torch.zeros_like(val, dtype=torch.float32, device=x.device) + + +def init_inner_scale(tensor: torch.Tensor, _min: float = torch.inf, _max: float = torch.inf) -> torch.Tensor: + # Calculate the original minimum and maximum values + min_vals, max_vals = torch.aminmax(tensor, dim=-1) + x_min = torch.min(min_vals, torch.zeros_like(min_vals)) + x_max = torch.max(max_vals, torch.zeros_like(max_vals)) + + if _max is torch.inf: # We do not need to scale the tensor. Just need to move it + return torch.ones_like(x_min) + + # Calculate the scale factor + scale = (_max - _min) / (x_max - x_min) + return scale + + + +############## Quant ############### + +@torch.enable_grad() +def learn_parameters( + x: torch.Tensor, + params: Dict[str, nn.Parameter], + qtz_func: nn.Module, + deqtz_func: nn.Module, + bits: int, + target_dtype: torch.dtype, + epochs: int = 1000, + early_stop: bool = True, + do_report: bool = False +) -> Tuple[Dict[str, nn.Parameter], torch.Tensor]: + + # Requires gradients in the parameters + for p in params.values(): + p.requires_grad = True + p.grad = None + + param_keys = list(params.keys()) + param_values = list(params.values()) + + # Defines optimizer and loss function + optimizer = torch.optim.Adam(param_values, lr=0.001) + scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=1.0, end_factor=0.001, total_iters=epochs // 10) + loss_fn = nn.MSELoss() + + # Contains the best loss and the best parameters + best_loss = float("inf") + best_params = None + + # Used to stop the search early + min_delta = 1e-7 + acc_loss = [] + percent_epochs_before_stop = 0.1 + + for i in range(epochs): + optimizer.zero_grad() + + quant = quantize(x, params, qtz_func, bits, target_dtype) + dequant = dequantize(quant, params, deqtz_func, bits, x.dtype) + loss = loss_fn(x, dequant) + + if loss.isnan() or loss.isinf(): + raise Exception("Loss is NaN or Inf. Stopping the search.") + + loss.backward() + optimizer.step() + scheduler.step() + + acc_loss.append(loss.item()) + + # Reports loss every 10 steps + if i % 10 == 0 and do_report: + print(f"Epoch {i}: Loss {loss.item()}") + + # Optimizes the parameter search by storing the best loss and the parameters + if loss.item() < best_loss: + best_loss = loss.item() + best_params = copy.deepcopy({ + k: v for k, v in params.items() if k in param_keys + }) + + # We also stop the search if the loss has not considerably during the last 10% epochs + if early_stop: + epochs_before_stop = int(epochs * percent_epochs_before_stop) + if i > epochs_before_stop and abs(acc_loss[i - epochs_before_stop] - acc_loss[i]) < min_delta: + break + + # No longer requires gradients in the parameters + for p in best_params.values(): + p.requires_grad = False + p.grad = None + + if do_report: + return best_params, acc_loss + else: + return best_params + + +def quantize( + x: torch.Tensor, + params: Dict[str, nn.Parameter], + func: nn.Module, + bits: int, + target_dtype: torch.dtype = torch.int8 +) -> torch.Tensor: + quant_min, quant_max = get_min_max_from_bits_signed(bits) + x = x.transpose(0, 1) # Aligns shapes + x = func(x=x, **params) + x = x.transpose(0, 1) + x = torch.clamp(round_func_BPDA(x), quant_min, quant_max).to(target_dtype) + return x + + +def dequantize( + x: torch.Tensor, + params: Dict[str, nn.Parameter], + func: nn.Module, + bits: int, + out_dtype: torch.dtype +) -> torch.Tensor: + x = x.to(dtype=out_dtype) + x = x.transpose(0, 1) + x = func(x=x, **params) + x = x.transpose(0, 1) + return x + + +def round_func_BPDA(input): + # This is equivalent to replacing round function (non-differentiable) with + # an identity function (differentiable) only when backward. + forward_value = torch.round(input) + out = input.clone() + out.data = forward_value.data + return out + + +def get_min_max_from_bits_signed(bit_width: int) -> Tuple[int, int]: + return -2 ** (bit_width - 1), 2 ** (bit_width - 1) - 1 + + + +############## Numpy ############### + +def np_domain_guard( + x: np.ndarray, + min: float = None, + max: float = None, + posinf: float = None, + neginf: float = None, + nan: float = None + ) -> np.ndarray: + """Guard a tensor to a valid domain.""" + x = np.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan) + if min is not None or max is not None: + x = np.clip(x, min, max) + return x + + +def np_replace_num(x: np.ndarray, num: float, to: float) -> np.ndarray: + """Replace a number in a tensor with another number. + + Args: + x (np.ndarray): The input tensor. + num (float): The number to replace. + to (float): The number to replace with. + + Returns: + np.ndarray: The tensor with the number replaced. + """ + return np.where(x == num, to, x) + + +def np_guarded_power(x: np.ndarray, exp: float) -> np.ndarray: + """Guard the power operation to a valid domain.""" + return np.power(x, exp) if exp >= 1 else np.power(np.maximum(x, 0), exp) + diff --git a/fn_gen/rnd_noisy_scale/5/loss.png b/fn_gen/rnd_noisy_scale/5/loss.png new file mode 100644 index 0000000000000000000000000000000000000000..de8bb1fcf9241ec566207e1c7ebce8d3f43d4397 Binary files /dev/null and b/fn_gen/rnd_noisy_scale/5/loss.png differ diff --git a/fn_gen/rnd_noisy_scale/5/quantization.png b/fn_gen/rnd_noisy_scale/5/quantization.png new file mode 100644 index 0000000000000000000000000000000000000000..c9c23e6df702b827f0690e523a513ffa709e011f Binary files /dev/null and b/fn_gen/rnd_noisy_scale/5/quantization.png differ diff --git a/fn_gen/rnd_noisy_scale/6/__pycache__/fn.cpython-310.pyc b/fn_gen/rnd_noisy_scale/6/__pycache__/fn.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9514947708c9b5a52fb829af0d5d41e50722c3cd Binary files /dev/null and b/fn_gen/rnd_noisy_scale/6/__pycache__/fn.cpython-310.pyc differ diff --git a/fn_gen/rnd_noisy_scale/6/distortion.png b/fn_gen/rnd_noisy_scale/6/distortion.png new file mode 100644 index 0000000000000000000000000000000000000000..2d1006ca149227a6c21c94db7e25d9e6a9dfc99a Binary files /dev/null and b/fn_gen/rnd_noisy_scale/6/distortion.png differ diff --git a/fn_gen/rnd_noisy_scale/6/expressions.txt b/fn_gen/rnd_noisy_scale/6/expressions.txt new file mode 100644 index 0000000000000000000000000000000000000000..23606e9f370f2e4adb43ed623c49d7fcaabd7355 --- /dev/null +++ b/fn_gen/rnd_noisy_scale/6/expressions.txt @@ -0,0 +1,2 @@ +tan(_0*x)/_s +atan(_s*x)/_0 \ No newline at end of file diff --git a/fn_gen/rnd_noisy_scale/6/fn.py b/fn_gen/rnd_noisy_scale/6/fn.py new file mode 100644 index 0000000000000000000000000000000000000000..30d6b98c51c7970594e5ab50ccd7c6acd792cb2c --- /dev/null +++ b/fn_gen/rnd_noisy_scale/6/fn.py @@ -0,0 +1,489 @@ +from __future__ import annotations + +import torch +from torch import amin # Necessary for arcsin +import copy +import torch.nn as nn +import numpy as np + +from scipy.optimize import curve_fit +from typing import Dict, Any, Tuple, List, Callable + + +def quantization(x, **params): + return (torch.div(1, replace_num(params['_s'], num=0, to=10000)) * torch.tan(domain_guard((params['_0'] * x), posinf=1, neginf=-1, nan=0))) + + +def dequantization(x, **params): + return (torch.div(1, replace_num(params['_0'], num=0, to=10000)) * torch.atan((params['_s'] * x))) + + +def init_linear_scale( # Symmetric scale. From the study folder + x: torch.Tensor, + **kwargs: Dict[str, Any], + ) -> torch.Tensor: + assert "bits" in kwargs, "bits must be provided." + assert "params" in kwargs, "params must be provided." + assert "qtz_func" in kwargs, "qtz_func must be provided." + + bits = kwargs.get('bits') + params = kwargs.get('params') + qtz_func = kwargs.get('qtz_func') + + x_ = x.transpose(0, 1) + x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs)) + x_ = x_.transpose(0, 1) + + quant_min, quant_max = get_min_max_from_bits_signed(bits) + min_vals, max_vals = torch.aminmax(x_, dim=1) + min_vals = torch.min(min_vals, torch.zeros_like(min_vals)) + max_vals = torch.max(max_vals, torch.zeros_like(max_vals)) + + eps = torch.finfo(torch.float32).eps + + abs_max_val_per_ch = torch.max(-min_vals, max_vals) + scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2) + + scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device) + + # Introduces some noise in scale + # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything + scale = scale + 0.01 * torch.randn_like(scale) + return scale + + +def init_params(x: torch.Tensor, **kwargs: Dict[str, Any]) -> Dict[str, nn.Parameter]: + params = { + '_0': init_rand(x, qtz_func=quantization, deqtz_func=dequantization, param='_0', params_list=['_0', '_s'], **kwargs), + } + params['_s'] = init_linear_scale(x, params=params, qtz_func=quantization, **kwargs) + params = {k: nn.Parameter(v, requires_grad=False) for k, v in params.items()} + + if 'post_init_hook' in kwargs: + kwargs['post_init_hook'](parameters=params) + + + if 'post_train_hook' in kwargs: + kwargs['post_train_hook'](parameters=params) + + return params + + +############### Numpy Qtz ############### + + +def np_quantization(x, _0, _s): + return (np.divide(1, np_replace_num(_s, num=0, to=10000)) * np.tan(np_domain_guard((_0 * x), posinf=1, neginf=-1, nan=0))) + + +def np_dequantization(x, _0, _s): + return (np.divide(1, np_replace_num(_0, num=0, to=10000)) * np.arctan((_s * x))) + + +def fit_func(x, _0, _s): + x_ = np_quantization(x, _0, _s) + x_ = np_dequantization(x_, _0, _s) + return x_ + + + +############### HELPERS ############### + +def domain_guard( + x: torch.Tensor, + min: float = None, + max: float = None, + posinf: float = None, + neginf: float = None, + nan: float = None + ) -> torch.Tensor: + """Guard a tensor to a valid domain.""" + x = torch.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan) + if min is not None or max is not None: + x = torch.clamp(x, min=min, max=max) + return x + + +def replace_num(x: torch.Tensor, num: float, to: float) -> torch.Tensor: + """Replace a number in a tensor with another number. + + Args: + x (torch.Tensor): The input tensor. + num (float): The number to replace. + to (float): The number to replace with. + + Returns: + torch.Tensor: The tensor with the number replaced. + """ + return torch.where(x == num, to, x) + + +def guarded_torch_power(x: torch.Tensor, exp: float) -> torch.Tensor: + """Guard the power operation to a valid domain.""" + return torch.pow(x, exp) if exp >= 1 else torch.pow(torch.relu(x), exp) + + +def init_ones(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor: + val = torch.amin(x, dim=1) + return torch.ones_like(val, dtype=torch.float32, device=x.device) + + +def init_rand(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor: + val = torch.amin(x, dim=1) + return torch.randn_like(val, dtype=torch.float32, device=x.device) + + +def init_space_search( + x: torch.Tensor, + **kwargs: Dict[str, Any], + ) -> torch.Tensor: + + def _build_initial_param(tensor: torch.Tensor, max_initial: int, n_params: int): + """Generates the initial set of parameters. The first iteration generates 10 times more parameters.""" + for _ in range(n_params * 10): # The first iteration generates 10 times more parameters + yield init_rand(tensor) * max_initial # Generates n_params in range [-max_initial, max_initial] + + def _search_param(tensors: List[torch.tensor], n_params): + """Takes the best parameters and generates new parameters around the mean of the best parameters.""" + torch_tensors = torch.stack(tensors) + min_vals, max_vals = torch.aminmax(torch_tensors, dim=0) + abs_max_val_per_ch = torch.max(-min_vals, max_vals) + mean = torch.mean(torch_tensors, dim=0) + for _ in range(n_params): # Generates n_params around the mean of the tensors + yield torch.randn_like(min_vals) * abs_max_val_per_ch + mean + + def _calc(x, qtz_func, deqtz_func, **params): + x_ = x.transpose(0, 1) + x_ = qtz_func(x=x_, **params) + x_ = deqtz_func(x=x_, **params) + x_ = x_.transpose(0, 1) + return x_ + + assert "qtz_func" in kwargs, "qtz_func must be provided." + assert "deqtz_func" in kwargs, "deqtz_func must be provided." + assert "params_list" in kwargs, "params list must be provided." + assert "param" in kwargs, "param must be provided." + + qtz_func = kwargs.get('qtz_func') + deqtz_func = kwargs.get('deqtz_func') + params_list = kwargs.get('params_list') + param = kwargs.get('param') + + n_runs = 50 # Number of runs to try to find the best parameters + n_random_params = 50 # Number of random parameters to generate + n_best_to_pick = 5 # Number of best parameters to pick after each run + max_initial = 10000 # Maximum value to initialize the parameters + + # Initializes the parameters + base_params = { p: init_ones(x, **kwargs) for p in params_list if p != param } + params = _build_initial_param(x, max_initial, n_random_params) + + # Performs the search + for _ in range(n_runs): + + best_params = [] + for param_ in params: + try: + x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: param_}) + loss_ones = nn.MSELoss()(x, x_) + + if len(best_params) < n_best_to_pick: + best_params.append((param_, loss_ones.item())) + best_params = sorted(best_params, key=lambda x: x[1]) + elif loss_ones < best_params[-1][1]: + best_params[-1] = (param_, loss_ones.item()) + best_params = sorted(best_params, key=lambda x: x[1]) + + except Exception: # The parameters might not be valid for the function's domain + continue + + # Generates new parameters around the mean + params = _search_param([p for p, _ in best_params], n_random_params) + + # Checks if the best parameter is better than the init_ones + p_ones = init_ones(x, **kwargs) + x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_ones}) + loss_ones = nn.MSELoss()(x, x_) + + # Checks if the best parameter is better than the init_rand + p_rand = init_rand(x, **kwargs) + x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_rand}) + loss_rand = nn.MSELoss()(x, x_) + + if loss_rand < best_params[0][1] and loss_rand < loss_ones: + return p_rand + elif loss_ones < best_params[0][1] and loss_ones < loss_rand: + return p_ones + else: + return best_params[0][0] + + +def init_linear_scale( # Symmetric scale. From the study folder + x: torch.Tensor, + **kwargs: Dict[str, Any], + ) -> torch.Tensor: + assert "bits" in kwargs, "bits must be provided." + assert "params" in kwargs, "params must be provided." + assert "qtz_func" in kwargs, "qtz_func must be provided." + + bits = kwargs.get('bits') + params = kwargs.get('params') + qtz_func = kwargs.get('qtz_func') + + x_ = x.transpose(0, 1) + x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs)) + x_ = x_.transpose(0, 1) + + quant_min, quant_max = get_min_max_from_bits_signed(bits) + min_vals, max_vals = torch.aminmax(x_, dim=1) + min_vals = torch.min(min_vals, torch.zeros_like(min_vals)) + max_vals = torch.max(max_vals, torch.zeros_like(max_vals)) + + eps = torch.finfo(torch.float32).eps + + abs_max_val_per_ch = torch.max(-min_vals, max_vals) + scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2) + + scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device) + + # Introduces some noise in scale + # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything + scale = scale + 0.01 * torch.randn_like(scale) + return scale + + +def init_non_linear_regression_fit( + x: torch.Tensor, + **kwargs: Dict[str, Any], + ) -> torch.Tensor: + + assert "params_list" in kwargs, "params list must be provided." + assert "np_fit_func" in kwargs, "np_fit_func must be provided." + assert "p0" in kwargs, "p0 must be provided." + np_fit_func = kwargs.get('np_fit_func') + params_list = kwargs.get('params_list') + p0 = kwargs.get('p0') + + def _fit(xdata: np.ndarray, ydata: np.ndarray, func: Callable, p0: List[float]): + popt, _ = curve_fit( + func, + xdata, + ydata, + maxfev=1000, + p0=p0, + method='lm' + ) + return popt + + # 1. Needs to convert the torch tensor to numpy tensor + xdata = x.cpu().numpy() + + # 2. Sorts the data so that it makes it easier to fit to it + sorted_xdata = np.sort(xdata, axis=-1) + + p0 = {k: v.cpu().numpy() for k, v in p0.items()} + params_list = sorted(params_list) # We need to make sure that it matches the numpy fit func arg order + + # 3. Finds the best parameters for each channel + try: + params = [] + for i in range(sorted_xdata.shape[0]): + xdata_ = sorted_xdata[i] + p0_ = [p0[p][i] for p in params_list] + ch_params = _fit(xdata_, xdata_, np_fit_func, p0_) + params.append(ch_params) + + # 4. Builds the parameters + result = {} + for i, p in enumerate(params_list): + result[p] = torch.tensor([p_[i] for p_ in params], dtype=torch.float32).to(x.device) + + return result + + except ValueError as e: + print(f"Could not fit the function with error: {e}") + print(f"Using fallback result...") + return { + k: torch.tensor(v, dtype=torch.float32).to(x.device) for k, v in p0.items() + } + + +def init_zeros(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor: + val = torch.amin(x, dim=1) + return torch.zeros_like(val, dtype=torch.float32, device=x.device) + + +def init_inner_scale(tensor: torch.Tensor, _min: float = torch.inf, _max: float = torch.inf) -> torch.Tensor: + # Calculate the original minimum and maximum values + min_vals, max_vals = torch.aminmax(tensor, dim=-1) + x_min = torch.min(min_vals, torch.zeros_like(min_vals)) + x_max = torch.max(max_vals, torch.zeros_like(max_vals)) + + if _max is torch.inf: # We do not need to scale the tensor. Just need to move it + return torch.ones_like(x_min) + + # Calculate the scale factor + scale = (_max - _min) / (x_max - x_min) + return scale + + + +############## Quant ############### + +@torch.enable_grad() +def learn_parameters( + x: torch.Tensor, + params: Dict[str, nn.Parameter], + qtz_func: nn.Module, + deqtz_func: nn.Module, + bits: int, + target_dtype: torch.dtype, + epochs: int = 1000, + early_stop: bool = True, + do_report: bool = False +) -> Tuple[Dict[str, nn.Parameter], torch.Tensor]: + + # Requires gradients in the parameters + for p in params.values(): + p.requires_grad = True + p.grad = None + + param_keys = list(params.keys()) + param_values = list(params.values()) + + # Defines optimizer and loss function + optimizer = torch.optim.Adam(param_values, lr=0.001) + scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=1.0, end_factor=0.001, total_iters=epochs // 10) + loss_fn = nn.MSELoss() + + # Contains the best loss and the best parameters + best_loss = float("inf") + best_params = None + + # Used to stop the search early + min_delta = 1e-7 + acc_loss = [] + percent_epochs_before_stop = 0.1 + + for i in range(epochs): + optimizer.zero_grad() + + quant = quantize(x, params, qtz_func, bits, target_dtype) + dequant = dequantize(quant, params, deqtz_func, bits, x.dtype) + loss = loss_fn(x, dequant) + + if loss.isnan() or loss.isinf(): + raise Exception("Loss is NaN or Inf. Stopping the search.") + + loss.backward() + optimizer.step() + scheduler.step() + + acc_loss.append(loss.item()) + + # Reports loss every 10 steps + if i % 10 == 0 and do_report: + print(f"Epoch {i}: Loss {loss.item()}") + + # Optimizes the parameter search by storing the best loss and the parameters + if loss.item() < best_loss: + best_loss = loss.item() + best_params = copy.deepcopy({ + k: v for k, v in params.items() if k in param_keys + }) + + # We also stop the search if the loss has not considerably during the last 10% epochs + if early_stop: + epochs_before_stop = int(epochs * percent_epochs_before_stop) + if i > epochs_before_stop and abs(acc_loss[i - epochs_before_stop] - acc_loss[i]) < min_delta: + break + + # No longer requires gradients in the parameters + for p in best_params.values(): + p.requires_grad = False + p.grad = None + + if do_report: + return best_params, acc_loss + else: + return best_params + + +def quantize( + x: torch.Tensor, + params: Dict[str, nn.Parameter], + func: nn.Module, + bits: int, + target_dtype: torch.dtype = torch.int8 +) -> torch.Tensor: + quant_min, quant_max = get_min_max_from_bits_signed(bits) + x = x.transpose(0, 1) # Aligns shapes + x = func(x=x, **params) + x = x.transpose(0, 1) + x = torch.clamp(round_func_BPDA(x), quant_min, quant_max).to(target_dtype) + return x + + +def dequantize( + x: torch.Tensor, + params: Dict[str, nn.Parameter], + func: nn.Module, + bits: int, + out_dtype: torch.dtype +) -> torch.Tensor: + x = x.to(dtype=out_dtype) + x = x.transpose(0, 1) + x = func(x=x, **params) + x = x.transpose(0, 1) + return x + + +def round_func_BPDA(input): + # This is equivalent to replacing round function (non-differentiable) with + # an identity function (differentiable) only when backward. + forward_value = torch.round(input) + out = input.clone() + out.data = forward_value.data + return out + + +def get_min_max_from_bits_signed(bit_width: int) -> Tuple[int, int]: + return -2 ** (bit_width - 1), 2 ** (bit_width - 1) - 1 + + + +############## Numpy ############### + +def np_domain_guard( + x: np.ndarray, + min: float = None, + max: float = None, + posinf: float = None, + neginf: float = None, + nan: float = None + ) -> np.ndarray: + """Guard a tensor to a valid domain.""" + x = np.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan) + if min is not None or max is not None: + x = np.clip(x, min, max) + return x + + +def np_replace_num(x: np.ndarray, num: float, to: float) -> np.ndarray: + """Replace a number in a tensor with another number. + + Args: + x (np.ndarray): The input tensor. + num (float): The number to replace. + to (float): The number to replace with. + + Returns: + np.ndarray: The tensor with the number replaced. + """ + return np.where(x == num, to, x) + + +def np_guarded_power(x: np.ndarray, exp: float) -> np.ndarray: + """Guard the power operation to a valid domain.""" + return np.power(x, exp) if exp >= 1 else np.power(np.maximum(x, 0), exp) + diff --git a/fn_gen/rnd_noisy_scale/6/loss.png b/fn_gen/rnd_noisy_scale/6/loss.png new file mode 100644 index 0000000000000000000000000000000000000000..b3ced063b0beaf5c448b63dc1e12e729642be963 Binary files /dev/null and b/fn_gen/rnd_noisy_scale/6/loss.png differ diff --git a/fn_gen/rnd_noisy_scale/6/quantization.png b/fn_gen/rnd_noisy_scale/6/quantization.png new file mode 100644 index 0000000000000000000000000000000000000000..7f603eb9b8b26a0502ae75d5e74e558db6b33cd3 Binary files /dev/null and b/fn_gen/rnd_noisy_scale/6/quantization.png differ diff --git a/fn_gen/rnd_noisy_scale/7/__pycache__/fn.cpython-310.pyc b/fn_gen/rnd_noisy_scale/7/__pycache__/fn.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b9e6aa003ad406377dd4aac242e6176eabb02f32 Binary files /dev/null and b/fn_gen/rnd_noisy_scale/7/__pycache__/fn.cpython-310.pyc differ diff --git a/fn_gen/rnd_noisy_scale/7/distortion.png b/fn_gen/rnd_noisy_scale/7/distortion.png new file mode 100644 index 0000000000000000000000000000000000000000..5d0903fb916bb9a1e80acd6dd395e40a5562d1e4 Binary files /dev/null and b/fn_gen/rnd_noisy_scale/7/distortion.png differ diff --git a/fn_gen/rnd_noisy_scale/7/expressions.txt b/fn_gen/rnd_noisy_scale/7/expressions.txt new file mode 100644 index 0000000000000000000000000000000000000000..6d6553d091cd1d343d7aa9b52b85ef6ec88ea854 --- /dev/null +++ b/fn_gen/rnd_noisy_scale/7/expressions.txt @@ -0,0 +1,2 @@ +x/_s +_s*x \ No newline at end of file diff --git a/fn_gen/rnd_noisy_scale/7/fn.py b/fn_gen/rnd_noisy_scale/7/fn.py new file mode 100644 index 0000000000000000000000000000000000000000..b2e9750eabeb36caefecea6401df6ce9a0f064da --- /dev/null +++ b/fn_gen/rnd_noisy_scale/7/fn.py @@ -0,0 +1,488 @@ +from __future__ import annotations + +import torch +from torch import amin # Necessary for arcsin +import copy +import torch.nn as nn +import numpy as np + +from scipy.optimize import curve_fit +from typing import Dict, Any, Tuple, List, Callable + + +def quantization(x, **params): + return (x * torch.div(1, replace_num(params['_s'], num=0, to=10000))) + + +def dequantization(x, **params): + return (params['_s'] * x) + + +def init_linear_scale( # Symmetric scale. From the study folder + x: torch.Tensor, + **kwargs: Dict[str, Any], + ) -> torch.Tensor: + assert "bits" in kwargs, "bits must be provided." + assert "params" in kwargs, "params must be provided." + assert "qtz_func" in kwargs, "qtz_func must be provided." + + bits = kwargs.get('bits') + params = kwargs.get('params') + qtz_func = kwargs.get('qtz_func') + + x_ = x.transpose(0, 1) + x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs)) + x_ = x_.transpose(0, 1) + + quant_min, quant_max = get_min_max_from_bits_signed(bits) + min_vals, max_vals = torch.aminmax(x_, dim=1) + min_vals = torch.min(min_vals, torch.zeros_like(min_vals)) + max_vals = torch.max(max_vals, torch.zeros_like(max_vals)) + + eps = torch.finfo(torch.float32).eps + + abs_max_val_per_ch = torch.max(-min_vals, max_vals) + scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2) + + scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device) + + # Introduces some noise in scale + # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything + scale = scale + 0.01 * torch.randn_like(scale) + return scale + + +def init_params(x: torch.Tensor, **kwargs: Dict[str, Any]) -> Dict[str, nn.Parameter]: + params = { + } + params['_s'] = init_linear_scale(x, params=params, qtz_func=quantization, **kwargs) + params = {k: nn.Parameter(v, requires_grad=False) for k, v in params.items()} + + if 'post_init_hook' in kwargs: + kwargs['post_init_hook'](parameters=params) + + + if 'post_train_hook' in kwargs: + kwargs['post_train_hook'](parameters=params) + + return params + + +############### Numpy Qtz ############### + + +def np_quantization(x, _s): + return (x * np.divide(1, np_replace_num(_s, num=0, to=10000))) + + +def np_dequantization(x, _s): + return (_s * x) + + +def fit_func(x, _s): + x_ = np_quantization(x, _s) + x_ = np_dequantization(x_, _s) + return x_ + + + +############### HELPERS ############### + +def domain_guard( + x: torch.Tensor, + min: float = None, + max: float = None, + posinf: float = None, + neginf: float = None, + nan: float = None + ) -> torch.Tensor: + """Guard a tensor to a valid domain.""" + x = torch.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan) + if min is not None or max is not None: + x = torch.clamp(x, min=min, max=max) + return x + + +def replace_num(x: torch.Tensor, num: float, to: float) -> torch.Tensor: + """Replace a number in a tensor with another number. + + Args: + x (torch.Tensor): The input tensor. + num (float): The number to replace. + to (float): The number to replace with. + + Returns: + torch.Tensor: The tensor with the number replaced. + """ + return torch.where(x == num, to, x) + + +def guarded_torch_power(x: torch.Tensor, exp: float) -> torch.Tensor: + """Guard the power operation to a valid domain.""" + return torch.pow(x, exp) if exp >= 1 else torch.pow(torch.relu(x), exp) + + +def init_ones(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor: + val = torch.amin(x, dim=1) + return torch.ones_like(val, dtype=torch.float32, device=x.device) + + +def init_rand(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor: + val = torch.amin(x, dim=1) + return torch.randn_like(val, dtype=torch.float32, device=x.device) + + +def init_space_search( + x: torch.Tensor, + **kwargs: Dict[str, Any], + ) -> torch.Tensor: + + def _build_initial_param(tensor: torch.Tensor, max_initial: int, n_params: int): + """Generates the initial set of parameters. The first iteration generates 10 times more parameters.""" + for _ in range(n_params * 10): # The first iteration generates 10 times more parameters + yield init_rand(tensor) * max_initial # Generates n_params in range [-max_initial, max_initial] + + def _search_param(tensors: List[torch.tensor], n_params): + """Takes the best parameters and generates new parameters around the mean of the best parameters.""" + torch_tensors = torch.stack(tensors) + min_vals, max_vals = torch.aminmax(torch_tensors, dim=0) + abs_max_val_per_ch = torch.max(-min_vals, max_vals) + mean = torch.mean(torch_tensors, dim=0) + for _ in range(n_params): # Generates n_params around the mean of the tensors + yield torch.randn_like(min_vals) * abs_max_val_per_ch + mean + + def _calc(x, qtz_func, deqtz_func, **params): + x_ = x.transpose(0, 1) + x_ = qtz_func(x=x_, **params) + x_ = deqtz_func(x=x_, **params) + x_ = x_.transpose(0, 1) + return x_ + + assert "qtz_func" in kwargs, "qtz_func must be provided." + assert "deqtz_func" in kwargs, "deqtz_func must be provided." + assert "params_list" in kwargs, "params list must be provided." + assert "param" in kwargs, "param must be provided." + + qtz_func = kwargs.get('qtz_func') + deqtz_func = kwargs.get('deqtz_func') + params_list = kwargs.get('params_list') + param = kwargs.get('param') + + n_runs = 50 # Number of runs to try to find the best parameters + n_random_params = 50 # Number of random parameters to generate + n_best_to_pick = 5 # Number of best parameters to pick after each run + max_initial = 10000 # Maximum value to initialize the parameters + + # Initializes the parameters + base_params = { p: init_ones(x, **kwargs) for p in params_list if p != param } + params = _build_initial_param(x, max_initial, n_random_params) + + # Performs the search + for _ in range(n_runs): + + best_params = [] + for param_ in params: + try: + x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: param_}) + loss_ones = nn.MSELoss()(x, x_) + + if len(best_params) < n_best_to_pick: + best_params.append((param_, loss_ones.item())) + best_params = sorted(best_params, key=lambda x: x[1]) + elif loss_ones < best_params[-1][1]: + best_params[-1] = (param_, loss_ones.item()) + best_params = sorted(best_params, key=lambda x: x[1]) + + except Exception: # The parameters might not be valid for the function's domain + continue + + # Generates new parameters around the mean + params = _search_param([p for p, _ in best_params], n_random_params) + + # Checks if the best parameter is better than the init_ones + p_ones = init_ones(x, **kwargs) + x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_ones}) + loss_ones = nn.MSELoss()(x, x_) + + # Checks if the best parameter is better than the init_rand + p_rand = init_rand(x, **kwargs) + x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_rand}) + loss_rand = nn.MSELoss()(x, x_) + + if loss_rand < best_params[0][1] and loss_rand < loss_ones: + return p_rand + elif loss_ones < best_params[0][1] and loss_ones < loss_rand: + return p_ones + else: + return best_params[0][0] + + +def init_linear_scale( # Symmetric scale. From the study folder + x: torch.Tensor, + **kwargs: Dict[str, Any], + ) -> torch.Tensor: + assert "bits" in kwargs, "bits must be provided." + assert "params" in kwargs, "params must be provided." + assert "qtz_func" in kwargs, "qtz_func must be provided." + + bits = kwargs.get('bits') + params = kwargs.get('params') + qtz_func = kwargs.get('qtz_func') + + x_ = x.transpose(0, 1) + x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs)) + x_ = x_.transpose(0, 1) + + quant_min, quant_max = get_min_max_from_bits_signed(bits) + min_vals, max_vals = torch.aminmax(x_, dim=1) + min_vals = torch.min(min_vals, torch.zeros_like(min_vals)) + max_vals = torch.max(max_vals, torch.zeros_like(max_vals)) + + eps = torch.finfo(torch.float32).eps + + abs_max_val_per_ch = torch.max(-min_vals, max_vals) + scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2) + + scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device) + + # Introduces some noise in scale + # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything + scale = scale + 0.01 * torch.randn_like(scale) + return scale + + +def init_non_linear_regression_fit( + x: torch.Tensor, + **kwargs: Dict[str, Any], + ) -> torch.Tensor: + + assert "params_list" in kwargs, "params list must be provided." + assert "np_fit_func" in kwargs, "np_fit_func must be provided." + assert "p0" in kwargs, "p0 must be provided." + np_fit_func = kwargs.get('np_fit_func') + params_list = kwargs.get('params_list') + p0 = kwargs.get('p0') + + def _fit(xdata: np.ndarray, ydata: np.ndarray, func: Callable, p0: List[float]): + popt, _ = curve_fit( + func, + xdata, + ydata, + maxfev=1000, + p0=p0, + method='lm' + ) + return popt + + # 1. Needs to convert the torch tensor to numpy tensor + xdata = x.cpu().numpy() + + # 2. Sorts the data so that it makes it easier to fit to it + sorted_xdata = np.sort(xdata, axis=-1) + + p0 = {k: v.cpu().numpy() for k, v in p0.items()} + params_list = sorted(params_list) # We need to make sure that it matches the numpy fit func arg order + + # 3. Finds the best parameters for each channel + try: + params = [] + for i in range(sorted_xdata.shape[0]): + xdata_ = sorted_xdata[i] + p0_ = [p0[p][i] for p in params_list] + ch_params = _fit(xdata_, xdata_, np_fit_func, p0_) + params.append(ch_params) + + # 4. Builds the parameters + result = {} + for i, p in enumerate(params_list): + result[p] = torch.tensor([p_[i] for p_ in params], dtype=torch.float32).to(x.device) + + return result + + except ValueError as e: + print(f"Could not fit the function with error: {e}") + print(f"Using fallback result...") + return { + k: torch.tensor(v, dtype=torch.float32).to(x.device) for k, v in p0.items() + } + + +def init_zeros(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor: + val = torch.amin(x, dim=1) + return torch.zeros_like(val, dtype=torch.float32, device=x.device) + + +def init_inner_scale(tensor: torch.Tensor, _min: float = torch.inf, _max: float = torch.inf) -> torch.Tensor: + # Calculate the original minimum and maximum values + min_vals, max_vals = torch.aminmax(tensor, dim=-1) + x_min = torch.min(min_vals, torch.zeros_like(min_vals)) + x_max = torch.max(max_vals, torch.zeros_like(max_vals)) + + if _max is torch.inf: # We do not need to scale the tensor. Just need to move it + return torch.ones_like(x_min) + + # Calculate the scale factor + scale = (_max - _min) / (x_max - x_min) + return scale + + + +############## Quant ############### + +@torch.enable_grad() +def learn_parameters( + x: torch.Tensor, + params: Dict[str, nn.Parameter], + qtz_func: nn.Module, + deqtz_func: nn.Module, + bits: int, + target_dtype: torch.dtype, + epochs: int = 1000, + early_stop: bool = True, + do_report: bool = False +) -> Tuple[Dict[str, nn.Parameter], torch.Tensor]: + + # Requires gradients in the parameters + for p in params.values(): + p.requires_grad = True + p.grad = None + + param_keys = list(params.keys()) + param_values = list(params.values()) + + # Defines optimizer and loss function + optimizer = torch.optim.Adam(param_values, lr=0.001) + scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=1.0, end_factor=0.001, total_iters=epochs // 10) + loss_fn = nn.MSELoss() + + # Contains the best loss and the best parameters + best_loss = float("inf") + best_params = None + + # Used to stop the search early + min_delta = 1e-7 + acc_loss = [] + percent_epochs_before_stop = 0.1 + + for i in range(epochs): + optimizer.zero_grad() + + quant = quantize(x, params, qtz_func, bits, target_dtype) + dequant = dequantize(quant, params, deqtz_func, bits, x.dtype) + loss = loss_fn(x, dequant) + + if loss.isnan() or loss.isinf(): + raise Exception("Loss is NaN or Inf. Stopping the search.") + + loss.backward() + optimizer.step() + scheduler.step() + + acc_loss.append(loss.item()) + + # Reports loss every 10 steps + if i % 10 == 0 and do_report: + print(f"Epoch {i}: Loss {loss.item()}") + + # Optimizes the parameter search by storing the best loss and the parameters + if loss.item() < best_loss: + best_loss = loss.item() + best_params = copy.deepcopy({ + k: v for k, v in params.items() if k in param_keys + }) + + # We also stop the search if the loss has not considerably during the last 10% epochs + if early_stop: + epochs_before_stop = int(epochs * percent_epochs_before_stop) + if i > epochs_before_stop and abs(acc_loss[i - epochs_before_stop] - acc_loss[i]) < min_delta: + break + + # No longer requires gradients in the parameters + for p in best_params.values(): + p.requires_grad = False + p.grad = None + + if do_report: + return best_params, acc_loss + else: + return best_params + + +def quantize( + x: torch.Tensor, + params: Dict[str, nn.Parameter], + func: nn.Module, + bits: int, + target_dtype: torch.dtype = torch.int8 +) -> torch.Tensor: + quant_min, quant_max = get_min_max_from_bits_signed(bits) + x = x.transpose(0, 1) # Aligns shapes + x = func(x=x, **params) + x = x.transpose(0, 1) + x = torch.clamp(round_func_BPDA(x), quant_min, quant_max).to(target_dtype) + return x + + +def dequantize( + x: torch.Tensor, + params: Dict[str, nn.Parameter], + func: nn.Module, + bits: int, + out_dtype: torch.dtype +) -> torch.Tensor: + x = x.to(dtype=out_dtype) + x = x.transpose(0, 1) + x = func(x=x, **params) + x = x.transpose(0, 1) + return x + + +def round_func_BPDA(input): + # This is equivalent to replacing round function (non-differentiable) with + # an identity function (differentiable) only when backward. + forward_value = torch.round(input) + out = input.clone() + out.data = forward_value.data + return out + + +def get_min_max_from_bits_signed(bit_width: int) -> Tuple[int, int]: + return -2 ** (bit_width - 1), 2 ** (bit_width - 1) - 1 + + + +############## Numpy ############### + +def np_domain_guard( + x: np.ndarray, + min: float = None, + max: float = None, + posinf: float = None, + neginf: float = None, + nan: float = None + ) -> np.ndarray: + """Guard a tensor to a valid domain.""" + x = np.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan) + if min is not None or max is not None: + x = np.clip(x, min, max) + return x + + +def np_replace_num(x: np.ndarray, num: float, to: float) -> np.ndarray: + """Replace a number in a tensor with another number. + + Args: + x (np.ndarray): The input tensor. + num (float): The number to replace. + to (float): The number to replace with. + + Returns: + np.ndarray: The tensor with the number replaced. + """ + return np.where(x == num, to, x) + + +def np_guarded_power(x: np.ndarray, exp: float) -> np.ndarray: + """Guard the power operation to a valid domain.""" + return np.power(x, exp) if exp >= 1 else np.power(np.maximum(x, 0), exp) + diff --git a/fn_gen/rnd_noisy_scale/7/loss.png b/fn_gen/rnd_noisy_scale/7/loss.png new file mode 100644 index 0000000000000000000000000000000000000000..6d94839911e639a0ef26aaee019277b4ab973aca Binary files /dev/null and b/fn_gen/rnd_noisy_scale/7/loss.png differ diff --git a/fn_gen/rnd_noisy_scale/7/quantization.png b/fn_gen/rnd_noisy_scale/7/quantization.png new file mode 100644 index 0000000000000000000000000000000000000000..3e73b05ec4d079ed81e8513a101e730cf09f21b8 Binary files /dev/null and b/fn_gen/rnd_noisy_scale/7/quantization.png differ diff --git a/fn_gen/rnd_noisy_scale/8/__pycache__/fn.cpython-310.pyc b/fn_gen/rnd_noisy_scale/8/__pycache__/fn.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5f3a15219d5f9361b3eed709c29187f902330ba2 Binary files /dev/null and b/fn_gen/rnd_noisy_scale/8/__pycache__/fn.cpython-310.pyc differ diff --git a/fn_gen/rnd_noisy_scale/8/distortion.png b/fn_gen/rnd_noisy_scale/8/distortion.png new file mode 100644 index 0000000000000000000000000000000000000000..54b9826a752d1bc37c4b1eac3c108b224a6f9f57 Binary files /dev/null and b/fn_gen/rnd_noisy_scale/8/distortion.png differ diff --git a/fn_gen/rnd_noisy_scale/8/expressions.txt b/fn_gen/rnd_noisy_scale/8/expressions.txt new file mode 100644 index 0000000000000000000000000000000000000000..e8458af52eb4cfce21cf8459f3c454003cd78158 --- /dev/null +++ b/fn_gen/rnd_noisy_scale/8/expressions.txt @@ -0,0 +1,2 @@ +sqrt(_0*x)/_s +_s**2*x**2/_0 \ No newline at end of file diff --git a/fn_gen/rnd_noisy_scale/8/fn.py b/fn_gen/rnd_noisy_scale/8/fn.py new file mode 100644 index 0000000000000000000000000000000000000000..85e967979eb053dd29556a48990a5139c529a638 --- /dev/null +++ b/fn_gen/rnd_noisy_scale/8/fn.py @@ -0,0 +1,489 @@ +from __future__ import annotations + +import torch +from torch import amin # Necessary for arcsin +import copy +import torch.nn as nn +import numpy as np + +from scipy.optimize import curve_fit +from typing import Dict, Any, Tuple, List, Callable + + +def quantization(x, **params): + return (torch.div(1, replace_num(params['_s'], num=0, to=10000)) * torch.sqrt(domain_guard((params['_0'] * x), min=0.1, nan=0.1))) + + +def dequantization(x, **params): + return (torch.div(1, replace_num(params['_0'], num=0, to=10000)) * guarded_torch_power(params['_s'], torch.tensor(2)) * guarded_torch_power(x, torch.tensor(2))) + + +def init_linear_scale( # Symmetric scale. From the study folder + x: torch.Tensor, + **kwargs: Dict[str, Any], + ) -> torch.Tensor: + assert "bits" in kwargs, "bits must be provided." + assert "params" in kwargs, "params must be provided." + assert "qtz_func" in kwargs, "qtz_func must be provided." + + bits = kwargs.get('bits') + params = kwargs.get('params') + qtz_func = kwargs.get('qtz_func') + + x_ = x.transpose(0, 1) + x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs)) + x_ = x_.transpose(0, 1) + + quant_min, quant_max = get_min_max_from_bits_signed(bits) + min_vals, max_vals = torch.aminmax(x_, dim=1) + min_vals = torch.min(min_vals, torch.zeros_like(min_vals)) + max_vals = torch.max(max_vals, torch.zeros_like(max_vals)) + + eps = torch.finfo(torch.float32).eps + + abs_max_val_per_ch = torch.max(-min_vals, max_vals) + scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2) + + scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device) + + # Introduces some noise in scale + # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything + scale = scale + 0.01 * torch.randn_like(scale) + return scale + + +def init_params(x: torch.Tensor, **kwargs: Dict[str, Any]) -> Dict[str, nn.Parameter]: + params = { + '_0': init_rand(x, qtz_func=quantization, deqtz_func=dequantization, param='_0', params_list=['_0', '_s'], **kwargs), + } + params['_s'] = init_linear_scale(x, params=params, qtz_func=quantization, **kwargs) + params = {k: nn.Parameter(v, requires_grad=False) for k, v in params.items()} + + if 'post_init_hook' in kwargs: + kwargs['post_init_hook'](parameters=params) + + + if 'post_train_hook' in kwargs: + kwargs['post_train_hook'](parameters=params) + + return params + + +############### Numpy Qtz ############### + + +def np_quantization(x, _0, _s): + return (np.divide(1, np_replace_num(_s, num=0, to=10000)) * np.sqrt(np_domain_guard((_0 * x), min=0.1, nan=0.1))) + + +def np_dequantization(x, _0, _s): + return (np.divide(1, np_replace_num(_0, num=0, to=10000)) * np_guarded_power(_s, np.array(2)) * np_guarded_power(x, np.array(2))) + + +def fit_func(x, _0, _s): + x_ = np_quantization(x, _0, _s) + x_ = np_dequantization(x_, _0, _s) + return x_ + + + +############### HELPERS ############### + +def domain_guard( + x: torch.Tensor, + min: float = None, + max: float = None, + posinf: float = None, + neginf: float = None, + nan: float = None + ) -> torch.Tensor: + """Guard a tensor to a valid domain.""" + x = torch.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan) + if min is not None or max is not None: + x = torch.clamp(x, min=min, max=max) + return x + + +def replace_num(x: torch.Tensor, num: float, to: float) -> torch.Tensor: + """Replace a number in a tensor with another number. + + Args: + x (torch.Tensor): The input tensor. + num (float): The number to replace. + to (float): The number to replace with. + + Returns: + torch.Tensor: The tensor with the number replaced. + """ + return torch.where(x == num, to, x) + + +def guarded_torch_power(x: torch.Tensor, exp: float) -> torch.Tensor: + """Guard the power operation to a valid domain.""" + return torch.pow(x, exp) if exp >= 1 else torch.pow(torch.relu(x), exp) + + +def init_ones(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor: + val = torch.amin(x, dim=1) + return torch.ones_like(val, dtype=torch.float32, device=x.device) + + +def init_rand(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor: + val = torch.amin(x, dim=1) + return torch.randn_like(val, dtype=torch.float32, device=x.device) + + +def init_space_search( + x: torch.Tensor, + **kwargs: Dict[str, Any], + ) -> torch.Tensor: + + def _build_initial_param(tensor: torch.Tensor, max_initial: int, n_params: int): + """Generates the initial set of parameters. The first iteration generates 10 times more parameters.""" + for _ in range(n_params * 10): # The first iteration generates 10 times more parameters + yield init_rand(tensor) * max_initial # Generates n_params in range [-max_initial, max_initial] + + def _search_param(tensors: List[torch.tensor], n_params): + """Takes the best parameters and generates new parameters around the mean of the best parameters.""" + torch_tensors = torch.stack(tensors) + min_vals, max_vals = torch.aminmax(torch_tensors, dim=0) + abs_max_val_per_ch = torch.max(-min_vals, max_vals) + mean = torch.mean(torch_tensors, dim=0) + for _ in range(n_params): # Generates n_params around the mean of the tensors + yield torch.randn_like(min_vals) * abs_max_val_per_ch + mean + + def _calc(x, qtz_func, deqtz_func, **params): + x_ = x.transpose(0, 1) + x_ = qtz_func(x=x_, **params) + x_ = deqtz_func(x=x_, **params) + x_ = x_.transpose(0, 1) + return x_ + + assert "qtz_func" in kwargs, "qtz_func must be provided." + assert "deqtz_func" in kwargs, "deqtz_func must be provided." + assert "params_list" in kwargs, "params list must be provided." + assert "param" in kwargs, "param must be provided." + + qtz_func = kwargs.get('qtz_func') + deqtz_func = kwargs.get('deqtz_func') + params_list = kwargs.get('params_list') + param = kwargs.get('param') + + n_runs = 50 # Number of runs to try to find the best parameters + n_random_params = 50 # Number of random parameters to generate + n_best_to_pick = 5 # Number of best parameters to pick after each run + max_initial = 10000 # Maximum value to initialize the parameters + + # Initializes the parameters + base_params = { p: init_ones(x, **kwargs) for p in params_list if p != param } + params = _build_initial_param(x, max_initial, n_random_params) + + # Performs the search + for _ in range(n_runs): + + best_params = [] + for param_ in params: + try: + x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: param_}) + loss_ones = nn.MSELoss()(x, x_) + + if len(best_params) < n_best_to_pick: + best_params.append((param_, loss_ones.item())) + best_params = sorted(best_params, key=lambda x: x[1]) + elif loss_ones < best_params[-1][1]: + best_params[-1] = (param_, loss_ones.item()) + best_params = sorted(best_params, key=lambda x: x[1]) + + except Exception: # The parameters might not be valid for the function's domain + continue + + # Generates new parameters around the mean + params = _search_param([p for p, _ in best_params], n_random_params) + + # Checks if the best parameter is better than the init_ones + p_ones = init_ones(x, **kwargs) + x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_ones}) + loss_ones = nn.MSELoss()(x, x_) + + # Checks if the best parameter is better than the init_rand + p_rand = init_rand(x, **kwargs) + x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_rand}) + loss_rand = nn.MSELoss()(x, x_) + + if loss_rand < best_params[0][1] and loss_rand < loss_ones: + return p_rand + elif loss_ones < best_params[0][1] and loss_ones < loss_rand: + return p_ones + else: + return best_params[0][0] + + +def init_linear_scale( # Symmetric scale. From the study folder + x: torch.Tensor, + **kwargs: Dict[str, Any], + ) -> torch.Tensor: + assert "bits" in kwargs, "bits must be provided." + assert "params" in kwargs, "params must be provided." + assert "qtz_func" in kwargs, "qtz_func must be provided." + + bits = kwargs.get('bits') + params = kwargs.get('params') + qtz_func = kwargs.get('qtz_func') + + x_ = x.transpose(0, 1) + x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs)) + x_ = x_.transpose(0, 1) + + quant_min, quant_max = get_min_max_from_bits_signed(bits) + min_vals, max_vals = torch.aminmax(x_, dim=1) + min_vals = torch.min(min_vals, torch.zeros_like(min_vals)) + max_vals = torch.max(max_vals, torch.zeros_like(max_vals)) + + eps = torch.finfo(torch.float32).eps + + abs_max_val_per_ch = torch.max(-min_vals, max_vals) + scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2) + + scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device) + + # Introduces some noise in scale + # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything + scale = scale + 0.01 * torch.randn_like(scale) + return scale + + +def init_non_linear_regression_fit( + x: torch.Tensor, + **kwargs: Dict[str, Any], + ) -> torch.Tensor: + + assert "params_list" in kwargs, "params list must be provided." + assert "np_fit_func" in kwargs, "np_fit_func must be provided." + assert "p0" in kwargs, "p0 must be provided." + np_fit_func = kwargs.get('np_fit_func') + params_list = kwargs.get('params_list') + p0 = kwargs.get('p0') + + def _fit(xdata: np.ndarray, ydata: np.ndarray, func: Callable, p0: List[float]): + popt, _ = curve_fit( + func, + xdata, + ydata, + maxfev=1000, + p0=p0, + method='lm' + ) + return popt + + # 1. Needs to convert the torch tensor to numpy tensor + xdata = x.cpu().numpy() + + # 2. Sorts the data so that it makes it easier to fit to it + sorted_xdata = np.sort(xdata, axis=-1) + + p0 = {k: v.cpu().numpy() for k, v in p0.items()} + params_list = sorted(params_list) # We need to make sure that it matches the numpy fit func arg order + + # 3. Finds the best parameters for each channel + try: + params = [] + for i in range(sorted_xdata.shape[0]): + xdata_ = sorted_xdata[i] + p0_ = [p0[p][i] for p in params_list] + ch_params = _fit(xdata_, xdata_, np_fit_func, p0_) + params.append(ch_params) + + # 4. Builds the parameters + result = {} + for i, p in enumerate(params_list): + result[p] = torch.tensor([p_[i] for p_ in params], dtype=torch.float32).to(x.device) + + return result + + except ValueError as e: + print(f"Could not fit the function with error: {e}") + print(f"Using fallback result...") + return { + k: torch.tensor(v, dtype=torch.float32).to(x.device) for k, v in p0.items() + } + + +def init_zeros(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor: + val = torch.amin(x, dim=1) + return torch.zeros_like(val, dtype=torch.float32, device=x.device) + + +def init_inner_scale(tensor: torch.Tensor, _min: float = torch.inf, _max: float = torch.inf) -> torch.Tensor: + # Calculate the original minimum and maximum values + min_vals, max_vals = torch.aminmax(tensor, dim=-1) + x_min = torch.min(min_vals, torch.zeros_like(min_vals)) + x_max = torch.max(max_vals, torch.zeros_like(max_vals)) + + if _max is torch.inf: # We do not need to scale the tensor. Just need to move it + return torch.ones_like(x_min) + + # Calculate the scale factor + scale = (_max - _min) / (x_max - x_min) + return scale + + + +############## Quant ############### + +@torch.enable_grad() +def learn_parameters( + x: torch.Tensor, + params: Dict[str, nn.Parameter], + qtz_func: nn.Module, + deqtz_func: nn.Module, + bits: int, + target_dtype: torch.dtype, + epochs: int = 1000, + early_stop: bool = True, + do_report: bool = False +) -> Tuple[Dict[str, nn.Parameter], torch.Tensor]: + + # Requires gradients in the parameters + for p in params.values(): + p.requires_grad = True + p.grad = None + + param_keys = list(params.keys()) + param_values = list(params.values()) + + # Defines optimizer and loss function + optimizer = torch.optim.Adam(param_values, lr=0.001) + scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=1.0, end_factor=0.001, total_iters=epochs // 10) + loss_fn = nn.MSELoss() + + # Contains the best loss and the best parameters + best_loss = float("inf") + best_params = None + + # Used to stop the search early + min_delta = 1e-7 + acc_loss = [] + percent_epochs_before_stop = 0.1 + + for i in range(epochs): + optimizer.zero_grad() + + quant = quantize(x, params, qtz_func, bits, target_dtype) + dequant = dequantize(quant, params, deqtz_func, bits, x.dtype) + loss = loss_fn(x, dequant) + + if loss.isnan() or loss.isinf(): + raise Exception("Loss is NaN or Inf. Stopping the search.") + + loss.backward() + optimizer.step() + scheduler.step() + + acc_loss.append(loss.item()) + + # Reports loss every 10 steps + if i % 10 == 0 and do_report: + print(f"Epoch {i}: Loss {loss.item()}") + + # Optimizes the parameter search by storing the best loss and the parameters + if loss.item() < best_loss: + best_loss = loss.item() + best_params = copy.deepcopy({ + k: v for k, v in params.items() if k in param_keys + }) + + # We also stop the search if the loss has not considerably during the last 10% epochs + if early_stop: + epochs_before_stop = int(epochs * percent_epochs_before_stop) + if i > epochs_before_stop and abs(acc_loss[i - epochs_before_stop] - acc_loss[i]) < min_delta: + break + + # No longer requires gradients in the parameters + for p in best_params.values(): + p.requires_grad = False + p.grad = None + + if do_report: + return best_params, acc_loss + else: + return best_params + + +def quantize( + x: torch.Tensor, + params: Dict[str, nn.Parameter], + func: nn.Module, + bits: int, + target_dtype: torch.dtype = torch.int8 +) -> torch.Tensor: + quant_min, quant_max = get_min_max_from_bits_signed(bits) + x = x.transpose(0, 1) # Aligns shapes + x = func(x=x, **params) + x = x.transpose(0, 1) + x = torch.clamp(round_func_BPDA(x), quant_min, quant_max).to(target_dtype) + return x + + +def dequantize( + x: torch.Tensor, + params: Dict[str, nn.Parameter], + func: nn.Module, + bits: int, + out_dtype: torch.dtype +) -> torch.Tensor: + x = x.to(dtype=out_dtype) + x = x.transpose(0, 1) + x = func(x=x, **params) + x = x.transpose(0, 1) + return x + + +def round_func_BPDA(input): + # This is equivalent to replacing round function (non-differentiable) with + # an identity function (differentiable) only when backward. + forward_value = torch.round(input) + out = input.clone() + out.data = forward_value.data + return out + + +def get_min_max_from_bits_signed(bit_width: int) -> Tuple[int, int]: + return -2 ** (bit_width - 1), 2 ** (bit_width - 1) - 1 + + + +############## Numpy ############### + +def np_domain_guard( + x: np.ndarray, + min: float = None, + max: float = None, + posinf: float = None, + neginf: float = None, + nan: float = None + ) -> np.ndarray: + """Guard a tensor to a valid domain.""" + x = np.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan) + if min is not None or max is not None: + x = np.clip(x, min, max) + return x + + +def np_replace_num(x: np.ndarray, num: float, to: float) -> np.ndarray: + """Replace a number in a tensor with another number. + + Args: + x (np.ndarray): The input tensor. + num (float): The number to replace. + to (float): The number to replace with. + + Returns: + np.ndarray: The tensor with the number replaced. + """ + return np.where(x == num, to, x) + + +def np_guarded_power(x: np.ndarray, exp: float) -> np.ndarray: + """Guard the power operation to a valid domain.""" + return np.power(x, exp) if exp >= 1 else np.power(np.maximum(x, 0), exp) + diff --git a/fn_gen/rnd_noisy_scale/8/loss.png b/fn_gen/rnd_noisy_scale/8/loss.png new file mode 100644 index 0000000000000000000000000000000000000000..f82173586811255b76197b71caa5a1215543a5f7 Binary files /dev/null and b/fn_gen/rnd_noisy_scale/8/loss.png differ diff --git a/fn_gen/rnd_noisy_scale/8/quantization.png b/fn_gen/rnd_noisy_scale/8/quantization.png new file mode 100644 index 0000000000000000000000000000000000000000..cd3cdd059c697783e79799b19ed72cff1f4309d6 Binary files /dev/null and b/fn_gen/rnd_noisy_scale/8/quantization.png differ diff --git a/fn_gen/rnd_noisy_scale/9/__pycache__/fn.cpython-310.pyc b/fn_gen/rnd_noisy_scale/9/__pycache__/fn.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..179d2e73f8552ada6be4f432171b03db584e954e Binary files /dev/null and b/fn_gen/rnd_noisy_scale/9/__pycache__/fn.cpython-310.pyc differ diff --git a/fn_gen/rnd_noisy_scale/9/distortion.png b/fn_gen/rnd_noisy_scale/9/distortion.png new file mode 100644 index 0000000000000000000000000000000000000000..a2ee371b1e8485c42a4c71d82570d92f84b5ccd0 Binary files /dev/null and b/fn_gen/rnd_noisy_scale/9/distortion.png differ diff --git a/fn_gen/rnd_noisy_scale/9/expressions.txt b/fn_gen/rnd_noisy_scale/9/expressions.txt new file mode 100644 index 0000000000000000000000000000000000000000..dbb6da0fc54c6f23dc12daf2e2c3a395819e1bf4 --- /dev/null +++ b/fn_gen/rnd_noisy_scale/9/expressions.txt @@ -0,0 +1,2 @@ +x**2/_s +sqrt(_s*x) \ No newline at end of file diff --git a/fn_gen/rnd_noisy_scale/9/fn.py b/fn_gen/rnd_noisy_scale/9/fn.py new file mode 100644 index 0000000000000000000000000000000000000000..b95b879e410733d747572481e3cbae6ea3b3645a --- /dev/null +++ b/fn_gen/rnd_noisy_scale/9/fn.py @@ -0,0 +1,488 @@ +from __future__ import annotations + +import torch +from torch import amin # Necessary for arcsin +import copy +import torch.nn as nn +import numpy as np + +from scipy.optimize import curve_fit +from typing import Dict, Any, Tuple, List, Callable + + +def quantization(x, **params): + return (torch.div(1, replace_num(params['_s'], num=0, to=10000)) * guarded_torch_power(x, torch.tensor(2))) + + +def dequantization(x, **params): + return torch.sqrt(domain_guard((params['_s'] * x), min=0.1, nan=0.1)) + + +def init_linear_scale( # Symmetric scale. From the study folder + x: torch.Tensor, + **kwargs: Dict[str, Any], + ) -> torch.Tensor: + assert "bits" in kwargs, "bits must be provided." + assert "params" in kwargs, "params must be provided." + assert "qtz_func" in kwargs, "qtz_func must be provided." + + bits = kwargs.get('bits') + params = kwargs.get('params') + qtz_func = kwargs.get('qtz_func') + + x_ = x.transpose(0, 1) + x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs)) + x_ = x_.transpose(0, 1) + + quant_min, quant_max = get_min_max_from_bits_signed(bits) + min_vals, max_vals = torch.aminmax(x_, dim=1) + min_vals = torch.min(min_vals, torch.zeros_like(min_vals)) + max_vals = torch.max(max_vals, torch.zeros_like(max_vals)) + + eps = torch.finfo(torch.float32).eps + + abs_max_val_per_ch = torch.max(-min_vals, max_vals) + scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2) + + scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device) + + # Introduces some noise in scale + # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything + scale = scale + 0.01 * torch.randn_like(scale) + return scale + + +def init_params(x: torch.Tensor, **kwargs: Dict[str, Any]) -> Dict[str, nn.Parameter]: + params = { + } + params['_s'] = init_linear_scale(x, params=params, qtz_func=quantization, **kwargs) + params = {k: nn.Parameter(v, requires_grad=False) for k, v in params.items()} + + if 'post_init_hook' in kwargs: + kwargs['post_init_hook'](parameters=params) + + + if 'post_train_hook' in kwargs: + kwargs['post_train_hook'](parameters=params) + + return params + + +############### Numpy Qtz ############### + + +def np_quantization(x, _s): + return (np.divide(1, np_replace_num(_s, num=0, to=10000)) * np_guarded_power(x, np.array(2))) + + +def np_dequantization(x, _s): + return np.sqrt(np_domain_guard((_s * x), min=0.1, nan=0.1)) + + +def fit_func(x, _s): + x_ = np_quantization(x, _s) + x_ = np_dequantization(x_, _s) + return x_ + + + +############### HELPERS ############### + +def domain_guard( + x: torch.Tensor, + min: float = None, + max: float = None, + posinf: float = None, + neginf: float = None, + nan: float = None + ) -> torch.Tensor: + """Guard a tensor to a valid domain.""" + x = torch.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan) + if min is not None or max is not None: + x = torch.clamp(x, min=min, max=max) + return x + + +def replace_num(x: torch.Tensor, num: float, to: float) -> torch.Tensor: + """Replace a number in a tensor with another number. + + Args: + x (torch.Tensor): The input tensor. + num (float): The number to replace. + to (float): The number to replace with. + + Returns: + torch.Tensor: The tensor with the number replaced. + """ + return torch.where(x == num, to, x) + + +def guarded_torch_power(x: torch.Tensor, exp: float) -> torch.Tensor: + """Guard the power operation to a valid domain.""" + return torch.pow(x, exp) if exp >= 1 else torch.pow(torch.relu(x), exp) + + +def init_ones(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor: + val = torch.amin(x, dim=1) + return torch.ones_like(val, dtype=torch.float32, device=x.device) + + +def init_rand(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor: + val = torch.amin(x, dim=1) + return torch.randn_like(val, dtype=torch.float32, device=x.device) + + +def init_space_search( + x: torch.Tensor, + **kwargs: Dict[str, Any], + ) -> torch.Tensor: + + def _build_initial_param(tensor: torch.Tensor, max_initial: int, n_params: int): + """Generates the initial set of parameters. The first iteration generates 10 times more parameters.""" + for _ in range(n_params * 10): # The first iteration generates 10 times more parameters + yield init_rand(tensor) * max_initial # Generates n_params in range [-max_initial, max_initial] + + def _search_param(tensors: List[torch.tensor], n_params): + """Takes the best parameters and generates new parameters around the mean of the best parameters.""" + torch_tensors = torch.stack(tensors) + min_vals, max_vals = torch.aminmax(torch_tensors, dim=0) + abs_max_val_per_ch = torch.max(-min_vals, max_vals) + mean = torch.mean(torch_tensors, dim=0) + for _ in range(n_params): # Generates n_params around the mean of the tensors + yield torch.randn_like(min_vals) * abs_max_val_per_ch + mean + + def _calc(x, qtz_func, deqtz_func, **params): + x_ = x.transpose(0, 1) + x_ = qtz_func(x=x_, **params) + x_ = deqtz_func(x=x_, **params) + x_ = x_.transpose(0, 1) + return x_ + + assert "qtz_func" in kwargs, "qtz_func must be provided." + assert "deqtz_func" in kwargs, "deqtz_func must be provided." + assert "params_list" in kwargs, "params list must be provided." + assert "param" in kwargs, "param must be provided." + + qtz_func = kwargs.get('qtz_func') + deqtz_func = kwargs.get('deqtz_func') + params_list = kwargs.get('params_list') + param = kwargs.get('param') + + n_runs = 50 # Number of runs to try to find the best parameters + n_random_params = 50 # Number of random parameters to generate + n_best_to_pick = 5 # Number of best parameters to pick after each run + max_initial = 10000 # Maximum value to initialize the parameters + + # Initializes the parameters + base_params = { p: init_ones(x, **kwargs) for p in params_list if p != param } + params = _build_initial_param(x, max_initial, n_random_params) + + # Performs the search + for _ in range(n_runs): + + best_params = [] + for param_ in params: + try: + x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: param_}) + loss_ones = nn.MSELoss()(x, x_) + + if len(best_params) < n_best_to_pick: + best_params.append((param_, loss_ones.item())) + best_params = sorted(best_params, key=lambda x: x[1]) + elif loss_ones < best_params[-1][1]: + best_params[-1] = (param_, loss_ones.item()) + best_params = sorted(best_params, key=lambda x: x[1]) + + except Exception: # The parameters might not be valid for the function's domain + continue + + # Generates new parameters around the mean + params = _search_param([p for p, _ in best_params], n_random_params) + + # Checks if the best parameter is better than the init_ones + p_ones = init_ones(x, **kwargs) + x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_ones}) + loss_ones = nn.MSELoss()(x, x_) + + # Checks if the best parameter is better than the init_rand + p_rand = init_rand(x, **kwargs) + x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_rand}) + loss_rand = nn.MSELoss()(x, x_) + + if loss_rand < best_params[0][1] and loss_rand < loss_ones: + return p_rand + elif loss_ones < best_params[0][1] and loss_ones < loss_rand: + return p_ones + else: + return best_params[0][0] + + +def init_linear_scale( # Symmetric scale. From the study folder + x: torch.Tensor, + **kwargs: Dict[str, Any], + ) -> torch.Tensor: + assert "bits" in kwargs, "bits must be provided." + assert "params" in kwargs, "params must be provided." + assert "qtz_func" in kwargs, "qtz_func must be provided." + + bits = kwargs.get('bits') + params = kwargs.get('params') + qtz_func = kwargs.get('qtz_func') + + x_ = x.transpose(0, 1) + x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs)) + x_ = x_.transpose(0, 1) + + quant_min, quant_max = get_min_max_from_bits_signed(bits) + min_vals, max_vals = torch.aminmax(x_, dim=1) + min_vals = torch.min(min_vals, torch.zeros_like(min_vals)) + max_vals = torch.max(max_vals, torch.zeros_like(max_vals)) + + eps = torch.finfo(torch.float32).eps + + abs_max_val_per_ch = torch.max(-min_vals, max_vals) + scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2) + + scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device) + + # Introduces some noise in scale + # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything + scale = scale + 0.01 * torch.randn_like(scale) + return scale + + +def init_non_linear_regression_fit( + x: torch.Tensor, + **kwargs: Dict[str, Any], + ) -> torch.Tensor: + + assert "params_list" in kwargs, "params list must be provided." + assert "np_fit_func" in kwargs, "np_fit_func must be provided." + assert "p0" in kwargs, "p0 must be provided." + np_fit_func = kwargs.get('np_fit_func') + params_list = kwargs.get('params_list') + p0 = kwargs.get('p0') + + def _fit(xdata: np.ndarray, ydata: np.ndarray, func: Callable, p0: List[float]): + popt, _ = curve_fit( + func, + xdata, + ydata, + maxfev=1000, + p0=p0, + method='lm' + ) + return popt + + # 1. Needs to convert the torch tensor to numpy tensor + xdata = x.cpu().numpy() + + # 2. Sorts the data so that it makes it easier to fit to it + sorted_xdata = np.sort(xdata, axis=-1) + + p0 = {k: v.cpu().numpy() for k, v in p0.items()} + params_list = sorted(params_list) # We need to make sure that it matches the numpy fit func arg order + + # 3. Finds the best parameters for each channel + try: + params = [] + for i in range(sorted_xdata.shape[0]): + xdata_ = sorted_xdata[i] + p0_ = [p0[p][i] for p in params_list] + ch_params = _fit(xdata_, xdata_, np_fit_func, p0_) + params.append(ch_params) + + # 4. Builds the parameters + result = {} + for i, p in enumerate(params_list): + result[p] = torch.tensor([p_[i] for p_ in params], dtype=torch.float32).to(x.device) + + return result + + except ValueError as e: + print(f"Could not fit the function with error: {e}") + print(f"Using fallback result...") + return { + k: torch.tensor(v, dtype=torch.float32).to(x.device) for k, v in p0.items() + } + + +def init_zeros(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor: + val = torch.amin(x, dim=1) + return torch.zeros_like(val, dtype=torch.float32, device=x.device) + + +def init_inner_scale(tensor: torch.Tensor, _min: float = torch.inf, _max: float = torch.inf) -> torch.Tensor: + # Calculate the original minimum and maximum values + min_vals, max_vals = torch.aminmax(tensor, dim=-1) + x_min = torch.min(min_vals, torch.zeros_like(min_vals)) + x_max = torch.max(max_vals, torch.zeros_like(max_vals)) + + if _max is torch.inf: # We do not need to scale the tensor. Just need to move it + return torch.ones_like(x_min) + + # Calculate the scale factor + scale = (_max - _min) / (x_max - x_min) + return scale + + + +############## Quant ############### + +@torch.enable_grad() +def learn_parameters( + x: torch.Tensor, + params: Dict[str, nn.Parameter], + qtz_func: nn.Module, + deqtz_func: nn.Module, + bits: int, + target_dtype: torch.dtype, + epochs: int = 1000, + early_stop: bool = True, + do_report: bool = False +) -> Tuple[Dict[str, nn.Parameter], torch.Tensor]: + + # Requires gradients in the parameters + for p in params.values(): + p.requires_grad = True + p.grad = None + + param_keys = list(params.keys()) + param_values = list(params.values()) + + # Defines optimizer and loss function + optimizer = torch.optim.Adam(param_values, lr=0.001) + scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=1.0, end_factor=0.001, total_iters=epochs // 10) + loss_fn = nn.MSELoss() + + # Contains the best loss and the best parameters + best_loss = float("inf") + best_params = None + + # Used to stop the search early + min_delta = 1e-7 + acc_loss = [] + percent_epochs_before_stop = 0.1 + + for i in range(epochs): + optimizer.zero_grad() + + quant = quantize(x, params, qtz_func, bits, target_dtype) + dequant = dequantize(quant, params, deqtz_func, bits, x.dtype) + loss = loss_fn(x, dequant) + + if loss.isnan() or loss.isinf(): + raise Exception("Loss is NaN or Inf. Stopping the search.") + + loss.backward() + optimizer.step() + scheduler.step() + + acc_loss.append(loss.item()) + + # Reports loss every 10 steps + if i % 10 == 0 and do_report: + print(f"Epoch {i}: Loss {loss.item()}") + + # Optimizes the parameter search by storing the best loss and the parameters + if loss.item() < best_loss: + best_loss = loss.item() + best_params = copy.deepcopy({ + k: v for k, v in params.items() if k in param_keys + }) + + # We also stop the search if the loss has not considerably during the last 10% epochs + if early_stop: + epochs_before_stop = int(epochs * percent_epochs_before_stop) + if i > epochs_before_stop and abs(acc_loss[i - epochs_before_stop] - acc_loss[i]) < min_delta: + break + + # No longer requires gradients in the parameters + for p in best_params.values(): + p.requires_grad = False + p.grad = None + + if do_report: + return best_params, acc_loss + else: + return best_params + + +def quantize( + x: torch.Tensor, + params: Dict[str, nn.Parameter], + func: nn.Module, + bits: int, + target_dtype: torch.dtype = torch.int8 +) -> torch.Tensor: + quant_min, quant_max = get_min_max_from_bits_signed(bits) + x = x.transpose(0, 1) # Aligns shapes + x = func(x=x, **params) + x = x.transpose(0, 1) + x = torch.clamp(round_func_BPDA(x), quant_min, quant_max).to(target_dtype) + return x + + +def dequantize( + x: torch.Tensor, + params: Dict[str, nn.Parameter], + func: nn.Module, + bits: int, + out_dtype: torch.dtype +) -> torch.Tensor: + x = x.to(dtype=out_dtype) + x = x.transpose(0, 1) + x = func(x=x, **params) + x = x.transpose(0, 1) + return x + + +def round_func_BPDA(input): + # This is equivalent to replacing round function (non-differentiable) with + # an identity function (differentiable) only when backward. + forward_value = torch.round(input) + out = input.clone() + out.data = forward_value.data + return out + + +def get_min_max_from_bits_signed(bit_width: int) -> Tuple[int, int]: + return -2 ** (bit_width - 1), 2 ** (bit_width - 1) - 1 + + + +############## Numpy ############### + +def np_domain_guard( + x: np.ndarray, + min: float = None, + max: float = None, + posinf: float = None, + neginf: float = None, + nan: float = None + ) -> np.ndarray: + """Guard a tensor to a valid domain.""" + x = np.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan) + if min is not None or max is not None: + x = np.clip(x, min, max) + return x + + +def np_replace_num(x: np.ndarray, num: float, to: float) -> np.ndarray: + """Replace a number in a tensor with another number. + + Args: + x (np.ndarray): The input tensor. + num (float): The number to replace. + to (float): The number to replace with. + + Returns: + np.ndarray: The tensor with the number replaced. + """ + return np.where(x == num, to, x) + + +def np_guarded_power(x: np.ndarray, exp: float) -> np.ndarray: + """Guard the power operation to a valid domain.""" + return np.power(x, exp) if exp >= 1 else np.power(np.maximum(x, 0), exp) + diff --git a/fn_gen/rnd_noisy_scale/9/loss.png b/fn_gen/rnd_noisy_scale/9/loss.png new file mode 100644 index 0000000000000000000000000000000000000000..8fdad74e2230e5a09e378804edffb67a3e0bae7d Binary files /dev/null and b/fn_gen/rnd_noisy_scale/9/loss.png differ diff --git a/fn_gen/rnd_noisy_scale/9/quantization.png b/fn_gen/rnd_noisy_scale/9/quantization.png new file mode 100644 index 0000000000000000000000000000000000000000..db6db15f99225d868b87966f04c888b9d0325b20 Binary files /dev/null and b/fn_gen/rnd_noisy_scale/9/quantization.png differ