Upload learned functions
Note: this view is limited to 50 files because the commit contains too many changes; see the raw diff for the full change set.
- fn_gen/ones_affine_scale/0/distortion.png +0 -0
- fn_gen/ones_affine_scale/0/expressions.txt +2 -0
- fn_gen/ones_affine_scale/0/fn.py +493 -0
- fn_gen/ones_affine_scale/0/loss.png +0 -0
- fn_gen/ones_affine_scale/0/quantization.png +0 -0
- fn_gen/ones_affine_scale/1/distortion.png +0 -0
- fn_gen/ones_affine_scale/1/expressions.txt +2 -0
- fn_gen/ones_affine_scale/1/fn.py +493 -0
- fn_gen/ones_affine_scale/1/loss.png +0 -0
- fn_gen/ones_affine_scale/1/quantization.png +0 -0
- fn_gen/ones_affine_scale/10/distortion.png +0 -0
- fn_gen/ones_affine_scale/10/expressions.txt +2 -0
- fn_gen/ones_affine_scale/10/fn.py +492 -0
- fn_gen/ones_affine_scale/10/loss.png +0 -0
- fn_gen/ones_affine_scale/10/quantization.png +0 -0
- fn_gen/ones_affine_scale/11/distortion.png +0 -0
- fn_gen/ones_affine_scale/11/expressions.txt +2 -0
- fn_gen/ones_affine_scale/11/fn.py +493 -0
- fn_gen/ones_affine_scale/11/loss.png +0 -0
- fn_gen/ones_affine_scale/11/quantization.png +0 -0
- fn_gen/ones_affine_scale/12/distortion.png +0 -0
- fn_gen/ones_affine_scale/12/expressions.txt +2 -0
- fn_gen/ones_affine_scale/12/fn.py +493 -0
- fn_gen/ones_affine_scale/12/loss.png +0 -0
- fn_gen/ones_affine_scale/12/quantization.png +0 -0
- fn_gen/ones_affine_scale/13/distortion.png +0 -0
- fn_gen/ones_affine_scale/13/expressions.txt +2 -0
- fn_gen/ones_affine_scale/13/fn.py +493 -0
- fn_gen/ones_affine_scale/13/loss.png +0 -0
- fn_gen/ones_affine_scale/13/quantization.png +0 -0
- fn_gen/ones_affine_scale/14/distortion.png +0 -0
- fn_gen/ones_affine_scale/14/expressions.txt +2 -0
- fn_gen/ones_affine_scale/14/fn.py +493 -0
- fn_gen/ones_affine_scale/14/loss.png +0 -0
- fn_gen/ones_affine_scale/14/quantization.png +0 -0
- fn_gen/ones_affine_scale/15/distortion.png +0 -0
- fn_gen/ones_affine_scale/15/expressions.txt +2 -0
- fn_gen/ones_affine_scale/15/fn.py +493 -0
- fn_gen/ones_affine_scale/15/loss.png +0 -0
- fn_gen/ones_affine_scale/15/quantization.png +0 -0
- fn_gen/ones_affine_scale/16/distortion.png +0 -0
- fn_gen/ones_affine_scale/16/expressions.txt +2 -0
- fn_gen/ones_affine_scale/16/fn.py +493 -0
- fn_gen/ones_affine_scale/16/loss.png +0 -0
- fn_gen/ones_affine_scale/16/quantization.png +0 -0
- fn_gen/ones_affine_scale/17/distortion.png +0 -0
- fn_gen/ones_affine_scale/17/expressions.txt +2 -0
- fn_gen/ones_affine_scale/17/fn.py +493 -0
- fn_gen/ones_affine_scale/17/loss.png +0 -0
- fn_gen/ones_affine_scale/17/quantization.png +0 -0
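
Each numbered candidate directory follows the same layout: expressions.txt records the learned quantization/dequantization expression pair, fn.py is the generated quantization code, and distortion.png, loss.png, and quantization.png are the associated plots.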
fn_gen/ones_affine_scale/0/distortion.png
ADDED
fn_gen/ones_affine_scale/0/expressions.txt
ADDED
@@ -0,0 +1,2 @@
+atanh(_0*x)/_s
+tanh(_s*x)/_0
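
Editor's note: the first line is the quantization expression and the second its dequantization inverse. Writing $a$ for `_0` and $s$ for `_s`, the round trip is exact wherever atanh is defined:

$$\frac{1}{a}\tanh\!\left(s\cdot\frac{\operatorname{atanh}(a x)}{s}\right) = \frac{\tanh(\operatorname{atanh}(a x))}{a} = \frac{a x}{a} = x, \qquad |a x| < 1,$$

which is why the generated fn.py below clamps the atanh argument to [-0.9999, 0.9999].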
fn_gen/ones_affine_scale/0/fn.py
ADDED
@@ -0,0 +1,493 @@
+from __future__ import annotations
+
+import torch
+from torch import amin # Necessary for arcsin
+import copy
+import torch.nn as nn
+import numpy as np
+
+from scipy.optimize import curve_fit
+from typing import Dict, Any, Tuple, List, Callable
+
+
+def quantization(x, **params):
+    return (torch.div(1, replace_num(params['_s'], num=0, to=10000)) * torch.atanh(domain_guard((params['_0'] * x), min=-0.9999, max=0.9999, nan=0)))
+
+
+def dequantization(x, **params):
+    return (torch.div(1, replace_num(params['_0'], num=0, to=10000)) * torch.tanh((params['_s'] * x)))
+
+
+def init_linear_scale(  # Symmetric scale. From the study folder
+    x: torch.Tensor,
+    **kwargs: Dict[str, Any],
+) -> torch.Tensor:
+    assert "bits" in kwargs, "bits must be provided."
+    assert "params" in kwargs, "params must be provided."
+    assert "qtz_func" in kwargs, "qtz_func must be provided."
+
+    bits = kwargs.get('bits')
+    params = kwargs.get('params')
+    qtz_func = kwargs.get('qtz_func')
+
+    x_ = x.transpose(0, 1)
+    x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs))
+    x_ = x_.transpose(0, 1)
+
+    quant_min, quant_max = get_min_max_from_bits_signed(bits)
+    min_vals, max_vals = torch.aminmax(x_, dim=1)
+    min_vals = torch.min(min_vals, torch.zeros_like(min_vals))
+    max_vals = torch.max(max_vals, torch.zeros_like(max_vals))
+
+    eps = torch.finfo(torch.float32).eps
+
+    abs_max_val_per_ch = torch.max(-min_vals, max_vals)
+    scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2)
+
+    scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device)
+
+    return scale
+
+    # Introduces some noise in scale
+    # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything
+    scale = scale + 0.01 * torch.randn_like(scale)
+    return scale
+
+
+def init_params(x: torch.Tensor, **kwargs: Dict[str, Any]) -> Dict[str, nn.Parameter]:
+    params = {
+        '_0': init_ones(x, qtz_func=quantization, deqtz_func=dequantization, param='_0', params_list=['_0', '_s'], **kwargs),
+    }
+    params['_s'] = init_linear_scale(x, params=params, qtz_func=quantization, **kwargs)
+    params = {k: nn.Parameter(v, requires_grad=False) for k, v in params.items()}
+
+    if 'post_init_hook' in kwargs:
+        kwargs['post_init_hook'](parameters=params)
+
+
+    if 'post_train_hook' in kwargs:
+        kwargs['post_train_hook'](parameters=params)
+
+    return params
+
+
+############### Numpy Qtz ###############
+
+
+def np_quantization(x, _0, _s):
+    return (np.divide(1, np_replace_num(_s, num=0, to=10000)) * np.arctanh(np_domain_guard((_0 * x), min=-0.9999, max=0.9999, nan=0)))
+
+
+def np_dequantization(x, _0, _s):
+    return (np.divide(1, np_replace_num(_0, num=0, to=10000)) * np.tanh((_s * x)))
+
+
+def fit_func(x, _0, _s):
+    x_ = np_quantization(x, _0, _s)
+    x_ = np_dequantization(x_, _0, _s)
+    return x_
+
+
+
+############### HELPERS ###############
+
+def domain_guard(
+    x: torch.Tensor,
+    min: float = None,
+    max: float = None,
+    posinf: float = None,
+    neginf: float = None,
+    nan: float = None
+) -> torch.Tensor:
+    """Guard a tensor to a valid domain."""
+    x = torch.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan)
+    if min is not None or max is not None:
+        x = torch.clamp(x, min=min, max=max)
+    return x
+
+
+def replace_num(x: torch.Tensor, num: float, to: float) -> torch.Tensor:
+    """Replace a number in a tensor with another number.
+
+    Args:
+        x (torch.Tensor): The input tensor.
+        num (float): The number to replace.
+        to (float): The number to replace with.
+
+    Returns:
+        torch.Tensor: The tensor with the number replaced.
+    """
+    return torch.where(x == num, to, x)
+
+
+def guarded_torch_power(x: torch.Tensor, exp: float) -> torch.Tensor:
+    """Guard the power operation to a valid domain."""
+    return torch.pow(x, exp) if exp >= 1 else torch.pow(torch.relu(x), exp)
+
+
+def init_ones(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
+    val = torch.amin(x, dim=1)
+    return torch.ones_like(val, dtype=torch.float32, device=x.device)
+
+
+def init_rand(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
+    val = torch.amin(x, dim=1)
+    return torch.randn_like(val, dtype=torch.float32, device=x.device)
+
+
+def init_space_search(
+    x: torch.Tensor,
+    **kwargs: Dict[str, Any],
+) -> torch.Tensor:
+
+    def _build_initial_param(tensor: torch.Tensor, max_initial: int, n_params: int):
+        """Generates the initial set of parameters. The first iteration generates 10 times more parameters."""
+        for _ in range(n_params * 10):  # The first iteration generates 10 times more parameters
+            yield init_rand(tensor) * max_initial  # Generates n_params in range [-max_initial, max_initial]
+
+    def _search_param(tensors: List[torch.tensor], n_params):
+        """Takes the best parameters and generates new parameters around the mean of the best parameters."""
+        torch_tensors = torch.stack(tensors)
+        min_vals, max_vals = torch.aminmax(torch_tensors, dim=0)
+        abs_max_val_per_ch = torch.max(-min_vals, max_vals)
+        mean = torch.mean(torch_tensors, dim=0)
+        for _ in range(n_params):  # Generates n_params around the mean of the tensors
+            yield torch.randn_like(min_vals) * abs_max_val_per_ch + mean
+
+    def _calc(x, qtz_func, deqtz_func, **params):
+        x_ = x.transpose(0, 1)
+        x_ = qtz_func(x=x_, **params)
+        x_ = deqtz_func(x=x_, **params)
+        x_ = x_.transpose(0, 1)
+        return x_
+
+    assert "qtz_func" in kwargs, "qtz_func must be provided."
+    assert "deqtz_func" in kwargs, "deqtz_func must be provided."
+    assert "params_list" in kwargs, "params list must be provided."
+    assert "param" in kwargs, "param must be provided."
+
+    qtz_func = kwargs.get('qtz_func')
+    deqtz_func = kwargs.get('deqtz_func')
+    params_list = kwargs.get('params_list')
+    param = kwargs.get('param')
+
+    n_runs = 50  # Number of runs to try to find the best parameters
+    n_random_params = 50  # Number of random parameters to generate
+    n_best_to_pick = 5  # Number of best parameters to pick after each run
+    max_initial = 10000  # Maximum value to initialize the parameters
+
+    # Initializes the parameters
+    base_params = { p: init_ones(x, **kwargs) for p in params_list if p != param }
+    params = _build_initial_param(x, max_initial, n_random_params)
+
+    # Performs the search
+    for _ in range(n_runs):
+
+        best_params = []
+        for param_ in params:
+            try:
+                x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: param_})
+                loss_ones = nn.MSELoss()(x, x_)
+
+                if len(best_params) < n_best_to_pick:
+                    best_params.append((param_, loss_ones.item()))
+                    best_params = sorted(best_params, key=lambda x: x[1])
+                elif loss_ones < best_params[-1][1]:
+                    best_params[-1] = (param_, loss_ones.item())
+                    best_params = sorted(best_params, key=lambda x: x[1])
+
+            except Exception:  # The parameters might not be valid for the function's domain
+                continue
+
+        # Generates new parameters around the mean
+        params = _search_param([p for p, _ in best_params], n_random_params)
+
+    # Checks if the best parameter is better than the init_ones
+    p_ones = init_ones(x, **kwargs)
+    x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_ones})
+    loss_ones = nn.MSELoss()(x, x_)
+
+    # Checks if the best parameter is better than the init_rand
+    p_rand = init_rand(x, **kwargs)
+    x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_rand})
+    loss_rand = nn.MSELoss()(x, x_)
+
+    if loss_rand < best_params[0][1] and loss_rand < loss_ones:
+        return p_rand
+    elif loss_ones < best_params[0][1] and loss_ones < loss_rand:
+        return p_ones
+    else:
+        return best_params[0][0]
+
+
+def init_linear_scale(  # Symmetric scale. From the study folder
+    x: torch.Tensor,
+    **kwargs: Dict[str, Any],
+) -> torch.Tensor:
+    assert "bits" in kwargs, "bits must be provided."
+    assert "params" in kwargs, "params must be provided."
+    assert "qtz_func" in kwargs, "qtz_func must be provided."
+
+    bits = kwargs.get('bits')
+    params = kwargs.get('params')
+    qtz_func = kwargs.get('qtz_func')
+
+    x_ = x.transpose(0, 1)
+    x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs))
+    x_ = x_.transpose(0, 1)
+
+    quant_min, quant_max = get_min_max_from_bits_signed(bits)
+    min_vals, max_vals = torch.aminmax(x_, dim=1)
+    min_vals = torch.min(min_vals, torch.zeros_like(min_vals))
+    max_vals = torch.max(max_vals, torch.zeros_like(max_vals))
+
+    eps = torch.finfo(torch.float32).eps
+
+    abs_max_val_per_ch = torch.max(-min_vals, max_vals)
+    scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2)
+
+    scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device)
+
+    return scale
+
+    # Introduces some noise in scale
+    # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything
+    scale = scale + 0.01 * torch.randn_like(scale)
+    return scale
+
+
+def init_non_linear_regression_fit(
+    x: torch.Tensor,
+    **kwargs: Dict[str, Any],
+) -> torch.Tensor:
+
+    assert "params_list" in kwargs, "params list must be provided."
+    assert "np_fit_func" in kwargs, "np_fit_func must be provided."
+    assert "p0" in kwargs, "p0 must be provided."
+    np_fit_func = kwargs.get('np_fit_func')
+    params_list = kwargs.get('params_list')
+    p0 = kwargs.get('p0')
+
+    def _fit(xdata: np.ndarray, ydata: np.ndarray, func: Callable, p0: List[float]):
+        popt, _ = curve_fit(
+            func,
+            xdata,
+            ydata,
+            maxfev=1000,
+            p0=p0,
+            method='lm'
+        )
+        return popt
+
+    # 1. Needs to convert the torch tensor to numpy tensor
+    xdata = x.cpu().numpy()
+
+    # 2. Sorts the data so that it makes it easier to fit to it
+    sorted_xdata = np.sort(xdata, axis=-1)
+
+    p0 = {k: v.cpu().numpy() for k, v in p0.items()}
+    params_list = sorted(params_list)  # We need to make sure that it matches the numpy fit func arg order
+
+    # 3. Finds the best parameters for each channel
+    try:
+        params = []
+        for i in range(sorted_xdata.shape[0]):
+            xdata_ = sorted_xdata[i]
+            p0_ = [p0[p][i] for p in params_list]
+            ch_params = _fit(xdata_, xdata_, np_fit_func, p0_)
+            params.append(ch_params)
+
+        # 4. Builds the parameters
+        result = {}
+        for i, p in enumerate(params_list):
+            result[p] = torch.tensor([p_[i] for p_ in params], dtype=torch.float32).to(x.device)
+
+        return result
+
+    except ValueError as e:
+        print(f"Could not fit the function with error: {e}")
+        print(f"Using fallback result...")
+        return {
+            k: torch.tensor(v, dtype=torch.float32).to(x.device) for k, v in p0.items()
+        }
+
+
+def init_zeros(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
+    val = torch.amin(x, dim=1)
+    return torch.zeros_like(val, dtype=torch.float32, device=x.device)
+
+
+def init_inner_scale(tensor: torch.Tensor, _min: float = torch.inf, _max: float = torch.inf) -> torch.Tensor:
+    # Calculate the original minimum and maximum values
+    min_vals, max_vals = torch.aminmax(tensor, dim=-1)
+    x_min = torch.min(min_vals, torch.zeros_like(min_vals))
+    x_max = torch.max(max_vals, torch.zeros_like(max_vals))
+
+    if _max is torch.inf:  # We do not need to scale the tensor. Just need to move it
+        return torch.ones_like(x_min)
+
+    # Calculate the scale factor
+    scale = (_max - _min) / (x_max - x_min)
+    return scale
+
+
+
+############## Quant ###############
+
+@torch.enable_grad()
+def learn_parameters(
+    x: torch.Tensor,
+    params: Dict[str, nn.Parameter],
+    qtz_func: nn.Module,
+    deqtz_func: nn.Module,
+    bits: int,
+    target_dtype: torch.dtype,
+    epochs: int = 1000,
+    early_stop: bool = True,
+    do_report: bool = False
+) -> Tuple[Dict[str, nn.Parameter], torch.Tensor]:
+
+    # Requires gradients in the parameters
+    for p in params.values():
+        p.requires_grad = True
+        p.grad = None
+
+    param_keys = list(params.keys())
+    param_values = list(params.values())
+
+    # Defines optimizer and loss function
+    optimizer = torch.optim.Adam(param_values, lr=0.001)
+    scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=1.0, end_factor=0.001, total_iters=epochs // 10)
+    loss_fn = nn.MSELoss()
+
+    # Contains the best loss and the best parameters
+    best_loss = float("inf")
+    best_params = None
+
+    # Used to stop the search early
+    min_delta = 1e-7
+    acc_loss = []
+    percent_epochs_before_stop = 0.1
+
+    for i in range(epochs):
+        optimizer.zero_grad()
+
+        quant = quantize(x, params, qtz_func, bits, target_dtype)
+        dequant = dequantize(quant, params, deqtz_func, bits, x.dtype)
+        loss = loss_fn(x, dequant)
+
+        if loss.isnan() or loss.isinf():
+            raise Exception("Loss is NaN or Inf. Stopping the search.")
+
+        loss.backward()
+        optimizer.step()
+        scheduler.step()
+
+        acc_loss.append(loss.item())
+
+        # Reports loss every 10 steps
+        if i % 10 == 0 and do_report:
+            print(f"Epoch {i}: Loss {loss.item()}")
+
+        # Optimizes the parameter search by storing the best loss and the parameters
+        if loss.item() < best_loss:
+            best_loss = loss.item()
+            best_params = copy.deepcopy({
+                k: v for k, v in params.items() if k in param_keys
+            })
+
+        # We also stop the search if the loss has not considerably during the last 10% epochs
+        if early_stop:
+            epochs_before_stop = int(epochs * percent_epochs_before_stop)
+            if i > epochs_before_stop and abs(acc_loss[i - epochs_before_stop] - acc_loss[i]) < min_delta:
+                break
+
+    # No longer requires gradients in the parameters
+    for p in best_params.values():
+        p.requires_grad = False
+        p.grad = None
+
+    if do_report:
+        return best_params, acc_loss
+    else:
+        return best_params
+
+
+def quantize(
+    x: torch.Tensor,
+    params: Dict[str, nn.Parameter],
+    func: nn.Module,
+    bits: int,
+    target_dtype: torch.dtype = torch.int8
+) -> torch.Tensor:
+    quant_min, quant_max = get_min_max_from_bits_signed(bits)
+    x = x.transpose(0, 1)  # Aligns shapes
+    x = func(x=x, **params)
+    x = x.transpose(0, 1)
+    x = torch.clamp(round_func_BPDA(x), quant_min, quant_max).to(target_dtype)
+    return x
+
+
+def dequantize(
+    x: torch.Tensor,
+    params: Dict[str, nn.Parameter],
+    func: nn.Module,
+    bits: int,
+    out_dtype: torch.dtype
+) -> torch.Tensor:
+    x = x.to(dtype=out_dtype)
+    x = x.transpose(0, 1)
+    x = func(x=x, **params)
+    x = x.transpose(0, 1)
+    return x
+
+
+def round_func_BPDA(input):
+    # This is equivalent to replacing round function (non-differentiable) with
+    # an identity function (differentiable) only when backward.
+    forward_value = torch.round(input)
+    out = input.clone()
+    out.data = forward_value.data
+    return out
+
+
+def get_min_max_from_bits_signed(bit_width: int) -> Tuple[int, int]:
+    return -2 ** (bit_width - 1), 2 ** (bit_width - 1) - 1
+
+
+
+############## Numpy ###############
+
+def np_domain_guard(
+    x: np.ndarray,
+    min: float = None,
+    max: float = None,
+    posinf: float = None,
+    neginf: float = None,
+    nan: float = None
+) -> np.ndarray:
+    """Guard a tensor to a valid domain."""
+    x = np.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan)
+    if min is not None or max is not None:
+        x = np.clip(x, min, max)
+    return x
+
+
+def np_replace_num(x: np.ndarray, num: float, to: float) -> np.ndarray:
+    """Replace a number in a tensor with another number.
+
+    Args:
+        x (np.ndarray): The input tensor.
+        num (float): The number to replace.
+        to (float): The number to replace with.
+
+    Returns:
+        np.ndarray: The tensor with the number replaced.
+    """
+    return np.where(x == num, to, x)
+
+
+def np_guarded_power(x: np.ndarray, exp: float) -> np.ndarray:
+    """Guard the power operation to a valid domain."""
+    return np.power(x, exp) if exp >= 1 else np.power(np.maximum(x, 0), exp)
+
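
Editor's note: the commit contains no driver, so the following is a minimal usage sketch assuming the calling convention implied by the signatures above, in particular that x is laid out as [channels, weights_per_channel], which matches the amin(x, dim=1) and transpose(0, 1) calls. The tensor shape and bit width are illustrative only.

    import torch

    # Hypothetical driver, not part of the commit.
    x = torch.randn(16, 256)          # 16 channels, 256 weights per channel

    params = init_params(x, bits=8)   # '_0' via init_ones, '_s' via init_linear_scale
    params = learn_parameters(
        x, params,
        qtz_func=quantization, deqtz_func=dequantization,
        bits=8, target_dtype=torch.int8, epochs=1000,
    )

    q = quantize(x, params, quantization, 8, target_dtype=torch.int8)
    x_hat = dequantize(q, params, dequantization, 8, out_dtype=x.dtype)
    print(torch.nn.functional.mse_loss(x_hat, x))   # reconstruction error

Two details worth flagging: in init_linear_scale the noise-injection lines after the first `return scale` are unreachable as committed, and round_func_BPDA implements a straight-through estimator (round on the forward pass, identity on the backward pass).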
fn_gen/ones_affine_scale/0/loss.png
ADDED
fn_gen/ones_affine_scale/0/quantization.png
ADDED
fn_gen/ones_affine_scale/1/distortion.png
ADDED
fn_gen/ones_affine_scale/1/expressions.txt
ADDED
@@ -0,0 +1,2 @@
+(_0*x)**(1/3)/_s
+_s**3*x**3/_0
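
Editor's note: the same inverse check applies to this pair. Writing $a$ for `_0` and $s$ for `_s`:

$$\frac{s^{3}}{a}\left(\frac{(a x)^{1/3}}{s}\right)^{3} = \frac{s^{3}}{a}\cdot\frac{a x}{s^{3}} = x.$$

As implemented below, guarded_torch_power clamps negative bases to zero for fractional exponents, so the round trip is exact only where a*x >= 0.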
fn_gen/ones_affine_scale/1/fn.py
ADDED
@@ -0,0 +1,493 @@
[Editor's note: this generated 493-line file is identical to fn_gen/ones_affine_scale/0/fn.py above except for the four expression-specific functions reproduced below; the duplicated boilerplate is omitted.]
+def quantization(x, **params):
+    return (torch.div(1, replace_num(params['_s'], num=0, to=10000)) * guarded_torch_power((params['_0'] * x), 1 / 3))
+
+
+def dequantization(x, **params):
+    return (torch.div(1, replace_num(params['_0'], num=0, to=10000)) * guarded_torch_power(params['_s'], torch.tensor(3)) * guarded_torch_power(x, torch.tensor(3)))
+
+
+def np_quantization(x, _0, _s):
+    return (np.divide(1, np_replace_num(_s, num=0, to=10000)) * np_guarded_power((_0 * x), 1 / 3))
+
+
+def np_dequantization(x, _0, _s):
+    return (np.divide(1, np_replace_num(_0, num=0, to=10000)) * np_guarded_power(_s, np.array(3)) * np_guarded_power(x, np.array(3)))
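
Editor's note: one behavioral detail specific to this cube-root variant. guarded_torch_power routes fractional exponents through relu, so negative inputs to the quantizer are clamped to zero rather than given a signed cube root. A quick illustration (hypothetical snippet, not part of the commit):

    import torch

    def guarded_torch_power(x, exp):  # as defined in fn.py above
        return torch.pow(x, exp) if exp >= 1 else torch.pow(torch.relu(x), exp)

    x = torch.tensor([-8.0, 8.0])
    print(guarded_torch_power(x, 1 / 3))  # tensor([0., 2.]): the negative value collapses to 0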
fn_gen/ones_affine_scale/1/loss.png
ADDED
fn_gen/ones_affine_scale/1/quantization.png
ADDED
fn_gen/ones_affine_scale/10/distortion.png
ADDED
fn_gen/ones_affine_scale/10/expressions.txt
ADDED
@@ -0,0 +1,2 @@
+x**2/_s
+sqrt(_s*x)
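
Editor's note: this pair is only a one-sided inverse. Writing $s$ for `_s`:

$$\sqrt{s \cdot \frac{x^{2}}{s}} = \sqrt{x^{2}} = |x|,$$

so the round trip recovers x only for x >= 0, and the generated fn.py additionally clamps the sqrt argument to a minimum of 0.1 via domain_guard.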
fn_gen/ones_affine_scale/10/fn.py
ADDED
@@ -0,0 +1,492 @@
[Editor's note: identical to fn_gen/ones_affine_scale/0/fn.py except for the expression-specific functions below and an init_params that learns only '_s' (hence 492 lines rather than 493); the duplicated boilerplate is omitted, and the 50-file view truncates partway through this file.]
+def quantization(x, **params):
+    return (torch.div(1, replace_num(params['_s'], num=0, to=10000)) * guarded_torch_power(x, torch.tensor(2)))
+
+
+def dequantization(x, **params):
+    return torch.sqrt(domain_guard((params['_s'] * x), min=0.1, nan=0.1))
+
+
+def init_params(x: torch.Tensor, **kwargs: Dict[str, Any]) -> Dict[str, nn.Parameter]:
+    params = {
+    }
+    params['_s'] = init_linear_scale(x, params=params, qtz_func=quantization, **kwargs)
+    params = {k: nn.Parameter(v, requires_grad=False) for k, v in params.items()}
+
+    if 'post_init_hook' in kwargs:
+        kwargs['post_init_hook'](parameters=params)
+
+
+    if 'post_train_hook' in kwargs:
+        kwargs['post_train_hook'](parameters=params)
+
+    return params
+
+
+def np_quantization(x, _s):
+    return (np.divide(1, np_replace_num(_s, num=0, to=10000)) * np_guarded_power(x, np.array(2)))
+
+
+def np_dequantization(x, _s):
+    return np.sqrt(np_domain_guard((_s * x), min=0.1, nan=0.1))
+
+
+def fit_func(x, _s):
+    x_ = np_quantization(x, _s)
+    x_ = np_dequantization(x_, _s)
+    return x_
|
321 |
+
min_vals, max_vals = torch.aminmax(tensor, dim=-1)
|
322 |
+
x_min = torch.min(min_vals, torch.zeros_like(min_vals))
|
323 |
+
x_max = torch.max(max_vals, torch.zeros_like(max_vals))
|
324 |
+
|
325 |
+
if _max is torch.inf: # We do not need to scale the tensor. Just need to move it
|
326 |
+
return torch.ones_like(x_min)
|
327 |
+
|
328 |
+
# Calculate the scale factor
|
329 |
+
scale = (_max - _min) / (x_max - x_min)
|
330 |
+
return scale
|
331 |
+
|
332 |
+
|
333 |
+
|
334 |
+
############## Quant ###############
|
335 |
+
|
336 |
+
@torch.enable_grad()
|
337 |
+
def learn_parameters(
|
338 |
+
x: torch.Tensor,
|
339 |
+
params: Dict[str, nn.Parameter],
|
340 |
+
qtz_func: nn.Module,
|
341 |
+
deqtz_func: nn.Module,
|
342 |
+
bits: int,
|
343 |
+
target_dtype: torch.dtype,
|
344 |
+
epochs: int = 1000,
|
345 |
+
early_stop: bool = True,
|
346 |
+
do_report: bool = False
|
347 |
+
) -> Tuple[Dict[str, nn.Parameter], torch.Tensor]:
|
348 |
+
|
349 |
+
# Requires gradients in the parameters
|
350 |
+
for p in params.values():
|
351 |
+
p.requires_grad = True
|
352 |
+
p.grad = None
|
353 |
+
|
354 |
+
param_keys = list(params.keys())
|
355 |
+
param_values = list(params.values())
|
356 |
+
|
357 |
+
# Defines optimizer and loss function
|
358 |
+
optimizer = torch.optim.Adam(param_values, lr=0.001)
|
359 |
+
scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=1.0, end_factor=0.001, total_iters=epochs // 10)
|
360 |
+
loss_fn = nn.MSELoss()
|
361 |
+
|
362 |
+
# Contains the best loss and the best parameters
|
363 |
+
best_loss = float("inf")
|
364 |
+
best_params = None
|
365 |
+
|
366 |
+
# Used to stop the search early
|
367 |
+
min_delta = 1e-7
|
368 |
+
acc_loss = []
|
369 |
+
percent_epochs_before_stop = 0.1
|
370 |
+
|
371 |
+
for i in range(epochs):
|
372 |
+
optimizer.zero_grad()
|
373 |
+
|
374 |
+
quant = quantize(x, params, qtz_func, bits, target_dtype)
|
375 |
+
dequant = dequantize(quant, params, deqtz_func, bits, x.dtype)
|
376 |
+
loss = loss_fn(x, dequant)
|
377 |
+
|
378 |
+
if loss.isnan() or loss.isinf():
|
379 |
+
raise Exception("Loss is NaN or Inf. Stopping the search.")
|
380 |
+
|
381 |
+
loss.backward()
|
382 |
+
optimizer.step()
|
383 |
+
scheduler.step()
|
384 |
+
|
385 |
+
acc_loss.append(loss.item())
|
386 |
+
|
387 |
+
# Reports loss every 10 steps
|
388 |
+
if i % 10 == 0 and do_report:
|
389 |
+
print(f"Epoch {i}: Loss {loss.item()}")
|
390 |
+
|
391 |
+
# Optimizes the parameter search by storing the best loss and the parameters
|
392 |
+
if loss.item() < best_loss:
|
393 |
+
best_loss = loss.item()
|
394 |
+
best_params = copy.deepcopy({
|
395 |
+
k: v for k, v in params.items() if k in param_keys
|
396 |
+
})
|
397 |
+
|
398 |
+
# We also stop the search if the loss has not considerably during the last 10% epochs
|
399 |
+
if early_stop:
|
400 |
+
epochs_before_stop = int(epochs * percent_epochs_before_stop)
|
401 |
+
if i > epochs_before_stop and abs(acc_loss[i - epochs_before_stop] - acc_loss[i]) < min_delta:
|
402 |
+
break
|
403 |
+
|
404 |
+
# No longer requires gradients in the parameters
|
405 |
+
for p in best_params.values():
|
406 |
+
p.requires_grad = False
|
407 |
+
p.grad = None
|
408 |
+
|
409 |
+
if do_report:
|
410 |
+
return best_params, acc_loss
|
411 |
+
else:
|
412 |
+
return best_params
|
413 |
+
|
414 |
+
|
415 |
+
def quantize(
|
416 |
+
x: torch.Tensor,
|
417 |
+
params: Dict[str, nn.Parameter],
|
418 |
+
func: nn.Module,
|
419 |
+
bits: int,
|
420 |
+
target_dtype: torch.dtype = torch.int8
|
421 |
+
) -> torch.Tensor:
|
422 |
+
quant_min, quant_max = get_min_max_from_bits_signed(bits)
|
423 |
+
x = x.transpose(0, 1) # Aligns shapes
|
424 |
+
x = func(x=x, **params)
|
425 |
+
x = x.transpose(0, 1)
|
426 |
+
x = torch.clamp(round_func_BPDA(x), quant_min, quant_max).to(target_dtype)
|
427 |
+
return x
|
428 |
+
|
429 |
+
|
430 |
+
def dequantize(
|
431 |
+
x: torch.Tensor,
|
432 |
+
params: Dict[str, nn.Parameter],
|
433 |
+
func: nn.Module,
|
434 |
+
bits: int,
|
435 |
+
out_dtype: torch.dtype
|
436 |
+
) -> torch.Tensor:
|
437 |
+
x = x.to(dtype=out_dtype)
|
438 |
+
x = x.transpose(0, 1)
|
439 |
+
x = func(x=x, **params)
|
440 |
+
x = x.transpose(0, 1)
|
441 |
+
return x
|
442 |
+
|
443 |
+
|
444 |
+
def round_func_BPDA(input):
|
445 |
+
# This is equivalent to replacing round function (non-differentiable) with
|
446 |
+
# an identity function (differentiable) only when backward.
|
447 |
+
forward_value = torch.round(input)
|
448 |
+
out = input.clone()
|
449 |
+
out.data = forward_value.data
|
450 |
+
return out
|
451 |
+
|
452 |
+
|
453 |
+
def get_min_max_from_bits_signed(bit_width: int) -> Tuple[int, int]:
|
454 |
+
return -2 ** (bit_width - 1), 2 ** (bit_width - 1) - 1
|
455 |
+
|
456 |
+
|
457 |
+
|
458 |
+
############## Numpy ###############
|
459 |
+
|
460 |
+
def np_domain_guard(
|
461 |
+
x: np.ndarray,
|
462 |
+
min: float = None,
|
463 |
+
max: float = None,
|
464 |
+
posinf: float = None,
|
465 |
+
neginf: float = None,
|
466 |
+
nan: float = None
|
467 |
+
) -> np.ndarray:
|
468 |
+
"""Guard a tensor to a valid domain."""
|
469 |
+
x = np.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan)
|
470 |
+
if min is not None or max is not None:
|
471 |
+
x = np.clip(x, min, max)
|
472 |
+
return x
|
473 |
+
|
474 |
+
|
475 |
+
def np_replace_num(x: np.ndarray, num: float, to: float) -> np.ndarray:
|
476 |
+
"""Replace a number in a tensor with another number.
|
477 |
+
|
478 |
+
Args:
|
479 |
+
x (np.ndarray): The input tensor.
|
480 |
+
num (float): The number to replace.
|
481 |
+
to (float): The number to replace with.
|
482 |
+
|
483 |
+
Returns:
|
484 |
+
np.ndarray: The tensor with the number replaced.
|
485 |
+
"""
|
486 |
+
return np.where(x == num, to, x)
|
487 |
+
|
488 |
+
|
489 |
+
def np_guarded_power(x: np.ndarray, exp: float) -> np.ndarray:
|
490 |
+
"""Guard the power operation to a valid domain."""
|
491 |
+
return np.power(x, exp) if exp >= 1 else np.power(np.maximum(x, 0), exp)
|
492 |
+
|
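The file above wires the pieces together: `init_params` builds the per-channel parameters, `learn_parameters` refines them by gradient descent on the round-trip MSE, and `quantize`/`dequantize` apply the learned pair. A minimal, hypothetical driver, assuming one of these generated files is importable as `fn` (the module name, shapes, and hyperparameters here are illustrative, not part of the repo):

import torch
import fn  # hypothetical: one generated fn.py module on the import path

x = torch.randn(64, 512)  # 64 channels of 512 values; parameters are per-row
params = fn.init_params(x, bits=8)

# Refine the parameters on the round-trip MSE. A float target_dtype keeps the
# fake quantization differentiable; an integer dtype would detach the graph.
best = fn.learn_parameters(
    x, params,
    qtz_func=fn.quantization, deqtz_func=fn.dequantization,
    bits=8, target_dtype=torch.float32, epochs=200,
)

q = fn.quantize(x, best, fn.quantization, bits=8, target_dtype=torch.int8)
x_hat = fn.dequantize(q, best, fn.dequantization, bits=8, out_dtype=x.dtype)
print("round-trip MSE:", torch.mean((x - x_hat) ** 2).item())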
fn_gen/ones_affine_scale/10/loss.png
ADDED
fn_gen/ones_affine_scale/10/quantization.png
ADDED
fn_gen/ones_affine_scale/11/distortion.png
ADDED
fn_gen/ones_affine_scale/11/expressions.txt
ADDED
@@ -0,0 +1,2 @@
asinh(_0*x)/_s
sinh(_s*x)/_0
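The two expressions form a quantizer/dequantizer pair: the second is the exact inverse of the first whenever the parameters match, so the only round-trip error comes from the later rounding and clamping. A quick illustrative check (not part of the repository), here with _0 = _s = 1:

import numpy as np

x = np.linspace(-3.0, 3.0, 7)
q = np.arcsinh(1.0 * x) / 1.0    # asinh(_0*x)/_s
x_back = np.sinh(1.0 * q) / 1.0  # sinh(_s*x)/_0
assert np.allclose(x, x_back)    # sinh inverts asinh on the whole real line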
fn_gen/ones_affine_scale/11/fn.py
ADDED
@@ -0,0 +1,493 @@
from __future__ import annotations

import torch
from torch import amin  # Necessary for arcsin
import copy
import torch.nn as nn
import numpy as np

from scipy.optimize import curve_fit
from typing import Dict, Any, Tuple, List, Callable


def quantization(x, **params):
    return (torch.div(1, replace_num(params['_s'], num=0, to=10000)) * torch.asinh((params['_0'] * x)))


def dequantization(x, **params):
    return (torch.div(1, replace_num(params['_0'], num=0, to=10000)) * torch.sinh((params['_s'] * x)))


def init_linear_scale(  # Symmetric scale. From the study folder
    x: torch.Tensor,
    **kwargs: Dict[str, Any],
) -> torch.Tensor:
    assert "bits" in kwargs, "bits must be provided."
    assert "params" in kwargs, "params must be provided."
    assert "qtz_func" in kwargs, "qtz_func must be provided."

    bits = kwargs.get('bits')
    params = kwargs.get('params')
    qtz_func = kwargs.get('qtz_func')

    x_ = x.transpose(0, 1)
    x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs))
    x_ = x_.transpose(0, 1)

    quant_min, quant_max = get_min_max_from_bits_signed(bits)
    min_vals, max_vals = torch.aminmax(x_, dim=1)
    min_vals = torch.min(min_vals, torch.zeros_like(min_vals))
    max_vals = torch.max(max_vals, torch.zeros_like(max_vals))

    eps = torch.finfo(torch.float32).eps

    abs_max_val_per_ch = torch.max(-min_vals, max_vals)
    scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2)

    scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device)

    # Introduces some noise in the scale; without it the accuracy stays at 0.0
    # and the parameters never learn anything.
    scale = scale + 0.01 * torch.randn_like(scale)
    return scale


def init_params(x: torch.Tensor, **kwargs: Dict[str, Any]) -> Dict[str, nn.Parameter]:
    params = {
        '_0': init_ones(x, qtz_func=quantization, deqtz_func=dequantization, param='_0', params_list=['_0', '_s'], **kwargs),
    }
    params['_s'] = init_linear_scale(x, params=params, qtz_func=quantization, **kwargs)
    params = {k: nn.Parameter(v, requires_grad=False) for k, v in params.items()}

    if 'post_init_hook' in kwargs:
        kwargs['post_init_hook'](parameters=params)

    if 'post_train_hook' in kwargs:
        kwargs['post_train_hook'](parameters=params)

    return params


############### Numpy Qtz ###############


def np_quantization(x, _0, _s):
    return (np.divide(1, np_replace_num(_s, num=0, to=10000)) * np.arcsinh((_0 * x)))


def np_dequantization(x, _0, _s):
    return (np.divide(1, np_replace_num(_0, num=0, to=10000)) * np.sinh((_s * x)))


def fit_func(x, _0, _s):
    x_ = np_quantization(x, _0, _s)
    x_ = np_dequantization(x_, _0, _s)
    return x_
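The remainder of this file repeats the shared helpers already shown in full for variant 10 above. One of them, `round_func_BPDA`, is what keeps the rounding step trainable: it rounds in the forward pass but lets gradients flow through as if it were the identity (a straight-through estimator). A self-contained sketch of the same trick (the helper name here is mine, not the repo's):

import torch

def round_ste(x: torch.Tensor) -> torch.Tensor:
    out = x.clone()                 # keeps x in the autograd graph
    out.data = torch.round(x).data  # forward pass sees rounded values
    return out                      # backward pass sees the identity

x = torch.tensor([0.2, 1.7, -0.6], requires_grad=True)
y = round_ste(x)
print(y.detach())    # tensor([ 0.,  2., -1.])
y.sum().backward()
print(x.grad)        # tensor([1., 1., 1.]) -- identity gradient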
fn_gen/ones_affine_scale/11/loss.png
ADDED
fn_gen/ones_affine_scale/11/quantization.png
ADDED
fn_gen/ones_affine_scale/12/distortion.png
ADDED
fn_gen/ones_affine_scale/12/expressions.txt
ADDED
@@ -0,0 +1,2 @@
acos(_0*x)/_s
cos(_s*x)/_0
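Unlike the asinh/sinh pair, these two are only mutually inverse on a restricted domain: cos(acos(v)) = v holds only for v in [-1, 1], which is why the generated fn.py below clamps _0*x with domain_guard before calling acos. An illustrative check in plain NumPy (not repo code):

import numpy as np

v = np.clip(np.array([-1.5, -0.5, 0.5, 1.5]), -0.99999, 0.99999)
assert np.allclose(np.cos(np.arccos(v)), v)  # holds only after the clamp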
fn_gen/ones_affine_scale/12/fn.py
ADDED
@@ -0,0 +1,493 @@
from __future__ import annotations

import torch
from torch import amin  # Necessary for arcsin
import copy
import torch.nn as nn
import numpy as np

from scipy.optimize import curve_fit
from typing import Dict, Any, Tuple, List, Callable


def quantization(x, **params):
    return (torch.div(1, replace_num(params['_s'], num=0, to=10000)) * torch.acos(domain_guard((params['_0'] * x), min=-0.99999, max=0.99999, nan=0)))


def dequantization(x, **params):
    return (torch.div(1, replace_num(params['_0'], num=0, to=10000)) * torch.cos((params['_s'] * x)))


def init_linear_scale(  # Symmetric scale. From the study folder
    x: torch.Tensor,
    **kwargs: Dict[str, Any],
) -> torch.Tensor:
    assert "bits" in kwargs, "bits must be provided."
    assert "params" in kwargs, "params must be provided."
    assert "qtz_func" in kwargs, "qtz_func must be provided."

    bits = kwargs.get('bits')
    params = kwargs.get('params')
    qtz_func = kwargs.get('qtz_func')

    x_ = x.transpose(0, 1)
    x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs))
    x_ = x_.transpose(0, 1)

    quant_min, quant_max = get_min_max_from_bits_signed(bits)
    min_vals, max_vals = torch.aminmax(x_, dim=1)
    min_vals = torch.min(min_vals, torch.zeros_like(min_vals))
    max_vals = torch.max(max_vals, torch.zeros_like(max_vals))

    eps = torch.finfo(torch.float32).eps

    abs_max_val_per_ch = torch.max(-min_vals, max_vals)
    scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2)

    scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device)

    # Introduces some noise in the scale; without it the accuracy stays at 0.0
    # and the parameters never learn anything.
    scale = scale + 0.01 * torch.randn_like(scale)
    return scale


def init_params(x: torch.Tensor, **kwargs: Dict[str, Any]) -> Dict[str, nn.Parameter]:
    params = {
        '_0': init_ones(x, qtz_func=quantization, deqtz_func=dequantization, param='_0', params_list=['_0', '_s'], **kwargs),
    }
    params['_s'] = init_linear_scale(x, params=params, qtz_func=quantization, **kwargs)
    params = {k: nn.Parameter(v, requires_grad=False) for k, v in params.items()}

    if 'post_init_hook' in kwargs:
        kwargs['post_init_hook'](parameters=params)

    if 'post_train_hook' in kwargs:
        kwargs['post_train_hook'](parameters=params)

    return params


############### Numpy Qtz ###############


def np_quantization(x, _0, _s):
    return (np.divide(1, np_replace_num(_s, num=0, to=10000)) * np.arccos(np_domain_guard((_0 * x), min=-0.99999, max=0.99999, nan=0)))


def np_dequantization(x, _0, _s):
    return (np.divide(1, np_replace_num(_0, num=0, to=10000)) * np.cos((_s * x)))


def fit_func(x, _0, _s):
    x_ = np_quantization(x, _0, _s)
    x_ = np_dequantization(x_, _0, _s)
    return x_
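The rest of this file again repeats the shared helpers shown for variant 10. For reference, `domain_guard` as defined there first maps NaN/inf to finite values with `torch.nan_to_num` and then clamps into the requested interval; the snippet below reproduces that behaviour inline for the acos case (values are illustrative):

import torch

raw = torch.tensor([float('nan'), -2.0, 0.3, 2.0])
# nan -> 0, then clamp into the acos-safe interval, as in
# domain_guard(x, min=-0.99999, max=0.99999, nan=0)
safe = torch.clamp(torch.nan_to_num(raw, nan=0.0), min=-0.99999, max=0.99999)
print(safe)  # -> [0.0000, -0.99999, 0.3000, 0.99999]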
fn_gen/ones_affine_scale/12/loss.png
ADDED
fn_gen/ones_affine_scale/12/quantization.png
ADDED
fn_gen/ones_affine_scale/13/distortion.png
ADDED
fn_gen/ones_affine_scale/13/expressions.txt
ADDED
@@ -0,0 +1,2 @@
acosh(_0*x)/_s
cosh(_s*x)/_0
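acosh is only defined for arguments >= 1, and cosh(acosh(v)) = v only holds there, so the generated code below guards _0*x with domain_guard(min=1, nan=1); inputs below 1 all collapse to acosh(1) = 0. An illustrative NumPy check (not repo code):

import numpy as np

v = np.maximum(np.array([0.5, 1.0, 2.0, 10.0]), 1.0)
assert np.allclose(np.cosh(np.arccosh(v)), v)  # holds only for v >= 1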
fn_gen/ones_affine_scale/13/fn.py
ADDED
@@ -0,0 +1,493 @@
from __future__ import annotations

import torch
from torch import amin  # Necessary for arcsin
import copy
import torch.nn as nn
import numpy as np

from scipy.optimize import curve_fit
from typing import Dict, Any, Tuple, List, Callable


def quantization(x, **params):
    return (torch.div(1, replace_num(params['_s'], num=0, to=10000)) * torch.acosh(domain_guard((params['_0'] * x), min=1, nan=1)))


def dequantization(x, **params):
    return (torch.div(1, replace_num(params['_0'], num=0, to=10000)) * torch.cosh((params['_s'] * x)))


def init_linear_scale(  # Symmetric scale. From the study folder
    x: torch.Tensor,
    **kwargs: Dict[str, Any],
) -> torch.Tensor:
    assert "bits" in kwargs, "bits must be provided."
    assert "params" in kwargs, "params must be provided."
    assert "qtz_func" in kwargs, "qtz_func must be provided."

    bits = kwargs.get('bits')
    params = kwargs.get('params')
    qtz_func = kwargs.get('qtz_func')

    x_ = x.transpose(0, 1)
    x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs))
    x_ = x_.transpose(0, 1)

    quant_min, quant_max = get_min_max_from_bits_signed(bits)
    min_vals, max_vals = torch.aminmax(x_, dim=1)
    min_vals = torch.min(min_vals, torch.zeros_like(min_vals))
    max_vals = torch.max(max_vals, torch.zeros_like(max_vals))

    eps = torch.finfo(torch.float32).eps

    abs_max_val_per_ch = torch.max(-min_vals, max_vals)
    scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2)

    scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device)

    # Introduces some noise in the scale; without it the accuracy stays at 0.0
    # and the parameters never learn anything.
    scale = scale + 0.01 * torch.randn_like(scale)
    return scale


def init_params(x: torch.Tensor, **kwargs: Dict[str, Any]) -> Dict[str, nn.Parameter]:
    params = {
        '_0': init_ones(x, qtz_func=quantization, deqtz_func=dequantization, param='_0', params_list=['_0', '_s'], **kwargs),
    }
    params['_s'] = init_linear_scale(x, params=params, qtz_func=quantization, **kwargs)
    params = {k: nn.Parameter(v, requires_grad=False) for k, v in params.items()}

    if 'post_init_hook' in kwargs:
        kwargs['post_init_hook'](parameters=params)

    if 'post_train_hook' in kwargs:
        kwargs['post_train_hook'](parameters=params)

    return params


############### Numpy Qtz ###############


def np_quantization(x, _0, _s):
    return (np.divide(1, np_replace_num(_s, num=0, to=10000)) * np.arccosh(np_domain_guard((_0 * x), min=1, nan=1)))


def np_dequantization(x, _0, _s):
    return (np.divide(1, np_replace_num(_0, num=0, to=10000)) * np.cosh((_s * x)))


def fit_func(x, _0, _s):
    x_ = np_quantization(x, _0, _s)
    x_ = np_dequantization(x_, _0, _s)
    return x_
92 |
+
############### HELPERS ###############
|
93 |
+
|
94 |
+
def domain_guard(
|
95 |
+
x: torch.Tensor,
|
96 |
+
min: float = None,
|
97 |
+
max: float = None,
|
98 |
+
posinf: float = None,
|
99 |
+
neginf: float = None,
|
100 |
+
nan: float = None
|
101 |
+
) -> torch.Tensor:
|
102 |
+
"""Guard a tensor to a valid domain."""
|
103 |
+
x = torch.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan)
|
104 |
+
if min is not None or max is not None:
|
105 |
+
x = torch.clamp(x, min=min, max=max)
|
106 |
+
return x
|
107 |
+
|
108 |
+
|
109 |
+
def replace_num(x: torch.Tensor, num: float, to: float) -> torch.Tensor:
|
110 |
+
"""Replace a number in a tensor with another number.
|
111 |
+
|
112 |
+
Args:
|
113 |
+
x (torch.Tensor): The input tensor.
|
114 |
+
num (float): The number to replace.
|
115 |
+
to (float): The number to replace with.
|
116 |
+
|
117 |
+
Returns:
|
118 |
+
torch.Tensor: The tensor with the number replaced.
|
119 |
+
"""
|
120 |
+
return torch.where(x == num, to, x)
|
121 |
+
|
122 |
+
|
123 |
+
def guarded_torch_power(x: torch.Tensor, exp: float) -> torch.Tensor:
|
124 |
+
"""Guard the power operation to a valid domain."""
|
125 |
+
return torch.pow(x, exp) if exp >= 1 else torch.pow(torch.relu(x), exp)
|
126 |
+
|
127 |
+
|
128 |
+
def init_ones(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
|
129 |
+
val = torch.amin(x, dim=1)
|
130 |
+
return torch.ones_like(val, dtype=torch.float32, device=x.device)
|
131 |
+
|
132 |
+
|
133 |
+
def init_rand(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
|
134 |
+
val = torch.amin(x, dim=1)
|
135 |
+
return torch.randn_like(val, dtype=torch.float32, device=x.device)
|
136 |
+
|
137 |
+
|
138 |
+
def init_space_search(
|
139 |
+
x: torch.Tensor,
|
140 |
+
**kwargs: Dict[str, Any],
|
141 |
+
) -> torch.Tensor:
|
142 |
+
|
143 |
+
def _build_initial_param(tensor: torch.Tensor, max_initial: int, n_params: int):
|
144 |
+
"""Generates the initial set of parameters. The first iteration generates 10 times more parameters."""
|
145 |
+
for _ in range(n_params * 10): # The first iteration generates 10 times more parameters
|
146 |
+
yield init_rand(tensor) * max_initial # Generates n_params in range [-max_initial, max_initial]
|
147 |
+
|
148 |
+
def _search_param(tensors: List[torch.tensor], n_params):
|
149 |
+
"""Takes the best parameters and generates new parameters around the mean of the best parameters."""
|
150 |
+
torch_tensors = torch.stack(tensors)
|
151 |
+
min_vals, max_vals = torch.aminmax(torch_tensors, dim=0)
|
152 |
+
abs_max_val_per_ch = torch.max(-min_vals, max_vals)
|
153 |
+
mean = torch.mean(torch_tensors, dim=0)
|
154 |
+
for _ in range(n_params): # Generates n_params around the mean of the tensors
|
155 |
+
yield torch.randn_like(min_vals) * abs_max_val_per_ch + mean
|
156 |
+
|
157 |
+
def _calc(x, qtz_func, deqtz_func, **params):
|
158 |
+
x_ = x.transpose(0, 1)
|
159 |
+
x_ = qtz_func(x=x_, **params)
|
160 |
+
x_ = deqtz_func(x=x_, **params)
|
161 |
+
x_ = x_.transpose(0, 1)
|
162 |
+
return x_
|
163 |
+
|
164 |
+
assert "qtz_func" in kwargs, "qtz_func must be provided."
|
165 |
+
assert "deqtz_func" in kwargs, "deqtz_func must be provided."
|
166 |
+
assert "params_list" in kwargs, "params list must be provided."
|
167 |
+
assert "param" in kwargs, "param must be provided."
|
168 |
+
|
169 |
+
qtz_func = kwargs.get('qtz_func')
|
170 |
+
deqtz_func = kwargs.get('deqtz_func')
|
171 |
+
params_list = kwargs.get('params_list')
|
172 |
+
param = kwargs.get('param')
|
173 |
+
|
174 |
+
n_runs = 50 # Number of runs to try to find the best parameters
|
175 |
+
n_random_params = 50 # Number of random parameters to generate
|
176 |
+
n_best_to_pick = 5 # Number of best parameters to pick after each run
|
177 |
+
max_initial = 10000 # Maximum value to initialize the parameters
|
178 |
+
|
179 |
+
# Initializes the parameters
|
180 |
+
base_params = { p: init_ones(x, **kwargs) for p in params_list if p != param }
|
181 |
+
params = _build_initial_param(x, max_initial, n_random_params)
|
182 |
+
|
183 |
+
# Performs the search
|
184 |
+
for _ in range(n_runs):
|
185 |
+
|
186 |
+
best_params = []
|
187 |
+
for param_ in params:
|
188 |
+
try:
|
189 |
+
x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: param_})
|
190 |
+
loss_ones = nn.MSELoss()(x, x_)
|
191 |
+
|
192 |
+
if len(best_params) < n_best_to_pick:
|
193 |
+
best_params.append((param_, loss_ones.item()))
|
194 |
+
best_params = sorted(best_params, key=lambda x: x[1])
|
195 |
+
elif loss_ones < best_params[-1][1]:
|
196 |
+
best_params[-1] = (param_, loss_ones.item())
|
197 |
+
best_params = sorted(best_params, key=lambda x: x[1])
|
198 |
+
|
199 |
+
except Exception: # The parameters might not be valid for the function's domain
|
200 |
+
continue
|
201 |
+
|
202 |
+
# Generates new parameters around the mean
|
203 |
+
params = _search_param([p for p, _ in best_params], n_random_params)
|
204 |
+
|
205 |
+
# Checks if the best parameter is better than the init_ones
|
206 |
+
p_ones = init_ones(x, **kwargs)
|
207 |
+
x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_ones})
|
208 |
+
loss_ones = nn.MSELoss()(x, x_)
|
209 |
+
|
210 |
+
# Checks if the best parameter is better than the init_rand
|
211 |
+
p_rand = init_rand(x, **kwargs)
|
212 |
+
x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_rand})
|
213 |
+
loss_rand = nn.MSELoss()(x, x_)
|
214 |
+
|
215 |
+
if loss_rand < best_params[0][1] and loss_rand < loss_ones:
|
216 |
+
return p_rand
|
217 |
+
elif loss_ones < best_params[0][1] and loss_ones < loss_rand:
|
218 |
+
return p_ones
|
219 |
+
else:
|
220 |
+
return best_params[0][0]
|
221 |
+
|
222 |
+
|
223 |
+
def init_linear_scale( # Symmetric scale. From the study folder
|
224 |
+
x: torch.Tensor,
|
225 |
+
**kwargs: Dict[str, Any],
|
226 |
+
) -> torch.Tensor:
|
227 |
+
assert "bits" in kwargs, "bits must be provided."
|
228 |
+
assert "params" in kwargs, "params must be provided."
|
229 |
+
assert "qtz_func" in kwargs, "qtz_func must be provided."
|
230 |
+
|
231 |
+
bits = kwargs.get('bits')
|
232 |
+
params = kwargs.get('params')
|
233 |
+
qtz_func = kwargs.get('qtz_func')
|
234 |
+
|
235 |
+
x_ = x.transpose(0, 1)
|
236 |
+
x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs))
|
237 |
+
x_ = x_.transpose(0, 1)
|
238 |
+
|
239 |
+
quant_min, quant_max = get_min_max_from_bits_signed(bits)
|
240 |
+
min_vals, max_vals = torch.aminmax(x_, dim=1)
|
241 |
+
min_vals = torch.min(min_vals, torch.zeros_like(min_vals))
|
242 |
+
max_vals = torch.max(max_vals, torch.zeros_like(max_vals))
|
243 |
+
|
244 |
+
eps = torch.finfo(torch.float32).eps
|
245 |
+
|
246 |
+
abs_max_val_per_ch = torch.max(-min_vals, max_vals)
|
247 |
+
scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2)
|
248 |
+
|
249 |
+
scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device)
|
250 |
+
|
251 |
+
return scale
|
252 |
+
|
253 |
+
# Introduces some noise in scale
|
254 |
+
# If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything
|
255 |
+
scale = scale + 0.01 * torch.randn_like(scale)
|
256 |
+
return scale
|
257 |
+
|
258 |
+
|
259 |
+
def init_non_linear_regression_fit(
|
260 |
+
x: torch.Tensor,
|
261 |
+
**kwargs: Dict[str, Any],
|
262 |
+
) -> torch.Tensor:
|
263 |
+
|
264 |
+
assert "params_list" in kwargs, "params list must be provided."
|
265 |
+
assert "np_fit_func" in kwargs, "np_fit_func must be provided."
|
266 |
+
assert "p0" in kwargs, "p0 must be provided."
|
267 |
+
np_fit_func = kwargs.get('np_fit_func')
|
268 |
+
params_list = kwargs.get('params_list')
|
269 |
+
p0 = kwargs.get('p0')
|
270 |
+
|
271 |
+
def _fit(xdata: np.ndarray, ydata: np.ndarray, func: Callable, p0: List[float]):
|
272 |
+
popt, _ = curve_fit(
|
273 |
+
func,
|
274 |
+
xdata,
|
275 |
+
ydata,
|
276 |
+
maxfev=1000,
|
277 |
+
p0=p0,
|
278 |
+
method='lm'
|
279 |
+
)
|
280 |
+
return popt
|
281 |
+
|
282 |
+
# 1. Needs to convert the torch tensor to numpy tensor
|
283 |
+
xdata = x.cpu().numpy()
|
284 |
+
|
285 |
+
# 2. Sorts the data so that it makes it easier to fit to it
|
286 |
+
sorted_xdata = np.sort(xdata, axis=-1)
|
287 |
+
|
288 |
+
p0 = {k: v.cpu().numpy() for k, v in p0.items()}
|
289 |
+
params_list = sorted(params_list) # We need to make sure that it matches the numpy fit func arg order
|
290 |
+
|
291 |
+
# 3. Finds the best parameters for each channel
|
292 |
+
try:
|
293 |
+
params = []
|
294 |
+
for i in range(sorted_xdata.shape[0]):
|
295 |
+
xdata_ = sorted_xdata[i]
|
296 |
+
p0_ = [p0[p][i] for p in params_list]
|
297 |
+
ch_params = _fit(xdata_, xdata_, np_fit_func, p0_)
|
298 |
+
params.append(ch_params)
|
299 |
+
|
300 |
+
# 4. Builds the parameters
|
301 |
+
result = {}
|
302 |
+
for i, p in enumerate(params_list):
|
303 |
+
result[p] = torch.tensor([p_[i] for p_ in params], dtype=torch.float32).to(x.device)
|
304 |
+
|
305 |
+
return result
|
306 |
+
|
307 |
+
except ValueError as e:
|
308 |
+
print(f"Could not fit the function with error: {e}")
|
309 |
+
print(f"Using fallback result...")
|
310 |
+
return {
|
311 |
+
k: torch.tensor(v, dtype=torch.float32).to(x.device) for k, v in p0.items()
|
312 |
+
}
|
313 |
+
|
314 |
+
|
315 |
+
def init_zeros(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
|
316 |
+
val = torch.amin(x, dim=1)
|
317 |
+
return torch.zeros_like(val, dtype=torch.float32, device=x.device)
|
318 |
+
|
319 |
+
|
320 |
+
def init_inner_scale(tensor: torch.Tensor, _min: float = torch.inf, _max: float = torch.inf) -> torch.Tensor:
|
321 |
+
# Calculate the original minimum and maximum values
|
322 |
+
min_vals, max_vals = torch.aminmax(tensor, dim=-1)
|
323 |
+
x_min = torch.min(min_vals, torch.zeros_like(min_vals))
|
324 |
+
x_max = torch.max(max_vals, torch.zeros_like(max_vals))
|
325 |
+
|
326 |
+
if _max is torch.inf: # We do not need to scale the tensor. Just need to move it
|
327 |
+
return torch.ones_like(x_min)
|
328 |
+
|
329 |
+
# Calculate the scale factor
|
330 |
+
scale = (_max - _min) / (x_max - x_min)
|
331 |
+
return scale
|
332 |
+
|
333 |
+
|
334 |
+
|
335 |
+
############## Quant ###############
|
336 |
+
|
337 |
+
@torch.enable_grad()
|
338 |
+
def learn_parameters(
|
339 |
+
x: torch.Tensor,
|
340 |
+
params: Dict[str, nn.Parameter],
|
341 |
+
qtz_func: nn.Module,
|
342 |
+
deqtz_func: nn.Module,
|
343 |
+
bits: int,
|
344 |
+
target_dtype: torch.dtype,
|
345 |
+
epochs: int = 1000,
|
346 |
+
early_stop: bool = True,
|
347 |
+
do_report: bool = False
|
348 |
+
) -> Tuple[Dict[str, nn.Parameter], torch.Tensor]:
|
349 |
+
|
350 |
+
# Requires gradients in the parameters
|
351 |
+
for p in params.values():
|
352 |
+
p.requires_grad = True
|
353 |
+
p.grad = None
|
354 |
+
|
355 |
+
param_keys = list(params.keys())
|
356 |
+
param_values = list(params.values())
|
357 |
+
|
358 |
+
# Defines optimizer and loss function
|
359 |
+
optimizer = torch.optim.Adam(param_values, lr=0.001)
|
360 |
+
scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=1.0, end_factor=0.001, total_iters=epochs // 10)
|
361 |
+
loss_fn = nn.MSELoss()
|
362 |
+
|
363 |
+
# Contains the best loss and the best parameters
|
364 |
+
best_loss = float("inf")
|
365 |
+
best_params = None
|
366 |
+
|
367 |
+
# Used to stop the search early
|
368 |
+
min_delta = 1e-7
|
369 |
+
acc_loss = []
|
370 |
+
percent_epochs_before_stop = 0.1
|
371 |
+
|
372 |
+
for i in range(epochs):
|
373 |
+
optimizer.zero_grad()
|
374 |
+
|
375 |
+
quant = quantize(x, params, qtz_func, bits, target_dtype)
|
376 |
+
dequant = dequantize(quant, params, deqtz_func, bits, x.dtype)
|
377 |
+
loss = loss_fn(x, dequant)
|
378 |
+
|
379 |
+
if loss.isnan() or loss.isinf():
|
380 |
+
raise Exception("Loss is NaN or Inf. Stopping the search.")
|
381 |
+
|
382 |
+
loss.backward()
|
383 |
+
optimizer.step()
|
384 |
+
scheduler.step()
|
385 |
+
|
386 |
+
acc_loss.append(loss.item())
|
387 |
+
|
388 |
+
# Reports loss every 10 steps
|
389 |
+
if i % 10 == 0 and do_report:
|
390 |
+
print(f"Epoch {i}: Loss {loss.item()}")
|
391 |
+
|
392 |
+
# Optimizes the parameter search by storing the best loss and the parameters
|
393 |
+
if loss.item() < best_loss:
|
394 |
+
best_loss = loss.item()
|
395 |
+
best_params = copy.deepcopy({
|
396 |
+
k: v for k, v in params.items() if k in param_keys
|
397 |
+
})
|
398 |
+
|
399 |
+
# We also stop the search if the loss has not considerably during the last 10% epochs
|
400 |
+
if early_stop:
|
401 |
+
epochs_before_stop = int(epochs * percent_epochs_before_stop)
|
402 |
+
if i > epochs_before_stop and abs(acc_loss[i - epochs_before_stop] - acc_loss[i]) < min_delta:
|
403 |
+
break
|
404 |
+
|
405 |
+
# No longer requires gradients in the parameters
|
406 |
+
for p in best_params.values():
|
407 |
+
p.requires_grad = False
|
408 |
+
p.grad = None
|
409 |
+
|
410 |
+
if do_report:
|
411 |
+
return best_params, acc_loss
|
412 |
+
else:
|
413 |
+
return best_params
|
414 |
+
|
415 |
+
|
416 |
+
def quantize(
|
417 |
+
x: torch.Tensor,
|
418 |
+
params: Dict[str, nn.Parameter],
|
419 |
+
func: nn.Module,
|
420 |
+
bits: int,
|
421 |
+
target_dtype: torch.dtype = torch.int8
|
422 |
+
) -> torch.Tensor:
|
423 |
+
quant_min, quant_max = get_min_max_from_bits_signed(bits)
|
424 |
+
x = x.transpose(0, 1) # Aligns shapes
|
425 |
+
x = func(x=x, **params)
|
426 |
+
x = x.transpose(0, 1)
|
427 |
+
x = torch.clamp(round_func_BPDA(x), quant_min, quant_max).to(target_dtype)
|
428 |
+
return x
|
429 |
+
|
430 |
+
|
431 |
+
def dequantize(
|
432 |
+
x: torch.Tensor,
|
433 |
+
params: Dict[str, nn.Parameter],
|
434 |
+
func: nn.Module,
|
435 |
+
bits: int,
|
436 |
+
out_dtype: torch.dtype
|
437 |
+
) -> torch.Tensor:
|
438 |
+
x = x.to(dtype=out_dtype)
|
439 |
+
x = x.transpose(0, 1)
|
440 |
+
x = func(x=x, **params)
|
441 |
+
x = x.transpose(0, 1)
|
442 |
+
return x
|
443 |
+
|
444 |
+
|
445 |
+
def round_func_BPDA(input):
|
446 |
+
# This is equivalent to replacing round function (non-differentiable) with
|
447 |
+
# an identity function (differentiable) only when backward.
|
448 |
+
forward_value = torch.round(input)
|
449 |
+
out = input.clone()
|
450 |
+
out.data = forward_value.data
|
451 |
+
return out
|
452 |
+
|
453 |
+
|
454 |
+
def get_min_max_from_bits_signed(bit_width: int) -> Tuple[int, int]:
|
455 |
+
return -2 ** (bit_width - 1), 2 ** (bit_width - 1) - 1
|
456 |
+
|
457 |
+
|
458 |
+
|
459 |
+
############## Numpy ###############
|
460 |
+
|
461 |
+
def np_domain_guard(
|
462 |
+
x: np.ndarray,
|
463 |
+
min: float = None,
|
464 |
+
max: float = None,
|
465 |
+
posinf: float = None,
|
466 |
+
neginf: float = None,
|
467 |
+
nan: float = None
|
468 |
+
) -> np.ndarray:
|
469 |
+
"""Guard a tensor to a valid domain."""
|
470 |
+
x = np.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan)
|
471 |
+
if min is not None or max is not None:
|
472 |
+
x = np.clip(x, min, max)
|
473 |
+
return x
|
474 |
+
|
475 |
+
|
476 |
+
def np_replace_num(x: np.ndarray, num: float, to: float) -> np.ndarray:
|
477 |
+
"""Replace a number in a tensor with another number.
|
478 |
+
|
479 |
+
Args:
|
480 |
+
x (np.ndarray): The input tensor.
|
481 |
+
num (float): The number to replace.
|
482 |
+
to (float): The number to replace with.
|
483 |
+
|
484 |
+
Returns:
|
485 |
+
np.ndarray: The tensor with the number replaced.
|
486 |
+
"""
|
487 |
+
return np.where(x == num, to, x)
|
488 |
+
|
489 |
+
|
490 |
+
def np_guarded_power(x: np.ndarray, exp: float) -> np.ndarray:
|
491 |
+
"""Guard the power operation to a valid domain."""
|
492 |
+
return np.power(x, exp) if exp >= 1 else np.power(np.maximum(x, 0), exp)
|
493 |
+
|
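Editorial aside: a minimal sketch of how the pieces of the file above fit together, assuming its functions are in scope (e.g. the file is imported as a module). The tensor shape, bit width, and epoch count are illustrative choices, not values taken from this upload.

    import torch

    x = torch.randn(4, 16)            # (channels, values); params are per-channel (dim 0)
    params = init_params(x, bits=8)   # initializes _0 (ones) and _s (symmetric scale)
    params = learn_parameters(
        x, params,
        qtz_func=quantization, deqtz_func=dequantization,
        bits=8, target_dtype=torch.int8, epochs=100,
    )
    q = quantize(x, params, quantization, bits=8, target_dtype=torch.int8)
    x_hat = dequantize(q, params, dequantization, bits=8, out_dtype=x.dtype)
    print(torch.mean((x - x_hat) ** 2))  # round-trip reconstruction error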
fn_gen/ones_affine_scale/13/loss.png ADDED
fn_gen/ones_affine_scale/13/quantization.png ADDED
fn_gen/ones_affine_scale/14/distortion.png ADDED
fn_gen/ones_affine_scale/14/expressions.txt ADDED
@@ -0,0 +1,2 @@
+tan(_0*x)/_s
+atan(_s*x)/_0
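As a quick numerical sanity check (editorial; the constants below are arbitrary), the two expressions are inverses of one another on the principal branch of tan:

    import numpy as np

    _0, _s = 0.7, 1.3              # arbitrary positive parameters
    x = np.linspace(-1.0, 1.0, 9)  # keeps _0*x inside (-pi/2, pi/2)
    q = np.tan(_0 * x) / _s        # expression 1
    x_rec = np.arctan(_s * q) / _0 # expression 2
    print(np.allclose(x_rec, x))   # True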
fn_gen/ones_affine_scale/14/fn.py ADDED
@@ -0,0 +1,493 @@
from __future__ import annotations

import torch
from torch import amin  # Necessary for arcsin
import copy
import torch.nn as nn
import numpy as np

from scipy.optimize import curve_fit
from typing import Dict, Any, Tuple, List, Callable


def quantization(x, **params):
    return (torch.div(1, replace_num(params['_s'], num=0, to=10000)) * torch.tan(domain_guard((params['_0'] * x), posinf=1, neginf=-1, nan=0)))


def dequantization(x, **params):
    return (torch.div(1, replace_num(params['_0'], num=0, to=10000)) * torch.atan((params['_s'] * x)))


############### Numpy Qtz ###############


def np_quantization(x, _0, _s):
    return (np.divide(1, np_replace_num(_s, num=0, to=10000)) * np.tan(np_domain_guard((_0 * x), posinf=1, neginf=-1, nan=0)))


def np_dequantization(x, _0, _s):
    return (np.divide(1, np_replace_num(_0, num=0, to=10000)) * np.arctan((_s * x)))


# [The remaining lines of this file -- init_linear_scale, init_params, fit_func,
#  and the HELPERS / Quant / Numpy sections -- are byte-for-byte identical to
#  fn_gen/ones_affine_scale/13/fn.py reproduced above and are not repeated here.]
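A small illustration (editorial, with made-up inputs) of what the np_domain_guard call in np_quantization above actually does before np.tan is applied: infinities and NaNs are mapped to finite values.

    import numpy as np

    raw = np.array([0.5, np.inf, -np.inf, np.nan])
    print(np.nan_to_num(raw, posinf=1, neginf=-1, nan=0))
    # -> [ 0.5  1.  -1.   0. ]  (the same substitution np_domain_guard performs)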
fn_gen/ones_affine_scale/14/loss.png ADDED
fn_gen/ones_affine_scale/14/quantization.png ADDED
fn_gen/ones_affine_scale/15/distortion.png ADDED
fn_gen/ones_affine_scale/15/expressions.txt ADDED
@@ -0,0 +1,2 @@
+sinh(_0*x)/_s
+log(_s*x - sqrt(_s**2*x**2 + 1))/_0
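For reference (a standard identity, not part of the upload), the real inverse of sinh is

    \operatorname{asinh}(y) = \ln\!\left(y + \sqrt{y^{2} + 1}\right)

Expression 2 above carries the opposite sign inside the log, so its argument is negative for every real input; the fn.py below only evaluates it because domain_guard clamps the log argument to at least 1e-5.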
fn_gen/ones_affine_scale/15/fn.py ADDED
@@ -0,0 +1,493 @@
1 |
+
from __future__ import annotations
|
2 |
+
|
3 |
+
import torch
|
4 |
+
from torch import amin # Necessary for arcsin
|
5 |
+
import copy
|
6 |
+
import torch.nn as nn
|
7 |
+
import numpy as np
|
8 |
+
|
9 |
+
from scipy.optimize import curve_fit
|
10 |
+
from typing import Dict, Any, Tuple, List, Callable
|
11 |
+
|
12 |
+
|
13 |
+
def quantization(x, **params):
|
14 |
+
return (torch.div(1, replace_num(params['_s'], num=0, to=10000)) * torch.sinh((params['_0'] * x)))
|
15 |
+
|
16 |
+
|
17 |
+
def dequantization(x, **params):
|
18 |
+
return (torch.div(1, replace_num(params['_0'], num=0, to=10000)) * torch.log(domain_guard(((torch.tensor(-1) * torch.sqrt(domain_guard((torch.tensor(1) + (guarded_torch_power(params['_s'], torch.tensor(2)) * guarded_torch_power(x, torch.tensor(2)))), min=0.1, nan=0.1))) + (params['_s'] * x)), min=1e-5, nan=1e-5)))
|
19 |
+
|
20 |
+
|
21 |
+
def init_linear_scale( # Symmetric scale. From the study folder
|
22 |
+
x: torch.Tensor,
|
23 |
+
**kwargs: Dict[str, Any],
|
24 |
+
) -> torch.Tensor:
|
25 |
+
assert "bits" in kwargs, "bits must be provided."
|
26 |
+
assert "params" in kwargs, "params must be provided."
|
27 |
+
assert "qtz_func" in kwargs, "qtz_func must be provided."
|
28 |
+
|
29 |
+
bits = kwargs.get('bits')
|
30 |
+
params = kwargs.get('params')
|
31 |
+
qtz_func = kwargs.get('qtz_func')
|
32 |
+
|
33 |
+
x_ = x.transpose(0, 1)
|
34 |
+
x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs))
|
35 |
+
x_ = x_.transpose(0, 1)
|
36 |
+
|
37 |
+
quant_min, quant_max = get_min_max_from_bits_signed(bits)
|
38 |
+
min_vals, max_vals = torch.aminmax(x_, dim=1)
|
39 |
+
min_vals = torch.min(min_vals, torch.zeros_like(min_vals))
|
40 |
+
max_vals = torch.max(max_vals, torch.zeros_like(max_vals))
|
41 |
+
|
42 |
+
eps = torch.finfo(torch.float32).eps
|
43 |
+
|
44 |
+
abs_max_val_per_ch = torch.max(-min_vals, max_vals)
|
45 |
+
scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2)
|
46 |
+
|
47 |
+
scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device)
|
48 |
+
|
49 |
+
return scale
|
50 |
+
|
51 |
+
# Introduces some noise in scale
|
52 |
+
# If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything
|
53 |
+
scale = scale + 0.01 * torch.randn_like(scale)
|
54 |
+
return scale
|
55 |
+
|
56 |
+
|
57 |
+
def init_params(x: torch.Tensor, **kwargs: Dict[str, Any]) -> Dict[str, nn.Parameter]:
|
58 |
+
params = {
|
59 |
+
'_0': init_ones(x, qtz_func=quantization, deqtz_func=dequantization, param='_0', params_list=['_0', '_s'], **kwargs),
|
60 |
+
}
|
61 |
+
params['_s'] = init_linear_scale(x, params=params, qtz_func=quantization, **kwargs)
|
62 |
+
params = {k: nn.Parameter(v, requires_grad=False) for k, v in params.items()}
|
63 |
+
|
64 |
+
if 'post_init_hook' in kwargs:
|
65 |
+
kwargs['post_init_hook'](parameters=params)
|
66 |
+
|
67 |
+
|
68 |
+
if 'post_train_hook' in kwargs:
|
69 |
+
kwargs['post_train_hook'](parameters=params)
|
70 |
+
|
71 |
+
return params
|
72 |
+
|
73 |
+
|
74 |
+
############### Numpy Qtz ###############
|
75 |
+
|
76 |
+
|
77 |
+
def np_quantization(x, _0, _s):
|
78 |
+
return (np.divide(1, np_replace_num(_s, num=0, to=10000)) * np.sinh((_0 * x)))
|
79 |
+
|
80 |
+
|
81 |
+
def np_dequantization(x, _0, _s):
|
82 |
+
return (np.divide(1, np_replace_num(_0, num=0, to=10000)) * np.log(np_domain_guard(((np.array(-1) * np.sqrt(np_domain_guard((np.array(1) + (np_guarded_power(_s, np.array(2)) * np_guarded_power(x, np.array(2)))), min=0.1, nan=0.1))) + (_s * x)), min=1e-5, nan=1e-5)))
|
83 |
+
|
84 |
+
|
85 |
+
def fit_func(x, _0, _s):
|
86 |
+
x_ = np_quantization(x, _0, _s)
|
87 |
+
x_ = np_dequantization(x_, _0, _s)
|
88 |
+
return x_
|
89 |
+
|
90 |
+
|
91 |
+
|
92 |
+
############### HELPERS ###############
|
93 |
+
|
94 |
+
def domain_guard(
|
95 |
+
x: torch.Tensor,
|
96 |
+
min: float = None,
|
97 |
+
max: float = None,
|
98 |
+
posinf: float = None,
|
99 |
+
neginf: float = None,
|
100 |
+
nan: float = None
|
101 |
+
) -> torch.Tensor:
|
102 |
+
"""Guard a tensor to a valid domain."""
|
103 |
+
x = torch.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan)
|
104 |
+
if min is not None or max is not None:
|
105 |
+
x = torch.clamp(x, min=min, max=max)
|
106 |
+
return x
|
107 |
+
|
108 |
+
|
109 |
+
def replace_num(x: torch.Tensor, num: float, to: float) -> torch.Tensor:
|
110 |
+
"""Replace a number in a tensor with another number.
|
111 |
+
|
112 |
+
Args:
|
113 |
+
x (torch.Tensor): The input tensor.
|
114 |
+
num (float): The number to replace.
|
115 |
+
to (float): The number to replace with.
|
116 |
+
|
117 |
+
Returns:
|
118 |
+
torch.Tensor: The tensor with the number replaced.
|
119 |
+
"""
|
120 |
+
return torch.where(x == num, to, x)
|
121 |
+
|
122 |
+
|
123 |
+
def guarded_torch_power(x: torch.Tensor, exp: float) -> torch.Tensor:
|
124 |
+
"""Guard the power operation to a valid domain."""
|
125 |
+
return torch.pow(x, exp) if exp >= 1 else torch.pow(torch.relu(x), exp)
|
126 |
+
|
127 |
+
|
128 |
+
def init_ones(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
|
129 |
+
val = torch.amin(x, dim=1)
|
130 |
+
return torch.ones_like(val, dtype=torch.float32, device=x.device)
|
131 |
+
|
132 |
+
|
133 |
+
def init_rand(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
|
134 |
+
val = torch.amin(x, dim=1)
|
135 |
+
return torch.randn_like(val, dtype=torch.float32, device=x.device)
|
136 |
+
|
137 |
+
|
138 |
+
def init_space_search(
|
139 |
+
x: torch.Tensor,
|
140 |
+
**kwargs: Dict[str, Any],
|
141 |
+
) -> torch.Tensor:
|
142 |
+
|
143 |
+
def _build_initial_param(tensor: torch.Tensor, max_initial: int, n_params: int):
|
144 |
+
"""Generates the initial set of parameters. The first iteration generates 10 times more parameters."""
|
145 |
+
for _ in range(n_params * 10): # The first iteration generates 10 times more parameters
|
146 |
+
yield init_rand(tensor) * max_initial # Generates n_params in range [-max_initial, max_initial]
|
147 |
+
|
148 |
+
def _search_param(tensors: List[torch.tensor], n_params):
|
149 |
+
"""Takes the best parameters and generates new parameters around the mean of the best parameters."""
|
150 |
+
torch_tensors = torch.stack(tensors)
|
151 |
+
min_vals, max_vals = torch.aminmax(torch_tensors, dim=0)
|
152 |
+
abs_max_val_per_ch = torch.max(-min_vals, max_vals)
|
153 |
+
mean = torch.mean(torch_tensors, dim=0)
|
154 |
+
for _ in range(n_params): # Generates n_params around the mean of the tensors
|
155 |
+
yield torch.randn_like(min_vals) * abs_max_val_per_ch + mean
|
156 |
+
|
157 |
+
def _calc(x, qtz_func, deqtz_func, **params):
|
158 |
+
x_ = x.transpose(0, 1)
|
159 |
+
x_ = qtz_func(x=x_, **params)
|
160 |
+
x_ = deqtz_func(x=x_, **params)
|
161 |
+
x_ = x_.transpose(0, 1)
|
162 |
+
return x_
|
163 |
+
|
164 |
+
assert "qtz_func" in kwargs, "qtz_func must be provided."
|
165 |
+
assert "deqtz_func" in kwargs, "deqtz_func must be provided."
|
166 |
+
assert "params_list" in kwargs, "params list must be provided."
|
167 |
+
assert "param" in kwargs, "param must be provided."
|
168 |
+
|
169 |
+
qtz_func = kwargs.get('qtz_func')
|
170 |
+
deqtz_func = kwargs.get('deqtz_func')
|
171 |
+
params_list = kwargs.get('params_list')
|
172 |
+
param = kwargs.get('param')
|
173 |
+
|
174 |
+
n_runs = 50 # Number of runs to try to find the best parameters
|
175 |
+
n_random_params = 50 # Number of random parameters to generate
|
176 |
+
n_best_to_pick = 5 # Number of best parameters to pick after each run
|
177 |
+
max_initial = 10000 # Maximum value to initialize the parameters
|
178 |
+
|
179 |
+
# Initializes the parameters
|
180 |
+
base_params = { p: init_ones(x, **kwargs) for p in params_list if p != param }
|
181 |
+
params = _build_initial_param(x, max_initial, n_random_params)
|
182 |
+
|
183 |
+
# Performs the search
|
184 |
+
for _ in range(n_runs):
|
185 |
+
|
186 |
+
best_params = []
|
187 |
+
for param_ in params:
|
188 |
+
try:
|
189 |
+
x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: param_})
|
190 |
+
loss_ones = nn.MSELoss()(x, x_)
|
191 |
+
|
192 |
+
if len(best_params) < n_best_to_pick:
|
193 |
+
best_params.append((param_, loss_ones.item()))
|
194 |
+
best_params = sorted(best_params, key=lambda x: x[1])
|
195 |
+
elif loss_ones < best_params[-1][1]:
|
196 |
+
best_params[-1] = (param_, loss_ones.item())
|
197 |
+
best_params = sorted(best_params, key=lambda x: x[1])
|
198 |
+
|
199 |
+
except Exception: # The parameters might not be valid for the function's domain
|
200 |
+
continue
|
201 |
+
|
202 |
+
# Generates new parameters around the mean
|
203 |
+
params = _search_param([p for p, _ in best_params], n_random_params)
|
204 |
+
|
205 |
+
# Checks if the best parameter is better than the init_ones
|
206 |
+
p_ones = init_ones(x, **kwargs)
|
207 |
+
x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_ones})
|
208 |
+
loss_ones = nn.MSELoss()(x, x_)
|
209 |
+
|
210 |
+
# Checks if the best parameter is better than the init_rand
|
211 |
+
p_rand = init_rand(x, **kwargs)
|
212 |
+
x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_rand})
|
213 |
+
loss_rand = nn.MSELoss()(x, x_)
|
214 |
+
|
215 |
+
if loss_rand < best_params[0][1] and loss_rand < loss_ones:
|
216 |
+
return p_rand
|
217 |
+
elif loss_ones < best_params[0][1] and loss_ones < loss_rand:
|
218 |
+
return p_ones
|
219 |
+
else:
|
220 |
+
return best_params[0][0]
|
221 |
+
|
222 |
+
|
223 |
+
def init_linear_scale( # Symmetric scale. From the study folder
|
224 |
+
x: torch.Tensor,
|
225 |
+
**kwargs: Dict[str, Any],
|
226 |
+
) -> torch.Tensor:
|
227 |
+
assert "bits" in kwargs, "bits must be provided."
|
228 |
+
assert "params" in kwargs, "params must be provided."
|
229 |
+
assert "qtz_func" in kwargs, "qtz_func must be provided."
|
230 |
+
|
231 |
+
bits = kwargs.get('bits')
|
232 |
+
params = kwargs.get('params')
|
233 |
+
qtz_func = kwargs.get('qtz_func')
|
234 |
+
|
235 |
+
x_ = x.transpose(0, 1)
|
236 |
+
x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs))
|
237 |
+
x_ = x_.transpose(0, 1)
|
238 |
+
|
239 |
+
quant_min, quant_max = get_min_max_from_bits_signed(bits)
|
240 |
+
min_vals, max_vals = torch.aminmax(x_, dim=1)
|
241 |
+
min_vals = torch.min(min_vals, torch.zeros_like(min_vals))
|
242 |
+
max_vals = torch.max(max_vals, torch.zeros_like(max_vals))
|
243 |
+
|
244 |
+
eps = torch.finfo(torch.float32).eps
|
245 |
+
|
246 |
+
abs_max_val_per_ch = torch.max(-min_vals, max_vals)
|
247 |
+
scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2)
|
248 |
+
|
249 |
+
scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device)
|
250 |
+
|
251 |
+
return scale
|
252 |
+
|
253 |
+
# Introduces some noise in scale
|
254 |
+
# If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything
|
255 |
+
scale = scale + 0.01 * torch.randn_like(scale)
|
256 |
+
return scale
|
257 |
+
|
258 |
+
|
259 |
+
def init_non_linear_regression_fit(
|
260 |
+
x: torch.Tensor,
|
261 |
+
**kwargs: Dict[str, Any],
|
262 |
+
) -> torch.Tensor:
|
263 |
+
|
264 |
+
assert "params_list" in kwargs, "params list must be provided."
|
265 |
+
assert "np_fit_func" in kwargs, "np_fit_func must be provided."
|
266 |
+
assert "p0" in kwargs, "p0 must be provided."
|
267 |
+
np_fit_func = kwargs.get('np_fit_func')
|
268 |
+
params_list = kwargs.get('params_list')
|
269 |
+
p0 = kwargs.get('p0')
|
270 |
+
|
271 |
+
def _fit(xdata: np.ndarray, ydata: np.ndarray, func: Callable, p0: List[float]):
|
272 |
+
popt, _ = curve_fit(
|
273 |
+
func,
|
274 |
+
xdata,
|
275 |
+
ydata,
|
276 |
+
maxfev=1000,
|
277 |
+
p0=p0,
|
278 |
+
method='lm'
|
279 |
+
)
|
280 |
+
return popt
|
281 |
+
|
282 |
+
# 1. Needs to convert the torch tensor to numpy tensor
|
283 |
+
xdata = x.cpu().numpy()
|
284 |
+
|
285 |
+
# 2. Sorts the data so that it makes it easier to fit to it
|
286 |
+
sorted_xdata = np.sort(xdata, axis=-1)
|
287 |
+
|
288 |
+
p0 = {k: v.cpu().numpy() for k, v in p0.items()}
|
289 |
+
params_list = sorted(params_list) # We need to make sure that it matches the numpy fit func arg order
|
290 |
+
|
291 |
+
# 3. Finds the best parameters for each channel
|
292 |
+
try:
|
293 |
+
params = []
|
294 |
+
for i in range(sorted_xdata.shape[0]):
|
295 |
+
xdata_ = sorted_xdata[i]
|
296 |
+
p0_ = [p0[p][i] for p in params_list]
|
297 |
+
ch_params = _fit(xdata_, xdata_, np_fit_func, p0_)
|
298 |
+
params.append(ch_params)
|
299 |
+
|
300 |
+
# 4. Builds the parameters
|
301 |
+
result = {}
|
302 |
+
for i, p in enumerate(params_list):
|
303 |
+
result[p] = torch.tensor([p_[i] for p_ in params], dtype=torch.float32).to(x.device)
|
304 |
+
|
305 |
+
return result
|
306 |
+
|
307 |
+
except ValueError as e:
|
308 |
+
print(f"Could not fit the function with error: {e}")
|
309 |
+
print(f"Using fallback result...")
|
310 |
+
        return {
            k: torch.tensor(v, dtype=torch.float32).to(x.device) for k, v in p0.items()
        }


def init_zeros(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
    val = torch.amin(x, dim=1)
    return torch.zeros_like(val, dtype=torch.float32, device=x.device)


def init_inner_scale(tensor: torch.Tensor, _min: float = torch.inf, _max: float = torch.inf) -> torch.Tensor:
    # Calculates the original minimum and maximum values
    min_vals, max_vals = torch.aminmax(tensor, dim=-1)
    x_min = torch.min(min_vals, torch.zeros_like(min_vals))
    x_max = torch.max(max_vals, torch.zeros_like(max_vals))

    if _max is torch.inf:  # We do not need to scale the tensor, just move it
        return torch.ones_like(x_min)

    # Calculates the scale factor
    scale = (_max - _min) / (x_max - x_min)
    return scale


############## Quant ###############

@torch.enable_grad()
def learn_parameters(
    x: torch.Tensor,
    params: Dict[str, nn.Parameter],
    qtz_func: nn.Module,
    deqtz_func: nn.Module,
    bits: int,
    target_dtype: torch.dtype,
    epochs: int = 1000,
    early_stop: bool = True,
    do_report: bool = False
) -> Tuple[Dict[str, nn.Parameter], torch.Tensor]:

    # Requires gradients in the parameters
    for p in params.values():
        p.requires_grad = True
        p.grad = None

    param_keys = list(params.keys())
    param_values = list(params.values())

    # Defines the optimizer and the loss function
    optimizer = torch.optim.Adam(param_values, lr=0.001)
    scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=1.0, end_factor=0.001, total_iters=epochs // 10)
    loss_fn = nn.MSELoss()

    # Holds the best loss and the best parameters
    best_loss = float("inf")
    best_params = None

    # Used to stop the search early
    min_delta = 1e-7
    acc_loss = []
    percent_epochs_before_stop = 0.1

    for i in range(epochs):
        optimizer.zero_grad()

        quant = quantize(x, params, qtz_func, bits, target_dtype)
        dequant = dequantize(quant, params, deqtz_func, bits, x.dtype)
        loss = loss_fn(x, dequant)

        if loss.isnan() or loss.isinf():
            raise Exception("Loss is NaN or Inf. Stopping the search.")

        loss.backward()
        optimizer.step()
        scheduler.step()

        acc_loss.append(loss.item())

        # Reports the loss every 10 steps
        if i % 10 == 0 and do_report:
            print(f"Epoch {i}: Loss {loss.item()}")

        # Keeps track of the best loss and the corresponding parameters
        if loss.item() < best_loss:
            best_loss = loss.item()
            best_params = copy.deepcopy({
                k: v for k, v in params.items() if k in param_keys
            })

        # We also stop the search if the loss has not improved considerably during the last 10% of the epochs
        if early_stop:
            epochs_before_stop = int(epochs * percent_epochs_before_stop)
            if i > epochs_before_stop and abs(acc_loss[i - epochs_before_stop] - acc_loss[i]) < min_delta:
                break

    # No longer requires gradients in the parameters
    for p in best_params.values():
        p.requires_grad = False
        p.grad = None

    if do_report:
        return best_params, acc_loss
    else:
        return best_params


def quantize(
    x: torch.Tensor,
    params: Dict[str, nn.Parameter],
    func: nn.Module,
    bits: int,
    target_dtype: torch.dtype = torch.int8
) -> torch.Tensor:
    quant_min, quant_max = get_min_max_from_bits_signed(bits)
    x = x.transpose(0, 1)  # Aligns shapes
    x = func(x=x, **params)
    x = x.transpose(0, 1)
    x = torch.clamp(round_func_BPDA(x), quant_min, quant_max).to(target_dtype)
    return x


def dequantize(
    x: torch.Tensor,
    params: Dict[str, nn.Parameter],
    func: nn.Module,
    bits: int,
    out_dtype: torch.dtype
) -> torch.Tensor:
    x = x.to(dtype=out_dtype)
    x = x.transpose(0, 1)
    x = func(x=x, **params)
    x = x.transpose(0, 1)
    return x


def round_func_BPDA(input):
    # Equivalent to replacing the round function (non-differentiable) with
    # an identity function (differentiable) in the backward pass only.
    forward_value = torch.round(input)
    out = input.clone()
    out.data = forward_value.data
    return out


def get_min_max_from_bits_signed(bit_width: int) -> Tuple[int, int]:
    return -2 ** (bit_width - 1), 2 ** (bit_width - 1) - 1


############## Numpy ###############

def np_domain_guard(
    x: np.ndarray,
    min: float = None,
    max: float = None,
    posinf: float = None,
    neginf: float = None,
    nan: float = None
) -> np.ndarray:
    """Guard a tensor to a valid domain."""
    x = np.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan)
    if min is not None or max is not None:
        x = np.clip(x, min, max)
    return x


def np_replace_num(x: np.ndarray, num: float, to: float) -> np.ndarray:
    """Replace a number in a tensor with another number.

    Args:
        x (np.ndarray): The input tensor.
        num (float): The number to replace.
        to (float): The number to replace with.

    Returns:
        np.ndarray: The tensor with the number replaced.
    """
    return np.where(x == num, to, x)


def np_guarded_power(x: np.ndarray, exp: float) -> np.ndarray:
    """Guard the power operation to a valid domain."""
    return np.power(x, exp) if exp >= 1 else np.power(np.maximum(x, 0), exp)
fn_gen/ones_affine_scale/15/loss.png ADDED
fn_gen/ones_affine_scale/15/quantization.png ADDED
fn_gen/ones_affine_scale/16/distortion.png ADDED
fn_gen/ones_affine_scale/16/expressions.txt ADDED
@@ -0,0 +1,2 @@
atan(_0*x)/_s
tan(_s*x)/_0
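As a quick standalone check (parameter values are illustrative, not taken from this upload), these two expressions are exact inverses of one another: the _s factors cancel and tan undoes atan.

import numpy as np

_0, _s = 2.0, 0.5                            # illustrative values
x = np.linspace(-3, 3, 7)
q = np.arctan(_0 * x) / _s                   # atan(_0*x)/_s
assert np.allclose(np.tan(_s * q) / _0, x)   # tan(_s*x)/_0 recovers x exactly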
fn_gen/ones_affine_scale/16/fn.py ADDED
@@ -0,0 +1,493 @@
from __future__ import annotations

import torch
from torch import amin  # Necessary for arcsin
import copy
import torch.nn as nn
import numpy as np

from scipy.optimize import curve_fit
from typing import Dict, Any, Tuple, List, Callable


def quantization(x, **params):
    return (torch.div(1, replace_num(params['_s'], num=0, to=10000)) * torch.atan((params['_0'] * x)))


def dequantization(x, **params):
    return (torch.div(1, replace_num(params['_0'], num=0, to=10000)) * torch.tan(domain_guard((params['_s'] * x), posinf=1, neginf=-1, nan=0)))


def init_linear_scale(  # Symmetric scale. From the study folder
    x: torch.Tensor,
    **kwargs: Dict[str, Any],
) -> torch.Tensor:
    assert "bits" in kwargs, "bits must be provided."
    assert "params" in kwargs, "params must be provided."
    assert "qtz_func" in kwargs, "qtz_func must be provided."

    bits = kwargs.get('bits')
    params = kwargs.get('params')
    qtz_func = kwargs.get('qtz_func')

    x_ = x.transpose(0, 1)
    x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs))
    x_ = x_.transpose(0, 1)

    quant_min, quant_max = get_min_max_from_bits_signed(bits)
    min_vals, max_vals = torch.aminmax(x_, dim=1)
    min_vals = torch.min(min_vals, torch.zeros_like(min_vals))
    max_vals = torch.max(max_vals, torch.zeros_like(max_vals))

    eps = torch.finfo(torch.float32).eps

    abs_max_val_per_ch = torch.max(-min_vals, max_vals)
    scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2)

    scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device)

    # Introduces some noise in the scale.
    # If I don't introduce noise, the accuracy is going to be 0.0 and nothing is learned.
    scale = scale + 0.01 * torch.randn_like(scale)
    return scale


def init_params(x: torch.Tensor, **kwargs: Dict[str, Any]) -> Dict[str, nn.Parameter]:
    params = {
        '_0': init_ones(x, qtz_func=quantization, deqtz_func=dequantization, param='_0', params_list=['_0', '_s'], **kwargs),
    }
    params['_s'] = init_linear_scale(x, params=params, qtz_func=quantization, **kwargs)
    params = {k: nn.Parameter(v, requires_grad=False) for k, v in params.items()}

    if 'post_init_hook' in kwargs:
        kwargs['post_init_hook'](parameters=params)

    if 'post_train_hook' in kwargs:
        kwargs['post_train_hook'](parameters=params)

    return params


############### Numpy Qtz ###############


def np_quantization(x, _0, _s):
    return (np.divide(1, np_replace_num(_s, num=0, to=10000)) * np.arctan((_0 * x)))


def np_dequantization(x, _0, _s):
    return (np.divide(1, np_replace_num(_0, num=0, to=10000)) * np.tan(np_domain_guard((_s * x), posinf=1, neginf=-1, nan=0)))


def fit_func(x, _0, _s):
    x_ = np_quantization(x, _0, _s)
    x_ = np_dequantization(x_, _0, _s)
    return x_


############### HELPERS ###############

def domain_guard(
    x: torch.Tensor,
    min: float = None,
    max: float = None,
    posinf: float = None,
    neginf: float = None,
    nan: float = None
) -> torch.Tensor:
    """Guard a tensor to a valid domain."""
    x = torch.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan)
    if min is not None or max is not None:
        x = torch.clamp(x, min=min, max=max)
    return x


def replace_num(x: torch.Tensor, num: float, to: float) -> torch.Tensor:
    """Replace a number in a tensor with another number.

    Args:
        x (torch.Tensor): The input tensor.
        num (float): The number to replace.
        to (float): The number to replace with.

    Returns:
        torch.Tensor: The tensor with the number replaced.
    """
    return torch.where(x == num, to, x)


def guarded_torch_power(x: torch.Tensor, exp: float) -> torch.Tensor:
    """Guard the power operation to a valid domain."""
    return torch.pow(x, exp) if exp >= 1 else torch.pow(torch.relu(x), exp)


def init_ones(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
    val = torch.amin(x, dim=1)
    return torch.ones_like(val, dtype=torch.float32, device=x.device)


def init_rand(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
    val = torch.amin(x, dim=1)
    return torch.randn_like(val, dtype=torch.float32, device=x.device)


def init_space_search(
    x: torch.Tensor,
    **kwargs: Dict[str, Any],
) -> torch.Tensor:

    def _build_initial_param(tensor: torch.Tensor, max_initial: int, n_params: int):
        """Generates the initial set of parameters. The first iteration generates 10 times more parameters."""
        for _ in range(n_params * 10):  # The first iteration generates 10 times more parameters
            yield init_rand(tensor) * max_initial  # Generates n_params in range [-max_initial, max_initial]

    def _search_param(tensors: List[torch.Tensor], n_params):
        """Takes the best parameters and generates new parameters around the mean of the best parameters."""
        torch_tensors = torch.stack(tensors)
        min_vals, max_vals = torch.aminmax(torch_tensors, dim=0)
        abs_max_val_per_ch = torch.max(-min_vals, max_vals)
        mean = torch.mean(torch_tensors, dim=0)
        for _ in range(n_params):  # Generates n_params around the mean of the tensors
            yield torch.randn_like(min_vals) * abs_max_val_per_ch + mean

    def _calc(x, qtz_func, deqtz_func, **params):
        x_ = x.transpose(0, 1)
        x_ = qtz_func(x=x_, **params)
        x_ = deqtz_func(x=x_, **params)
        x_ = x_.transpose(0, 1)
        return x_

    assert "qtz_func" in kwargs, "qtz_func must be provided."
    assert "deqtz_func" in kwargs, "deqtz_func must be provided."
    assert "params_list" in kwargs, "params list must be provided."
    assert "param" in kwargs, "param must be provided."

    qtz_func = kwargs.get('qtz_func')
    deqtz_func = kwargs.get('deqtz_func')
    params_list = kwargs.get('params_list')
    param = kwargs.get('param')

    n_runs = 50           # Number of runs to try to find the best parameters
    n_random_params = 50  # Number of random parameters to generate
    n_best_to_pick = 5    # Number of best parameters to pick after each run
    max_initial = 10000   # Maximum value used to initialize the parameters

    # Initializes the parameters
    base_params = {p: init_ones(x, **kwargs) for p in params_list if p != param}
    params = _build_initial_param(x, max_initial, n_random_params)

    # Performs the search
    for _ in range(n_runs):

        best_params = []
        for param_ in params:
            try:
                x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: param_})
                loss_ones = nn.MSELoss()(x, x_)

                if len(best_params) < n_best_to_pick:
                    best_params.append((param_, loss_ones.item()))
                    best_params = sorted(best_params, key=lambda x: x[1])
                elif loss_ones < best_params[-1][1]:
                    best_params[-1] = (param_, loss_ones.item())
                    best_params = sorted(best_params, key=lambda x: x[1])

            except Exception:  # The parameters might not be valid for the function's domain
                continue

        # Generates new parameters around the mean
        params = _search_param([p for p, _ in best_params], n_random_params)

    # Checks if the best parameter is better than the init_ones
    p_ones = init_ones(x, **kwargs)
    x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_ones})
    loss_ones = nn.MSELoss()(x, x_)

    # Checks if the best parameter is better than the init_rand
    p_rand = init_rand(x, **kwargs)
    x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_rand})
    loss_rand = nn.MSELoss()(x, x_)

    if loss_rand < best_params[0][1] and loss_rand < loss_ones:
        return p_rand
    elif loss_ones < best_params[0][1] and loss_ones < loss_rand:
        return p_ones
    else:
        return best_params[0][0]


def init_non_linear_regression_fit(
    x: torch.Tensor,
    **kwargs: Dict[str, Any],
) -> torch.Tensor:

    assert "params_list" in kwargs, "params list must be provided."
    assert "np_fit_func" in kwargs, "np_fit_func must be provided."
    assert "p0" in kwargs, "p0 must be provided."
    np_fit_func = kwargs.get('np_fit_func')
    params_list = kwargs.get('params_list')
    p0 = kwargs.get('p0')

    def _fit(xdata: np.ndarray, ydata: np.ndarray, func: Callable, p0: List[float]):
        popt, _ = curve_fit(
            func,
            xdata,
            ydata,
            maxfev=1000,
            p0=p0,
            method='lm'
        )
        return popt

    # 1. Converts the torch tensor to a numpy array
    xdata = x.cpu().numpy()

    # 2. Sorts the data, which makes it easier to fit
    sorted_xdata = np.sort(xdata, axis=-1)

    p0 = {k: v.cpu().numpy() for k, v in p0.items()}
    params_list = sorted(params_list)  # We need to make sure that it matches the numpy fit func arg order

    # 3. Finds the best parameters for each channel
    try:
        params = []
        for i in range(sorted_xdata.shape[0]):
            xdata_ = sorted_xdata[i]
            p0_ = [p0[p][i] for p in params_list]
            ch_params = _fit(xdata_, xdata_, np_fit_func, p0_)
            params.append(ch_params)

        # 4. Builds the parameters
        result = {}
        for i, p in enumerate(params_list):
            result[p] = torch.tensor([p_[i] for p_ in params], dtype=torch.float32).to(x.device)

        return result

    except ValueError as e:
        print(f"Could not fit the function with error: {e}")
        print("Using fallback result...")
        return {
            k: torch.tensor(v, dtype=torch.float32).to(x.device) for k, v in p0.items()
        }


def init_zeros(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
    val = torch.amin(x, dim=1)
    return torch.zeros_like(val, dtype=torch.float32, device=x.device)


def init_inner_scale(tensor: torch.Tensor, _min: float = torch.inf, _max: float = torch.inf) -> torch.Tensor:
    # Calculates the original minimum and maximum values
    min_vals, max_vals = torch.aminmax(tensor, dim=-1)
    x_min = torch.min(min_vals, torch.zeros_like(min_vals))
    x_max = torch.max(max_vals, torch.zeros_like(max_vals))

    if _max is torch.inf:  # We do not need to scale the tensor, just move it
        return torch.ones_like(x_min)

    # Calculates the scale factor
    scale = (_max - _min) / (x_max - x_min)
    return scale


############## Quant ###############

@torch.enable_grad()
def learn_parameters(
    x: torch.Tensor,
    params: Dict[str, nn.Parameter],
    qtz_func: nn.Module,
    deqtz_func: nn.Module,
    bits: int,
    target_dtype: torch.dtype,
    epochs: int = 1000,
    early_stop: bool = True,
    do_report: bool = False
) -> Tuple[Dict[str, nn.Parameter], torch.Tensor]:

    # Requires gradients in the parameters
    for p in params.values():
        p.requires_grad = True
        p.grad = None

    param_keys = list(params.keys())
    param_values = list(params.values())

    # Defines the optimizer and the loss function
    optimizer = torch.optim.Adam(param_values, lr=0.001)
    scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=1.0, end_factor=0.001, total_iters=epochs // 10)
    loss_fn = nn.MSELoss()

    # Holds the best loss and the best parameters
    best_loss = float("inf")
    best_params = None

    # Used to stop the search early
    min_delta = 1e-7
    acc_loss = []
    percent_epochs_before_stop = 0.1

    for i in range(epochs):
        optimizer.zero_grad()

        quant = quantize(x, params, qtz_func, bits, target_dtype)
        dequant = dequantize(quant, params, deqtz_func, bits, x.dtype)
        loss = loss_fn(x, dequant)

        if loss.isnan() or loss.isinf():
            raise Exception("Loss is NaN or Inf. Stopping the search.")

        loss.backward()
        optimizer.step()
        scheduler.step()

        acc_loss.append(loss.item())

        # Reports the loss every 10 steps
        if i % 10 == 0 and do_report:
            print(f"Epoch {i}: Loss {loss.item()}")

        # Keeps track of the best loss and the corresponding parameters
        if loss.item() < best_loss:
            best_loss = loss.item()
            best_params = copy.deepcopy({
                k: v for k, v in params.items() if k in param_keys
            })

        # We also stop the search if the loss has not improved considerably during the last 10% of the epochs
        if early_stop:
            epochs_before_stop = int(epochs * percent_epochs_before_stop)
            if i > epochs_before_stop and abs(acc_loss[i - epochs_before_stop] - acc_loss[i]) < min_delta:
                break

    # No longer requires gradients in the parameters
    for p in best_params.values():
        p.requires_grad = False
        p.grad = None

    if do_report:
        return best_params, acc_loss
    else:
        return best_params


def quantize(
    x: torch.Tensor,
    params: Dict[str, nn.Parameter],
    func: nn.Module,
    bits: int,
    target_dtype: torch.dtype = torch.int8
) -> torch.Tensor:
    quant_min, quant_max = get_min_max_from_bits_signed(bits)
    x = x.transpose(0, 1)  # Aligns shapes
    x = func(x=x, **params)
    x = x.transpose(0, 1)
    x = torch.clamp(round_func_BPDA(x), quant_min, quant_max).to(target_dtype)
    return x


def dequantize(
    x: torch.Tensor,
    params: Dict[str, nn.Parameter],
    func: nn.Module,
    bits: int,
    out_dtype: torch.dtype
) -> torch.Tensor:
    x = x.to(dtype=out_dtype)
    x = x.transpose(0, 1)
    x = func(x=x, **params)
    x = x.transpose(0, 1)
    return x


def round_func_BPDA(input):
    # Equivalent to replacing the round function (non-differentiable) with
    # an identity function (differentiable) in the backward pass only.
    forward_value = torch.round(input)
    out = input.clone()
    out.data = forward_value.data
    return out


def get_min_max_from_bits_signed(bit_width: int) -> Tuple[int, int]:
    return -2 ** (bit_width - 1), 2 ** (bit_width - 1) - 1


############## Numpy ###############

def np_domain_guard(
    x: np.ndarray,
    min: float = None,
    max: float = None,
    posinf: float = None,
    neginf: float = None,
    nan: float = None
) -> np.ndarray:
    """Guard a tensor to a valid domain."""
    x = np.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan)
    if min is not None or max is not None:
        x = np.clip(x, min, max)
    return x


def np_replace_num(x: np.ndarray, num: float, to: float) -> np.ndarray:
    """Replace a number in a tensor with another number.

    Args:
        x (np.ndarray): The input tensor.
        num (float): The number to replace.
        to (float): The number to replace with.

    Returns:
        np.ndarray: The tensor with the number replaced.
    """
    return np.where(x == num, to, x)


def np_guarded_power(x: np.ndarray, exp: float) -> np.ndarray:
    """Guard the power operation to a valid domain."""
    return np.power(x, exp) if exp >= 1 else np.power(np.maximum(x, 0), exp)
fn_gen/ones_affine_scale/16/loss.png ADDED
fn_gen/ones_affine_scale/16/quantization.png ADDED
fn_gen/ones_affine_scale/17/distortion.png ADDED
fn_gen/ones_affine_scale/17/expressions.txt ADDED
@@ -0,0 +1,2 @@
tanh(_0*x)/_s
log((-_s*x - 1)/(_s*x - 1))/_0
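For reference, the second expression is an inverse hyperbolic tangent in disguise. Writing $a$ for `_0` and $s$ for `_s`, for $|s\,y| < 1$

$$\log\!\left(\frac{-s y - 1}{s y - 1}\right) = \log\!\left(\frac{1 + s y}{1 - s y}\right) = 2\,\operatorname{artanh}(s y),$$

so composing the two expressions gives $\tfrac{1}{a}\,2\,\operatorname{artanh}(\tanh(a x)) = 2x$: the round trip reproduces the input up to a fixed factor of 2.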
fn_gen/ones_affine_scale/17/fn.py ADDED
@@ -0,0 +1,493 @@
from __future__ import annotations

import torch
from torch import amin  # Necessary for arcsin
import copy
import torch.nn as nn
import numpy as np

from scipy.optimize import curve_fit
from typing import Dict, Any, Tuple, List, Callable


def quantization(x, **params):
    return (torch.div(1, replace_num(params['_s'], num=0, to=10000)) * torch.tanh((params['_0'] * x)))


def dequantization(x, **params):
    return (torch.div(1, replace_num(params['_0'], num=0, to=10000)) * torch.log(domain_guard((torch.div(1, replace_num((torch.tensor(-1) + (params['_s'] * x)), num=0, to=10000)) * (torch.tensor(-1) + (torch.tensor(-1) * params['_s'] * x))), min=1e-5, nan=1e-5)))


def init_linear_scale(  # Symmetric scale. From the study folder
    x: torch.Tensor,
    **kwargs: Dict[str, Any],
) -> torch.Tensor:
    assert "bits" in kwargs, "bits must be provided."
    assert "params" in kwargs, "params must be provided."
    assert "qtz_func" in kwargs, "qtz_func must be provided."

    bits = kwargs.get('bits')
    params = kwargs.get('params')
    qtz_func = kwargs.get('qtz_func')

    x_ = x.transpose(0, 1)
    x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs))
    x_ = x_.transpose(0, 1)

    quant_min, quant_max = get_min_max_from_bits_signed(bits)
    min_vals, max_vals = torch.aminmax(x_, dim=1)
    min_vals = torch.min(min_vals, torch.zeros_like(min_vals))
    max_vals = torch.max(max_vals, torch.zeros_like(max_vals))

    eps = torch.finfo(torch.float32).eps

    abs_max_val_per_ch = torch.max(-min_vals, max_vals)
    scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2)

    scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device)

    # Introduces some noise in the scale.
    # If I don't introduce noise, the accuracy is going to be 0.0 and nothing is learned.
    scale = scale + 0.01 * torch.randn_like(scale)
    return scale


def init_params(x: torch.Tensor, **kwargs: Dict[str, Any]) -> Dict[str, nn.Parameter]:
    params = {
        '_0': init_ones(x, qtz_func=quantization, deqtz_func=dequantization, param='_0', params_list=['_0', '_s'], **kwargs),
    }
    params['_s'] = init_linear_scale(x, params=params, qtz_func=quantization, **kwargs)
    params = {k: nn.Parameter(v, requires_grad=False) for k, v in params.items()}

    if 'post_init_hook' in kwargs:
        kwargs['post_init_hook'](parameters=params)

    if 'post_train_hook' in kwargs:
        kwargs['post_train_hook'](parameters=params)

    return params


############### Numpy Qtz ###############


def np_quantization(x, _0, _s):
    return (np.divide(1, np_replace_num(_s, num=0, to=10000)) * np.tanh((_0 * x)))


def np_dequantization(x, _0, _s):
    return (np.divide(1, np_replace_num(_0, num=0, to=10000)) * np.log(np_domain_guard((np.divide(1, np_replace_num((np.array(-1) + (_s * x)), num=0, to=10000)) * (np.array(-1) + (np.array(-1) * _s * x))), min=1e-5, nan=1e-5)))


def fit_func(x, _0, _s):
    x_ = np_quantization(x, _0, _s)
    x_ = np_dequantization(x_, _0, _s)
    return x_


############### HELPERS ###############

def domain_guard(
    x: torch.Tensor,
    min: float = None,
    max: float = None,
    posinf: float = None,
    neginf: float = None,
    nan: float = None
) -> torch.Tensor:
    """Guard a tensor to a valid domain."""
    x = torch.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan)
    if min is not None or max is not None:
        x = torch.clamp(x, min=min, max=max)
    return x


def replace_num(x: torch.Tensor, num: float, to: float) -> torch.Tensor:
    """Replace a number in a tensor with another number.

    Args:
        x (torch.Tensor): The input tensor.
        num (float): The number to replace.
        to (float): The number to replace with.

    Returns:
        torch.Tensor: The tensor with the number replaced.
    """
    return torch.where(x == num, to, x)


def guarded_torch_power(x: torch.Tensor, exp: float) -> torch.Tensor:
    """Guard the power operation to a valid domain."""
    return torch.pow(x, exp) if exp >= 1 else torch.pow(torch.relu(x), exp)


def init_ones(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
    val = torch.amin(x, dim=1)
    return torch.ones_like(val, dtype=torch.float32, device=x.device)


def init_rand(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
    val = torch.amin(x, dim=1)
    return torch.randn_like(val, dtype=torch.float32, device=x.device)


def init_space_search(
    x: torch.Tensor,
    **kwargs: Dict[str, Any],
) -> torch.Tensor:

    def _build_initial_param(tensor: torch.Tensor, max_initial: int, n_params: int):
        """Generates the initial set of parameters. The first iteration generates 10 times more parameters."""
        for _ in range(n_params * 10):  # The first iteration generates 10 times more parameters
            yield init_rand(tensor) * max_initial  # Generates n_params in range [-max_initial, max_initial]

    def _search_param(tensors: List[torch.Tensor], n_params):
        """Takes the best parameters and generates new parameters around the mean of the best parameters."""
        torch_tensors = torch.stack(tensors)
        min_vals, max_vals = torch.aminmax(torch_tensors, dim=0)
        abs_max_val_per_ch = torch.max(-min_vals, max_vals)
        mean = torch.mean(torch_tensors, dim=0)
        for _ in range(n_params):  # Generates n_params around the mean of the tensors
            yield torch.randn_like(min_vals) * abs_max_val_per_ch + mean

    def _calc(x, qtz_func, deqtz_func, **params):
        x_ = x.transpose(0, 1)
        x_ = qtz_func(x=x_, **params)
        x_ = deqtz_func(x=x_, **params)
        x_ = x_.transpose(0, 1)
        return x_

    assert "qtz_func" in kwargs, "qtz_func must be provided."
    assert "deqtz_func" in kwargs, "deqtz_func must be provided."
    assert "params_list" in kwargs, "params list must be provided."
    assert "param" in kwargs, "param must be provided."

    qtz_func = kwargs.get('qtz_func')
    deqtz_func = kwargs.get('deqtz_func')
    params_list = kwargs.get('params_list')
    param = kwargs.get('param')

    n_runs = 50           # Number of runs to try to find the best parameters
    n_random_params = 50  # Number of random parameters to generate
    n_best_to_pick = 5    # Number of best parameters to pick after each run
    max_initial = 10000   # Maximum value used to initialize the parameters

    # Initializes the parameters
    base_params = {p: init_ones(x, **kwargs) for p in params_list if p != param}
    params = _build_initial_param(x, max_initial, n_random_params)

    # Performs the search
    for _ in range(n_runs):

        best_params = []
        for param_ in params:
            try:
                x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: param_})
                loss_ones = nn.MSELoss()(x, x_)

                if len(best_params) < n_best_to_pick:
                    best_params.append((param_, loss_ones.item()))
                    best_params = sorted(best_params, key=lambda x: x[1])
                elif loss_ones < best_params[-1][1]:
                    best_params[-1] = (param_, loss_ones.item())
                    best_params = sorted(best_params, key=lambda x: x[1])

            except Exception:  # The parameters might not be valid for the function's domain
                continue

        # Generates new parameters around the mean
        params = _search_param([p for p, _ in best_params], n_random_params)

    # Checks if the best parameter is better than the init_ones
    p_ones = init_ones(x, **kwargs)
    x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_ones})
    loss_ones = nn.MSELoss()(x, x_)

    # Checks if the best parameter is better than the init_rand
    p_rand = init_rand(x, **kwargs)
    x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_rand})
    loss_rand = nn.MSELoss()(x, x_)

    if loss_rand < best_params[0][1] and loss_rand < loss_ones:
        return p_rand
    elif loss_ones < best_params[0][1] and loss_ones < loss_rand:
        return p_ones
    else:
        return best_params[0][0]


def init_non_linear_regression_fit(
    x: torch.Tensor,
    **kwargs: Dict[str, Any],
) -> torch.Tensor:

    assert "params_list" in kwargs, "params list must be provided."
    assert "np_fit_func" in kwargs, "np_fit_func must be provided."
    assert "p0" in kwargs, "p0 must be provided."
    np_fit_func = kwargs.get('np_fit_func')
    params_list = kwargs.get('params_list')
    p0 = kwargs.get('p0')

    def _fit(xdata: np.ndarray, ydata: np.ndarray, func: Callable, p0: List[float]):
        popt, _ = curve_fit(
            func,
            xdata,
            ydata,
            maxfev=1000,
            p0=p0,
            method='lm'
        )
        return popt

    # 1. Converts the torch tensor to a numpy array
    xdata = x.cpu().numpy()

    # 2. Sorts the data, which makes it easier to fit
    sorted_xdata = np.sort(xdata, axis=-1)

    p0 = {k: v.cpu().numpy() for k, v in p0.items()}
    params_list = sorted(params_list)  # We need to make sure that it matches the numpy fit func arg order

    # 3. Finds the best parameters for each channel
    try:
        params = []
        for i in range(sorted_xdata.shape[0]):
            xdata_ = sorted_xdata[i]
            p0_ = [p0[p][i] for p in params_list]
            ch_params = _fit(xdata_, xdata_, np_fit_func, p0_)
            params.append(ch_params)

        # 4. Builds the parameters
        result = {}
        for i, p in enumerate(params_list):
            result[p] = torch.tensor([p_[i] for p_ in params], dtype=torch.float32).to(x.device)

        return result

    except ValueError as e:
        print(f"Could not fit the function with error: {e}")
        print("Using fallback result...")
        return {
            k: torch.tensor(v, dtype=torch.float32).to(x.device) for k, v in p0.items()
        }


def init_zeros(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
    val = torch.amin(x, dim=1)
    return torch.zeros_like(val, dtype=torch.float32, device=x.device)


def init_inner_scale(tensor: torch.Tensor, _min: float = torch.inf, _max: float = torch.inf) -> torch.Tensor:
    # Calculates the original minimum and maximum values
    min_vals, max_vals = torch.aminmax(tensor, dim=-1)
    x_min = torch.min(min_vals, torch.zeros_like(min_vals))
    x_max = torch.max(max_vals, torch.zeros_like(max_vals))

    if _max is torch.inf:  # We do not need to scale the tensor, just move it
        return torch.ones_like(x_min)

    # Calculates the scale factor
    scale = (_max - _min) / (x_max - x_min)
    return scale


############## Quant ###############

@torch.enable_grad()
def learn_parameters(
    x: torch.Tensor,
    params: Dict[str, nn.Parameter],
    qtz_func: nn.Module,
    deqtz_func: nn.Module,
    bits: int,
    target_dtype: torch.dtype,
    epochs: int = 1000,
    early_stop: bool = True,
    do_report: bool = False
) -> Tuple[Dict[str, nn.Parameter], torch.Tensor]:

    # Requires gradients in the parameters
    for p in params.values():
        p.requires_grad = True
        p.grad = None

    param_keys = list(params.keys())
    param_values = list(params.values())

    # Defines the optimizer and the loss function
    optimizer = torch.optim.Adam(param_values, lr=0.001)
    scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=1.0, end_factor=0.001, total_iters=epochs // 10)
    loss_fn = nn.MSELoss()

    # Holds the best loss and the best parameters
    best_loss = float("inf")
    best_params = None

    # Used to stop the search early
    min_delta = 1e-7
    acc_loss = []
    percent_epochs_before_stop = 0.1

    for i in range(epochs):
        optimizer.zero_grad()

        quant = quantize(x, params, qtz_func, bits, target_dtype)
        dequant = dequantize(quant, params, deqtz_func, bits, x.dtype)
        loss = loss_fn(x, dequant)

        if loss.isnan() or loss.isinf():
            raise Exception("Loss is NaN or Inf. Stopping the search.")

        loss.backward()
        optimizer.step()
        scheduler.step()

        acc_loss.append(loss.item())

        # Reports the loss every 10 steps
        if i % 10 == 0 and do_report:
            print(f"Epoch {i}: Loss {loss.item()}")

        # Keeps track of the best loss and the corresponding parameters
        if loss.item() < best_loss:
            best_loss = loss.item()
            best_params = copy.deepcopy({
                k: v for k, v in params.items() if k in param_keys
            })

        # We also stop the search if the loss has not improved considerably during the last 10% of the epochs
        if early_stop:
            epochs_before_stop = int(epochs * percent_epochs_before_stop)
            if i > epochs_before_stop and abs(acc_loss[i - epochs_before_stop] - acc_loss[i]) < min_delta:
                break

    # No longer requires gradients in the parameters
    for p in best_params.values():
        p.requires_grad = False
        p.grad = None

    if do_report:
        return best_params, acc_loss
    else:
        return best_params


def quantize(
    x: torch.Tensor,
    params: Dict[str, nn.Parameter],
    func: nn.Module,
    bits: int,
    target_dtype: torch.dtype = torch.int8
) -> torch.Tensor:
    quant_min, quant_max = get_min_max_from_bits_signed(bits)
    x = x.transpose(0, 1)  # Aligns shapes
    x = func(x=x, **params)
    x = x.transpose(0, 1)
    x = torch.clamp(round_func_BPDA(x), quant_min, quant_max).to(target_dtype)
    return x


def dequantize(
    x: torch.Tensor,
    params: Dict[str, nn.Parameter],
    func: nn.Module,
    bits: int,
    out_dtype: torch.dtype
) -> torch.Tensor:
    x = x.to(dtype=out_dtype)
    x = x.transpose(0, 1)
    x = func(x=x, **params)
    x = x.transpose(0, 1)
    return x


def round_func_BPDA(input):
    # Equivalent to replacing the round function (non-differentiable) with
    # an identity function (differentiable) in the backward pass only.
    forward_value = torch.round(input)
    out = input.clone()
    out.data = forward_value.data
    return out


def get_min_max_from_bits_signed(bit_width: int) -> Tuple[int, int]:
    return -2 ** (bit_width - 1), 2 ** (bit_width - 1) - 1


############## Numpy ###############

def np_domain_guard(
    x: np.ndarray,
    min: float = None,
    max: float = None,
    posinf: float = None,
    neginf: float = None,
    nan: float = None
) -> np.ndarray:
    """Guard a tensor to a valid domain."""
    x = np.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan)
    if min is not None or max is not None:
        x = np.clip(x, min, max)
    return x


def np_replace_num(x: np.ndarray, num: float, to: float) -> np.ndarray:
    """Replace a number in a tensor with another number.

    Args:
        x (np.ndarray): The input tensor.
        num (float): The number to replace.
        to (float): The number to replace with.

    Returns:
        np.ndarray: The tensor with the number replaced.
    """
    return np.where(x == num, to, x)


def np_guarded_power(x: np.ndarray, exp: float) -> np.ndarray:
    """Guard the power operation to a valid domain."""
    return np.power(x, exp) if exp >= 1 else np.power(np.maximum(x, 0), exp)
fn_gen/ones_affine_scale/17/loss.png ADDED
fn_gen/ones_affine_scale/17/quantization.png ADDED