Diogo-V committed
Commit 9d6529f
Parent: bb1a581

Upload learned functions

This view is limited to 50 files because the commit contains too many changes. See the raw diff for the complete set.
Files changed (50)
  1. fn_gen/rnd_search_t_no_sched/0/distortion.png +0 -0
  2. fn_gen/rnd_search_t_no_sched/0/expressions.txt +2 -0
  3. fn_gen/rnd_search_t_no_sched/0/fn.py +598 -0
  4. fn_gen/rnd_search_t_no_sched/0/loss.png +0 -0
  5. fn_gen/rnd_search_t_no_sched/0/quantization.png +0 -0
  6. fn_gen/rnd_search_t_no_sched/10/distortion.png +0 -0
  7. fn_gen/rnd_search_t_no_sched/10/expressions.txt +2 -0
  8. fn_gen/rnd_search_t_no_sched/10/fn.py +598 -0
  9. fn_gen/rnd_search_t_no_sched/10/loss.png +0 -0
  10. fn_gen/rnd_search_t_no_sched/10/quantization.png +0 -0
  11. fn_gen/rnd_search_t_no_sched/11/distortion.png +0 -0
  12. fn_gen/rnd_search_t_no_sched/11/expressions.txt +2 -0
  13. fn_gen/rnd_search_t_no_sched/11/fn.py +598 -0
  14. fn_gen/rnd_search_t_no_sched/11/loss.png +0 -0
  15. fn_gen/rnd_search_t_no_sched/11/quantization.png +0 -0
  16. fn_gen/rnd_search_t_no_sched/12/distortion.png +0 -0
  17. fn_gen/rnd_search_t_no_sched/12/expressions.txt +2 -0
  18. fn_gen/rnd_search_t_no_sched/12/fn.py +598 -0
  19. fn_gen/rnd_search_t_no_sched/12/loss.png +0 -0
  20. fn_gen/rnd_search_t_no_sched/12/quantization.png +0 -0
  21. fn_gen/rnd_search_t_no_sched/13/distortion.png +0 -0
  22. fn_gen/rnd_search_t_no_sched/13/expressions.txt +2 -0
  23. fn_gen/rnd_search_t_no_sched/13/fn.py +598 -0
  24. fn_gen/rnd_search_t_no_sched/13/loss.png +0 -0
  25. fn_gen/rnd_search_t_no_sched/13/quantization.png +0 -0
  26. fn_gen/rnd_search_t_no_sched/14/distortion.png +0 -0
  27. fn_gen/rnd_search_t_no_sched/14/expressions.txt +2 -0
  28. fn_gen/rnd_search_t_no_sched/14/fn.py +598 -0
  29. fn_gen/rnd_search_t_no_sched/14/loss.png +0 -0
  30. fn_gen/rnd_search_t_no_sched/14/quantization.png +0 -0
  31. fn_gen/rnd_search_t_no_sched/15/distortion.png +0 -0
  32. fn_gen/rnd_search_t_no_sched/15/expressions.txt +2 -0
  33. fn_gen/rnd_search_t_no_sched/15/fn.py +598 -0
  34. fn_gen/rnd_search_t_no_sched/15/loss.png +0 -0
  35. fn_gen/rnd_search_t_no_sched/15/quantization.png +0 -0
  36. fn_gen/rnd_search_t_no_sched/16/distortion.png +0 -0
  37. fn_gen/rnd_search_t_no_sched/16/expressions.txt +2 -0
  38. fn_gen/rnd_search_t_no_sched/16/fn.py +598 -0
  39. fn_gen/rnd_search_t_no_sched/16/loss.png +0 -0
  40. fn_gen/rnd_search_t_no_sched/16/quantization.png +0 -0
  41. fn_gen/rnd_search_t_no_sched/18/distortion.png +0 -0
  42. fn_gen/rnd_search_t_no_sched/18/expressions.txt +2 -0
  43. fn_gen/rnd_search_t_no_sched/18/fn.py +512 -0
  44. fn_gen/rnd_search_t_no_sched/18/loss.png +0 -0
  45. fn_gen/rnd_search_t_no_sched/18/quantization.png +0 -0
  46. fn_gen/rnd_search_t_no_sched/2/distortion.png +0 -0
  47. fn_gen/rnd_search_t_no_sched/2/expressions.txt +2 -0
  48. fn_gen/rnd_search_t_no_sched/2/fn.py +512 -0
  49. fn_gen/rnd_search_t_no_sched/2/loss.png +0 -0
  50. fn_gen/rnd_search_t_no_sched/2/quantization.png +0 -0
fn_gen/rnd_search_t_no_sched/0/distortion.png ADDED
fn_gen/rnd_search_t_no_sched/0/expressions.txt ADDED
@@ -0,0 +1,2 @@
+ tanh(_0*x)/_s
+ log((-_s*x - 1)/(_s*x - 1))/_0
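
The first line is the learned quantization expression and the second its dequantization counterpart; fn.py below implements the same pair in torch with domain guards. As a minimal illustrative sketch (hypothetical, not part of the commit), with a standing in for the learned parameter _0 and s for the learned scale _s:

    import math
    quant   = lambda x, a, s: math.tanh(a * x) / s                      # tanh(_0*x)/_s
    dequant = lambda y, a, s: math.log((-s * y - 1) / (s * y - 1)) / a  # log((-_s*x - 1)/(_s*x - 1))/_0
    # dequant is only defined for |s*y| < 1; fn.py enforces this numerically via domain_guard.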
fn_gen/rnd_search_t_no_sched/0/fn.py ADDED
@@ -0,0 +1,598 @@
1
+ from __future__ import annotations
2
+
3
+ import torch
4
+ from torch import amin # Necessary for arcsin
5
+ import copy
6
+ import torch.nn as nn
7
+ import numpy as np
8
+
9
+ from scipy.optimize import curve_fit
10
+ from typing import Dict, Any, Tuple, List, Callable
11
+
12
+
13
+ def quantization(x, **params):
14
+ return (torch.div(1, replace_num(params['_s'], num=0, to=10000)) * torch.tanh((params['_0'] * x)))
15
+
16
+
17
+ def dequantization(x, **params):
18
+ return (torch.div(1, replace_num(params['_0'], num=0, to=10000)) * torch.log(domain_guard((torch.div(1, replace_num((torch.tensor(-1) + (params['_s'] * x)), num=0, to=10000)) * (torch.tensor(-1) + (torch.tensor(-1) * params['_s'] * x))), min=1e-5, nan=1e-5)))
19
+
20
+
21
+ def init_space_search(
22
+ x: torch.Tensor,
23
+ **kwargs: Dict[str, Any],
24
+ ) -> torch.Tensor:
25
+
26
+ def _build_initial_param(tensor: torch.Tensor, max_initial: int, n_params: int):
27
+ """Generates the initial set of parameters. The first iteration generates 10 times more parameters."""
28
+ for _ in range(n_params * 10): # The first iteration generates 10 times more parameters
29
+ yield init_rand(tensor) * max_initial # Generates n_params in range [-max_initial, max_initial]
30
+
31
+ def _search_param(tensors: List[torch.tensor], n_params):
32
+ """Takes the best parameters and generates new parameters around the mean of the best parameters."""
33
+ torch_tensors = torch.stack(tensors)
34
+ min_vals, max_vals = torch.aminmax(torch_tensors, dim=0)
35
+ abs_max_val_per_ch = torch.max(-min_vals, max_vals)
36
+ mean = torch.mean(torch_tensors, dim=0)
37
+ for _ in range(n_params): # Generates n_params around the mean of the tensors
38
+ yield torch.randn_like(min_vals) * abs_max_val_per_ch + mean
39
+
40
+ def _calc(x, qtz_func, deqtz_func, **params):
41
+ x_ = x.transpose(0, 1)
42
+ x_ = qtz_func(x=x_, **params)
43
+ x_ = deqtz_func(x=x_, **params)
44
+ x_ = x_.transpose(0, 1)
45
+ return x_
46
+
47
+ assert "qtz_func" in kwargs, "qtz_func must be provided."
48
+ assert "deqtz_func" in kwargs, "deqtz_func must be provided."
49
+ assert "params_list" in kwargs, "params list must be provided."
50
+ assert "param" in kwargs, "param must be provided."
51
+
52
+ qtz_func = kwargs.get('qtz_func')
53
+ deqtz_func = kwargs.get('deqtz_func')
54
+ params_list = kwargs.get('params_list')
55
+ param = kwargs.get('param')
56
+
57
+ n_runs = 50 # Number of runs to try to find the best parameters
58
+ n_random_params = 50 # Number of random parameters to generate
59
+ n_best_to_pick = 5 # Number of best parameters to pick after each run
60
+ max_initial = 10000 # Maximum value to initialize the parameters
61
+
62
+ # Initializes the parameters
63
+ base_params = { p: init_ones(x, **kwargs) for p in params_list if p != param }
64
+ params = _build_initial_param(x, max_initial, n_random_params)
65
+
66
+ # Performs the search
67
+ for _ in range(n_runs):
68
+
69
+ best_params = []
70
+ for param_ in params:
71
+ try:
72
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: param_})
73
+ loss_ones = nn.MSELoss()(x, x_)
74
+
75
+ if len(best_params) < n_best_to_pick:
76
+ best_params.append((param_, loss_ones.item()))
77
+ best_params = sorted(best_params, key=lambda x: x[1])
78
+ elif loss_ones < best_params[-1][1]:
79
+ best_params[-1] = (param_, loss_ones.item())
80
+ best_params = sorted(best_params, key=lambda x: x[1])
81
+
82
+ except Exception: # The parameters might not be valid for the function's domain
83
+ continue
84
+
85
+ # Generates new parameters around the mean
86
+ params = _search_param([p for p, _ in best_params], n_random_params)
87
+
88
+ # Checks if the best parameter is better than the init_ones
89
+ p_ones = init_ones(x, **kwargs)
90
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_ones})
91
+ loss_ones = nn.MSELoss()(x, x_)
92
+
93
+ # Checks if the best parameter is better than the init_rand
94
+ p_rand = init_rand(x, **kwargs)
95
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_rand})
96
+ loss_rand = nn.MSELoss()(x, x_)
97
+
98
+ if loss_rand < best_params[0][1] and loss_rand < loss_ones:
99
+ return p_rand
100
+ elif loss_ones < best_params[0][1] and loss_ones < loss_rand:
101
+ return p_ones
102
+ else:
103
+ return best_params[0][0]
104
+
105
+
106
+ def init_linear_scale( # Symmetric scale. From the study folder
107
+ x: torch.Tensor,
108
+ **kwargs: Dict[str, Any],
109
+ ) -> torch.Tensor:
110
+ assert "bits" in kwargs, "bits must be provided."
111
+ assert "params" in kwargs, "params must be provided."
112
+ assert "qtz_func" in kwargs, "qtz_func must be provided."
113
+
114
+ bits = kwargs.get('bits')
115
+ params = kwargs.get('params')
116
+ qtz_func = kwargs.get('qtz_func')
117
+
118
+ x_ = x.transpose(0, 1)
119
+ x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs))
120
+ x_ = x_.transpose(0, 1)
121
+
122
+ quant_min, quant_max = get_min_max_from_bits_signed(bits)
123
+ min_vals, max_vals = torch.aminmax(x_, dim=1)
124
+ min_vals = torch.min(min_vals, torch.zeros_like(min_vals))
125
+ max_vals = torch.max(max_vals, torch.zeros_like(max_vals))
126
+
127
+ eps = torch.finfo(torch.float32).eps
128
+
129
+ abs_max_val_per_ch = torch.max(-min_vals, max_vals)
130
+ scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2)
131
+
132
+ scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device)
133
+
134
+ # Introduces some noise in scale
135
+ # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything
136
+ # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything.
137
+ # NOTE(diogo): This has been disproven. The noise does not help the learning process but I still
138
+ # left it here for future reference. Will be removed later.
139
+ # scale = scale + 0.01 * torch.randn_like(scale)
140
+
141
+ return scale
142
+
143
+
144
+ def init_params(x: torch.Tensor, **kwargs: Dict[str, Any]) -> Dict[str, nn.Parameter]:
145
+ params = {
146
+ '_0': init_space_search(x, qtz_func=quantization, deqtz_func=dequantization, param='_0', params_list=['_0', '_s'], **kwargs),
147
+ }
148
+ params['_s'] = init_linear_scale(x, params=params, qtz_func=quantization, **kwargs)
149
+ params = {k: nn.Parameter(v, requires_grad=False) for k, v in params.items()}
150
+
151
+ if 'post_init_hook' in kwargs:
152
+ kwargs['post_init_hook'](parameters=params)
153
+
154
+ params = learn_parameters(x, params,
155
+ qtz_func=quantization,
156
+ deqtz_func=dequantization,
157
+ bits=kwargs['bits'],
158
+ target_dtype=torch.int8,
159
+ epochs=500,
160
+ early_stop=False,
161
+ )
162
+ if 'post_train_hook' in kwargs:
163
+ kwargs['post_train_hook'](parameters=params)
164
+
165
+ return params
166
+
167
+
168
+ ############### Numpy Qtz ###############
169
+
170
+
171
+ def np_quantization(x, _0, _s):
172
+ return (np.divide(1, np_replace_num(_s, num=0, to=10000)) * np.tanh((_0 * x)))
173
+
174
+
175
+ def np_dequantization(x, _0, _s):
176
+ return (np.divide(1, np_replace_num(_0, num=0, to=10000)) * np.log(np_domain_guard((np.divide(1, np_replace_num((np.array(-1) + (_s * x)), num=0, to=10000)) * (np.array(-1) + (np.array(-1) * _s * x))), min=1e-5, nan=1e-5)))
177
+
178
+
179
+ def fit_func(x, _0, _s):
180
+ x_ = np_quantization(x, _0, _s)
181
+ x_ = np_dequantization(x_, _0, _s)
182
+ return x_
183
+
184
+
185
+
186
+ ############### HELPERS ###############
187
+
188
+ def domain_guard(
189
+ x: torch.Tensor,
190
+ min: float = None,
191
+ max: float = None,
192
+ posinf: float = None,
193
+ neginf: float = None,
194
+ nan: float = None
195
+ ) -> torch.Tensor:
196
+ """Guard a tensor to a valid domain."""
197
+ x = torch.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan)
198
+ if min is not None or max is not None:
199
+ x = torch.clamp(x, min=min, max=max)
200
+ return x
201
+
202
+
203
+ def replace_num(x: torch.Tensor, num: float, to: float) -> torch.Tensor:
204
+ """Replace a number in a tensor with another number.
205
+
206
+ Args:
207
+ x (torch.Tensor): The input tensor.
208
+ num (float): The number to replace.
209
+ to (float): The number to replace with.
210
+
211
+ Returns:
212
+ torch.Tensor: The tensor with the number replaced.
213
+ """
214
+ return torch.where(x == num, to, x)
215
+
216
+
217
+ def guarded_torch_power(x: torch.Tensor, exp: float) -> torch.Tensor:
218
+ """Guard the power operation to a valid domain."""
219
+ return torch.pow(x, exp) if exp >= 1 else torch.pow(torch.relu(x), exp)
220
+
221
+
222
+ def init_ones(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
223
+ val = torch.amin(x, dim=1)
224
+ return torch.ones_like(val, dtype=torch.float32, device=x.device)
225
+
226
+
227
+ def init_rand(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
228
+ val = torch.amin(x, dim=1)
229
+ return torch.randn_like(val, dtype=torch.float32, device=x.device)
230
+
231
+
232
+ def init_space_search(
233
+ x: torch.Tensor,
234
+ **kwargs: Dict[str, Any],
235
+ ) -> torch.Tensor:
236
+
237
+ def _build_initial_param(tensor: torch.Tensor, max_initial: int, n_params: int):
238
+ """Generates the initial set of parameters. The first iteration generates 10 times more parameters."""
239
+ for _ in range(n_params * 10): # The first iteration generates 10 times more parameters
240
+ yield init_rand(tensor) * max_initial # Generates n_params in range [-max_initial, max_initial]
241
+
242
+ def _search_param(tensors: List[torch.tensor], n_params):
243
+ """Takes the best parameters and generates new parameters around the mean of the best parameters."""
244
+ torch_tensors = torch.stack(tensors)
245
+ min_vals, max_vals = torch.aminmax(torch_tensors, dim=0)
246
+ abs_max_val_per_ch = torch.max(-min_vals, max_vals)
247
+ mean = torch.mean(torch_tensors, dim=0)
248
+ for _ in range(n_params): # Generates n_params around the mean of the tensors
249
+ yield torch.randn_like(min_vals) * abs_max_val_per_ch + mean
250
+
251
+ def _calc(x, qtz_func, deqtz_func, **params):
252
+ x_ = x.transpose(0, 1)
253
+ x_ = qtz_func(x=x_, **params)
254
+ x_ = deqtz_func(x=x_, **params)
255
+ x_ = x_.transpose(0, 1)
256
+ return x_
257
+
258
+ assert "qtz_func" in kwargs, "qtz_func must be provided."
259
+ assert "deqtz_func" in kwargs, "deqtz_func must be provided."
260
+ assert "params_list" in kwargs, "params list must be provided."
261
+ assert "param" in kwargs, "param must be provided."
262
+
263
+ qtz_func = kwargs.get('qtz_func')
264
+ deqtz_func = kwargs.get('deqtz_func')
265
+ params_list = kwargs.get('params_list')
266
+ param = kwargs.get('param')
267
+
268
+ n_runs = 50 # Number of runs to try to find the best parameters
269
+ n_random_params = 50 # Number of random parameters to generate
270
+ n_best_to_pick = 5 # Number of best parameters to pick after each run
271
+ max_initial = 10000 # Maximum value to initialize the parameters
272
+
273
+ # Initializes the parameters
274
+ base_params = { p: init_ones(x, **kwargs) for p in params_list if p != param }
275
+ params = _build_initial_param(x, max_initial, n_random_params)
276
+
277
+ # Performs the search
278
+ for _ in range(n_runs):
279
+
280
+ best_params = []
281
+ for param_ in params:
282
+ try:
283
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: param_})
284
+ loss_ones = nn.MSELoss()(x, x_)
285
+
286
+ if len(best_params) < n_best_to_pick:
287
+ best_params.append((param_, loss_ones.item()))
288
+ best_params = sorted(best_params, key=lambda x: x[1])
289
+ elif loss_ones < best_params[-1][1]:
290
+ best_params[-1] = (param_, loss_ones.item())
291
+ best_params = sorted(best_params, key=lambda x: x[1])
292
+
293
+ except Exception: # The parameters might not be valid for the function's domain
294
+ continue
295
+
296
+ # Generates new parameters around the mean
297
+ params = _search_param([p for p, _ in best_params], n_random_params)
298
+
299
+ # Checks if the best parameter is better than the init_ones
300
+ p_ones = init_ones(x, **kwargs)
301
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_ones})
302
+ loss_ones = nn.MSELoss()(x, x_)
303
+
304
+ # Checks if the best parameter is better than the init_rand
305
+ p_rand = init_rand(x, **kwargs)
306
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_rand})
307
+ loss_rand = nn.MSELoss()(x, x_)
308
+
309
+ if loss_rand < best_params[0][1] and loss_rand < loss_ones:
310
+ return p_rand
311
+ elif loss_ones < best_params[0][1] and loss_ones < loss_rand:
312
+ return p_ones
313
+ else:
314
+ return best_params[0][0]
315
+
316
+
317
+ def init_linear_scale( # Symmetric scale. From the study folder
318
+ x: torch.Tensor,
319
+ **kwargs: Dict[str, Any],
320
+ ) -> torch.Tensor:
321
+ assert "bits" in kwargs, "bits must be provided."
322
+ assert "params" in kwargs, "params must be provided."
323
+ assert "qtz_func" in kwargs, "qtz_func must be provided."
324
+
325
+ bits = kwargs.get('bits')
326
+ params = kwargs.get('params')
327
+ qtz_func = kwargs.get('qtz_func')
328
+
329
+ x_ = x.transpose(0, 1)
330
+ x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs))
331
+ x_ = x_.transpose(0, 1)
332
+
333
+ quant_min, quant_max = get_min_max_from_bits_signed(bits)
334
+ min_vals, max_vals = torch.aminmax(x_, dim=1)
335
+ min_vals = torch.min(min_vals, torch.zeros_like(min_vals))
336
+ max_vals = torch.max(max_vals, torch.zeros_like(max_vals))
337
+
338
+ eps = torch.finfo(torch.float32).eps
339
+
340
+ abs_max_val_per_ch = torch.max(-min_vals, max_vals)
341
+ scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2)
342
+
343
+ scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device)
344
+
345
+ # Introduces some noise in scale
346
+ # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything
347
+ # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything.
348
+ # NOTE(diogo): This has been disproven. The noise does not help the learning process but I still
349
+ # left it here for future reference. Will be removed later.
350
+ # scale = scale + 0.01 * torch.randn_like(scale)
351
+
352
+ return scale
353
+
354
+
355
+ def init_non_linear_regression_fit(
356
+ x: torch.Tensor,
357
+ **kwargs: Dict[str, Any],
358
+ ) -> torch.Tensor:
359
+
360
+ assert "params_list" in kwargs, "params list must be provided."
361
+ assert "np_fit_func" in kwargs, "np_fit_func must be provided."
362
+ assert "p0" in kwargs, "p0 must be provided."
363
+ np_fit_func = kwargs.get('np_fit_func')
364
+ params_list = kwargs.get('params_list')
365
+ p0 = kwargs.get('p0')
366
+
367
+ def _fit(xdata: np.ndarray, ydata: np.ndarray, func: Callable, p0: List[float]):
368
+ popt, _ = curve_fit(
369
+ func,
370
+ xdata,
371
+ ydata,
372
+ maxfev=1000,
373
+ p0=p0,
374
+ method='lm'
375
+ )
376
+ return popt
377
+
378
+ # 1. Needs to convert the torch tensor to numpy tensor
379
+ xdata = x.cpu().numpy()
380
+
381
+ # 2. Sorts the data so that it makes it easier to fit to it
382
+ sorted_xdata = np.sort(xdata, axis=-1)
383
+
384
+ p0 = {k: v.cpu().numpy() for k, v in p0.items()}
385
+ params_list = sorted(params_list) # We need to make sure that it matches the numpy fit func arg order
386
+
387
+ # 3. Finds the best parameters for each channel
388
+ try:
389
+ params = []
390
+ for i in range(sorted_xdata.shape[0]):
391
+ xdata_ = sorted_xdata[i]
392
+ p0_ = [p0[p][i] for p in params_list]
393
+ ch_params = _fit(xdata_, xdata_, np_fit_func, p0_)
394
+ params.append(ch_params)
395
+
396
+ # 4. Builds the parameters
397
+ result = {}
398
+ for i, p in enumerate(params_list):
399
+ result[p] = torch.tensor([p_[i] for p_ in params], dtype=torch.float32).to(x.device)
400
+
401
+ return result
402
+
403
+ except ValueError as e:
404
+ print(f"Could not fit the function with error: {e}")
405
+ print(f"Using fallback result...")
406
+ return {
407
+ k: torch.tensor(v, dtype=torch.float32).to(x.device) for k, v in p0.items()
408
+ }
409
+
410
+
411
+ def init_zeros(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
412
+ val = torch.amin(x, dim=1)
413
+ return torch.zeros_like(val, dtype=torch.float32, device=x.device)
414
+
415
+
416
+ def init_inner_scale(tensor: torch.Tensor, _min: float = torch.inf, _max: float = torch.inf) -> torch.Tensor:
417
+ # Calculate the original minimum and maximum values
418
+ min_vals, max_vals = torch.aminmax(tensor, dim=-1)
419
+ x_min = torch.min(min_vals, torch.zeros_like(min_vals))
420
+ x_max = torch.max(max_vals, torch.zeros_like(max_vals))
421
+
422
+ if _max is torch.inf: # We do not need to scale the tensor. Just need to move it
423
+ return torch.ones_like(x_min)
424
+
425
+ # Calculate the scale factor
426
+ scale = (_max - _min) / (x_max - x_min)
427
+ return scale
428
+
429
+
430
+
431
+ ############## Quant ###############
432
+
433
+ @torch.enable_grad()
434
+ def learn_parameters(
435
+ x: torch.Tensor,
436
+ params: Dict[str, nn.Parameter],
437
+ qtz_func: nn.Module,
438
+ deqtz_func: nn.Module,
439
+ bits: int,
440
+ target_dtype: torch.dtype,
441
+ epochs: int = 1000,
442
+ early_stop: bool = True,
443
+ do_report: bool = False
444
+ ) -> Tuple[Dict[str, nn.Parameter], torch.Tensor]:
445
+ loss_fn = nn.MSELoss()
446
+
447
+ # Determines the initial learning rate by computing the initial loss and multiplying it by
448
+ # the order of magnitude of the loss divided by 2
449
+ quant = quantize(x, params, qtz_func, bits, target_dtype)
450
+ dequant = dequantize(quant, params, deqtz_func, bits, x.dtype)
451
+ loss = loss_fn(x, dequant)
452
+
453
+ base_lr = 0.1
454
+ exponent = int(np.floor(np.log10(loss.item())))
455
+ lr = base_lr * (10 ** (exponent // 2))
456
+
457
+ # Requires gradients in the parameters
458
+ for p in params.values():
459
+ p.requires_grad = True
460
+ p.grad = None
461
+
462
+ param_keys = list(params.keys())
463
+ param_values = list(params.values())
464
+
465
+ # Defines optimizer and loss function
466
+ optimizer = torch.optim.Adam(param_values, lr=lr)
467
+
468
+ # Contains the best loss and the best parameters
469
+ best_loss = float("inf")
470
+ best_params = None
471
+
472
+ # Used to stop the search early
473
+ min_delta = 1e-7
474
+ acc_loss = []
475
+ percent_epochs_before_stop = 0.1
476
+
477
+ for i in range(epochs):
478
+ optimizer.zero_grad()
479
+
480
+ quant = quantize(x, params, qtz_func, bits, target_dtype)
481
+ dequant = dequantize(quant, params, deqtz_func, bits, x.dtype)
482
+ loss = loss_fn(x, dequant)
483
+
484
+ if loss.isnan() or loss.isinf():
485
+ raise Exception("Loss is NaN or Inf. Stopping the search.")
486
+
487
+ loss.backward()
488
+ optimizer.step()
489
+
490
+ acc_loss.append(loss.item())
491
+
492
+ # Reports loss every 10 steps
493
+ if i % 10 == 0 and do_report:
494
+ print(f"Epoch {i}: Loss {loss.item()}")
495
+
496
+ # Optimizes the parameter search by storing the best loss and the parameters
497
+ if loss.item() < best_loss:
498
+ best_loss = loss.item()
499
+ best_params = copy.deepcopy({
500
+ k: v for k, v in params.items() if k in param_keys
501
+ })
502
+
503
+ # We also stop the search if the loss has not improved considerably during the last 10% of epochs
504
+ if early_stop:
505
+ epochs_before_stop = int(epochs * percent_epochs_before_stop)
506
+ if i > epochs_before_stop and abs(acc_loss[i - epochs_before_stop] - acc_loss[i]) < min_delta:
507
+ break
508
+
509
+ # No longer requires gradients in the parameters
510
+ for p in best_params.values():
511
+ p.requires_grad = False
512
+ p.grad = None
513
+
514
+ if do_report:
515
+ print(f"Best loss: {best_loss}")
516
+ return best_params, acc_loss
517
+ else:
518
+ return best_params
519
+
520
+
521
+ def quantize(
522
+ x: torch.Tensor,
523
+ params: Dict[str, nn.Parameter],
524
+ func: nn.Module,
525
+ bits: int,
526
+ target_dtype: torch.dtype = torch.int8
527
+ ) -> torch.Tensor:
528
+ quant_min, quant_max = get_min_max_from_bits_signed(bits)
529
+ x = x.transpose(0, 1) # Aligns shapes
530
+ x = func(x=x, **params)
531
+ x = x.transpose(0, 1)
532
+ x = torch.clamp(round_func_BPDA(x), quant_min, quant_max).to(target_dtype)
533
+ return x
534
+
535
+
536
+ def dequantize(
537
+ x: torch.Tensor,
538
+ params: Dict[str, nn.Parameter],
539
+ func: nn.Module,
540
+ bits: int,
541
+ out_dtype: torch.dtype
542
+ ) -> torch.Tensor:
543
+ x = x.to(dtype=out_dtype)
544
+ x = x.transpose(0, 1)
545
+ x = func(x=x, **params)
546
+ x = x.transpose(0, 1)
547
+ return x
548
+
549
+
550
+ def round_func_BPDA(input):
551
+ # This is equivalent to replacing round function (non-differentiable) with
552
+ # an identity function (differentiable) only when backward.
553
+ forward_value = torch.round(input)
554
+ out = input.clone()
555
+ out.data = forward_value.data
556
+ return out
557
+
558
+
559
+ def get_min_max_from_bits_signed(bit_width: int) -> Tuple[int, int]:
560
+ return -2 ** (bit_width - 1), 2 ** (bit_width - 1) - 1
561
+
562
+
563
+
564
+ ############## Numpy ###############
565
+
566
+ def np_domain_guard(
567
+ x: np.ndarray,
568
+ min: float = None,
569
+ max: float = None,
570
+ posinf: float = None,
571
+ neginf: float = None,
572
+ nan: float = None
573
+ ) -> np.ndarray:
574
+ """Guard a tensor to a valid domain."""
575
+ x = np.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan)
576
+ if min is not None or max is not None:
577
+ x = np.clip(x, min, max)
578
+ return x
579
+
580
+
581
+ def np_replace_num(x: np.ndarray, num: float, to: float) -> np.ndarray:
582
+ """Replace a number in a tensor with another number.
583
+
584
+ Args:
585
+ x (np.ndarray): The input tensor.
586
+ num (float): The number to replace.
587
+ to (float): The number to replace with.
588
+
589
+ Returns:
590
+ np.ndarray: The tensor with the number replaced.
591
+ """
592
+ return np.where(x == num, to, x)
593
+
594
+
595
+ def np_guarded_power(x: np.ndarray, exp: float) -> np.ndarray:
596
+ """Guard the power operation to a valid domain."""
597
+ return np.power(x, exp) if exp >= 1 else np.power(np.maximum(x, 0), exp)
598
+
fn_gen/rnd_search_t_no_sched/0/loss.png ADDED
fn_gen/rnd_search_t_no_sched/0/quantization.png ADDED
fn_gen/rnd_search_t_no_sched/10/distortion.png ADDED
fn_gen/rnd_search_t_no_sched/10/expressions.txt ADDED
@@ -0,0 +1,2 @@
+ acos(_0*x)/_s
+ cos(_s*x)/_0
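
For this pair the dequantization expression is the exact algebraic inverse of the quantization expression on the guarded domain |_0*x| < 1: if y = acos(_0*x)/_s, then _s*y = acos(_0*x), hence _0*x = cos(_s*y) and x = cos(_s*y)/_0, which is the second line with y in place of x.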
fn_gen/rnd_search_t_no_sched/10/fn.py ADDED
@@ -0,0 +1,598 @@
1
+ from __future__ import annotations
2
+
3
+ import torch
4
+ from torch import amin # Necessary for arcsin
5
+ import copy
6
+ import torch.nn as nn
7
+ import numpy as np
8
+
9
+ from scipy.optimize import curve_fit
10
+ from typing import Dict, Any, Tuple, List, Callable
11
+
12
+
13
+ def quantization(x, **params):
14
+ return (torch.div(1, replace_num(params['_s'], num=0, to=10000)) * torch.acos(domain_guard((params['_0'] * x), min=-0.99999, max=0.99999, nan=0)))
15
+
16
+
17
+ def dequantization(x, **params):
18
+ return (torch.div(1, replace_num(params['_0'], num=0, to=10000)) * torch.cos((params['_s'] * x)))
19
+
20
+
21
+ def init_space_search(
22
+ x: torch.Tensor,
23
+ **kwargs: Dict[str, Any],
24
+ ) -> torch.Tensor:
25
+
26
+ def _build_initial_param(tensor: torch.Tensor, max_initial: int, n_params: int):
27
+ """Generates the initial set of parameters. The first iteration generates 10 times more parameters."""
28
+ for _ in range(n_params * 10): # The first iteration generates 10 times more parameters
29
+ yield init_rand(tensor) * max_initial # Generates n_params in range [-max_initial, max_initial]
30
+
31
+ def _search_param(tensors: List[torch.tensor], n_params):
32
+ """Takes the best parameters and generates new parameters around the mean of the best parameters."""
33
+ torch_tensors = torch.stack(tensors)
34
+ min_vals, max_vals = torch.aminmax(torch_tensors, dim=0)
35
+ abs_max_val_per_ch = torch.max(-min_vals, max_vals)
36
+ mean = torch.mean(torch_tensors, dim=0)
37
+ for _ in range(n_params): # Generates n_params around the mean of the tensors
38
+ yield torch.randn_like(min_vals) * abs_max_val_per_ch + mean
39
+
40
+ def _calc(x, qtz_func, deqtz_func, **params):
41
+ x_ = x.transpose(0, 1)
42
+ x_ = qtz_func(x=x_, **params)
43
+ x_ = deqtz_func(x=x_, **params)
44
+ x_ = x_.transpose(0, 1)
45
+ return x_
46
+
47
+ assert "qtz_func" in kwargs, "qtz_func must be provided."
48
+ assert "deqtz_func" in kwargs, "deqtz_func must be provided."
49
+ assert "params_list" in kwargs, "params list must be provided."
50
+ assert "param" in kwargs, "param must be provided."
51
+
52
+ qtz_func = kwargs.get('qtz_func')
53
+ deqtz_func = kwargs.get('deqtz_func')
54
+ params_list = kwargs.get('params_list')
55
+ param = kwargs.get('param')
56
+
57
+ n_runs = 50 # Number of runs to try to find the best parameters
58
+ n_random_params = 50 # Number of random parameters to generate
59
+ n_best_to_pick = 5 # Number of best parameters to pick after each run
60
+ max_initial = 10000 # Maximum value to initialize the parameters
61
+
62
+ # Initializes the parameters
63
+ base_params = { p: init_ones(x, **kwargs) for p in params_list if p != param }
64
+ params = _build_initial_param(x, max_initial, n_random_params)
65
+
66
+ # Performs the search
67
+ for _ in range(n_runs):
68
+
69
+ best_params = []
70
+ for param_ in params:
71
+ try:
72
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: param_})
73
+ loss_ones = nn.MSELoss()(x, x_)
74
+
75
+ if len(best_params) < n_best_to_pick:
76
+ best_params.append((param_, loss_ones.item()))
77
+ best_params = sorted(best_params, key=lambda x: x[1])
78
+ elif loss_ones < best_params[-1][1]:
79
+ best_params[-1] = (param_, loss_ones.item())
80
+ best_params = sorted(best_params, key=lambda x: x[1])
81
+
82
+ except Exception: # The parameters might not be valid for the function's domain
83
+ continue
84
+
85
+ # Generates new parameters around the mean
86
+ params = _search_param([p for p, _ in best_params], n_random_params)
87
+
88
+ # Checks if the best parameter is better than the init_ones
89
+ p_ones = init_ones(x, **kwargs)
90
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_ones})
91
+ loss_ones = nn.MSELoss()(x, x_)
92
+
93
+ # Checks if the best parameter is better than the init_rand
94
+ p_rand = init_rand(x, **kwargs)
95
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_rand})
96
+ loss_rand = nn.MSELoss()(x, x_)
97
+
98
+ if loss_rand < best_params[0][1] and loss_rand < loss_ones:
99
+ return p_rand
100
+ elif loss_ones < best_params[0][1] and loss_ones < loss_rand:
101
+ return p_ones
102
+ else:
103
+ return best_params[0][0]
104
+
105
+
106
+ def init_linear_scale( # Symmetric scale. From the study folder
107
+ x: torch.Tensor,
108
+ **kwargs: Dict[str, Any],
109
+ ) -> torch.Tensor:
110
+ assert "bits" in kwargs, "bits must be provided."
111
+ assert "params" in kwargs, "params must be provided."
112
+ assert "qtz_func" in kwargs, "qtz_func must be provided."
113
+
114
+ bits = kwargs.get('bits')
115
+ params = kwargs.get('params')
116
+ qtz_func = kwargs.get('qtz_func')
117
+
118
+ x_ = x.transpose(0, 1)
119
+ x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs))
120
+ x_ = x_.transpose(0, 1)
121
+
122
+ quant_min, quant_max = get_min_max_from_bits_signed(bits)
123
+ min_vals, max_vals = torch.aminmax(x_, dim=1)
124
+ min_vals = torch.min(min_vals, torch.zeros_like(min_vals))
125
+ max_vals = torch.max(max_vals, torch.zeros_like(max_vals))
126
+
127
+ eps = torch.finfo(torch.float32).eps
128
+
129
+ abs_max_val_per_ch = torch.max(-min_vals, max_vals)
130
+ scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2)
131
+
132
+ scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device)
133
+
134
+ # Introduces some noise in scale
135
+ # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything
136
+ # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything.
137
+ # NOTE(diogo): This has been disproven. The noise does not help the learning process but I still
138
+ # left it here for future reference. Will be removed later.
139
+ # scale = scale + 0.01 * torch.randn_like(scale)
140
+
141
+ return scale
142
+
143
+
144
+ def init_params(x: torch.Tensor, **kwargs: Dict[str, Any]) -> Dict[str, nn.Parameter]:
145
+ params = {
146
+ '_0': init_space_search(x, qtz_func=quantization, deqtz_func=dequantization, param='_0', params_list=['_0', '_s'], **kwargs),
147
+ }
148
+ params['_s'] = init_linear_scale(x, params=params, qtz_func=quantization, **kwargs)
149
+ params = {k: nn.Parameter(v, requires_grad=False) for k, v in params.items()}
150
+
151
+ if 'post_init_hook' in kwargs:
152
+ kwargs['post_init_hook'](parameters=params)
153
+
154
+ params = learn_parameters(x, params,
155
+ qtz_func=quantization,
156
+ deqtz_func=dequantization,
157
+ bits=kwargs['bits'],
158
+ target_dtype=torch.int8,
159
+ epochs=500,
160
+ early_stop=False,
161
+ )
162
+ if 'post_train_hook' in kwargs:
163
+ kwargs['post_train_hook'](parameters=params)
164
+
165
+ return params
166
+
167
+
168
+ ############### Numpy Qtz ###############
169
+
170
+
171
+ def np_quantization(x, _0, _s):
172
+ return (np.divide(1, np_replace_num(_s, num=0, to=10000)) * np.arccos(np_domain_guard((_0 * x), min=-0.99999, max=0.99999, nan=0)))
173
+
174
+
175
+ def np_dequantization(x, _0, _s):
176
+ return (np.divide(1, np_replace_num(_0, num=0, to=10000)) * np.cos((_s * x)))
177
+
178
+
179
+ def fit_func(x, _0, _s):
180
+ x_ = np_quantization(x, _0, _s)
181
+ x_ = np_dequantization(x_, _0, _s)
182
+ return x_
183
+
184
+
185
+
186
+ ############### HELPERS ###############
187
+
188
+ def domain_guard(
189
+ x: torch.Tensor,
190
+ min: float = None,
191
+ max: float = None,
192
+ posinf: float = None,
193
+ neginf: float = None,
194
+ nan: float = None
195
+ ) -> torch.Tensor:
196
+ """Guard a tensor to a valid domain."""
197
+ x = torch.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan)
198
+ if min is not None or max is not None:
199
+ x = torch.clamp(x, min=min, max=max)
200
+ return x
201
+
202
+
203
+ def replace_num(x: torch.Tensor, num: float, to: float) -> torch.Tensor:
204
+ """Replace a number in a tensor with another number.
205
+
206
+ Args:
207
+ x (torch.Tensor): The input tensor.
208
+ num (float): The number to replace.
209
+ to (float): The number to replace with.
210
+
211
+ Returns:
212
+ torch.Tensor: The tensor with the number replaced.
213
+ """
214
+ return torch.where(x == num, to, x)
215
+
216
+
217
+ def guarded_torch_power(x: torch.Tensor, exp: float) -> torch.Tensor:
218
+ """Guard the power operation to a valid domain."""
219
+ return torch.pow(x, exp) if exp >= 1 else torch.pow(torch.relu(x), exp)
220
+
221
+
222
+ def init_ones(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
223
+ val = torch.amin(x, dim=1)
224
+ return torch.ones_like(val, dtype=torch.float32, device=x.device)
225
+
226
+
227
+ def init_rand(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
228
+ val = torch.amin(x, dim=1)
229
+ return torch.randn_like(val, dtype=torch.float32, device=x.device)
230
+
231
+
232
+ def init_space_search(
233
+ x: torch.Tensor,
234
+ **kwargs: Dict[str, Any],
235
+ ) -> torch.Tensor:
236
+
237
+ def _build_initial_param(tensor: torch.Tensor, max_initial: int, n_params: int):
238
+ """Generates the initial set of parameters. The first iteration generates 10 times more parameters."""
239
+ for _ in range(n_params * 10): # The first iteration generates 10 times more parameters
240
+ yield init_rand(tensor) * max_initial # Generates n_params in range [-max_initial, max_initial]
241
+
242
+ def _search_param(tensors: List[torch.tensor], n_params):
243
+ """Takes the best parameters and generates new parameters around the mean of the best parameters."""
244
+ torch_tensors = torch.stack(tensors)
245
+ min_vals, max_vals = torch.aminmax(torch_tensors, dim=0)
246
+ abs_max_val_per_ch = torch.max(-min_vals, max_vals)
247
+ mean = torch.mean(torch_tensors, dim=0)
248
+ for _ in range(n_params): # Generates n_params around the mean of the tensors
249
+ yield torch.randn_like(min_vals) * abs_max_val_per_ch + mean
250
+
251
+ def _calc(x, qtz_func, deqtz_func, **params):
252
+ x_ = x.transpose(0, 1)
253
+ x_ = qtz_func(x=x_, **params)
254
+ x_ = deqtz_func(x=x_, **params)
255
+ x_ = x_.transpose(0, 1)
256
+ return x_
257
+
258
+ assert "qtz_func" in kwargs, "qtz_func must be provided."
259
+ assert "deqtz_func" in kwargs, "deqtz_func must be provided."
260
+ assert "params_list" in kwargs, "params list must be provided."
261
+ assert "param" in kwargs, "param must be provided."
262
+
263
+ qtz_func = kwargs.get('qtz_func')
264
+ deqtz_func = kwargs.get('deqtz_func')
265
+ params_list = kwargs.get('params_list')
266
+ param = kwargs.get('param')
267
+
268
+ n_runs = 50 # Number of runs to try to find the best parameters
269
+ n_random_params = 50 # Number of random parameters to generate
270
+ n_best_to_pick = 5 # Number of best parameters to pick after each run
271
+ max_initial = 10000 # Maximum value to initialize the parameters
272
+
273
+ # Initializes the parameters
274
+ base_params = { p: init_ones(x, **kwargs) for p in params_list if p != param }
275
+ params = _build_initial_param(x, max_initial, n_random_params)
276
+
277
+ # Performs the search
278
+ for _ in range(n_runs):
279
+
280
+ best_params = []
281
+ for param_ in params:
282
+ try:
283
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: param_})
284
+ loss_ones = nn.MSELoss()(x, x_)
285
+
286
+ if len(best_params) < n_best_to_pick:
287
+ best_params.append((param_, loss_ones.item()))
288
+ best_params = sorted(best_params, key=lambda x: x[1])
289
+ elif loss_ones < best_params[-1][1]:
290
+ best_params[-1] = (param_, loss_ones.item())
291
+ best_params = sorted(best_params, key=lambda x: x[1])
292
+
293
+ except Exception: # The parameters might not be valid for the function's domain
294
+ continue
295
+
296
+ # Generates new parameters around the mean
297
+ params = _search_param([p for p, _ in best_params], n_random_params)
298
+
299
+ # Checks if the best parameter is better than the init_ones
300
+ p_ones = init_ones(x, **kwargs)
301
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_ones})
302
+ loss_ones = nn.MSELoss()(x, x_)
303
+
304
+ # Checks if the best parameter is better than the init_rand
305
+ p_rand = init_rand(x, **kwargs)
306
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_rand})
307
+ loss_rand = nn.MSELoss()(x, x_)
308
+
309
+ if loss_rand < best_params[0][1] and loss_rand < loss_ones:
310
+ return p_rand
311
+ elif loss_ones < best_params[0][1] and loss_ones < loss_rand:
312
+ return p_ones
313
+ else:
314
+ return best_params[0][0]
315
+
316
+
317
+ def init_linear_scale( # Symmetric scale. From the study folder
318
+ x: torch.Tensor,
319
+ **kwargs: Dict[str, Any],
320
+ ) -> torch.Tensor:
321
+ assert "bits" in kwargs, "bits must be provided."
322
+ assert "params" in kwargs, "params must be provided."
323
+ assert "qtz_func" in kwargs, "qtz_func must be provided."
324
+
325
+ bits = kwargs.get('bits')
326
+ params = kwargs.get('params')
327
+ qtz_func = kwargs.get('qtz_func')
328
+
329
+ x_ = x.transpose(0, 1)
330
+ x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs))
331
+ x_ = x_.transpose(0, 1)
332
+
333
+ quant_min, quant_max = get_min_max_from_bits_signed(bits)
334
+ min_vals, max_vals = torch.aminmax(x_, dim=1)
335
+ min_vals = torch.min(min_vals, torch.zeros_like(min_vals))
336
+ max_vals = torch.max(max_vals, torch.zeros_like(max_vals))
337
+
338
+ eps = torch.finfo(torch.float32).eps
339
+
340
+ abs_max_val_per_ch = torch.max(-min_vals, max_vals)
341
+ scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2)
342
+
343
+ scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device)
344
+
345
+ # Introduces some noise in scale
346
+ # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything
347
+ # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything.
348
+ # NOTE(diogo): This has been disproven. The noise does not help the learning process but I still
349
+ # left it here for future reference. Will be removed later.
350
+ # scale = scale + 0.01 * torch.randn_like(scale)
351
+
352
+ return scale
353
+
354
+
355
+ def init_non_linear_regression_fit(
356
+ x: torch.Tensor,
357
+ **kwargs: Dict[str, Any],
358
+ ) -> torch.Tensor:
359
+
360
+ assert "params_list" in kwargs, "params list must be provided."
361
+ assert "np_fit_func" in kwargs, "np_fit_func must be provided."
362
+ assert "p0" in kwargs, "p0 must be provided."
363
+ np_fit_func = kwargs.get('np_fit_func')
364
+ params_list = kwargs.get('params_list')
365
+ p0 = kwargs.get('p0')
366
+
367
+ def _fit(xdata: np.ndarray, ydata: np.ndarray, func: Callable, p0: List[float]):
368
+ popt, _ = curve_fit(
369
+ func,
370
+ xdata,
371
+ ydata,
372
+ maxfev=1000,
373
+ p0=p0,
374
+ method='lm'
375
+ )
376
+ return popt
377
+
378
+ # 1. Needs to convert the torch tensor to numpy tensor
379
+ xdata = x.cpu().numpy()
380
+
381
+ # 2. Sorts the data so that it makes it easier to fit to it
382
+ sorted_xdata = np.sort(xdata, axis=-1)
383
+
384
+ p0 = {k: v.cpu().numpy() for k, v in p0.items()}
385
+ params_list = sorted(params_list) # We need to make sure that it matches the numpy fit func arg order
386
+
387
+ # 3. Finds the best parameters for each channel
388
+ try:
389
+ params = []
390
+ for i in range(sorted_xdata.shape[0]):
391
+ xdata_ = sorted_xdata[i]
392
+ p0_ = [p0[p][i] for p in params_list]
393
+ ch_params = _fit(xdata_, xdata_, np_fit_func, p0_)
394
+ params.append(ch_params)
395
+
396
+ # 4. Builds the parameters
397
+ result = {}
398
+ for i, p in enumerate(params_list):
399
+ result[p] = torch.tensor([p_[i] for p_ in params], dtype=torch.float32).to(x.device)
400
+
401
+ return result
402
+
403
+ except ValueError as e:
404
+ print(f"Could not fit the function with error: {e}")
405
+ print(f"Using fallback result...")
406
+ return {
407
+ k: torch.tensor(v, dtype=torch.float32).to(x.device) for k, v in p0.items()
408
+ }
409
+
410
+
411
+ def init_zeros(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
412
+ val = torch.amin(x, dim=1)
413
+ return torch.zeros_like(val, dtype=torch.float32, device=x.device)
414
+
415
+
416
+ def init_inner_scale(tensor: torch.Tensor, _min: float = torch.inf, _max: float = torch.inf) -> torch.Tensor:
417
+ # Calculate the original minimum and maximum values
418
+ min_vals, max_vals = torch.aminmax(tensor, dim=-1)
419
+ x_min = torch.min(min_vals, torch.zeros_like(min_vals))
420
+ x_max = torch.max(max_vals, torch.zeros_like(max_vals))
421
+
422
+ if _max is torch.inf: # We do not need to scale the tensor. Just need to move it
423
+ return torch.ones_like(x_min)
424
+
425
+ # Calculate the scale factor
426
+ scale = (_max - _min) / (x_max - x_min)
427
+ return scale
428
+
429
+
430
+
431
+ ############## Quant ###############
432
+
433
+ @torch.enable_grad()
434
+ def learn_parameters(
435
+ x: torch.Tensor,
436
+ params: Dict[str, nn.Parameter],
437
+ qtz_func: nn.Module,
438
+ deqtz_func: nn.Module,
439
+ bits: int,
440
+ target_dtype: torch.dtype,
441
+ epochs: int = 1000,
442
+ early_stop: bool = True,
443
+ do_report: bool = False
444
+ ) -> Tuple[Dict[str, nn.Parameter], torch.Tensor]:
445
+ loss_fn = nn.MSELoss()
446
+
447
+ # Determines the initial learning rate by computing the initial loss and multiplying it by
448
+ # the order of magnitude of the loss divided by 2
449
+ quant = quantize(x, params, qtz_func, bits, target_dtype)
450
+ dequant = dequantize(quant, params, deqtz_func, bits, x.dtype)
451
+ loss = loss_fn(x, dequant)
452
+
453
+ base_lr = 0.1
454
+ exponent = int(np.floor(np.log10(loss.item())))
455
+ lr = base_lr * (10 ** (exponent // 2))
456
+
457
+ # Requires gradients in the parameters
458
+ for p in params.values():
459
+ p.requires_grad = True
460
+ p.grad = None
461
+
462
+ param_keys = list(params.keys())
463
+ param_values = list(params.values())
464
+
465
+ # Defines optimizer and loss function
466
+ optimizer = torch.optim.Adam(param_values, lr=lr)
467
+
468
+ # Contains the best loss and the best parameters
469
+ best_loss = float("inf")
470
+ best_params = None
471
+
472
+ # Used to stop the search early
473
+ min_delta = 1e-7
474
+ acc_loss = []
475
+ percent_epochs_before_stop = 0.1
476
+
477
+ for i in range(epochs):
478
+ optimizer.zero_grad()
479
+
480
+ quant = quantize(x, params, qtz_func, bits, target_dtype)
481
+ dequant = dequantize(quant, params, deqtz_func, bits, x.dtype)
482
+ loss = loss_fn(x, dequant)
483
+
484
+ if loss.isnan() or loss.isinf():
485
+ raise Exception("Loss is NaN or Inf. Stopping the search.")
486
+
487
+ loss.backward()
488
+ optimizer.step()
489
+
490
+ acc_loss.append(loss.item())
491
+
492
+ # Reports loss every 10 steps
493
+ if i % 10 == 0 and do_report:
494
+ print(f"Epoch {i}: Loss {loss.item()}")
495
+
496
+ # Optimizes the parameter search by storing the best loss and the parameters
497
+ if loss.item() < best_loss:
498
+ best_loss = loss.item()
499
+ best_params = copy.deepcopy({
500
+ k: v for k, v in params.items() if k in param_keys
501
+ })
502
+
503
+ # We also stop the search if the loss has not improved considerably during the last 10% of epochs
504
+ if early_stop:
505
+ epochs_before_stop = int(epochs * percent_epochs_before_stop)
506
+ if i > epochs_before_stop and abs(acc_loss[i - epochs_before_stop] - acc_loss[i]) < min_delta:
507
+ break
508
+
509
+ # No longer requires gradients in the parameters
510
+ for p in best_params.values():
511
+ p.requires_grad = False
512
+ p.grad = None
513
+
514
+ if do_report:
515
+ print(f"Best loss: {best_loss}")
516
+ return best_params, acc_loss
517
+ else:
518
+ return best_params
519
+
520
+
521
+ def quantize(
522
+ x: torch.Tensor,
523
+ params: Dict[str, nn.Parameter],
524
+ func: nn.Module,
525
+ bits: int,
526
+ target_dtype: torch.dtype = torch.int8
527
+ ) -> torch.Tensor:
528
+ quant_min, quant_max = get_min_max_from_bits_signed(bits)
529
+ x = x.transpose(0, 1) # Aligns shapes
530
+ x = func(x=x, **params)
531
+ x = x.transpose(0, 1)
532
+ x = torch.clamp(round_func_BPDA(x), quant_min, quant_max).to(target_dtype)
533
+ return x
534
+
535
+
536
+ def dequantize(
537
+ x: torch.Tensor,
538
+ params: Dict[str, nn.Parameter],
539
+ func: nn.Module,
540
+ bits: int,
541
+ out_dtype: torch.dtype
542
+ ) -> torch.Tensor:
543
+ x = x.to(dtype=out_dtype)
544
+ x = x.transpose(0, 1)
545
+ x = func(x=x, **params)
546
+ x = x.transpose(0, 1)
547
+ return x
548
+
549
+
550
+ def round_func_BPDA(input):
551
+ # This is equivalent to replacing round function (non-differentiable) with
552
+ # an identity function (differentiable) only when backward.
553
+ forward_value = torch.round(input)
554
+ out = input.clone()
555
+ out.data = forward_value.data
556
+ return out
557
+
558
+
559
+ def get_min_max_from_bits_signed(bit_width: int) -> Tuple[int, int]:
560
+ return -2 ** (bit_width - 1), 2 ** (bit_width - 1) - 1
561
+
562
+
563
+
564
+ ############## Numpy ###############
565
+
566
+ def np_domain_guard(
567
+ x: np.ndarray,
568
+ min: float = None,
569
+ max: float = None,
570
+ posinf: float = None,
571
+ neginf: float = None,
572
+ nan: float = None
573
+ ) -> np.ndarray:
574
+ """Guard a tensor to a valid domain."""
575
+ x = np.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan)
576
+ if min is not None or max is not None:
577
+ x = np.clip(x, min, max)
578
+ return x
579
+
580
+
581
+ def np_replace_num(x: np.ndarray, num: float, to: float) -> np.ndarray:
582
+ """Replace a number in a tensor with another number.
583
+
584
+ Args:
585
+ x (np.ndarray): The input tensor.
586
+ num (float): The number to replace.
587
+ to (float): The number to replace with.
588
+
589
+ Returns:
590
+ np.ndarray: The tensor with the number replaced.
591
+ """
592
+ return np.where(x == num, to, x)
593
+
594
+
595
+ def np_guarded_power(x: np.ndarray, exp: float) -> np.ndarray:
596
+ """Guard the power operation to a valid domain."""
597
+ return np.power(x, exp) if exp >= 1 else np.power(np.maximum(x, 0), exp)
598
+
fn_gen/rnd_search_t_no_sched/10/loss.png ADDED
fn_gen/rnd_search_t_no_sched/10/quantization.png ADDED
fn_gen/rnd_search_t_no_sched/11/distortion.png ADDED
fn_gen/rnd_search_t_no_sched/11/expressions.txt ADDED
@@ -0,0 +1,2 @@
+ asin(_0*x)/_s
+ sin(_s*x)/_0
fn_gen/rnd_search_t_no_sched/11/fn.py ADDED
@@ -0,0 +1,598 @@
1
+ from __future__ import annotations
2
+
3
+ import torch
4
+ from torch import amin # Necessary for arcsin
5
+ import copy
6
+ import torch.nn as nn
7
+ import numpy as np
8
+
9
+ from scipy.optimize import curve_fit
10
+ from typing import Dict, Any, Tuple, List, Callable
11
+
12
+
13
+ def quantization(x, **params):
14
+ return (torch.div(1, replace_num(params['_s'], num=0, to=10000)) * torch.asin(domain_guard((params['_0'] * x), min=-0.99999, max=0.99999, nan=0)))
15
+
16
+
17
+ def dequantization(x, **params):
18
+ return (torch.div(1, replace_num(params['_0'], num=0, to=10000)) * torch.sin((params['_s'] * x)))
19
+
20
+
21
+ def init_space_search(
22
+ x: torch.Tensor,
23
+ **kwargs: Dict[str, Any],
24
+ ) -> torch.Tensor:
25
+
26
+ def _build_initial_param(tensor: torch.Tensor, max_initial: int, n_params: int):
27
+ """Generates the initial set of parameters. The first iteration generates 10 times more parameters."""
28
+ for _ in range(n_params * 10): # The first iteration generates 10 times more parameters
29
+ yield init_rand(tensor) * max_initial # Generates n_params in range [-max_initial, max_initial]
30
+
31
+ def _search_param(tensors: List[torch.tensor], n_params):
32
+ """Takes the best parameters and generates new parameters around the mean of the best parameters."""
33
+ torch_tensors = torch.stack(tensors)
34
+ min_vals, max_vals = torch.aminmax(torch_tensors, dim=0)
35
+ abs_max_val_per_ch = torch.max(-min_vals, max_vals)
36
+ mean = torch.mean(torch_tensors, dim=0)
37
+ for _ in range(n_params): # Generates n_params around the mean of the tensors
38
+ yield torch.randn_like(min_vals) * abs_max_val_per_ch + mean
39
+
40
+ def _calc(x, qtz_func, deqtz_func, **params):
41
+ x_ = x.transpose(0, 1)
42
+ x_ = qtz_func(x=x_, **params)
43
+ x_ = deqtz_func(x=x_, **params)
44
+ x_ = x_.transpose(0, 1)
45
+ return x_
46
+
47
+ assert "qtz_func" in kwargs, "qtz_func must be provided."
48
+ assert "deqtz_func" in kwargs, "deqtz_func must be provided."
49
+ assert "params_list" in kwargs, "params list must be provided."
50
+ assert "param" in kwargs, "param must be provided."
51
+
52
+ qtz_func = kwargs.get('qtz_func')
53
+ deqtz_func = kwargs.get('deqtz_func')
54
+ params_list = kwargs.get('params_list')
55
+ param = kwargs.get('param')
56
+
57
+ n_runs = 50 # Number of runs to try to find the best parameters
58
+ n_random_params = 50 # Number of random parameters to generate
59
+ n_best_to_pick = 5 # Number of best parameters to pick after each run
60
+ max_initial = 10000 # Maximum value to initialize the parameters
61
+
62
+ # Initializes the parameters
63
+ base_params = { p: init_ones(x, **kwargs) for p in params_list if p != param }
64
+ params = _build_initial_param(x, max_initial, n_random_params)
65
+
66
+ # Performs the search
67
+ for _ in range(n_runs):
68
+
69
+ best_params = []
70
+ for param_ in params:
71
+ try:
72
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: param_})
73
+ loss_ones = nn.MSELoss()(x, x_)
74
+
75
+ if len(best_params) < n_best_to_pick:
76
+ best_params.append((param_, loss_ones.item()))
77
+ best_params = sorted(best_params, key=lambda x: x[1])
78
+ elif loss_ones < best_params[-1][1]:
79
+ best_params[-1] = (param_, loss_ones.item())
80
+ best_params = sorted(best_params, key=lambda x: x[1])
81
+
82
+ except Exception: # The parameters might not be valid for the function's domain
83
+ continue
84
+
85
+ # Generates new parameters around the mean
86
+ params = _search_param([p for p, _ in best_params], n_random_params)
87
+
88
+ # Checks if the best parameter is better than the init_ones
89
+ p_ones = init_ones(x, **kwargs)
90
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_ones})
91
+ loss_ones = nn.MSELoss()(x, x_)
92
+
93
+ # Checks if the best parameter is better than the init_rand
94
+ p_rand = init_rand(x, **kwargs)
95
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_rand})
96
+ loss_rand = nn.MSELoss()(x, x_)
97
+
98
+ if loss_rand < best_params[0][1] and loss_rand < loss_ones:
99
+ return p_rand
100
+ elif loss_ones < best_params[0][1] and loss_ones < loss_rand:
101
+ return p_ones
102
+ else:
103
+ return best_params[0][0]
104
+
105
+
106
+ def init_linear_scale( # Symmetric scale. From the study folder
107
+ x: torch.Tensor,
108
+ **kwargs: Dict[str, Any],
109
+ ) -> torch.Tensor:
110
+ assert "bits" in kwargs, "bits must be provided."
111
+ assert "params" in kwargs, "params must be provided."
112
+ assert "qtz_func" in kwargs, "qtz_func must be provided."
113
+
114
+ bits = kwargs.get('bits')
115
+ params = kwargs.get('params')
116
+ qtz_func = kwargs.get('qtz_func')
117
+
118
+ x_ = x.transpose(0, 1)
119
+ x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs))
120
+ x_ = x_.transpose(0, 1)
121
+
122
+ quant_min, quant_max = get_min_max_from_bits_signed(bits)
123
+ min_vals, max_vals = torch.aminmax(x_, dim=1)
124
+ min_vals = torch.min(min_vals, torch.zeros_like(min_vals))
125
+ max_vals = torch.max(max_vals, torch.zeros_like(max_vals))
126
+
127
+ eps = torch.finfo(torch.float32).eps
128
+
129
+ abs_max_val_per_ch = torch.max(-min_vals, max_vals)
130
+ scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2)
131
+
132
+ scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device)
133
+
134
+ # Introduces some noise in scale
135
+ # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything
136
+ # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything.
137
+ # NOTE(diogo): This has been disproven. The noise does not help the learning process but I still
138
+ # left it here for future reference. Will be removed later.
139
+ # scale = scale + 0.01 * torch.randn_like(scale)
140
+
141
+ return scale
142
+
143
+
144
+ def init_params(x: torch.Tensor, **kwargs: Dict[str, Any]) -> Dict[str, nn.Parameter]:
145
+ params = {
146
+ '_0': init_space_search(x, qtz_func=quantization, deqtz_func=dequantization, param='_0', params_list=['_0', '_s'], **kwargs),
147
+ }
148
+ params['_s'] = init_linear_scale(x, params=params, qtz_func=quantization, **kwargs)
149
+ params = {k: nn.Parameter(v, requires_grad=False) for k, v in params.items()}
150
+
151
+ if 'post_init_hook' in kwargs:
152
+ kwargs['post_init_hook'](parameters=params)
153
+
154
+ params = learn_parameters(x, params,
155
+ qtz_func=quantization,
156
+ deqtz_func=dequantization,
157
+ bits=kwargs['bits'],
158
+ target_dtype=torch.int8,
159
+ epochs=500,
160
+ early_stop=False,
161
+ )
162
+ if 'post_train_hook' in kwargs:
163
+ kwargs['post_train_hook'](parameters=params)
164
+
165
+ return params
166
+
167
+
168
+ ############### Numpy Qtz ###############
169
+
170
+
171
+ def np_quantization(x, _0, _s):
172
+ return (np.divide(1, np_replace_num(_s, num=0, to=10000)) * np.arcsin(np_domain_guard((_0 * x), min=-0.99999, max=0.99999, nan=0)))
173
+
174
+
175
+ def np_dequantization(x, _0, _s):
176
+ return (np.divide(1, np_replace_num(_0, num=0, to=10000)) * np.sin((_s * x)))
177
+
178
+
179
+ def fit_func(x, _0, _s):
180
+ x_ = np_quantization(x, _0, _s)
181
+ x_ = np_dequantization(x_, _0, _s)
182
+ return x_
183
+
184
+
185
+
186
+ ############### HELPERS ###############
187
+
188
+ def domain_guard(
189
+ x: torch.Tensor,
190
+ min: float = None,
191
+ max: float = None,
192
+ posinf: float = None,
193
+ neginf: float = None,
194
+ nan: float = None
195
+ ) -> torch.Tensor:
196
+ """Guard a tensor to a valid domain."""
197
+ x = torch.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan)
198
+ if min is not None or max is not None:
199
+ x = torch.clamp(x, min=min, max=max)
200
+ return x
201
+
202
+
203
+ def replace_num(x: torch.Tensor, num: float, to: float) -> torch.Tensor:
204
+ """Replace a number in a tensor with another number.
205
+
206
+ Args:
207
+ x (torch.Tensor): The input tensor.
208
+ num (float): The number to replace.
209
+ to (float): The number to replace with.
210
+
211
+ Returns:
212
+ torch.Tensor: The tensor with the number replaced.
213
+ """
214
+ return torch.where(x == num, to, x)
215
+
216
+
217
+ def guarded_torch_power(x: torch.Tensor, exp: float) -> torch.Tensor:
218
+ """Guard the power operation to a valid domain."""
219
+ return torch.pow(x, exp) if exp >= 1 else torch.pow(torch.relu(x), exp)
220
+
221
+
222
+ def init_ones(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
223
+ val = torch.amin(x, dim=1)
224
+ return torch.ones_like(val, dtype=torch.float32, device=x.device)
225
+
226
+
227
+ def init_rand(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
228
+ val = torch.amin(x, dim=1)
229
+ return torch.randn_like(val, dtype=torch.float32, device=x.device)
230
+
231
+
232
+ def init_space_search(
233
+ x: torch.Tensor,
234
+ **kwargs: Dict[str, Any],
235
+ ) -> torch.Tensor:
236
+
237
+ def _build_initial_param(tensor: torch.Tensor, max_initial: int, n_params: int):
238
+ """Generates the initial set of parameters. The first iteration generates 10 times more parameters."""
239
+ for _ in range(n_params * 10): # The first iteration generates 10 times more parameters
240
+ yield init_rand(tensor) * max_initial # Generates n_params in range [-max_initial, max_initial]
241
+
242
+ def _search_param(tensors: List[torch.tensor], n_params):
243
+ """Takes the best parameters and generates new parameters around the mean of the best parameters."""
244
+ torch_tensors = torch.stack(tensors)
245
+ min_vals, max_vals = torch.aminmax(torch_tensors, dim=0)
246
+ abs_max_val_per_ch = torch.max(-min_vals, max_vals)
247
+ mean = torch.mean(torch_tensors, dim=0)
248
+ for _ in range(n_params): # Generates n_params around the mean of the tensors
249
+ yield torch.randn_like(min_vals) * abs_max_val_per_ch + mean
250
+
251
+ def _calc(x, qtz_func, deqtz_func, **params):
252
+ x_ = x.transpose(0, 1)
253
+ x_ = qtz_func(x=x_, **params)
254
+ x_ = deqtz_func(x=x_, **params)
255
+ x_ = x_.transpose(0, 1)
256
+ return x_
257
+
258
+ assert "qtz_func" in kwargs, "qtz_func must be provided."
259
+ assert "deqtz_func" in kwargs, "deqtz_func must be provided."
260
+ assert "params_list" in kwargs, "params list must be provided."
261
+ assert "param" in kwargs, "param must be provided."
262
+
263
+ qtz_func = kwargs.get('qtz_func')
264
+ deqtz_func = kwargs.get('deqtz_func')
265
+ params_list = kwargs.get('params_list')
266
+ param = kwargs.get('param')
267
+
268
+ n_runs = 50 # Number of runs to try to find the best parameters
269
+ n_random_params = 50 # Number of random parameters to generate
270
+ n_best_to_pick = 5 # Number of best parameters to pick after each run
271
+ max_initial = 10000 # Maximum value to initialize the parameters
272
+
273
+ # Initializes the parameters
274
+ base_params = { p: init_ones(x, **kwargs) for p in params_list if p != param }
275
+ params = _build_initial_param(x, max_initial, n_random_params)
276
+
277
+ # Performs the search
278
+ for _ in range(n_runs):
279
+
280
+ best_params = []
281
+ for param_ in params:
282
+ try:
283
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: param_})
284
+ loss_ones = nn.MSELoss()(x, x_)
285
+
286
+ if len(best_params) < n_best_to_pick:
287
+ best_params.append((param_, loss_ones.item()))
288
+ best_params = sorted(best_params, key=lambda x: x[1])
289
+ elif loss_ones < best_params[-1][1]:
290
+ best_params[-1] = (param_, loss_ones.item())
291
+ best_params = sorted(best_params, key=lambda x: x[1])
292
+
293
+ except Exception: # The parameters might not be valid for the function's domain
294
+ continue
295
+
296
+ # Generates new parameters around the mean
297
+ params = _search_param([p for p, _ in best_params], n_random_params)
298
+
299
+ # Checks if the best parameter is better than the init_ones
300
+ p_ones = init_ones(x, **kwargs)
301
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_ones})
302
+ loss_ones = nn.MSELoss()(x, x_)
303
+
304
+ # Checks if the best parameter is better than the init_rand
305
+ p_rand = init_rand(x, **kwargs)
306
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_rand})
307
+ loss_rand = nn.MSELoss()(x, x_)
308
+
309
+ if loss_rand < best_params[0][1] and loss_rand < loss_ones:
310
+ return p_rand
311
+ elif loss_ones < best_params[0][1] and loss_ones < loss_rand:
312
+ return p_ones
313
+ else:
314
+ return best_params[0][0]
315
+
316
+
317
+ def init_linear_scale( # Symmetric scale. From the study folder
318
+ x: torch.Tensor,
319
+ **kwargs: Dict[str, Any],
320
+ ) -> torch.Tensor:
321
+ assert "bits" in kwargs, "bits must be provided."
322
+ assert "params" in kwargs, "params must be provided."
323
+ assert "qtz_func" in kwargs, "qtz_func must be provided."
324
+
325
+ bits = kwargs.get('bits')
326
+ params = kwargs.get('params')
327
+ qtz_func = kwargs.get('qtz_func')
328
+
329
+ x_ = x.transpose(0, 1)
330
+ x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs))
331
+ x_ = x_.transpose(0, 1)
332
+
333
+ quant_min, quant_max = get_min_max_from_bits_signed(bits)
334
+ min_vals, max_vals = torch.aminmax(x_, dim=1)
335
+ min_vals = torch.min(min_vals, torch.zeros_like(min_vals))
336
+ max_vals = torch.max(max_vals, torch.zeros_like(max_vals))
337
+
338
+ eps = torch.finfo(torch.float32).eps
339
+
340
+ abs_max_val_per_ch = torch.max(-min_vals, max_vals)
341
+ scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2)
342
+
343
+ scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device)
344
+
345
+ # Introduces some noise in scale
346
+ # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything
347
+ # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything.
348
+ # NOTE(diogo): This has been disproven. The noise does not help the learning process but I still
349
+ # left it here for future reference. Will be removed later.
350
+ # scale = scale + 0.01 * torch.randn_like(scale)
351
+
352
+ return scale
353
+
354
+
355
+ def init_non_linear_regression_fit(
356
+ x: torch.Tensor,
357
+ **kwargs: Dict[str, Any],
358
+ ) -> torch.Tensor:
359
+
360
+ assert "params_list" in kwargs, "params list must be provided."
361
+ assert "np_fit_func" in kwargs, "np_fit_func must be provided."
362
+ assert "p0" in kwargs, "p0 must be provided."
363
+ np_fit_func = kwargs.get('np_fit_func')
364
+ params_list = kwargs.get('params_list')
365
+ p0 = kwargs.get('p0')
366
+
367
+ def _fit(xdata: np.ndarray, ydata: np.ndarray, func: Callable, p0: List[float]):
368
+ popt, _ = curve_fit(
369
+ func,
370
+ xdata,
371
+ ydata,
372
+ maxfev=1000,
373
+ p0=p0,
374
+ method='lm'
375
+ )
376
+ return popt
377
+
378
+ # 1. Needs to convert the torch tensor to numpy tensor
379
+ xdata = x.cpu().numpy()
380
+
381
+ # 2. Sorts the data so that it makes it easier to fit to it
382
+ sorted_xdata = np.sort(xdata, axis=-1)
383
+
384
+ p0 = {k: v.cpu().numpy() for k, v in p0.items()}
385
+ params_list = sorted(params_list) # We need to make sure that it matches the numpy fit func arg order
386
+
387
+ # 3. Finds the best parameters for each channel
388
+ try:
389
+ params = []
390
+ for i in range(sorted_xdata.shape[0]):
391
+ xdata_ = sorted_xdata[i]
392
+ p0_ = [p0[p][i] for p in params_list]
393
+ ch_params = _fit(xdata_, xdata_, np_fit_func, p0_)
394
+ params.append(ch_params)
395
+
396
+ # 4. Builds the parameters
397
+ result = {}
398
+ for i, p in enumerate(params_list):
399
+ result[p] = torch.tensor([p_[i] for p_ in params], dtype=torch.float32).to(x.device)
400
+
401
+ return result
402
+
403
+ except ValueError as e:
404
+ print(f"Could not fit the function with error: {e}")
405
+ print(f"Using fallback result...")
406
+ return {
407
+ k: torch.tensor(v, dtype=torch.float32).to(x.device) for k, v in p0.items()
408
+ }
409
+
410
+
411
+ def init_zeros(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
412
+ val = torch.amin(x, dim=1)
413
+ return torch.zeros_like(val, dtype=torch.float32, device=x.device)
414
+
415
+
416
+ def init_inner_scale(tensor: torch.Tensor, _min: float = torch.inf, _max: float = torch.inf) -> torch.Tensor:
417
+ # Calculate the original minimum and maximum values
418
+ min_vals, max_vals = torch.aminmax(tensor, dim=-1)
419
+ x_min = torch.min(min_vals, torch.zeros_like(min_vals))
420
+ x_max = torch.max(max_vals, torch.zeros_like(max_vals))
421
+
422
+ if _max is torch.inf: # We do not need to scale the tensor. Just need to move it
423
+ return torch.ones_like(x_min)
424
+
425
+ # Calculate the scale factor
426
+ scale = (_max - _min) / (x_max - x_min)
427
+ return scale
428
+
429
+
430
+
431
+ ############## Quant ###############
432
+
433
+ @torch.enable_grad()
434
+ def learn_parameters(
435
+ x: torch.Tensor,
436
+ params: Dict[str, nn.Parameter],
437
+ qtz_func: nn.Module,
438
+ deqtz_func: nn.Module,
439
+ bits: int,
440
+ target_dtype: torch.dtype,
441
+ epochs: int = 1000,
442
+ early_stop: bool = True,
443
+ do_report: bool = False
444
+ ) -> Tuple[Dict[str, nn.Parameter], torch.Tensor]:
445
+ loss_fn = nn.MSELoss()
446
+
447
+ # Determines the initial learning rate by computing the initial loss and multiplying it by
448
+ # the order of magnitude of the loss divided by 2
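+ # e.g. an initial loss of 3e-4 gives floor(log10(loss)) = -4, so lr = 0.1 * 10 ** (-4 // 2) = 1e-3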
449
+ quant = quantize(x, params, qtz_func, bits, target_dtype)
450
+ dequant = dequantize(quant, params, deqtz_func, bits, x.dtype)
451
+ loss = loss_fn(x, dequant)
452
+
453
+ base_lr = 0.1
454
+ exponent = int(np.floor(np.log10(loss.item())))
455
+ lr = base_lr * (10 ** (exponent // 2))
456
+
457
+ # Requires gradients in the parameters
458
+ for p in params.values():
459
+ p.requires_grad = True
460
+ p.grad = None
461
+
462
+ param_keys = list(params.keys())
463
+ param_values = list(params.values())
464
+
465
+ # Defines optimizer and loss function
466
+ optimizer = torch.optim.Adam(param_values, lr=lr)
467
+
468
+ # Contains the best loss and the best parameters
469
+ best_loss = float("inf")
470
+ best_params = None
471
+
472
+ # Used to stop the search early
473
+ min_delta = 1e-7
474
+ acc_loss = []
475
+ percent_epochs_before_stop = 0.1
476
+
477
+ for i in range(epochs):
478
+ optimizer.zero_grad()
479
+
480
+ quant = quantize(x, params, qtz_func, bits, target_dtype)
481
+ dequant = dequantize(quant, params, deqtz_func, bits, x.dtype)
482
+ loss = loss_fn(x, dequant)
483
+
484
+ if loss.isnan() or loss.isinf():
485
+ raise Exception("Loss is NaN or Inf. Stopping the search.")
486
+
487
+ loss.backward()
488
+ optimizer.step()
489
+
490
+ acc_loss.append(loss.item())
491
+
492
+ # Reports loss every 10 steps
493
+ if i % 10 == 0 and do_report:
494
+ print(f"Epoch {i}: Loss {loss.item()}")
495
+
496
+ # Optimizes the parameter search by storing the best loss and the parameters
497
+ if loss.item() < best_loss:
498
+ best_loss = loss.item()
499
+ best_params = copy.deepcopy({
500
+ k: v for k, v in params.items() if k in param_keys
501
+ })
502
+
503
+ # We also stop the search if the loss has not changed considerably during the last 10% of epochs
504
+ if early_stop:
505
+ epochs_before_stop = int(epochs * percent_epochs_before_stop)
506
+ if i > epochs_before_stop and abs(acc_loss[i - epochs_before_stop] - acc_loss[i]) < min_delta:
507
+ break
508
+
509
+ # No longer requires gradients in the parameters
510
+ for p in best_params.values():
511
+ p.requires_grad = False
512
+ p.grad = None
513
+
514
+ if do_report:
515
+ print(f"Best loss: {best_loss}")
516
+ return best_params, acc_loss
517
+ else:
518
+ return best_params
519
+
520
+
521
+ def quantize(
522
+ x: torch.Tensor,
523
+ params: Dict[str, nn.Parameter],
524
+ func: nn.Module,
525
+ bits: int,
526
+ target_dtype: torch.dtype = torch.int8
527
+ ) -> torch.Tensor:
528
+ quant_min, quant_max = get_min_max_from_bits_signed(bits)
529
+ x = x.transpose(0, 1) # Aligns shapes
530
+ x = func(x=x, **params)
531
+ x = x.transpose(0, 1)
532
+ x = torch.clamp(round_func_BPDA(x), quant_min, quant_max).to(target_dtype)
533
+ return x
534
+
535
+
536
+ def dequantize(
537
+ x: torch.Tensor,
538
+ params: Dict[str, nn.Parameter],
539
+ func: nn.Module,
540
+ bits: int,
541
+ out_dtype: torch.dtype
542
+ ) -> torch.Tensor:
543
+ x = x.to(dtype=out_dtype)
544
+ x = x.transpose(0, 1)
545
+ x = func(x=x, **params)
546
+ x = x.transpose(0, 1)
547
+ return x
548
+
549
+
550
+ def round_func_BPDA(input):
551
+ # This is equivalent to replacing round function (non-differentiable) with
552
+ # an identity function (differentiable) only when backward.
553
+ forward_value = torch.round(input)
554
+ out = input.clone()
555
+ out.data = forward_value.data
556
+ return out
557
+
558
+
559
+ def get_min_max_from_bits_signed(bit_width: int) -> Tuple[int, int]:
560
+ return -2 ** (bit_width - 1), 2 ** (bit_width - 1) - 1
561
+
562
+
563
+
564
+ ############## Numpy ###############
565
+
566
+ def np_domain_guard(
567
+ x: np.ndarray,
568
+ min: float = None,
569
+ max: float = None,
570
+ posinf: float = None,
571
+ neginf: float = None,
572
+ nan: float = None
573
+ ) -> np.ndarray:
574
+ """Guard a tensor to a valid domain."""
575
+ x = np.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan)
576
+ if min is not None or max is not None:
577
+ x = np.clip(x, min, max)
578
+ return x
579
+
580
+
581
+ def np_replace_num(x: np.ndarray, num: float, to: float) -> np.ndarray:
582
+ """Replace a number in a tensor with another number.
583
+
584
+ Args:
585
+ x (np.ndarray): The input tensor.
586
+ num (float): The number to replace.
587
+ to (float): The number to replace with.
588
+
589
+ Returns:
590
+ np.ndarray: The tensor with the number replaced.
591
+ """
592
+ return np.where(x == num, to, x)
593
+
594
+
595
+ def np_guarded_power(x: np.ndarray, exp: float) -> np.ndarray:
596
+ """Guard the power operation to a valid domain."""
597
+ return np.power(x, exp) if exp >= 1 else np.power(np.maximum(x, 0), exp)
598
+
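A minimal usage sketch of the entry points in the fn.py above (not part of the uploaded file; it assumes a 2-D float32 weight tensor, 8-bit signed quantization, and that init_params, quantize, dequantize, quantization and dequantization are imported from this fn.py):

    import torch

    weights = torch.randn(64, 128)          # one row per channel
    params = init_params(weights, bits=8)   # random-search init followed by 500 Adam epochs
    q = quantize(weights, params, quantization, bits=8, target_dtype=torch.int8)
    deq = dequantize(q, params, dequantization, bits=8, out_dtype=weights.dtype)
    mse = torch.mean((weights - deq) ** 2)  # the reconstruction error that learn_parameters minimizes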
fn_gen/rnd_search_t_no_sched/11/loss.png ADDED
fn_gen/rnd_search_t_no_sched/11/quantization.png ADDED
fn_gen/rnd_search_t_no_sched/12/distortion.png ADDED
fn_gen/rnd_search_t_no_sched/12/expressions.txt ADDED
@@ -0,0 +1,2 @@
1
+ atanh(_0*x)/_s
2
+ tanh(_s*x)/_0
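The second line is the algebraic inverse of the first: tanh(_s * atanh(_0*x)/_s) / _0 = x, so dequantization recovers x exactly before rounding and clamping are applied. A quick numerical check (illustrative only; _0 and _s are arbitrary example values):

    import numpy as np

    x = np.linspace(-0.5, 0.5, 5)
    _0, _s = 1.5, 0.3
    q = np.arctanh(_0 * x) / _s                    # atanh(_0*x)/_s
    assert np.allclose(np.tanh(_s * q) / _0, x)    # tanh(_s*x)/_0 applied to q gives x back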
fn_gen/rnd_search_t_no_sched/12/fn.py ADDED
@@ -0,0 +1,598 @@
1
+ from __future__ import annotations
2
+
3
+ import torch
4
+ from torch import amin # Necessary for arcsin
5
+ import copy
6
+ import torch.nn as nn
7
+ import numpy as np
8
+
9
+ from scipy.optimize import curve_fit
10
+ from typing import Dict, Any, Tuple, List, Callable
11
+
12
+
13
+ def quantization(x, **params):
14
+ return (torch.div(1, replace_num(params['_s'], num=0, to=10000)) * torch.atanh(domain_guard((params['_0'] * x), min=-0.9999, max=0.9999, nan=0)))
15
+
16
+
17
+ def dequantization(x, **params):
18
+ return (torch.div(1, replace_num(params['_0'], num=0, to=10000)) * torch.tanh((params['_s'] * x)))
19
+
20
+
21
+ def init_space_search(
22
+ x: torch.Tensor,
23
+ **kwargs: Dict[str, Any],
24
+ ) -> torch.Tensor:
25
+
26
+ def _build_initial_param(tensor: torch.Tensor, max_initial: int, n_params: int):
27
+ """Generates the initial set of parameters. The first iteration generates 10 times more parameters."""
28
+ for _ in range(n_params * 10): # The first iteration generates 10 times more parameters
29
+ yield init_rand(tensor) * max_initial # Generates n_params in range [-max_initial, max_initial]
30
+
31
+ def _search_param(tensors: List[torch.tensor], n_params):
32
+ """Takes the best parameters and generates new parameters around the mean of the best parameters."""
33
+ torch_tensors = torch.stack(tensors)
34
+ min_vals, max_vals = torch.aminmax(torch_tensors, dim=0)
35
+ abs_max_val_per_ch = torch.max(-min_vals, max_vals)
36
+ mean = torch.mean(torch_tensors, dim=0)
37
+ for _ in range(n_params): # Generates n_params around the mean of the tensors
38
+ yield torch.randn_like(min_vals) * abs_max_val_per_ch + mean
39
+
40
+ def _calc(x, qtz_func, deqtz_func, **params):
41
+ x_ = x.transpose(0, 1)
42
+ x_ = qtz_func(x=x_, **params)
43
+ x_ = deqtz_func(x=x_, **params)
44
+ x_ = x_.transpose(0, 1)
45
+ return x_
46
+
47
+ assert "qtz_func" in kwargs, "qtz_func must be provided."
48
+ assert "deqtz_func" in kwargs, "deqtz_func must be provided."
49
+ assert "params_list" in kwargs, "params list must be provided."
50
+ assert "param" in kwargs, "param must be provided."
51
+
52
+ qtz_func = kwargs.get('qtz_func')
53
+ deqtz_func = kwargs.get('deqtz_func')
54
+ params_list = kwargs.get('params_list')
55
+ param = kwargs.get('param')
56
+
57
+ n_runs = 50 # Number of runs to try to find the best parameters
58
+ n_random_params = 50 # Number of random parameters to generate
59
+ n_best_to_pick = 5 # Number of best parameters to pick after each run
60
+ max_initial = 10000 # Maximum value to initialize the parameters
61
+
62
+ # Initializes the parameters
63
+ base_params = { p: init_ones(x, **kwargs) for p in params_list if p != param }
64
+ params = _build_initial_param(x, max_initial, n_random_params)
65
+
66
+ # Performs the search
67
+ for _ in range(n_runs):
68
+
69
+ best_params = []
70
+ for param_ in params:
71
+ try:
72
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: param_})
73
+ loss_ones = nn.MSELoss()(x, x_)
74
+
75
+ if len(best_params) < n_best_to_pick:
76
+ best_params.append((param_, loss_ones.item()))
77
+ best_params = sorted(best_params, key=lambda x: x[1])
78
+ elif loss_ones < best_params[-1][1]:
79
+ best_params[-1] = (param_, loss_ones.item())
80
+ best_params = sorted(best_params, key=lambda x: x[1])
81
+
82
+ except Exception: # The parameters might not be valid for the function's domain
83
+ continue
84
+
85
+ # Generates new parameters around the mean
86
+ params = _search_param([p for p, _ in best_params], n_random_params)
87
+
88
+ # Checks if the best parameter is better than the init_ones
89
+ p_ones = init_ones(x, **kwargs)
90
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_ones})
91
+ loss_ones = nn.MSELoss()(x, x_)
92
+
93
+ # Checks if the best parameter is better than the init_rand
94
+ p_rand = init_rand(x, **kwargs)
95
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_rand})
96
+ loss_rand = nn.MSELoss()(x, x_)
97
+
98
+ if loss_rand < best_params[0][1] and loss_rand < loss_ones:
99
+ return p_rand
100
+ elif loss_ones < best_params[0][1] and loss_ones < loss_rand:
101
+ return p_ones
102
+ else:
103
+ return best_params[0][0]
104
+
105
+
106
+ def init_linear_scale( # Symmetric scale. From the study folder
107
+ x: torch.Tensor,
108
+ **kwargs: Dict[str, Any],
109
+ ) -> torch.Tensor:
110
+ assert "bits" in kwargs, "bits must be provided."
111
+ assert "params" in kwargs, "params must be provided."
112
+ assert "qtz_func" in kwargs, "qtz_func must be provided."
113
+
114
+ bits = kwargs.get('bits')
115
+ params = kwargs.get('params')
116
+ qtz_func = kwargs.get('qtz_func')
117
+
118
+ x_ = x.transpose(0, 1)
119
+ x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs))
120
+ x_ = x_.transpose(0, 1)
121
+
122
+ quant_min, quant_max = get_min_max_from_bits_signed(bits)
123
+ min_vals, max_vals = torch.aminmax(x_, dim=1)
124
+ min_vals = torch.min(min_vals, torch.zeros_like(min_vals))
125
+ max_vals = torch.max(max_vals, torch.zeros_like(max_vals))
126
+
127
+ eps = torch.finfo(torch.float32).eps
128
+
129
+ abs_max_val_per_ch = torch.max(-min_vals, max_vals)
130
+ scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2)
131
+
132
+ scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device)
133
+
134
+ # Introduces some noise in scale
135
+ # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything
136
+ # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything.
137
+ # NOTE(diogo): This has been disproven. The noise does not help the learning process but I still
138
+ # left it here for future reference. Will be removed later.
139
+ # scale = scale + 0.01 * torch.randn_like(scale)
140
+
141
+ return scale
142
+
143
+
144
+ def init_params(x: torch.Tensor, **kwargs: Dict[str, Any]) -> Dict[str, nn.Parameter]:
145
+ params = {
146
+ '_0': init_space_search(x, qtz_func=quantization, deqtz_func=dequantization, param='_0', params_list=['_0', '_s'], **kwargs),
147
+ }
148
+ params['_s'] = init_linear_scale(x, params=params, qtz_func=quantization, **kwargs)
149
+ params = {k: nn.Parameter(v, requires_grad=False) for k, v in params.items()}
150
+
151
+ if 'post_init_hook' in kwargs:
152
+ kwargs['post_init_hook'](parameters=params)
153
+
154
+ params = learn_parameters(x, params,
155
+ qtz_func=quantization,
156
+ deqtz_func=dequantization,
157
+ bits=kwargs['bits'],
158
+ target_dtype=torch.int8,
159
+ epochs=500,
160
+ early_stop=False,
161
+ )
162
+ if 'post_train_hook' in kwargs:
163
+ kwargs['post_train_hook'](parameters=params)
164
+
165
+ return params
166
+
167
+
168
+ ############### Numpy Qtz ###############
169
+
170
+
171
+ def np_quantization(x, _0, _s):
172
+ return (np.divide(1, np_replace_num(_s, num=0, to=10000)) * np.arctanh(np_domain_guard((_0 * x), min=-0.9999, max=0.9999, nan=0)))
173
+
174
+
175
+ def np_dequantization(x, _0, _s):
176
+ return (np.divide(1, np_replace_num(_0, num=0, to=10000)) * np.tanh((_s * x)))
177
+
178
+
179
+ def fit_func(x, _0, _s):
180
+ x_ = np_quantization(x, _0, _s)
181
+ x_ = np_dequantization(x_, _0, _s)
182
+ return x_
183
+
184
+
185
+
186
+ ############### HELPERS ###############
187
+
188
+ def domain_guard(
189
+ x: torch.Tensor,
190
+ min: float = None,
191
+ max: float = None,
192
+ posinf: float = None,
193
+ neginf: float = None,
194
+ nan: float = None
195
+ ) -> torch.Tensor:
196
+ """Guard a tensor to a valid domain."""
197
+ x = torch.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan)
198
+ if min is not None or max is not None:
199
+ x = torch.clamp(x, min=min, max=max)
200
+ return x
201
+
202
+
203
+ def replace_num(x: torch.Tensor, num: float, to: float) -> torch.Tensor:
204
+ """Replace a number in a tensor with another number.
205
+
206
+ Args:
207
+ x (torch.Tensor): The input tensor.
208
+ num (float): The number to replace.
209
+ to (float): The number to replace with.
210
+
211
+ Returns:
212
+ torch.Tensor: The tensor with the number replaced.
213
+ """
214
+ return torch.where(x == num, to, x)
215
+
216
+
217
+ def guarded_torch_power(x: torch.Tensor, exp: float) -> torch.Tensor:
218
+ """Guard the power operation to a valid domain."""
219
+ return torch.pow(x, exp) if exp >= 1 else torch.pow(torch.relu(x), exp)
220
+
221
+
222
+ def init_ones(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
223
+ val = torch.amin(x, dim=1)
224
+ return torch.ones_like(val, dtype=torch.float32, device=x.device)
225
+
226
+
227
+ def init_rand(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
228
+ val = torch.amin(x, dim=1)
229
+ return torch.randn_like(val, dtype=torch.float32, device=x.device)
230
+
231
+
232
+ def init_space_search(
233
+ x: torch.Tensor,
234
+ **kwargs: Dict[str, Any],
235
+ ) -> torch.Tensor:
236
+
237
+ def _build_initial_param(tensor: torch.Tensor, max_initial: int, n_params: int):
238
+ """Generates the initial set of parameters. The first iteration generates 10 times more parameters."""
239
+ for _ in range(n_params * 10): # The first iteration generates 10 times more parameters
240
+ yield init_rand(tensor) * max_initial # Generates n_params in range [-max_initial, max_initial]
241
+
242
+ def _search_param(tensors: List[torch.tensor], n_params):
243
+ """Takes the best parameters and generates new parameters around the mean of the best parameters."""
244
+ torch_tensors = torch.stack(tensors)
245
+ min_vals, max_vals = torch.aminmax(torch_tensors, dim=0)
246
+ abs_max_val_per_ch = torch.max(-min_vals, max_vals)
247
+ mean = torch.mean(torch_tensors, dim=0)
248
+ for _ in range(n_params): # Generates n_params around the mean of the tensors
249
+ yield torch.randn_like(min_vals) * abs_max_val_per_ch + mean
250
+
251
+ def _calc(x, qtz_func, deqtz_func, **params):
252
+ x_ = x.transpose(0, 1)
253
+ x_ = qtz_func(x=x_, **params)
254
+ x_ = deqtz_func(x=x_, **params)
255
+ x_ = x_.transpose(0, 1)
256
+ return x_
257
+
258
+ assert "qtz_func" in kwargs, "qtz_func must be provided."
259
+ assert "deqtz_func" in kwargs, "deqtz_func must be provided."
260
+ assert "params_list" in kwargs, "params list must be provided."
261
+ assert "param" in kwargs, "param must be provided."
262
+
263
+ qtz_func = kwargs.get('qtz_func')
264
+ deqtz_func = kwargs.get('deqtz_func')
265
+ params_list = kwargs.get('params_list')
266
+ param = kwargs.get('param')
267
+
268
+ n_runs = 50 # Number of runs to try to find the best parameters
269
+ n_random_params = 50 # Number of random parameters to generate
270
+ n_best_to_pick = 5 # Number of best parameters to pick after each run
271
+ max_initial = 10000 # Maximum value to initialize the parameters
272
+
273
+ # Initializes the parameters
274
+ base_params = { p: init_ones(x, **kwargs) for p in params_list if p != param }
275
+ params = _build_initial_param(x, max_initial, n_random_params)
276
+
277
+ # Performs the search
278
+ for _ in range(n_runs):
279
+
280
+ best_params = []
281
+ for param_ in params:
282
+ try:
283
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: param_})
284
+ loss_ones = nn.MSELoss()(x, x_)
285
+
286
+ if len(best_params) < n_best_to_pick:
287
+ best_params.append((param_, loss_ones.item()))
288
+ best_params = sorted(best_params, key=lambda x: x[1])
289
+ elif loss_ones < best_params[-1][1]:
290
+ best_params[-1] = (param_, loss_ones.item())
291
+ best_params = sorted(best_params, key=lambda x: x[1])
292
+
293
+ except Exception: # The parameters might not be valid for the function's domain
294
+ continue
295
+
296
+ # Generates new parameters around the mean
297
+ params = _search_param([p for p, _ in best_params], n_random_params)
298
+
299
+ # Checks if the best parameter is better than the init_ones
300
+ p_ones = init_ones(x, **kwargs)
301
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_ones})
302
+ loss_ones = nn.MSELoss()(x, x_)
303
+
304
+ # Checks if the best parameter is better than the init_rand
305
+ p_rand = init_rand(x, **kwargs)
306
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_rand})
307
+ loss_rand = nn.MSELoss()(x, x_)
308
+
309
+ if loss_rand < best_params[0][1] and loss_rand < loss_ones:
310
+ return p_rand
311
+ elif loss_ones < best_params[0][1] and loss_ones < loss_rand:
312
+ return p_ones
313
+ else:
314
+ return best_params[0][0]
315
+
316
+
317
+ def init_linear_scale( # Symmetric scale. From the study folder
318
+ x: torch.Tensor,
319
+ **kwargs: Dict[str, Any],
320
+ ) -> torch.Tensor:
321
+ assert "bits" in kwargs, "bits must be provided."
322
+ assert "params" in kwargs, "params must be provided."
323
+ assert "qtz_func" in kwargs, "qtz_func must be provided."
324
+
325
+ bits = kwargs.get('bits')
326
+ params = kwargs.get('params')
327
+ qtz_func = kwargs.get('qtz_func')
328
+
329
+ x_ = x.transpose(0, 1)
330
+ x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs))
331
+ x_ = x_.transpose(0, 1)
332
+
333
+ quant_min, quant_max = get_min_max_from_bits_signed(bits)
334
+ min_vals, max_vals = torch.aminmax(x_, dim=1)
335
+ min_vals = torch.min(min_vals, torch.zeros_like(min_vals))
336
+ max_vals = torch.max(max_vals, torch.zeros_like(max_vals))
337
+
338
+ eps = torch.finfo(torch.float32).eps
339
+
340
+ abs_max_val_per_ch = torch.max(-min_vals, max_vals)
341
+ scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2)
342
+
343
+ scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device)
344
+
345
+ # Introduces some noise in scale
346
+ # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything
347
+ # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything.
348
+ # NOTE(diogo): This has been disproven. The noise does not help the learning process but I still
349
+ # left it here for future reference. Will be removed later.
350
+ # scale = scale + 0.01 * torch.randn_like(scale)
351
+
352
+ return scale
353
+
354
+
355
+ def init_non_linear_regression_fit(
356
+ x: torch.Tensor,
357
+ **kwargs: Dict[str, Any],
358
+ ) -> torch.Tensor:
359
+
360
+ assert "params_list" in kwargs, "params list must be provided."
361
+ assert "np_fit_func" in kwargs, "np_fit_func must be provided."
362
+ assert "p0" in kwargs, "p0 must be provided."
363
+ np_fit_func = kwargs.get('np_fit_func')
364
+ params_list = kwargs.get('params_list')
365
+ p0 = kwargs.get('p0')
366
+
367
+ def _fit(xdata: np.ndarray, ydata: np.ndarray, func: Callable, p0: List[float]):
368
+ popt, _ = curve_fit(
369
+ func,
370
+ xdata,
371
+ ydata,
372
+ maxfev=1000,
373
+ p0=p0,
374
+ method='lm'
375
+ )
376
+ return popt
377
+
378
+ # 1. Needs to convert the torch tensor to numpy tensor
379
+ xdata = x.cpu().numpy()
380
+
381
+ # 2. Sorts the data so that it makes it easier to fit to it
382
+ sorted_xdata = np.sort(xdata, axis=-1)
383
+
384
+ p0 = {k: v.cpu().numpy() for k, v in p0.items()}
385
+ params_list = sorted(params_list) # We need to make sure that it matches the numpy fit func arg order
386
+
387
+ # 3. Finds the best parameters for each channel
388
+ try:
389
+ params = []
390
+ for i in range(sorted_xdata.shape[0]):
391
+ xdata_ = sorted_xdata[i]
392
+ p0_ = [p0[p][i] for p in params_list]
393
+ ch_params = _fit(xdata_, xdata_, np_fit_func, p0_)
394
+ params.append(ch_params)
395
+
396
+ # 4. Builds the parameters
397
+ result = {}
398
+ for i, p in enumerate(params_list):
399
+ result[p] = torch.tensor([p_[i] for p_ in params], dtype=torch.float32).to(x.device)
400
+
401
+ return result
402
+
403
+ except ValueError as e:
404
+ print(f"Could not fit the function with error: {e}")
405
+ print(f"Using fallback result...")
406
+ return {
407
+ k: torch.tensor(v, dtype=torch.float32).to(x.device) for k, v in p0.items()
408
+ }
409
+
410
+
411
+ def init_zeros(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
412
+ val = torch.amin(x, dim=1)
413
+ return torch.zeros_like(val, dtype=torch.float32, device=x.device)
414
+
415
+
416
+ def init_inner_scale(tensor: torch.Tensor, _min: float = torch.inf, _max: float = torch.inf) -> torch.Tensor:
417
+ # Calculate the original minimum and maximum values
418
+ min_vals, max_vals = torch.aminmax(tensor, dim=-1)
419
+ x_min = torch.min(min_vals, torch.zeros_like(min_vals))
420
+ x_max = torch.max(max_vals, torch.zeros_like(max_vals))
421
+
422
+ if _max is torch.inf: # We do not need to scale the tensor. Just need to move it
423
+ return torch.ones_like(x_min)
424
+
425
+ # Calculate the scale factor
426
+ scale = (_max - _min) / (x_max - x_min)
427
+ return scale
428
+
429
+
430
+
431
+ ############## Quant ###############
432
+
433
+ @torch.enable_grad()
434
+ def learn_parameters(
435
+ x: torch.Tensor,
436
+ params: Dict[str, nn.Parameter],
437
+ qtz_func: nn.Module,
438
+ deqtz_func: nn.Module,
439
+ bits: int,
440
+ target_dtype: torch.dtype,
441
+ epochs: int = 1000,
442
+ early_stop: bool = True,
443
+ do_report: bool = False
444
+ ) -> Tuple[Dict[str, nn.Parameter], torch.Tensor]:
445
+ loss_fn = nn.MSELoss()
446
+
447
+ # Determines the initial learning rate by computing the initial loss and multiplying it by
448
+ # the order of magnitude of the loss divided by 2
449
+ quant = quantize(x, params, qtz_func, bits, target_dtype)
450
+ dequant = dequantize(quant, params, deqtz_func, bits, x.dtype)
451
+ loss = loss_fn(x, dequant)
452
+
453
+ base_lr = 0.1
454
+ exponent = int(np.floor(np.log10(loss.item())))
455
+ lr = base_lr * (10 ** (exponent // 2))
456
+
457
+ # Requires gradients in the parameters
458
+ for p in params.values():
459
+ p.requires_grad = True
460
+ p.grad = None
461
+
462
+ param_keys = list(params.keys())
463
+ param_values = list(params.values())
464
+
465
+ # Defines optimizer and loss function
466
+ optimizer = torch.optim.Adam(param_values, lr=lr)
467
+
468
+ # Contains the best loss and the best parameters
469
+ best_loss = float("inf")
470
+ best_params = None
471
+
472
+ # Used to stop the search early
473
+ min_delta = 1e-7
474
+ acc_loss = []
475
+ percent_epochs_before_stop = 0.1
476
+
477
+ for i in range(epochs):
478
+ optimizer.zero_grad()
479
+
480
+ quant = quantize(x, params, qtz_func, bits, target_dtype)
481
+ dequant = dequantize(quant, params, deqtz_func, bits, x.dtype)
482
+ loss = loss_fn(x, dequant)
483
+
484
+ if loss.isnan() or loss.isinf():
485
+ raise Exception("Loss is NaN or Inf. Stopping the search.")
486
+
487
+ loss.backward()
488
+ optimizer.step()
489
+
490
+ acc_loss.append(loss.item())
491
+
492
+ # Reports loss every 10 steps
493
+ if i % 10 == 0 and do_report:
494
+ print(f"Epoch {i}: Loss {loss.item()}")
495
+
496
+ # Optimizes the parameter search by storing the best loss and the parameters
497
+ if loss.item() < best_loss:
498
+ best_loss = loss.item()
499
+ best_params = copy.deepcopy({
500
+ k: v for k, v in params.items() if k in param_keys
501
+ })
502
+
503
+ # We also stop the search if the loss has not changed considerably during the last 10% of epochs
504
+ if early_stop:
505
+ epochs_before_stop = int(epochs * percent_epochs_before_stop)
506
+ if i > epochs_before_stop and abs(acc_loss[i - epochs_before_stop] - acc_loss[i]) < min_delta:
507
+ break
508
+
509
+ # No longer requires gradients in the parameters
510
+ for p in best_params.values():
511
+ p.requires_grad = False
512
+ p.grad = None
513
+
514
+ if do_report:
515
+ print(f"Best loss: {best_loss}")
516
+ return best_params, acc_loss
517
+ else:
518
+ return best_params
519
+
520
+
521
+ def quantize(
522
+ x: torch.Tensor,
523
+ params: Dict[str, nn.Parameter],
524
+ func: nn.Module,
525
+ bits: int,
526
+ target_dtype: torch.dtype = torch.int8
527
+ ) -> torch.Tensor:
528
+ quant_min, quant_max = get_min_max_from_bits_signed(bits)
529
+ x = x.transpose(0, 1) # Aligns shapes
530
+ x = func(x=x, **params)
531
+ x = x.transpose(0, 1)
532
+ x = torch.clamp(round_func_BPDA(x), quant_min, quant_max).to(target_dtype)
533
+ return x
534
+
535
+
536
+ def dequantize(
537
+ x: torch.Tensor,
538
+ params: Dict[str, nn.Parameter],
539
+ func: nn.Module,
540
+ bits: int,
541
+ out_dtype: torch.dtype
542
+ ) -> torch.Tensor:
543
+ x = x.to(dtype=out_dtype)
544
+ x = x.transpose(0, 1)
545
+ x = func(x=x, **params)
546
+ x = x.transpose(0, 1)
547
+ return x
548
+
549
+
550
+ def round_func_BPDA(input):
551
+ # This is equivalent to replacing round function (non-differentiable) with
552
+ # an identity function (differentiable) only when backward.
553
+ forward_value = torch.round(input)
554
+ out = input.clone()
555
+ out.data = forward_value.data
556
+ return out
557
+
558
+
559
+ def get_min_max_from_bits_signed(bit_width: int) -> Tuple[int, int]:
560
+ return -2 ** (bit_width - 1), 2 ** (bit_width - 1) - 1
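+ # e.g. bit_width=8 gives (-128, 127)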
561
+
562
+
563
+
564
+ ############## Numpy ###############
565
+
566
+ def np_domain_guard(
567
+ x: np.ndarray,
568
+ min: float = None,
569
+ max: float = None,
570
+ posinf: float = None,
571
+ neginf: float = None,
572
+ nan: float = None
573
+ ) -> np.ndarray:
574
+ """Guard a tensor to a valid domain."""
575
+ x = np.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan)
576
+ if min is not None or max is not None:
577
+ x = np.clip(x, min, max)
578
+ return x
579
+
580
+
581
+ def np_replace_num(x: np.ndarray, num: float, to: float) -> np.ndarray:
582
+ """Replace a number in a tensor with another number.
583
+
584
+ Args:
585
+ x (np.ndarray): The input tensor.
586
+ num (float): The number to replace.
587
+ to (float): The number to replace with.
588
+
589
+ Returns:
590
+ np.ndarray: The tensor with the number replaced.
591
+ """
592
+ return np.where(x == num, to, x)
593
+
594
+
595
+ def np_guarded_power(x: np.ndarray, exp: float) -> np.ndarray:
596
+ """Guard the power operation to a valid domain."""
597
+ return np.power(x, exp) if exp >= 1 else np.power(np.maximum(x, 0), exp)
598
+
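round_func_BPDA in the file above is a straight-through estimator: the forward pass rounds, the backward pass acts as the identity. A small sketch of the effect (illustrative only, assuming round_func_BPDA is imported from this fn.py):

    import torch

    x = torch.tensor([0.4, 1.6], requires_grad=True)
    y = round_func_BPDA(x).sum()   # forward value is round(0.4) + round(1.6) = 2.0
    y.backward()
    print(x.grad)                  # tensor([1., 1.]) -- gradients pass through the rounding unchanged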
fn_gen/rnd_search_t_no_sched/12/loss.png ADDED
fn_gen/rnd_search_t_no_sched/12/quantization.png ADDED
fn_gen/rnd_search_t_no_sched/13/distortion.png ADDED
fn_gen/rnd_search_t_no_sched/13/expressions.txt ADDED
@@ -0,0 +1,2 @@
1
+ asinh(_0*x)/_s
2
+ sinh(_s*x)/_0
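As with the other pairs, the second expression inverts the first: sinh(_s * asinh(_0*x)/_s) / _0 = x for every real x. Because asinh is defined on all of R, the generated fn.py for this pair needs no domain_guard around the quantization call, unlike the atanh pair above.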
fn_gen/rnd_search_t_no_sched/13/fn.py ADDED
@@ -0,0 +1,598 @@
1
+ from __future__ import annotations
2
+
3
+ import torch
4
+ from torch import amin # Necessary for arcsin
5
+ import copy
6
+ import torch.nn as nn
7
+ import numpy as np
8
+
9
+ from scipy.optimize import curve_fit
10
+ from typing import Dict, Any, Tuple, List, Callable
11
+
12
+
13
+ def quantization(x, **params):
14
+ return (torch.div(1, replace_num(params['_s'], num=0, to=10000)) * torch.asinh((params['_0'] * x)))
15
+
16
+
17
+ def dequantization(x, **params):
18
+ return (torch.div(1, replace_num(params['_0'], num=0, to=10000)) * torch.sinh((params['_s'] * x)))
19
+
20
+
21
+ def init_space_search(
22
+ x: torch.Tensor,
23
+ **kwargs: Dict[str, Any],
24
+ ) -> torch.Tensor:
25
+
26
+ def _build_initial_param(tensor: torch.Tensor, max_initial: int, n_params: int):
27
+ """Generates the initial set of parameters. The first iteration generates 10 times more parameters."""
28
+ for _ in range(n_params * 10): # The first iteration generates 10 times more parameters
29
+ yield init_rand(tensor) * max_initial # Generates n_params in range [-max_initial, max_initial]
30
+
31
+ def _search_param(tensors: List[torch.tensor], n_params):
32
+ """Takes the best parameters and generates new parameters around the mean of the best parameters."""
33
+ torch_tensors = torch.stack(tensors)
34
+ min_vals, max_vals = torch.aminmax(torch_tensors, dim=0)
35
+ abs_max_val_per_ch = torch.max(-min_vals, max_vals)
36
+ mean = torch.mean(torch_tensors, dim=0)
37
+ for _ in range(n_params): # Generates n_params around the mean of the tensors
38
+ yield torch.randn_like(min_vals) * abs_max_val_per_ch + mean
39
+
40
+ def _calc(x, qtz_func, deqtz_func, **params):
41
+ x_ = x.transpose(0, 1)
42
+ x_ = qtz_func(x=x_, **params)
43
+ x_ = deqtz_func(x=x_, **params)
44
+ x_ = x_.transpose(0, 1)
45
+ return x_
46
+
47
+ assert "qtz_func" in kwargs, "qtz_func must be provided."
48
+ assert "deqtz_func" in kwargs, "deqtz_func must be provided."
49
+ assert "params_list" in kwargs, "params list must be provided."
50
+ assert "param" in kwargs, "param must be provided."
51
+
52
+ qtz_func = kwargs.get('qtz_func')
53
+ deqtz_func = kwargs.get('deqtz_func')
54
+ params_list = kwargs.get('params_list')
55
+ param = kwargs.get('param')
56
+
57
+ n_runs = 50 # Number of runs to try to find the best parameters
58
+ n_random_params = 50 # Number of random parameters to generate
59
+ n_best_to_pick = 5 # Number of best parameters to pick after each run
60
+ max_initial = 10000 # Maximum value to initialize the parameters
61
+
62
+ # Initializes the parameters
63
+ base_params = { p: init_ones(x, **kwargs) for p in params_list if p != param }
64
+ params = _build_initial_param(x, max_initial, n_random_params)
65
+
66
+ # Performs the search
67
+ for _ in range(n_runs):
68
+
69
+ best_params = []
70
+ for param_ in params:
71
+ try:
72
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: param_})
73
+ loss_ones = nn.MSELoss()(x, x_)
74
+
75
+ if len(best_params) < n_best_to_pick:
76
+ best_params.append((param_, loss_ones.item()))
77
+ best_params = sorted(best_params, key=lambda x: x[1])
78
+ elif loss_ones < best_params[-1][1]:
79
+ best_params[-1] = (param_, loss_ones.item())
80
+ best_params = sorted(best_params, key=lambda x: x[1])
81
+
82
+ except Exception: # The parameters might not be valid for the function's domain
83
+ continue
84
+
85
+ # Generates new parameters around the mean
86
+ params = _search_param([p for p, _ in best_params], n_random_params)
87
+
88
+ # Checks if the best parameter is better than the init_ones
89
+ p_ones = init_ones(x, **kwargs)
90
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_ones})
91
+ loss_ones = nn.MSELoss()(x, x_)
92
+
93
+ # Checks if the best parameter is better than the init_rand
94
+ p_rand = init_rand(x, **kwargs)
95
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_rand})
96
+ loss_rand = nn.MSELoss()(x, x_)
97
+
98
+ if loss_rand < best_params[0][1] and loss_rand < loss_ones:
99
+ return p_rand
100
+ elif loss_ones < best_params[0][1] and loss_ones < loss_rand:
101
+ return p_ones
102
+ else:
103
+ return best_params[0][0]
104
+
105
+
106
+ def init_linear_scale( # Symmetric scale. From the study folder
107
+ x: torch.Tensor,
108
+ **kwargs: Dict[str, Any],
109
+ ) -> torch.Tensor:
110
+ assert "bits" in kwargs, "bits must be provided."
111
+ assert "params" in kwargs, "params must be provided."
112
+ assert "qtz_func" in kwargs, "qtz_func must be provided."
113
+
114
+ bits = kwargs.get('bits')
115
+ params = kwargs.get('params')
116
+ qtz_func = kwargs.get('qtz_func')
117
+
118
+ x_ = x.transpose(0, 1)
119
+ x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs))
120
+ x_ = x_.transpose(0, 1)
121
+
122
+ quant_min, quant_max = get_min_max_from_bits_signed(bits)
123
+ min_vals, max_vals = torch.aminmax(x_, dim=1)
124
+ min_vals = torch.min(min_vals, torch.zeros_like(min_vals))
125
+ max_vals = torch.max(max_vals, torch.zeros_like(max_vals))
126
+
127
+ eps = torch.finfo(torch.float32).eps
128
+
129
+ abs_max_val_per_ch = torch.max(-min_vals, max_vals)
130
+ scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2)
131
+
132
+ scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device)
133
+
134
+ # Introduces some noise in scale
135
+ # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything
136
+ # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything.
137
+ # NOTE(diogo): This has been disproven. The noise does not help the learning process but I still
138
+ # left it here for future reference. Will be removed later.
139
+ # scale = scale + 0.01 * torch.randn_like(scale)
140
+
141
+ return scale
142
+
143
+
144
+ def init_params(x: torch.Tensor, **kwargs: Dict[str, Any]) -> Dict[str, nn.Parameter]:
145
+ params = {
146
+ '_0': init_space_search(x, qtz_func=quantization, deqtz_func=dequantization, param='_0', params_list=['_0', '_s'], **kwargs),
147
+ }
148
+ params['_s'] = init_linear_scale(x, params=params, qtz_func=quantization, **kwargs)
149
+ params = {k: nn.Parameter(v, requires_grad=False) for k, v in params.items()}
150
+
151
+ if 'post_init_hook' in kwargs:
152
+ kwargs['post_init_hook'](parameters=params)
153
+
154
+ params = learn_parameters(x, params,
155
+ qtz_func=quantization,
156
+ deqtz_func=dequantization,
157
+ bits=kwargs['bits'],
158
+ target_dtype=torch.int8,
159
+ epochs=500,
160
+ early_stop=False,
161
+ )
162
+ if 'post_train_hook' in kwargs:
163
+ kwargs['post_train_hook'](parameters=params)
164
+
165
+ return params
166
+
167
+
168
+ ############### Numpy Qtz ###############
169
+
170
+
171
+ def np_quantization(x, _0, _s):
172
+ return (np.divide(1, np_replace_num(_s, num=0, to=10000)) * np.arcsinh((_0 * x)))
173
+
174
+
175
+ def np_dequantization(x, _0, _s):
176
+ return (np.divide(1, np_replace_num(_0, num=0, to=10000)) * np.sinh((_s * x)))
177
+
178
+
179
+ def fit_func(x, _0, _s):
180
+ x_ = np_quantization(x, _0, _s)
181
+ x_ = np_dequantization(x_, _0, _s)
182
+ return x_
183
+
184
+
185
+
186
+ ############### HELPERS ###############
187
+
188
+ def domain_guard(
189
+ x: torch.Tensor,
190
+ min: float = None,
191
+ max: float = None,
192
+ posinf: float = None,
193
+ neginf: float = None,
194
+ nan: float = None
195
+ ) -> torch.Tensor:
196
+ """Guard a tensor to a valid domain."""
197
+ x = torch.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan)
198
+ if min is not None or max is not None:
199
+ x = torch.clamp(x, min=min, max=max)
200
+ return x
201
+
202
+
203
+ def replace_num(x: torch.Tensor, num: float, to: float) -> torch.Tensor:
204
+ """Replace a number in a tensor with another number.
205
+
206
+ Args:
207
+ x (torch.Tensor): The input tensor.
208
+ num (float): The number to replace.
209
+ to (float): The number to replace with.
210
+
211
+ Returns:
212
+ torch.Tensor: The tensor with the number replaced.
213
+ """
214
+ return torch.where(x == num, to, x)
215
+
216
+
217
+ def guarded_torch_power(x: torch.Tensor, exp: float) -> torch.Tensor:
218
+ """Guard the power operation to a valid domain."""
219
+ return torch.pow(x, exp) if exp >= 1 else torch.pow(torch.relu(x), exp)
220
+
221
+
222
+ def init_ones(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
223
+ val = torch.amin(x, dim=1)
224
+ return torch.ones_like(val, dtype=torch.float32, device=x.device)
225
+
226
+
227
+ def init_rand(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
228
+ val = torch.amin(x, dim=1)
229
+ return torch.randn_like(val, dtype=torch.float32, device=x.device)
230
+
231
+
232
+ def init_space_search(
233
+ x: torch.Tensor,
234
+ **kwargs: Dict[str, Any],
235
+ ) -> torch.Tensor:
236
+
237
+ def _build_initial_param(tensor: torch.Tensor, max_initial: int, n_params: int):
238
+ """Generates the initial set of parameters. The first iteration generates 10 times more parameters."""
239
+ for _ in range(n_params * 10): # The first iteration generates 10 times more parameters
240
+ yield init_rand(tensor) * max_initial # Generates n_params in range [-max_initial, max_initial]
241
+
242
+ def _search_param(tensors: List[torch.tensor], n_params):
243
+ """Takes the best parameters and generates new parameters around the mean of the best parameters."""
244
+ torch_tensors = torch.stack(tensors)
245
+ min_vals, max_vals = torch.aminmax(torch_tensors, dim=0)
246
+ abs_max_val_per_ch = torch.max(-min_vals, max_vals)
247
+ mean = torch.mean(torch_tensors, dim=0)
248
+ for _ in range(n_params): # Generates n_params around the mean of the tensors
249
+ yield torch.randn_like(min_vals) * abs_max_val_per_ch + mean
250
+
251
+ def _calc(x, qtz_func, deqtz_func, **params):
252
+ x_ = x.transpose(0, 1)
253
+ x_ = qtz_func(x=x_, **params)
254
+ x_ = deqtz_func(x=x_, **params)
255
+ x_ = x_.transpose(0, 1)
256
+ return x_
257
+
258
+ assert "qtz_func" in kwargs, "qtz_func must be provided."
259
+ assert "deqtz_func" in kwargs, "deqtz_func must be provided."
260
+ assert "params_list" in kwargs, "params list must be provided."
261
+ assert "param" in kwargs, "param must be provided."
262
+
263
+ qtz_func = kwargs.get('qtz_func')
264
+ deqtz_func = kwargs.get('deqtz_func')
265
+ params_list = kwargs.get('params_list')
266
+ param = kwargs.get('param')
267
+
268
+ n_runs = 50 # Number of runs to try to find the best parameters
269
+ n_random_params = 50 # Number of random parameters to generate
270
+ n_best_to_pick = 5 # Number of best parameters to pick after each run
271
+ max_initial = 10000 # Maximum value to initialize the parameters
272
+
273
+ # Initializes the parameters
274
+ base_params = { p: init_ones(x, **kwargs) for p in params_list if p != param }
275
+ params = _build_initial_param(x, max_initial, n_random_params)
276
+
277
+ # Performs the search
278
+ for _ in range(n_runs):
279
+
280
+ best_params = []
281
+ for param_ in params:
282
+ try:
283
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: param_})
284
+ loss_ones = nn.MSELoss()(x, x_)
285
+
286
+ if len(best_params) < n_best_to_pick:
287
+ best_params.append((param_, loss_ones.item()))
288
+ best_params = sorted(best_params, key=lambda x: x[1])
289
+ elif loss_ones < best_params[-1][1]:
290
+ best_params[-1] = (param_, loss_ones.item())
291
+ best_params = sorted(best_params, key=lambda x: x[1])
292
+
293
+ except Exception: # The parameters might not be valid for the function's domain
294
+ continue
295
+
296
+ # Generates new parameters around the mean
297
+ params = _search_param([p for p, _ in best_params], n_random_params)
298
+
299
+ # Checks if the best parameter is better than the init_ones
300
+ p_ones = init_ones(x, **kwargs)
301
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_ones})
302
+ loss_ones = nn.MSELoss()(x, x_)
303
+
304
+ # Checks if the best parameter is better than the init_rand
305
+ p_rand = init_rand(x, **kwargs)
306
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_rand})
307
+ loss_rand = nn.MSELoss()(x, x_)
308
+
309
+ if loss_rand < best_params[0][1] and loss_rand < loss_ones:
310
+ return p_rand
311
+ elif loss_ones < best_params[0][1] and loss_ones < loss_rand:
312
+ return p_ones
313
+ else:
314
+ return best_params[0][0]
315
+
316
+
317
+ def init_linear_scale( # Symmetric scale. From the study folder
318
+ x: torch.Tensor,
319
+ **kwargs: Dict[str, Any],
320
+ ) -> torch.Tensor:
321
+ assert "bits" in kwargs, "bits must be provided."
322
+ assert "params" in kwargs, "params must be provided."
323
+ assert "qtz_func" in kwargs, "qtz_func must be provided."
324
+
325
+ bits = kwargs.get('bits')
326
+ params = kwargs.get('params')
327
+ qtz_func = kwargs.get('qtz_func')
328
+
329
+ x_ = x.transpose(0, 1)
330
+ x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs))
331
+ x_ = x_.transpose(0, 1)
332
+
333
+ quant_min, quant_max = get_min_max_from_bits_signed(bits)
334
+ min_vals, max_vals = torch.aminmax(x_, dim=1)
335
+ min_vals = torch.min(min_vals, torch.zeros_like(min_vals))
336
+ max_vals = torch.max(max_vals, torch.zeros_like(max_vals))
337
+
338
+ eps = torch.finfo(torch.float32).eps
339
+
340
+ abs_max_val_per_ch = torch.max(-min_vals, max_vals)
341
+ scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2)
342
+
343
+ scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device)
344
+
345
+ # Introduces some noise in scale
346
+ # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything
347
+ # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything.
348
+ # NOTE(diogo): This has been disproven. The noise does not help the learning process but I still
349
+ # left it here for future reference. Will be removed later.
350
+ # scale = scale + 0.01 * torch.randn_like(scale)
351
+
352
+ return scale
353
+
354
+
355
+ def init_non_linear_regression_fit(
356
+ x: torch.Tensor,
357
+ **kwargs: Dict[str, Any],
358
+ ) -> torch.Tensor:
359
+
360
+ assert "params_list" in kwargs, "params list must be provided."
361
+ assert "np_fit_func" in kwargs, "np_fit_func must be provided."
362
+ assert "p0" in kwargs, "p0 must be provided."
363
+ np_fit_func = kwargs.get('np_fit_func')
364
+ params_list = kwargs.get('params_list')
365
+ p0 = kwargs.get('p0')
366
+
367
+ def _fit(xdata: np.ndarray, ydata: np.ndarray, func: Callable, p0: List[float]):
368
+ popt, _ = curve_fit(
369
+ func,
370
+ xdata,
371
+ ydata,
372
+ maxfev=1000,
373
+ p0=p0,
374
+ method='lm'
375
+ )
376
+ return popt
377
+
378
+ # 1. Needs to convert the torch tensor to numpy tensor
379
+ xdata = x.cpu().numpy()
380
+
381
+ # 2. Sorts the data so that it makes it easier to fit to it
382
+ sorted_xdata = np.sort(xdata, axis=-1)
383
+
384
+ p0 = {k: v.cpu().numpy() for k, v in p0.items()}
385
+ params_list = sorted(params_list) # We need to make sure that it matches the numpy fit func arg order
386
+
387
+ # 3. Finds the best parameters for each channel
388
+ try:
389
+ params = []
390
+ for i in range(sorted_xdata.shape[0]):
391
+ xdata_ = sorted_xdata[i]
392
+ p0_ = [p0[p][i] for p in params_list]
393
+ ch_params = _fit(xdata_, xdata_, np_fit_func, p0_)
394
+ params.append(ch_params)
395
+
396
+ # 4. Builds the parameters
397
+ result = {}
398
+ for i, p in enumerate(params_list):
399
+ result[p] = torch.tensor([p_[i] for p_ in params], dtype=torch.float32).to(x.device)
400
+
401
+ return result
402
+
403
+ except ValueError as e:
404
+ print(f"Could not fit the function with error: {e}")
405
+ print(f"Using fallback result...")
406
+ return {
407
+ k: torch.tensor(v, dtype=torch.float32).to(x.device) for k, v in p0.items()
408
+ }
409
+
410
+
411
+ def init_zeros(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
412
+ val = torch.amin(x, dim=1)
413
+ return torch.zeros_like(val, dtype=torch.float32, device=x.device)
414
+
415
+
416
+ def init_inner_scale(tensor: torch.Tensor, _min: float = torch.inf, _max: float = torch.inf) -> torch.Tensor:
417
+ # Calculate the original minimum and maximum values
418
+ min_vals, max_vals = torch.aminmax(tensor, dim=-1)
419
+ x_min = torch.min(min_vals, torch.zeros_like(min_vals))
420
+ x_max = torch.max(max_vals, torch.zeros_like(max_vals))
421
+
422
+ if _max is torch.inf: # We do not need to scale the tensor. Just need to move it
423
+ return torch.ones_like(x_min)
424
+
425
+ # Calculate the scale factor
426
+ scale = (_max - _min) / (x_max - x_min)
427
+ return scale
428
+
429
+
430
+
431
+ ############## Quant ###############
432
+
433
+ @torch.enable_grad()
434
+ def learn_parameters(
435
+ x: torch.Tensor,
436
+ params: Dict[str, nn.Parameter],
437
+ qtz_func: nn.Module,
438
+ deqtz_func: nn.Module,
439
+ bits: int,
440
+ target_dtype: torch.dtype,
441
+ epochs: int = 1000,
442
+ early_stop: bool = True,
443
+ do_report: bool = False
444
+ ) -> Tuple[Dict[str, nn.Parameter], torch.Tensor]:
445
+ loss_fn = nn.MSELoss()
446
+
447
+ # Determines the initial learning rate from the initial loss: the base learning rate is
448
+ # scaled by 10 to the power of half the loss's order of magnitude (integer division)
449
+ quant = quantize(x, params, qtz_func, bits, target_dtype)
450
+ dequant = dequantize(quant, params, deqtz_func, bits, x.dtype)
451
+ loss = loss_fn(x, dequant)
452
+
453
+ base_lr = 0.1
454
+ exponent = int(np.floor(np.log10(loss.item())))
455
+ lr = base_lr * (10 ** (exponent // 2))
456
+
457
+ # Requires gradients in the parameters
458
+ for p in params.values():
459
+ p.requires_grad = True
460
+ p.grad = None
461
+
462
+ param_keys = list(params.keys())
463
+ param_values = list(params.values())
464
+
465
+ # Defines optimizer and loss function
466
+ optimizer = torch.optim.Adam(param_values, lr=lr)
467
+
468
+ # Contains the best loss and the best parameters
469
+ best_loss = float("inf")
470
+ best_params = None
471
+
472
+ # Used to stop the search early
473
+ min_delta = 1e-7
474
+ acc_loss = []
475
+ percent_epochs_before_stop = 0.1
476
+
477
+ for i in range(epochs):
478
+ optimizer.zero_grad()
479
+
480
+ quant = quantize(x, params, qtz_func, bits, target_dtype)
481
+ dequant = dequantize(quant, params, deqtz_func, bits, x.dtype)
482
+ loss = loss_fn(x, dequant)
483
+
484
+ if loss.isnan() or loss.isinf():
485
+ raise Exception("Loss is NaN or Inf. Stopping the search.")
486
+
487
+ loss.backward()
488
+ optimizer.step()
489
+
490
+ acc_loss.append(loss.item())
491
+
492
+ # Reports loss every 10 steps
493
+ if i % 10 == 0 and do_report:
494
+ print(f"Epoch {i}: Loss {loss.item()}")
495
+
496
+ # Optimizes the parameter search by storing the best loss and the parameters
497
+ if loss.item() < best_loss:
498
+ best_loss = loss.item()
499
+ best_params = copy.deepcopy({
500
+ k: v for k, v in params.items() if k in param_keys
501
+ })
502
+
503
+ # We also stop the search if the loss has not improved considerably during the last 10% of epochs
504
+ if early_stop:
505
+ epochs_before_stop = int(epochs * percent_epochs_before_stop)
506
+ if i > epochs_before_stop and abs(acc_loss[i - epochs_before_stop] - acc_loss[i]) < min_delta:
507
+ break
508
+
509
+ # No longer requires gradients in the parameters
510
+ for p in best_params.values():
511
+ p.requires_grad = False
512
+ p.grad = None
513
+
514
+ if do_report:
515
+ print(f"Best loss: {best_loss}")
516
+ return best_params, acc_loss
517
+ else:
518
+ return best_params
519
+
520
+
521
+ def quantize(
522
+ x: torch.Tensor,
523
+ params: Dict[str, nn.Parameter],
524
+ func: nn.Module,
525
+ bits: int,
526
+ target_dtype: torch.dtype = torch.int8
527
+ ) -> torch.Tensor:
528
+ quant_min, quant_max = get_min_max_from_bits_signed(bits)
529
+ x = x.transpose(0, 1) # Aligns shapes
530
+ x = func(x=x, **params)
531
+ x = x.transpose(0, 1)
532
+ x = torch.clamp(round_func_BPDA(x), quant_min, quant_max).to(target_dtype)
533
+ return x
534
+
535
+
536
+ def dequantize(
537
+ x: torch.Tensor,
538
+ params: Dict[str, nn.Parameter],
539
+ func: nn.Module,
540
+ bits: int,
541
+ out_dtype: torch.dtype
542
+ ) -> torch.Tensor:
543
+ x = x.to(dtype=out_dtype)
544
+ x = x.transpose(0, 1)
545
+ x = func(x=x, **params)
546
+ x = x.transpose(0, 1)
547
+ return x
548
+
549
+
550
+ def round_func_BPDA(input):
551
+ # This is equivalent to replacing round function (non-differentiable) with
552
+ # an identity function (differentiable) only when backward.
553
+ forward_value = torch.round(input)
554
+ out = input.clone()
555
+ out.data = forward_value.data
556
+ return out
557
+
558
+
559
+ def get_min_max_from_bits_signed(bit_width: int) -> Tuple[int, int]:
560
+ return -2 ** (bit_width - 1), 2 ** (bit_width - 1) - 1
561
+
562
+
563
+
564
+ ############## Numpy ###############
565
+
566
+ def np_domain_guard(
567
+ x: np.ndarray,
568
+ min: float = None,
569
+ max: float = None,
570
+ posinf: float = None,
571
+ neginf: float = None,
572
+ nan: float = None
573
+ ) -> np.ndarray:
574
+ """Guard a tensor to a valid domain."""
575
+ x = np.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan)
576
+ if min is not None or max is not None:
577
+ x = np.clip(x, min, max)
578
+ return x
579
+
580
+
581
+ def np_replace_num(x: np.ndarray, num: float, to: float) -> np.ndarray:
582
+ """Replace a number in a tensor with another number.
583
+
584
+ Args:
585
+ x (np.ndarray): The input tensor.
586
+ num (float): The number to replace.
587
+ to (float): The number to replace with.
588
+
589
+ Returns:
590
+ np.ndarray: The tensor with the number replaced.
591
+ """
592
+ return np.where(x == num, to, x)
593
+
594
+
595
+ def np_guarded_power(x: np.ndarray, exp: float) -> np.ndarray:
596
+ """Guard the power operation to a valid domain."""
597
+ return np.power(x, exp) if exp >= 1 else np.power(np.maximum(x, 0), exp)
598
+
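Every fn.py in this upload relies on the same straight-through rounding trick (round_func_BPDA): the forward pass uses torch.round, but gradients flow as if the rounding were the identity. Below is a minimal standalone sketch of that behaviour, not taken from the uploaded files and written with an equivalent detach-based formulation:

import torch

def round_ste(x: torch.Tensor) -> torch.Tensor:
    # Forward: rounded values. Backward: gradient of the identity function.
    return x + (torch.round(x) - x).detach()

x = torch.tensor([0.2, 1.7, -2.4], requires_grad=True)
round_ste(x).sum().backward()
print(x.grad)  # tensor([1., 1., 1.]) -- the rounding step is transparent to gradients

This is why the quantization parameters can be optimised with Adam in learn_parameters even though the quantized values themselves are discrete.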
fn_gen/rnd_search_t_no_sched/13/loss.png ADDED
fn_gen/rnd_search_t_no_sched/13/quantization.png ADDED
fn_gen/rnd_search_t_no_sched/14/distortion.png ADDED
fn_gen/rnd_search_t_no_sched/14/expressions.txt ADDED
@@ -0,0 +1,2 @@
1
+ tan(_0*x)/_s
2
+ atan(_s*x)/_0
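The two expressions above are inverses as long as _0*x stays inside (-pi/2, pi/2), which is what the domain_guard call in the corresponding fn.py enforces. A quick numeric check (illustrative only; the _0 and _s values are arbitrary):

import numpy as np

_0, _s = 0.5, 2.0                   # arbitrary example values
x = np.linspace(-1.0, 1.0, 5)       # keeps _0*x inside (-pi/2, pi/2)
q = np.tan(_0 * x) / _s             # forward (quantization) expression
x_rec = np.arctan(_s * q) / _0      # inverse (dequantization) expression
print(np.allclose(x, x_rec))        # True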
fn_gen/rnd_search_t_no_sched/14/fn.py ADDED
@@ -0,0 +1,598 @@
1
+ from __future__ import annotations
2
+
3
+ import torch
4
+ from torch import amin # Necessary for arcsin
5
+ import copy
6
+ import torch.nn as nn
7
+ import numpy as np
8
+
9
+ from scipy.optimize import curve_fit
10
+ from typing import Dict, Any, Tuple, List, Callable
11
+
12
+
13
+ def quantization(x, **params):
14
+ return (torch.div(1, replace_num(params['_s'], num=0, to=10000)) * torch.tan(domain_guard((params['_0'] * x), posinf=1, neginf=-1, nan=0)))
15
+
16
+
17
+ def dequantization(x, **params):
18
+ return (torch.div(1, replace_num(params['_0'], num=0, to=10000)) * torch.atan((params['_s'] * x)))
19
+
20
+
21
+ def init_space_search(
22
+ x: torch.Tensor,
23
+ **kwargs: Dict[str, Any],
24
+ ) -> torch.Tensor:
25
+
26
+ def _build_initial_param(tensor: torch.Tensor, max_initial: int, n_params: int):
27
+ """Generates the initial set of parameters. The first iteration generates 10 times more parameters."""
28
+ for _ in range(n_params * 10): # The first iteration generates 10 times more parameters
29
+ yield init_rand(tensor) * max_initial # Generates n_params in range [-max_initial, max_initial]
30
+
31
+ def _search_param(tensors: List[torch.tensor], n_params):
32
+ """Takes the best parameters and generates new parameters around the mean of the best parameters."""
33
+ torch_tensors = torch.stack(tensors)
34
+ min_vals, max_vals = torch.aminmax(torch_tensors, dim=0)
35
+ abs_max_val_per_ch = torch.max(-min_vals, max_vals)
36
+ mean = torch.mean(torch_tensors, dim=0)
37
+ for _ in range(n_params): # Generates n_params around the mean of the tensors
38
+ yield torch.randn_like(min_vals) * abs_max_val_per_ch + mean
39
+
40
+ def _calc(x, qtz_func, deqtz_func, **params):
41
+ x_ = x.transpose(0, 1)
42
+ x_ = qtz_func(x=x_, **params)
43
+ x_ = deqtz_func(x=x_, **params)
44
+ x_ = x_.transpose(0, 1)
45
+ return x_
46
+
47
+ assert "qtz_func" in kwargs, "qtz_func must be provided."
48
+ assert "deqtz_func" in kwargs, "deqtz_func must be provided."
49
+ assert "params_list" in kwargs, "params list must be provided."
50
+ assert "param" in kwargs, "param must be provided."
51
+
52
+ qtz_func = kwargs.get('qtz_func')
53
+ deqtz_func = kwargs.get('deqtz_func')
54
+ params_list = kwargs.get('params_list')
55
+ param = kwargs.get('param')
56
+
57
+ n_runs = 50 # Number of runs to try to find the best parameters
58
+ n_random_params = 50 # Number of random parameters to generate
59
+ n_best_to_pick = 5 # Number of best parameters to pick after each run
60
+ max_initial = 10000 # Maximum value to initialize the parameters
61
+
62
+ # Initializes the parameters
63
+ base_params = { p: init_ones(x, **kwargs) for p in params_list if p != param }
64
+ params = _build_initial_param(x, max_initial, n_random_params)
65
+
66
+ # Performs the search
67
+ for _ in range(n_runs):
68
+
69
+ best_params = []
70
+ for param_ in params:
71
+ try:
72
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: param_})
73
+ loss_ones = nn.MSELoss()(x, x_)
74
+
75
+ if len(best_params) < n_best_to_pick:
76
+ best_params.append((param_, loss_ones.item()))
77
+ best_params = sorted(best_params, key=lambda x: x[1])
78
+ elif loss_ones < best_params[-1][1]:
79
+ best_params[-1] = (param_, loss_ones.item())
80
+ best_params = sorted(best_params, key=lambda x: x[1])
81
+
82
+ except Exception: # The parameters might not be valid for the function's domain
83
+ continue
84
+
85
+ # Generates new parameters around the mean
86
+ params = _search_param([p for p, _ in best_params], n_random_params)
87
+
88
+ # Checks if the best parameter is better than the init_ones
89
+ p_ones = init_ones(x, **kwargs)
90
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_ones})
91
+ loss_ones = nn.MSELoss()(x, x_)
92
+
93
+ # Checks if the best parameter is better than the init_rand
94
+ p_rand = init_rand(x, **kwargs)
95
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_rand})
96
+ loss_rand = nn.MSELoss()(x, x_)
97
+
98
+ if loss_rand < best_params[0][1] and loss_rand < loss_ones:
99
+ return p_rand
100
+ elif loss_ones < best_params[0][1] and loss_ones < loss_rand:
101
+ return p_ones
102
+ else:
103
+ return best_params[0][0]
104
+
105
+
106
+ def init_linear_scale( # Symmetric scale. From the study folder
107
+ x: torch.Tensor,
108
+ **kwargs: Dict[str, Any],
109
+ ) -> torch.Tensor:
110
+ assert "bits" in kwargs, "bits must be provided."
111
+ assert "params" in kwargs, "params must be provided."
112
+ assert "qtz_func" in kwargs, "qtz_func must be provided."
113
+
114
+ bits = kwargs.get('bits')
115
+ params = kwargs.get('params')
116
+ qtz_func = kwargs.get('qtz_func')
117
+
118
+ x_ = x.transpose(0, 1)
119
+ x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs))
120
+ x_ = x_.transpose(0, 1)
121
+
122
+ quant_min, quant_max = get_min_max_from_bits_signed(bits)
123
+ min_vals, max_vals = torch.aminmax(x_, dim=1)
124
+ min_vals = torch.min(min_vals, torch.zeros_like(min_vals))
125
+ max_vals = torch.max(max_vals, torch.zeros_like(max_vals))
126
+
127
+ eps = torch.finfo(torch.float32).eps
128
+
129
+ abs_max_val_per_ch = torch.max(-min_vals, max_vals)
130
+ scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2)
131
+
132
+ scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device)
133
+
134
+ # Introduces some noise in scale
135
+ # If I don't introduce noise, the accuracy is going to be 0.0 and the model will not learn anything
137
+ # NOTE(diogo): This has been disproven. The noise does not help the learning process but I still
138
+ # left it here for future reference. Will be removed later.
139
+ # scale = scale + 0.01 * torch.randn_like(scale)
140
+
141
+ return scale
142
+
143
+
144
+ def init_params(x: torch.Tensor, **kwargs: Dict[str, Any]) -> Dict[str, nn.Parameter]:
145
+ params = {
146
+ '_0': init_space_search(x, qtz_func=quantization, deqtz_func=dequantization, param='_0', params_list=['_0', '_s'], **kwargs),
147
+ }
148
+ params['_s'] = init_linear_scale(x, params=params, qtz_func=quantization, **kwargs)
149
+ params = {k: nn.Parameter(v, requires_grad=False) for k, v in params.items()}
150
+
151
+ if 'post_init_hook' in kwargs:
152
+ kwargs['post_init_hook'](parameters=params)
153
+
154
+ params = learn_parameters(x, params,
155
+ qtz_func=quantization,
156
+ deqtz_func=dequantization,
157
+ bits=kwargs['bits'],
158
+ target_dtype=torch.int8,
159
+ epochs=500,
160
+ early_stop=False,
161
+ )
162
+ if 'post_train_hook' in kwargs:
163
+ kwargs['post_train_hook'](parameters=params)
164
+
165
+ return params
166
+
167
+
168
+ ############### Numpy Qtz ###############
169
+
170
+
171
+ def np_quantization(x, _0, _s):
172
+ return (np.divide(1, np_replace_num(_s, num=0, to=10000)) * np.tan(np_domain_guard((_0 * x), posinf=1, neginf=-1, nan=0)))
173
+
174
+
175
+ def np_dequantization(x, _0, _s):
176
+ return (np.divide(1, np_replace_num(_0, num=0, to=10000)) * np.arctan((_s * x)))
177
+
178
+
179
+ def fit_func(x, _0, _s):
180
+ x_ = np_quantization(x, _0, _s)
181
+ x_ = np_dequantization(x_, _0, _s)
182
+ return x_
183
+
184
+
185
+
186
+ ############### HELPERS ###############
187
+
188
+ def domain_guard(
189
+ x: torch.Tensor,
190
+ min: float = None,
191
+ max: float = None,
192
+ posinf: float = None,
193
+ neginf: float = None,
194
+ nan: float = None
195
+ ) -> torch.Tensor:
196
+ """Guard a tensor to a valid domain."""
197
+ x = torch.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan)
198
+ if min is not None or max is not None:
199
+ x = torch.clamp(x, min=min, max=max)
200
+ return x
201
+
202
+
203
+ def replace_num(x: torch.Tensor, num: float, to: float) -> torch.Tensor:
204
+ """Replace a number in a tensor with another number.
205
+
206
+ Args:
207
+ x (torch.Tensor): The input tensor.
208
+ num (float): The number to replace.
209
+ to (float): The number to replace with.
210
+
211
+ Returns:
212
+ torch.Tensor: The tensor with the number replaced.
213
+ """
214
+ return torch.where(x == num, to, x)
215
+
216
+
217
+ def guarded_torch_power(x: torch.Tensor, exp: float) -> torch.Tensor:
218
+ """Guard the power operation to a valid domain."""
219
+ return torch.pow(x, exp) if exp >= 1 else torch.pow(torch.relu(x), exp)
220
+
221
+
222
+ def init_ones(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
223
+ val = torch.amin(x, dim=1)
224
+ return torch.ones_like(val, dtype=torch.float32, device=x.device)
225
+
226
+
227
+ def init_rand(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
228
+ val = torch.amin(x, dim=1)
229
+ return torch.randn_like(val, dtype=torch.float32, device=x.device)
230
+
231
+
232
+ def init_space_search(
233
+ x: torch.Tensor,
234
+ **kwargs: Dict[str, Any],
235
+ ) -> torch.Tensor:
236
+
237
+ def _build_initial_param(tensor: torch.Tensor, max_initial: int, n_params: int):
238
+ """Generates the initial set of parameters. The first iteration generates 10 times more parameters."""
239
+ for _ in range(n_params * 10): # The first iteration generates 10 times more parameters
240
+ yield init_rand(tensor) * max_initial # Generates n_params in range [-max_initial, max_initial]
241
+
242
+ def _search_param(tensors: List[torch.tensor], n_params):
243
+ """Takes the best parameters and generates new parameters around the mean of the best parameters."""
244
+ torch_tensors = torch.stack(tensors)
245
+ min_vals, max_vals = torch.aminmax(torch_tensors, dim=0)
246
+ abs_max_val_per_ch = torch.max(-min_vals, max_vals)
247
+ mean = torch.mean(torch_tensors, dim=0)
248
+ for _ in range(n_params): # Generates n_params around the mean of the tensors
249
+ yield torch.randn_like(min_vals) * abs_max_val_per_ch + mean
250
+
251
+ def _calc(x, qtz_func, deqtz_func, **params):
252
+ x_ = x.transpose(0, 1)
253
+ x_ = qtz_func(x=x_, **params)
254
+ x_ = deqtz_func(x=x_, **params)
255
+ x_ = x_.transpose(0, 1)
256
+ return x_
257
+
258
+ assert "qtz_func" in kwargs, "qtz_func must be provided."
259
+ assert "deqtz_func" in kwargs, "deqtz_func must be provided."
260
+ assert "params_list" in kwargs, "params list must be provided."
261
+ assert "param" in kwargs, "param must be provided."
262
+
263
+ qtz_func = kwargs.get('qtz_func')
264
+ deqtz_func = kwargs.get('deqtz_func')
265
+ params_list = kwargs.get('params_list')
266
+ param = kwargs.get('param')
267
+
268
+ n_runs = 50 # Number of runs to try to find the best parameters
269
+ n_random_params = 50 # Number of random parameters to generate
270
+ n_best_to_pick = 5 # Number of best parameters to pick after each run
271
+ max_initial = 10000 # Maximum value to initialize the parameters
272
+
273
+ # Initializes the parameters
274
+ base_params = { p: init_ones(x, **kwargs) for p in params_list if p != param }
275
+ params = _build_initial_param(x, max_initial, n_random_params)
276
+
277
+ # Performs the search
278
+ for _ in range(n_runs):
279
+
280
+ best_params = []
281
+ for param_ in params:
282
+ try:
283
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: param_})
284
+ loss_ones = nn.MSELoss()(x, x_)
285
+
286
+ if len(best_params) < n_best_to_pick:
287
+ best_params.append((param_, loss_ones.item()))
288
+ best_params = sorted(best_params, key=lambda x: x[1])
289
+ elif loss_ones < best_params[-1][1]:
290
+ best_params[-1] = (param_, loss_ones.item())
291
+ best_params = sorted(best_params, key=lambda x: x[1])
292
+
293
+ except Exception: # The parameters might not be valid for the function's domain
294
+ continue
295
+
296
+ # Generates new parameters around the mean
297
+ params = _search_param([p for p, _ in best_params], n_random_params)
298
+
299
+ # Checks if the best parameter is better than the init_ones
300
+ p_ones = init_ones(x, **kwargs)
301
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_ones})
302
+ loss_ones = nn.MSELoss()(x, x_)
303
+
304
+ # Checks if the best parameter is better than the init_rand
305
+ p_rand = init_rand(x, **kwargs)
306
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_rand})
307
+ loss_rand = nn.MSELoss()(x, x_)
308
+
309
+ if loss_rand < best_params[0][1] and loss_rand < loss_ones:
310
+ return p_rand
311
+ elif loss_ones < best_params[0][1] and loss_ones < loss_rand:
312
+ return p_ones
313
+ else:
314
+ return best_params[0][0]
315
+
316
+
317
+ def init_linear_scale( # Symmetric scale. From the study folder
318
+ x: torch.Tensor,
319
+ **kwargs: Dict[str, Any],
320
+ ) -> torch.Tensor:
321
+ assert "bits" in kwargs, "bits must be provided."
322
+ assert "params" in kwargs, "params must be provided."
323
+ assert "qtz_func" in kwargs, "qtz_func must be provided."
324
+
325
+ bits = kwargs.get('bits')
326
+ params = kwargs.get('params')
327
+ qtz_func = kwargs.get('qtz_func')
328
+
329
+ x_ = x.transpose(0, 1)
330
+ x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs))
331
+ x_ = x_.transpose(0, 1)
332
+
333
+ quant_min, quant_max = get_min_max_from_bits_signed(bits)
334
+ min_vals, max_vals = torch.aminmax(x_, dim=1)
335
+ min_vals = torch.min(min_vals, torch.zeros_like(min_vals))
336
+ max_vals = torch.max(max_vals, torch.zeros_like(max_vals))
337
+
338
+ eps = torch.finfo(torch.float32).eps
339
+
340
+ abs_max_val_per_ch = torch.max(-min_vals, max_vals)
341
+ scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2)
342
+
343
+ scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device)
344
+
345
+ # Introduces some noise in scale
346
+ # If I don't introduce noise, the accuracy is going to be 0.0 and the model will not learn anything
348
+ # NOTE(diogo): This has been disproven. The noise does not help the learning process but I still
349
+ # left it here for future reference. Will be removed later.
350
+ # scale = scale + 0.01 * torch.randn_like(scale)
351
+
352
+ return scale
353
+
354
+
355
+ def init_non_linear_regression_fit(
356
+ x: torch.Tensor,
357
+ **kwargs: Dict[str, Any],
358
+ ) -> torch.Tensor:
359
+
360
+ assert "params_list" in kwargs, "params list must be provided."
361
+ assert "np_fit_func" in kwargs, "np_fit_func must be provided."
362
+ assert "p0" in kwargs, "p0 must be provided."
363
+ np_fit_func = kwargs.get('np_fit_func')
364
+ params_list = kwargs.get('params_list')
365
+ p0 = kwargs.get('p0')
366
+
367
+ def _fit(xdata: np.ndarray, ydata: np.ndarray, func: Callable, p0: List[float]):
368
+ popt, _ = curve_fit(
369
+ func,
370
+ xdata,
371
+ ydata,
372
+ maxfev=1000,
373
+ p0=p0,
374
+ method='lm'
375
+ )
376
+ return popt
377
+
378
+ # 1. Needs to convert the torch tensor to numpy tensor
379
+ xdata = x.cpu().numpy()
380
+
381
+ # 2. Sorts the data so that it makes it easier to fit to it
382
+ sorted_xdata = np.sort(xdata, axis=-1)
383
+
384
+ p0 = {k: v.cpu().numpy() for k, v in p0.items()}
385
+ params_list = sorted(params_list) # We need to make sure that it matches the numpy fit func arg order
386
+
387
+ # 3. Finds the best parameters for each channel
388
+ try:
389
+ params = []
390
+ for i in range(sorted_xdata.shape[0]):
391
+ xdata_ = sorted_xdata[i]
392
+ p0_ = [p0[p][i] for p in params_list]
393
+ ch_params = _fit(xdata_, xdata_, np_fit_func, p0_)
394
+ params.append(ch_params)
395
+
396
+ # 4. Builds the parameters
397
+ result = {}
398
+ for i, p in enumerate(params_list):
399
+ result[p] = torch.tensor([p_[i] for p_ in params], dtype=torch.float32).to(x.device)
400
+
401
+ return result
402
+
403
+ except ValueError as e:
404
+ print(f"Could not fit the function with error: {e}")
405
+ print(f"Using fallback result...")
406
+ return {
407
+ k: torch.tensor(v, dtype=torch.float32).to(x.device) for k, v in p0.items()
408
+ }
409
+
410
+
411
+ def init_zeros(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
412
+ val = torch.amin(x, dim=1)
413
+ return torch.zeros_like(val, dtype=torch.float32, device=x.device)
414
+
415
+
416
+ def init_inner_scale(tensor: torch.Tensor, _min: float = torch.inf, _max: float = torch.inf) -> torch.Tensor:
417
+ # Calculate the original minimum and maximum values
418
+ min_vals, max_vals = torch.aminmax(tensor, dim=-1)
419
+ x_min = torch.min(min_vals, torch.zeros_like(min_vals))
420
+ x_max = torch.max(max_vals, torch.zeros_like(max_vals))
421
+
422
+ if _max is torch.inf: # We do not need to scale the tensor. Just need to move it
423
+ return torch.ones_like(x_min)
424
+
425
+ # Calculate the scale factor
426
+ scale = (_max - _min) / (x_max - x_min)
427
+ return scale
428
+
429
+
430
+
431
+ ############## Quant ###############
432
+
433
+ @torch.enable_grad()
434
+ def learn_parameters(
435
+ x: torch.Tensor,
436
+ params: Dict[str, nn.Parameter],
437
+ qtz_func: nn.Module,
438
+ deqtz_func: nn.Module,
439
+ bits: int,
440
+ target_dtype: torch.dtype,
441
+ epochs: int = 1000,
442
+ early_stop: bool = True,
443
+ do_report: bool = False
444
+ ) -> Tuple[Dict[str, nn.Parameter], torch.Tensor]:
445
+ loss_fn = nn.MSELoss()
446
+
447
+ # Determines the initial learning rate from the initial loss: the base learning rate is
448
+ # scaled by 10 to the power of half the loss's order of magnitude (integer division)
449
+ quant = quantize(x, params, qtz_func, bits, target_dtype)
450
+ dequant = dequantize(quant, params, deqtz_func, bits, x.dtype)
451
+ loss = loss_fn(x, dequant)
452
+
453
+ base_lr = 0.1
454
+ exponent = int(np.floor(np.log10(loss.item())))
455
+ lr = base_lr * (10 ** (exponent // 2))
456
+
457
+ # Requires gradients in the parameters
458
+ for p in params.values():
459
+ p.requires_grad = True
460
+ p.grad = None
461
+
462
+ param_keys = list(params.keys())
463
+ param_values = list(params.values())
464
+
465
+ # Defines optimizer and loss function
466
+ optimizer = torch.optim.Adam(param_values, lr=lr)
467
+
468
+ # Contains the best loss and the best parameters
469
+ best_loss = float("inf")
470
+ best_params = None
471
+
472
+ # Used to stop the search early
473
+ min_delta = 1e-7
474
+ acc_loss = []
475
+ percent_epochs_before_stop = 0.1
476
+
477
+ for i in range(epochs):
478
+ optimizer.zero_grad()
479
+
480
+ quant = quantize(x, params, qtz_func, bits, target_dtype)
481
+ dequant = dequantize(quant, params, deqtz_func, bits, x.dtype)
482
+ loss = loss_fn(x, dequant)
483
+
484
+ if loss.isnan() or loss.isinf():
485
+ raise Exception("Loss is NaN or Inf. Stopping the search.")
486
+
487
+ loss.backward()
488
+ optimizer.step()
489
+
490
+ acc_loss.append(loss.item())
491
+
492
+ # Reports loss every 10 steps
493
+ if i % 10 == 0 and do_report:
494
+ print(f"Epoch {i}: Loss {loss.item()}")
495
+
496
+ # Optimizes the parameter search by storing the best loss and the parameters
497
+ if loss.item() < best_loss:
498
+ best_loss = loss.item()
499
+ best_params = copy.deepcopy({
500
+ k: v for k, v in params.items() if k in param_keys
501
+ })
502
+
503
+ # We also stop the search if the loss has not considerably during the last 10% epochs
504
+ if early_stop:
505
+ epochs_before_stop = int(epochs * percent_epochs_before_stop)
506
+ if i > epochs_before_stop and abs(acc_loss[i - epochs_before_stop] - acc_loss[i]) < min_delta:
507
+ break
508
+
509
+ # No longer requires gradients in the parameters
510
+ for p in best_params.values():
511
+ p.requires_grad = False
512
+ p.grad = None
513
+
514
+ if do_report:
515
+ print(f"Best loss: {best_loss}")
516
+ return best_params, acc_loss
517
+ else:
518
+ return best_params
519
+
520
+
521
+ def quantize(
522
+ x: torch.Tensor,
523
+ params: Dict[str, nn.Parameter],
524
+ func: nn.Module,
525
+ bits: int,
526
+ target_dtype: torch.dtype = torch.int8
527
+ ) -> torch.Tensor:
528
+ quant_min, quant_max = get_min_max_from_bits_signed(bits)
529
+ x = x.transpose(0, 1) # Aligns shapes
530
+ x = func(x=x, **params)
531
+ x = x.transpose(0, 1)
532
+ x = torch.clamp(round_func_BPDA(x), quant_min, quant_max).to(target_dtype)
533
+ return x
534
+
535
+
536
+ def dequantize(
537
+ x: torch.Tensor,
538
+ params: Dict[str, nn.Parameter],
539
+ func: nn.Module,
540
+ bits: int,
541
+ out_dtype: torch.dtype
542
+ ) -> torch.Tensor:
543
+ x = x.to(dtype=out_dtype)
544
+ x = x.transpose(0, 1)
545
+ x = func(x=x, **params)
546
+ x = x.transpose(0, 1)
547
+ return x
548
+
549
+
550
+ def round_func_BPDA(input):
551
+ # This is equivalent to replacing round function (non-differentiable) with
552
+ # an identity function (differentiable) only when backward.
553
+ forward_value = torch.round(input)
554
+ out = input.clone()
555
+ out.data = forward_value.data
556
+ return out
557
+
558
+
559
+ def get_min_max_from_bits_signed(bit_width: int) -> Tuple[int, int]:
560
+ return -2 ** (bit_width - 1), 2 ** (bit_width - 1) - 1
561
+
562
+
563
+
564
+ ############## Numpy ###############
565
+
566
+ def np_domain_guard(
567
+ x: np.ndarray,
568
+ min: float = None,
569
+ max: float = None,
570
+ posinf: float = None,
571
+ neginf: float = None,
572
+ nan: float = None
573
+ ) -> np.ndarray:
574
+ """Guard a tensor to a valid domain."""
575
+ x = np.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan)
576
+ if min is not None or max is not None:
577
+ x = np.clip(x, min, max)
578
+ return x
579
+
580
+
581
+ def np_replace_num(x: np.ndarray, num: float, to: float) -> np.ndarray:
582
+ """Replace a number in a tensor with another number.
583
+
584
+ Args:
585
+ x (np.ndarray): The input tensor.
586
+ num (float): The number to replace.
587
+ to (float): The number to replace with.
588
+
589
+ Returns:
590
+ np.ndarray: The tensor with the number replaced.
591
+ """
592
+ return np.where(x == num, to, x)
593
+
594
+
595
+ def np_guarded_power(x: np.ndarray, exp: float) -> np.ndarray:
596
+ """Guard the power operation to a valid domain."""
597
+ return np.power(x, exp) if exp >= 1 else np.power(np.maximum(x, 0), exp)
598
+
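As a side note on the training loop shared by all of these files: learn_parameters picks its Adam learning rate from the magnitude of the initial reconstruction loss, as base_lr * 10 ** (floor(log10(loss)) // 2). A small worked example, with arbitrary loss values, of what that heuristic yields:

import numpy as np

def heuristic_lr(initial_loss: float, base_lr: float = 0.1) -> float:
    # Mirrors the learning-rate heuristic used in learn_parameters.
    exponent = int(np.floor(np.log10(initial_loss)))
    return base_lr * (10 ** (exponent // 2))

print(heuristic_lr(3e-4))   # 0.001 (exponent -4 -> factor 10**-2)
print(heuristic_lr(250.0))  # 1.0   (exponent  2 -> factor 10**1)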
fn_gen/rnd_search_t_no_sched/14/loss.png ADDED
fn_gen/rnd_search_t_no_sched/14/quantization.png ADDED
fn_gen/rnd_search_t_no_sched/15/distortion.png ADDED
fn_gen/rnd_search_t_no_sched/15/expressions.txt ADDED
@@ -0,0 +1,2 @@
1
+ log(_0*x)/_s
2
+ exp(_s*x)/_0
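This pair only round-trips where _0*x is positive; the fn.py below therefore clamps the log argument to at least 1e-5 before taking the logarithm. A short illustration with arbitrary example values:

import numpy as np

_0, _s = 2.0, 0.5                             # arbitrary example values
x = np.array([0.1, 1.0, 5.0])                 # _0*x > 0, so the round trip is exact
q = np.log(np.clip(_0 * x, 1e-5, None)) / _s  # forward expression with the clamp
x_rec = np.exp(_s * q) / _0                   # inverse expression
print(np.allclose(x, x_rec))                  # True; clamped (non-positive) inputs are not recovered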
fn_gen/rnd_search_t_no_sched/15/fn.py ADDED
@@ -0,0 +1,598 @@
1
+ from __future__ import annotations
2
+
3
+ import torch
4
+ from torch import amin # Necessary for arcsin
5
+ import copy
6
+ import torch.nn as nn
7
+ import numpy as np
8
+
9
+ from scipy.optimize import curve_fit
10
+ from typing import Dict, Any, Tuple, List, Callable
11
+
12
+
13
+ def quantization(x, **params):
14
+ return (torch.div(1, replace_num(params['_s'], num=0, to=10000)) * torch.log(domain_guard((params['_0'] * x), min=1e-5, nan=1e-5)))
15
+
16
+
17
+ def dequantization(x, **params):
18
+ return (torch.div(1, replace_num(params['_0'], num=0, to=10000)) * torch.exp((params['_s'] * x)))
19
+
20
+
21
+ def init_space_search(
22
+ x: torch.Tensor,
23
+ **kwargs: Dict[str, Any],
24
+ ) -> torch.Tensor:
25
+
26
+ def _build_initial_param(tensor: torch.Tensor, max_initial: int, n_params: int):
27
+ """Generates the initial set of parameters. The first iteration generates 10 times more parameters."""
28
+ for _ in range(n_params * 10): # The first iteration generates 10 times more parameters
29
+ yield init_rand(tensor) * max_initial # Generates n_params in range [-max_initial, max_initial]
30
+
31
+ def _search_param(tensors: List[torch.tensor], n_params):
32
+ """Takes the best parameters and generates new parameters around the mean of the best parameters."""
33
+ torch_tensors = torch.stack(tensors)
34
+ min_vals, max_vals = torch.aminmax(torch_tensors, dim=0)
35
+ abs_max_val_per_ch = torch.max(-min_vals, max_vals)
36
+ mean = torch.mean(torch_tensors, dim=0)
37
+ for _ in range(n_params): # Generates n_params around the mean of the tensors
38
+ yield torch.randn_like(min_vals) * abs_max_val_per_ch + mean
39
+
40
+ def _calc(x, qtz_func, deqtz_func, **params):
41
+ x_ = x.transpose(0, 1)
42
+ x_ = qtz_func(x=x_, **params)
43
+ x_ = deqtz_func(x=x_, **params)
44
+ x_ = x_.transpose(0, 1)
45
+ return x_
46
+
47
+ assert "qtz_func" in kwargs, "qtz_func must be provided."
48
+ assert "deqtz_func" in kwargs, "deqtz_func must be provided."
49
+ assert "params_list" in kwargs, "params list must be provided."
50
+ assert "param" in kwargs, "param must be provided."
51
+
52
+ qtz_func = kwargs.get('qtz_func')
53
+ deqtz_func = kwargs.get('deqtz_func')
54
+ params_list = kwargs.get('params_list')
55
+ param = kwargs.get('param')
56
+
57
+ n_runs = 50 # Number of runs to try to find the best parameters
58
+ n_random_params = 50 # Number of random parameters to generate
59
+ n_best_to_pick = 5 # Number of best parameters to pick after each run
60
+ max_initial = 10000 # Maximum value to initialize the parameters
61
+
62
+ # Initializes the parameters
63
+ base_params = { p: init_ones(x, **kwargs) for p in params_list if p != param }
64
+ params = _build_initial_param(x, max_initial, n_random_params)
65
+
66
+ # Performs the search
67
+ for _ in range(n_runs):
68
+
69
+ best_params = []
70
+ for param_ in params:
71
+ try:
72
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: param_})
73
+ loss_ones = nn.MSELoss()(x, x_)
74
+
75
+ if len(best_params) < n_best_to_pick:
76
+ best_params.append((param_, loss_ones.item()))
77
+ best_params = sorted(best_params, key=lambda x: x[1])
78
+ elif loss_ones < best_params[-1][1]:
79
+ best_params[-1] = (param_, loss_ones.item())
80
+ best_params = sorted(best_params, key=lambda x: x[1])
81
+
82
+ except Exception: # The parameters might not be valid for the function's domain
83
+ continue
84
+
85
+ # Generates new parameters around the mean
86
+ params = _search_param([p for p, _ in best_params], n_random_params)
87
+
88
+ # Checks if the best parameter is better than the init_ones
89
+ p_ones = init_ones(x, **kwargs)
90
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_ones})
91
+ loss_ones = nn.MSELoss()(x, x_)
92
+
93
+ # Checks if the best parameter is better than the init_rand
94
+ p_rand = init_rand(x, **kwargs)
95
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_rand})
96
+ loss_rand = nn.MSELoss()(x, x_)
97
+
98
+ if loss_rand < best_params[0][1] and loss_rand < loss_ones:
99
+ return p_rand
100
+ elif loss_ones < best_params[0][1] and loss_ones < loss_rand:
101
+ return p_ones
102
+ else:
103
+ return best_params[0][0]
104
+
105
+
106
+ def init_linear_scale( # Symmetric scale. From the study folder
107
+ x: torch.Tensor,
108
+ **kwargs: Dict[str, Any],
109
+ ) -> torch.Tensor:
110
+ assert "bits" in kwargs, "bits must be provided."
111
+ assert "params" in kwargs, "params must be provided."
112
+ assert "qtz_func" in kwargs, "qtz_func must be provided."
113
+
114
+ bits = kwargs.get('bits')
115
+ params = kwargs.get('params')
116
+ qtz_func = kwargs.get('qtz_func')
117
+
118
+ x_ = x.transpose(0, 1)
119
+ x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs))
120
+ x_ = x_.transpose(0, 1)
121
+
122
+ quant_min, quant_max = get_min_max_from_bits_signed(bits)
123
+ min_vals, max_vals = torch.aminmax(x_, dim=1)
124
+ min_vals = torch.min(min_vals, torch.zeros_like(min_vals))
125
+ max_vals = torch.max(max_vals, torch.zeros_like(max_vals))
126
+
127
+ eps = torch.finfo(torch.float32).eps
128
+
129
+ abs_max_val_per_ch = torch.max(-min_vals, max_vals)
130
+ scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2)
131
+
132
+ scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device)
133
+
134
+ # Introduces some noise in scale
135
+ # If I don't introduce noise, the accuracy is going to be 0.0 and the model will not learn anything
137
+ # NOTE(diogo): This has been disproven. The noise does not help the learning process but I still
138
+ # left it here for future reference. Will be removed later.
139
+ # scale = scale + 0.01 * torch.randn_like(scale)
140
+
141
+ return scale
142
+
143
+
144
+ def init_params(x: torch.Tensor, **kwargs: Dict[str, Any]) -> Dict[str, nn.Parameter]:
145
+ params = {
146
+ '_0': init_space_search(x, qtz_func=quantization, deqtz_func=dequantization, param='_0', params_list=['_0', '_s'], **kwargs),
147
+ }
148
+ params['_s'] = init_linear_scale(x, params=params, qtz_func=quantization, **kwargs)
149
+ params = {k: nn.Parameter(v, requires_grad=False) for k, v in params.items()}
150
+
151
+ if 'post_init_hook' in kwargs:
152
+ kwargs['post_init_hook'](parameters=params)
153
+
154
+ params = learn_parameters(x, params,
155
+ qtz_func=quantization,
156
+ deqtz_func=dequantization,
157
+ bits=kwargs['bits'],
158
+ target_dtype=torch.int8,
159
+ epochs=500,
160
+ early_stop=False,
161
+ )
162
+ if 'post_train_hook' in kwargs:
163
+ kwargs['post_train_hook'](parameters=params)
164
+
165
+ return params
166
+
167
+
168
+ ############### Numpy Qtz ###############
169
+
170
+
171
+ def np_quantization(x, _0, _s):
172
+ return (np.divide(1, np_replace_num(_s, num=0, to=10000)) * np.log(np_domain_guard((_0 * x), min=1e-5, nan=1e-5)))
173
+
174
+
175
+ def np_dequantization(x, _0, _s):
176
+ return (np.divide(1, np_replace_num(_0, num=0, to=10000)) * np.exp((_s * x)))
177
+
178
+
179
+ def fit_func(x, _0, _s):
180
+ x_ = np_quantization(x, _0, _s)
181
+ x_ = np_dequantization(x_, _0, _s)
182
+ return x_
183
+
184
+
185
+
186
+ ############### HELPERS ###############
187
+
188
+ def domain_guard(
189
+ x: torch.Tensor,
190
+ min: float = None,
191
+ max: float = None,
192
+ posinf: float = None,
193
+ neginf: float = None,
194
+ nan: float = None
195
+ ) -> torch.Tensor:
196
+ """Guard a tensor to a valid domain."""
197
+ x = torch.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan)
198
+ if min is not None or max is not None:
199
+ x = torch.clamp(x, min=min, max=max)
200
+ return x
201
+
202
+
203
+ def replace_num(x: torch.Tensor, num: float, to: float) -> torch.Tensor:
204
+ """Replace a number in a tensor with another number.
205
+
206
+ Args:
207
+ x (torch.Tensor): The input tensor.
208
+ num (float): The number to replace.
209
+ to (float): The number to replace with.
210
+
211
+ Returns:
212
+ torch.Tensor: The tensor with the number replaced.
213
+ """
214
+ return torch.where(x == num, to, x)
215
+
216
+
217
+ def guarded_torch_power(x: torch.Tensor, exp: float) -> torch.Tensor:
218
+ """Guard the power operation to a valid domain."""
219
+ return torch.pow(x, exp) if exp >= 1 else torch.pow(torch.relu(x), exp)
220
+
221
+
222
+ def init_ones(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
223
+ val = torch.amin(x, dim=1)
224
+ return torch.ones_like(val, dtype=torch.float32, device=x.device)
225
+
226
+
227
+ def init_rand(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
228
+ val = torch.amin(x, dim=1)
229
+ return torch.randn_like(val, dtype=torch.float32, device=x.device)
230
+
231
+
232
+ def init_space_search(
233
+ x: torch.Tensor,
234
+ **kwargs: Dict[str, Any],
235
+ ) -> torch.Tensor:
236
+
237
+ def _build_initial_param(tensor: torch.Tensor, max_initial: int, n_params: int):
238
+ """Generates the initial set of parameters. The first iteration generates 10 times more parameters."""
239
+ for _ in range(n_params * 10): # The first iteration generates 10 times more parameters
240
+ yield init_rand(tensor) * max_initial # Generates n_params in range [-max_initial, max_initial]
241
+
242
+ def _search_param(tensors: List[torch.tensor], n_params):
243
+ """Takes the best parameters and generates new parameters around the mean of the best parameters."""
244
+ torch_tensors = torch.stack(tensors)
245
+ min_vals, max_vals = torch.aminmax(torch_tensors, dim=0)
246
+ abs_max_val_per_ch = torch.max(-min_vals, max_vals)
247
+ mean = torch.mean(torch_tensors, dim=0)
248
+ for _ in range(n_params): # Generates n_params around the mean of the tensors
249
+ yield torch.randn_like(min_vals) * abs_max_val_per_ch + mean
250
+
251
+ def _calc(x, qtz_func, deqtz_func, **params):
252
+ x_ = x.transpose(0, 1)
253
+ x_ = qtz_func(x=x_, **params)
254
+ x_ = deqtz_func(x=x_, **params)
255
+ x_ = x_.transpose(0, 1)
256
+ return x_
257
+
258
+ assert "qtz_func" in kwargs, "qtz_func must be provided."
259
+ assert "deqtz_func" in kwargs, "deqtz_func must be provided."
260
+ assert "params_list" in kwargs, "params list must be provided."
261
+ assert "param" in kwargs, "param must be provided."
262
+
263
+ qtz_func = kwargs.get('qtz_func')
264
+ deqtz_func = kwargs.get('deqtz_func')
265
+ params_list = kwargs.get('params_list')
266
+ param = kwargs.get('param')
267
+
268
+ n_runs = 50 # Number of runs to try to find the best parameters
269
+ n_random_params = 50 # Number of random parameters to generate
270
+ n_best_to_pick = 5 # Number of best parameters to pick after each run
271
+ max_initial = 10000 # Maximum value to initialize the parameters
272
+
273
+ # Initializes the parameters
274
+ base_params = { p: init_ones(x, **kwargs) for p in params_list if p != param }
275
+ params = _build_initial_param(x, max_initial, n_random_params)
276
+
277
+ # Performs the search
278
+ for _ in range(n_runs):
279
+
280
+ best_params = []
281
+ for param_ in params:
282
+ try:
283
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: param_})
284
+ loss_ones = nn.MSELoss()(x, x_)
285
+
286
+ if len(best_params) < n_best_to_pick:
287
+ best_params.append((param_, loss_ones.item()))
288
+ best_params = sorted(best_params, key=lambda x: x[1])
289
+ elif loss_ones < best_params[-1][1]:
290
+ best_params[-1] = (param_, loss_ones.item())
291
+ best_params = sorted(best_params, key=lambda x: x[1])
292
+
293
+ except Exception: # The parameters might not be valid for the function's domain
294
+ continue
295
+
296
+ # Generates new parameters around the mean
297
+ params = _search_param([p for p, _ in best_params], n_random_params)
298
+
299
+ # Checks if the best parameter is better than the init_ones
300
+ p_ones = init_ones(x, **kwargs)
301
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_ones})
302
+ loss_ones = nn.MSELoss()(x, x_)
303
+
304
+ # Checks if the best parameter is better than the init_rand
305
+ p_rand = init_rand(x, **kwargs)
306
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_rand})
307
+ loss_rand = nn.MSELoss()(x, x_)
308
+
309
+ if loss_rand < best_params[0][1] and loss_rand < loss_ones:
310
+ return p_rand
311
+ elif loss_ones < best_params[0][1] and loss_ones < loss_rand:
312
+ return p_ones
313
+ else:
314
+ return best_params[0][0]
315
+
316
+
317
+ def init_linear_scale( # Symmetric scale. From the study folder
318
+ x: torch.Tensor,
319
+ **kwargs: Dict[str, Any],
320
+ ) -> torch.Tensor:
321
+ assert "bits" in kwargs, "bits must be provided."
322
+ assert "params" in kwargs, "params must be provided."
323
+ assert "qtz_func" in kwargs, "qtz_func must be provided."
324
+
325
+ bits = kwargs.get('bits')
326
+ params = kwargs.get('params')
327
+ qtz_func = kwargs.get('qtz_func')
328
+
329
+ x_ = x.transpose(0, 1)
330
+ x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs))
331
+ x_ = x_.transpose(0, 1)
332
+
333
+ quant_min, quant_max = get_min_max_from_bits_signed(bits)
334
+ min_vals, max_vals = torch.aminmax(x_, dim=1)
335
+ min_vals = torch.min(min_vals, torch.zeros_like(min_vals))
336
+ max_vals = torch.max(max_vals, torch.zeros_like(max_vals))
337
+
338
+ eps = torch.finfo(torch.float32).eps
339
+
340
+ abs_max_val_per_ch = torch.max(-min_vals, max_vals)
341
+ scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2)
342
+
343
+ scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device)
344
+
345
+ # Introduces some noise in scale
346
+ # If I don't introduce noise, the accuracy is going to be 0.0 and the model will not learn anything
348
+ # NOTE(diogo): This has been disproven. The noise does not help the learning process but I still
349
+ # left it here for future reference. Will be removed later.
350
+ # scale = scale + 0.01 * torch.randn_like(scale)
351
+
352
+ return scale
353
+
354
+
355
+ def init_non_linear_regression_fit(
356
+ x: torch.Tensor,
357
+ **kwargs: Dict[str, Any],
358
+ ) -> torch.Tensor:
359
+
360
+ assert "params_list" in kwargs, "params list must be provided."
361
+ assert "np_fit_func" in kwargs, "np_fit_func must be provided."
362
+ assert "p0" in kwargs, "p0 must be provided."
363
+ np_fit_func = kwargs.get('np_fit_func')
364
+ params_list = kwargs.get('params_list')
365
+ p0 = kwargs.get('p0')
366
+
367
+ def _fit(xdata: np.ndarray, ydata: np.ndarray, func: Callable, p0: List[float]):
368
+ popt, _ = curve_fit(
369
+ func,
370
+ xdata,
371
+ ydata,
372
+ maxfev=1000,
373
+ p0=p0,
374
+ method='lm'
375
+ )
376
+ return popt
377
+
378
+ # 1. Needs to convert the torch tensor to numpy tensor
379
+ xdata = x.cpu().numpy()
380
+
381
+ # 2. Sorts the data so that it makes it easier to fit to it
382
+ sorted_xdata = np.sort(xdata, axis=-1)
383
+
384
+ p0 = {k: v.cpu().numpy() for k, v in p0.items()}
385
+ params_list = sorted(params_list) # We need to make sure that it matches the numpy fit func arg order
386
+
387
+ # 3. Finds the best parameters for each channel
388
+ try:
389
+ params = []
390
+ for i in range(sorted_xdata.shape[0]):
391
+ xdata_ = sorted_xdata[i]
392
+ p0_ = [p0[p][i] for p in params_list]
393
+ ch_params = _fit(xdata_, xdata_, np_fit_func, p0_)
394
+ params.append(ch_params)
395
+
396
+ # 4. Builds the parameters
397
+ result = {}
398
+ for i, p in enumerate(params_list):
399
+ result[p] = torch.tensor([p_[i] for p_ in params], dtype=torch.float32).to(x.device)
400
+
401
+ return result
402
+
403
+ except ValueError as e:
404
+ print(f"Could not fit the function with error: {e}")
405
+ print(f"Using fallback result...")
406
+ return {
407
+ k: torch.tensor(v, dtype=torch.float32).to(x.device) for k, v in p0.items()
408
+ }
409
+
410
+
411
+ def init_zeros(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
412
+ val = torch.amin(x, dim=1)
413
+ return torch.zeros_like(val, dtype=torch.float32, device=x.device)
414
+
415
+
416
+ def init_inner_scale(tensor: torch.Tensor, _min: float = torch.inf, _max: float = torch.inf) -> torch.Tensor:
417
+ # Calculate the original minimum and maximum values
418
+ min_vals, max_vals = torch.aminmax(tensor, dim=-1)
419
+ x_min = torch.min(min_vals, torch.zeros_like(min_vals))
420
+ x_max = torch.max(max_vals, torch.zeros_like(max_vals))
421
+
422
+ if _max is torch.inf: # We do not need to scale the tensor. Just need to move it
423
+ return torch.ones_like(x_min)
424
+
425
+ # Calculate the scale factor
426
+ scale = (_max - _min) / (x_max - x_min)
427
+ return scale
428
+
429
+
430
+
431
+ ############## Quant ###############
432
+
433
+ @torch.enable_grad()
434
+ def learn_parameters(
435
+ x: torch.Tensor,
436
+ params: Dict[str, nn.Parameter],
437
+ qtz_func: nn.Module,
438
+ deqtz_func: nn.Module,
439
+ bits: int,
440
+ target_dtype: torch.dtype,
441
+ epochs: int = 1000,
442
+ early_stop: bool = True,
443
+ do_report: bool = False
444
+ ) -> Tuple[Dict[str, nn.Parameter], torch.Tensor]:
445
+ loss_fn = nn.MSELoss()
446
+
447
+ # Determines the initial learning rate from the initial loss: the base learning rate is
448
+ # scaled by 10 to the power of half the loss's order of magnitude (integer division)
449
+ quant = quantize(x, params, qtz_func, bits, target_dtype)
450
+ dequant = dequantize(quant, params, deqtz_func, bits, x.dtype)
451
+ loss = loss_fn(x, dequant)
452
+
453
+ base_lr = 0.1
454
+ exponent = int(np.floor(np.log10(loss.item())))
455
+ lr = base_lr * (10 ** (exponent // 2))
456
+
457
+ # Requires gradients in the parameters
458
+ for p in params.values():
459
+ p.requires_grad = True
460
+ p.grad = None
461
+
462
+ param_keys = list(params.keys())
463
+ param_values = list(params.values())
464
+
465
+ # Defines optimizer and loss function
466
+ optimizer = torch.optim.Adam(param_values, lr=lr)
467
+
468
+ # Contains the best loss and the best parameters
469
+ best_loss = float("inf")
470
+ best_params = None
471
+
472
+ # Used to stop the search early
473
+ min_delta = 1e-7
474
+ acc_loss = []
475
+ percent_epochs_before_stop = 0.1
476
+
477
+ for i in range(epochs):
478
+ optimizer.zero_grad()
479
+
480
+ quant = quantize(x, params, qtz_func, bits, target_dtype)
481
+ dequant = dequantize(quant, params, deqtz_func, bits, x.dtype)
482
+ loss = loss_fn(x, dequant)
483
+
484
+ if loss.isnan() or loss.isinf():
485
+ raise Exception("Loss is NaN or Inf. Stopping the search.")
486
+
487
+ loss.backward()
488
+ optimizer.step()
489
+
490
+ acc_loss.append(loss.item())
491
+
492
+ # Reports loss every 10 steps
493
+ if i % 10 == 0 and do_report:
494
+ print(f"Epoch {i}: Loss {loss.item()}")
495
+
496
+ # Optimizes the parameter search by storing the best loss and the parameters
497
+ if loss.item() < best_loss:
498
+ best_loss = loss.item()
499
+ best_params = copy.deepcopy({
500
+ k: v for k, v in params.items() if k in param_keys
501
+ })
502
+
503
+ # We also stop the search if the loss has not considerably during the last 10% epochs
504
+ if early_stop:
505
+ epochs_before_stop = int(epochs * percent_epochs_before_stop)
506
+ if i > epochs_before_stop and abs(acc_loss[i - epochs_before_stop] - acc_loss[i]) < min_delta:
507
+ break
508
+
509
+ # No longer requires gradients in the parameters
510
+ for p in best_params.values():
511
+ p.requires_grad = False
512
+ p.grad = None
513
+
514
+ if do_report:
515
+ print(f"Best loss: {best_loss}")
516
+ return best_params, acc_loss
517
+ else:
518
+ return best_params
519
+
520
+
521
+ def quantize(
522
+ x: torch.Tensor,
523
+ params: Dict[str, nn.Parameter],
524
+ func: nn.Module,
525
+ bits: int,
526
+ target_dtype: torch.dtype = torch.int8
527
+ ) -> torch.Tensor:
528
+ quant_min, quant_max = get_min_max_from_bits_signed(bits)
529
+ x = x.transpose(0, 1) # Aligns shapes
530
+ x = func(x=x, **params)
531
+ x = x.transpose(0, 1)
532
+ x = torch.clamp(round_func_BPDA(x), quant_min, quant_max).to(target_dtype)
533
+ return x
534
+
535
+
536
+ def dequantize(
537
+ x: torch.Tensor,
538
+ params: Dict[str, nn.Parameter],
539
+ func: nn.Module,
540
+ bits: int,
541
+ out_dtype: torch.dtype
542
+ ) -> torch.Tensor:
543
+ x = x.to(dtype=out_dtype)
544
+ x = x.transpose(0, 1)
545
+ x = func(x=x, **params)
546
+ x = x.transpose(0, 1)
547
+ return x
548
+
549
+
550
+ def round_func_BPDA(input):
551
+ # This is equivalent to replacing round function (non-differentiable) with
552
+ # an identity function (differentiable) only when backward.
553
+ forward_value = torch.round(input)
554
+ out = input.clone()
555
+ out.data = forward_value.data
556
+ return out
557
+
558
+
559
+ def get_min_max_from_bits_signed(bit_width: int) -> Tuple[int, int]:
560
+ return -2 ** (bit_width - 1), 2 ** (bit_width - 1) - 1
561
+
562
+
563
+
564
+ ############## Numpy ###############
565
+
566
+ def np_domain_guard(
567
+ x: np.ndarray,
568
+ min: float = None,
569
+ max: float = None,
570
+ posinf: float = None,
571
+ neginf: float = None,
572
+ nan: float = None
573
+ ) -> np.ndarray:
574
+ """Guard a tensor to a valid domain."""
575
+ x = np.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan)
576
+ if min is not None or max is not None:
577
+ x = np.clip(x, min, max)
578
+ return x
579
+
580
+
581
+ def np_replace_num(x: np.ndarray, num: float, to: float) -> np.ndarray:
582
+ """Replace a number in a tensor with another number.
583
+
584
+ Args:
585
+ x (np.ndarray): The input tensor.
586
+ num (float): The number to replace.
587
+ to (float): The number to replace with.
588
+
589
+ Returns:
590
+ np.ndarray: The tensor with the number replaced.
591
+ """
592
+ return np.where(x == num, to, x)
593
+
594
+
595
+ def np_guarded_power(x: np.ndarray, exp: float) -> np.ndarray:
596
+ """Guard the power operation to a valid domain."""
597
+ return np.power(x, exp) if exp >= 1 else np.power(np.maximum(x, 0), exp)
598
+
fn_gen/rnd_search_t_no_sched/15/loss.png ADDED
fn_gen/rnd_search_t_no_sched/15/quantization.png ADDED
fn_gen/rnd_search_t_no_sched/16/distortion.png ADDED
fn_gen/rnd_search_t_no_sched/16/expressions.txt ADDED
@@ -0,0 +1,2 @@
1
+ (_0*x)**(1/3)/_s
2
+ _s**3*x**3/_0
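Unlike the trigonometric and logarithmic pairs, this one inverts by plain algebra: _s**3 * ((_0*x)**(1/3) / _s)**3 / _0 simplifies back to x for non-negative _0*x (the guarded power in fn.py zeroes out negative bases). A tiny check with arbitrary values:

import numpy as np

_0, _s = 4.0, 0.25                     # arbitrary example values
x = np.array([0.0, 0.5, 2.0])          # keep _0*x >= 0 for the real cube root
q = np.power(_0 * x, 1.0 / 3.0) / _s   # forward (quantization) expression
x_rec = (_s ** 3) * (q ** 3) / _0      # inverse (dequantization) expression
print(np.allclose(x, x_rec))           # True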
fn_gen/rnd_search_t_no_sched/16/fn.py ADDED
@@ -0,0 +1,598 @@
1
+ from __future__ import annotations
2
+
3
+ import torch
4
+ from torch import amin # Necessary for arcsin
5
+ import copy
6
+ import torch.nn as nn
7
+ import numpy as np
8
+
9
+ from scipy.optimize import curve_fit
10
+ from typing import Dict, Any, Tuple, List, Callable
11
+
12
+
13
+ def quantization(x, **params):
14
+ return (torch.div(1, replace_num(params['_s'], num=0, to=10000)) * guarded_torch_power((params['_0'] * x), 1 / 3))
15
+
16
+
17
+ def dequantization(x, **params):
18
+ return (torch.div(1, replace_num(params['_0'], num=0, to=10000)) * guarded_torch_power(params['_s'], torch.tensor(3)) * guarded_torch_power(x, torch.tensor(3)))
19
+
20
+
21
+ def init_space_search(
22
+ x: torch.Tensor,
23
+ **kwargs: Dict[str, Any],
24
+ ) -> torch.Tensor:
25
+
26
+ def _build_initial_param(tensor: torch.Tensor, max_initial: int, n_params: int):
27
+ """Generates the initial set of parameters. The first iteration generates 10 times more parameters."""
28
+ for _ in range(n_params * 10): # The first iteration generates 10 times more parameters
29
+ yield init_rand(tensor) * max_initial # Generates n_params in range [-max_initial, max_initial]
30
+
31
+ def _search_param(tensors: List[torch.tensor], n_params):
32
+ """Takes the best parameters and generates new parameters around the mean of the best parameters."""
33
+ torch_tensors = torch.stack(tensors)
34
+ min_vals, max_vals = torch.aminmax(torch_tensors, dim=0)
35
+ abs_max_val_per_ch = torch.max(-min_vals, max_vals)
36
+ mean = torch.mean(torch_tensors, dim=0)
37
+ for _ in range(n_params): # Generates n_params around the mean of the tensors
38
+ yield torch.randn_like(min_vals) * abs_max_val_per_ch + mean
39
+
40
+ def _calc(x, qtz_func, deqtz_func, **params):
41
+ x_ = x.transpose(0, 1)
42
+ x_ = qtz_func(x=x_, **params)
43
+ x_ = deqtz_func(x=x_, **params)
44
+ x_ = x_.transpose(0, 1)
45
+ return x_
46
+
47
+ assert "qtz_func" in kwargs, "qtz_func must be provided."
48
+ assert "deqtz_func" in kwargs, "deqtz_func must be provided."
49
+ assert "params_list" in kwargs, "params list must be provided."
50
+ assert "param" in kwargs, "param must be provided."
51
+
52
+ qtz_func = kwargs.get('qtz_func')
53
+ deqtz_func = kwargs.get('deqtz_func')
54
+ params_list = kwargs.get('params_list')
55
+ param = kwargs.get('param')
56
+
57
+ n_runs = 50 # Number of runs to try to find the best parameters
58
+ n_random_params = 50 # Number of random parameters to generate
59
+ n_best_to_pick = 5 # Number of best parameters to pick after each run
60
+ max_initial = 10000 # Maximum value to initialize the parameters
61
+
62
+ # Initializes the parameters
63
+ base_params = { p: init_ones(x, **kwargs) for p in params_list if p != param }
64
+ params = _build_initial_param(x, max_initial, n_random_params)
65
+
66
+ # Performs the search
67
+ for _ in range(n_runs):
68
+
69
+ best_params = []
70
+ for param_ in params:
71
+ try:
72
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: param_})
73
+ loss_ones = nn.MSELoss()(x, x_)
74
+
75
+ if len(best_params) < n_best_to_pick:
76
+ best_params.append((param_, loss_ones.item()))
77
+ best_params = sorted(best_params, key=lambda x: x[1])
78
+ elif loss_ones < best_params[-1][1]:
79
+ best_params[-1] = (param_, loss_ones.item())
80
+ best_params = sorted(best_params, key=lambda x: x[1])
81
+
82
+ except Exception: # The parameters might not be valid for the function's domain
83
+ continue
84
+
85
+ # Generates new parameters around the mean
86
+ params = _search_param([p for p, _ in best_params], n_random_params)
87
+
88
+ # Checks if the best parameter is better than the init_ones
89
+ p_ones = init_ones(x, **kwargs)
90
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_ones})
91
+ loss_ones = nn.MSELoss()(x, x_)
92
+
93
+ # Checks if the best parameter is better than the init_rand
94
+ p_rand = init_rand(x, **kwargs)
95
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_rand})
96
+ loss_rand = nn.MSELoss()(x, x_)
97
+
98
+ if loss_rand < best_params[0][1] and loss_rand < loss_ones:
99
+ return p_rand
100
+ elif loss_ones < best_params[0][1] and loss_ones < loss_rand:
101
+ return p_ones
102
+ else:
103
+ return best_params[0][0]
104
+
105
+
106
+ def init_linear_scale( # Symmetric scale. From the study folder
107
+ x: torch.Tensor,
108
+ **kwargs: Dict[str, Any],
109
+ ) -> torch.Tensor:
110
+ assert "bits" in kwargs, "bits must be provided."
111
+ assert "params" in kwargs, "params must be provided."
112
+ assert "qtz_func" in kwargs, "qtz_func must be provided."
113
+
114
+ bits = kwargs.get('bits')
115
+ params = kwargs.get('params')
116
+ qtz_func = kwargs.get('qtz_func')
117
+
118
+ x_ = x.transpose(0, 1)
119
+ x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs))
120
+ x_ = x_.transpose(0, 1)
121
+
122
+ quant_min, quant_max = get_min_max_from_bits_signed(bits)
123
+ min_vals, max_vals = torch.aminmax(x_, dim=1)
124
+ min_vals = torch.min(min_vals, torch.zeros_like(min_vals))
125
+ max_vals = torch.max(max_vals, torch.zeros_like(max_vals))
126
+
127
+ eps = torch.finfo(torch.float32).eps
128
+
129
+ abs_max_val_per_ch = torch.max(-min_vals, max_vals)
130
+ scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2)
131
+
132
+ scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device)
133
+
134
+ # Introduces some noise in scale
135
+ # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything
137
+ # NOTE(diogo): This has been disproven. The noise does not help the learning process but I still
138
+ # left it here for future reference. Will be removed later.
139
+ # scale = scale + 0.01 * torch.randn_like(scale)
140
+
141
+ return scale
142
+
143
+
144
+ def init_params(x: torch.Tensor, **kwargs: Dict[str, Any]) -> Dict[str, nn.Parameter]:
145
+ params = {
146
+ '_0': init_space_search(x, qtz_func=quantization, deqtz_func=dequantization, param='_0', params_list=['_0', '_s'], **kwargs),
147
+ }
148
+ params['_s'] = init_linear_scale(x, params=params, qtz_func=quantization, **kwargs)
149
+ params = {k: nn.Parameter(v, requires_grad=False) for k, v in params.items()}
150
+
151
+ if 'post_init_hook' in kwargs:
152
+ kwargs['post_init_hook'](parameters=params)
153
+
154
+ params = learn_parameters(x, params,
155
+ qtz_func=quantization,
156
+ deqtz_func=dequantization,
157
+ bits=kwargs['bits'],
158
+ target_dtype=torch.int8,
159
+ epochs=500,
160
+ early_stop=False,
161
+ )
162
+ if 'post_train_hook' in kwargs:
163
+ kwargs['post_train_hook'](parameters=params)
164
+
165
+ return params
166
+
167
+
168
+ ############### Numpy Qtz ###############
169
+
170
+
171
+ def np_quantization(x, _0, _s):
172
+ return (np.divide(1, np_replace_num(_s, num=0, to=10000)) * np_guarded_power((_0 * x), 1 / 3))
173
+
174
+
175
+ def np_dequantization(x, _0, _s):
176
+ return (np.divide(1, np_replace_num(_0, num=0, to=10000)) * np_guarded_power(_s, np.array(3)) * np_guarded_power(x, np.array(3)))
177
+
178
+
179
+ def fit_func(x, _0, _s):
180
+ x_ = np_quantization(x, _0, _s)
181
+ x_ = np_dequantization(x_, _0, _s)
182
+ return x_
183
+
184
+
185
+
186
+ ############### HELPERS ###############
187
+
188
+ def domain_guard(
189
+ x: torch.Tensor,
190
+ min: float = None,
191
+ max: float = None,
192
+ posinf: float = None,
193
+ neginf: float = None,
194
+ nan: float = None
195
+ ) -> torch.Tensor:
196
+ """Guard a tensor to a valid domain."""
197
+ x = torch.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan)
198
+ if min is not None or max is not None:
199
+ x = torch.clamp(x, min=min, max=max)
200
+ return x
201
+
202
+
203
+ def replace_num(x: torch.Tensor, num: float, to: float) -> torch.Tensor:
204
+ """Replace a number in a tensor with another number.
205
+
206
+ Args:
207
+ x (torch.Tensor): The input tensor.
208
+ num (float): The number to replace.
209
+ to (float): The number to replace with.
210
+
211
+ Returns:
212
+ torch.Tensor: The tensor with the number replaced.
213
+ """
214
+ return torch.where(x == num, to, x)
215
+
216
+
217
+ def guarded_torch_power(x: torch.Tensor, exp: float) -> torch.Tensor:
218
+ """Guard the power operation to a valid domain."""
219
+ return torch.pow(x, exp) if exp >= 1 else torch.pow(torch.relu(x), exp)
220
+
221
+
222
+ def init_ones(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
223
+ val = torch.amin(x, dim=1)
224
+ return torch.ones_like(val, dtype=torch.float32, device=x.device)
225
+
226
+
227
+ def init_rand(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
228
+ val = torch.amin(x, dim=1)
229
+ return torch.randn_like(val, dtype=torch.float32, device=x.device)
230
+
231
+
232
+ def init_space_search(
233
+ x: torch.Tensor,
234
+ **kwargs: Dict[str, Any],
235
+ ) -> torch.Tensor:
236
+
237
+ def _build_initial_param(tensor: torch.Tensor, max_initial: int, n_params: int):
238
+ """Generates the initial set of parameters. The first iteration generates 10 times more parameters."""
239
+ for _ in range(n_params * 10): # The first iteration generates 10 times more parameters
240
+ yield init_rand(tensor) * max_initial # Generates n_params in range [-max_initial, max_initial]
241
+
242
+ def _search_param(tensors: List[torch.Tensor], n_params):
243
+ """Takes the best parameters and generates new parameters around the mean of the best parameters."""
244
+ torch_tensors = torch.stack(tensors)
245
+ min_vals, max_vals = torch.aminmax(torch_tensors, dim=0)
246
+ abs_max_val_per_ch = torch.max(-min_vals, max_vals)
247
+ mean = torch.mean(torch_tensors, dim=0)
248
+ for _ in range(n_params): # Generates n_params around the mean of the tensors
249
+ yield torch.randn_like(min_vals) * abs_max_val_per_ch + mean
250
+
251
+ def _calc(x, qtz_func, deqtz_func, **params):
252
+ x_ = x.transpose(0, 1)
253
+ x_ = qtz_func(x=x_, **params)
254
+ x_ = deqtz_func(x=x_, **params)
255
+ x_ = x_.transpose(0, 1)
256
+ return x_
257
+
258
+ assert "qtz_func" in kwargs, "qtz_func must be provided."
259
+ assert "deqtz_func" in kwargs, "deqtz_func must be provided."
260
+ assert "params_list" in kwargs, "params list must be provided."
261
+ assert "param" in kwargs, "param must be provided."
262
+
263
+ qtz_func = kwargs.get('qtz_func')
264
+ deqtz_func = kwargs.get('deqtz_func')
265
+ params_list = kwargs.get('params_list')
266
+ param = kwargs.get('param')
267
+
268
+ n_runs = 50 # Number of runs to try to find the best parameters
269
+ n_random_params = 50 # Number of random parameters to generate
270
+ n_best_to_pick = 5 # Number of best parameters to pick after each run
271
+ max_initial = 10000 # Maximum value to initialize the parameters
272
+
273
+ # Initializes the parameters
274
+ base_params = { p: init_ones(x, **kwargs) for p in params_list if p != param }
275
+ params = _build_initial_param(x, max_initial, n_random_params)
276
+
277
+ # Performs the search
278
+ for _ in range(n_runs):
279
+
280
+ best_params = []
281
+ for param_ in params:
282
+ try:
283
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: param_})
284
+ loss_ones = nn.MSELoss()(x, x_)
285
+
286
+ if len(best_params) < n_best_to_pick:
287
+ best_params.append((param_, loss_ones.item()))
288
+ best_params = sorted(best_params, key=lambda x: x[1])
289
+ elif loss_ones < best_params[-1][1]:
290
+ best_params[-1] = (param_, loss_ones.item())
291
+ best_params = sorted(best_params, key=lambda x: x[1])
292
+
293
+ except Exception: # The parameters might not be valid for the function's domain
294
+ continue
295
+
296
+ # Generates new parameters around the mean
297
+ params = _search_param([p for p, _ in best_params], n_random_params)
298
+
299
+ # Checks if the best parameter is better than the init_ones
300
+ p_ones = init_ones(x, **kwargs)
301
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_ones})
302
+ loss_ones = nn.MSELoss()(x, x_)
303
+
304
+ # Checks if the best parameter is better than the init_rand
305
+ p_rand = init_rand(x, **kwargs)
306
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_rand})
307
+ loss_rand = nn.MSELoss()(x, x_)
308
+
309
+ if loss_rand < best_params[0][1] and loss_rand < loss_ones:
310
+ return p_rand
311
+ elif loss_ones < best_params[0][1] and loss_ones < loss_rand:
312
+ return p_ones
313
+ else:
314
+ return best_params[0][0]
315
+
316
+
317
+ def init_linear_scale( # Symmetric scale. From the study folder
318
+ x: torch.Tensor,
319
+ **kwargs: Dict[str, Any],
320
+ ) -> torch.Tensor:
321
+ assert "bits" in kwargs, "bits must be provided."
322
+ assert "params" in kwargs, "params must be provided."
323
+ assert "qtz_func" in kwargs, "qtz_func must be provided."
324
+
325
+ bits = kwargs.get('bits')
326
+ params = kwargs.get('params')
327
+ qtz_func = kwargs.get('qtz_func')
328
+
329
+ x_ = x.transpose(0, 1)
330
+ x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs))
331
+ x_ = x_.transpose(0, 1)
332
+
333
+ quant_min, quant_max = get_min_max_from_bits_signed(bits)
334
+ min_vals, max_vals = torch.aminmax(x_, dim=1)
335
+ min_vals = torch.min(min_vals, torch.zeros_like(min_vals))
336
+ max_vals = torch.max(max_vals, torch.zeros_like(max_vals))
337
+
338
+ eps = torch.finfo(torch.float32).eps
339
+
340
+ abs_max_val_per_ch = torch.max(-min_vals, max_vals)
341
+ scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2)
342
+
343
+ scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device)
344
+
345
+ # Introduces some noise in scale
346
+ # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything
348
+ # NOTE(diogo): This has been disproven. The noise does not help the learning process but I still
349
+ # left it here for future reference. Will be removed later.
350
+ # scale = scale + 0.01 * torch.randn_like(scale)
351
+
352
+ return scale
353
+
354
+
355
+ def init_non_linear_regression_fit(
356
+ x: torch.Tensor,
357
+ **kwargs: Dict[str, Any],
358
+ ) -> torch.Tensor:
359
+
360
+ assert "params_list" in kwargs, "params list must be provided."
361
+ assert "np_fit_func" in kwargs, "np_fit_func must be provided."
362
+ assert "p0" in kwargs, "p0 must be provided."
363
+ np_fit_func = kwargs.get('np_fit_func')
364
+ params_list = kwargs.get('params_list')
365
+ p0 = kwargs.get('p0')
366
+
367
+ def _fit(xdata: np.ndarray, ydata: np.ndarray, func: Callable, p0: List[float]):
368
+ popt, _ = curve_fit(
369
+ func,
370
+ xdata,
371
+ ydata,
372
+ maxfev=1000,
373
+ p0=p0,
374
+ method='lm'
375
+ )
376
+ return popt
377
+
378
+ # 1. Needs to convert the torch tensor to numpy tensor
379
+ xdata = x.cpu().numpy()
380
+
381
+ # 2. Sorts the data so that it makes it easier to fit to it
382
+ sorted_xdata = np.sort(xdata, axis=-1)
383
+
384
+ p0 = {k: v.cpu().numpy() for k, v in p0.items()}
385
+ params_list = sorted(params_list) # We need to make sure that it matches the numpy fit func arg order
386
+
387
+ # 3. Finds the best parameters for each channel
388
+ try:
389
+ params = []
390
+ for i in range(sorted_xdata.shape[0]):
391
+ xdata_ = sorted_xdata[i]
392
+ p0_ = [p0[p][i] for p in params_list]
393
+ ch_params = _fit(xdata_, xdata_, np_fit_func, p0_)
394
+ params.append(ch_params)
395
+
396
+ # 4. Builds the parameters
397
+ result = {}
398
+ for i, p in enumerate(params_list):
399
+ result[p] = torch.tensor([p_[i] for p_ in params], dtype=torch.float32).to(x.device)
400
+
401
+ return result
402
+
403
+ except ValueError as e:
404
+ print(f"Could not fit the function with error: {e}")
405
+ print(f"Using fallback result...")
406
+ return {
407
+ k: torch.tensor(v, dtype=torch.float32).to(x.device) for k, v in p0.items()
408
+ }
409
+
410
+
411
+ def init_zeros(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
412
+ val = torch.amin(x, dim=1)
413
+ return torch.zeros_like(val, dtype=torch.float32, device=x.device)
414
+
415
+
416
+ def init_inner_scale(tensor: torch.Tensor, _min: float = torch.inf, _max: float = torch.inf) -> torch.Tensor:
417
+ # Calculate the original minimum and maximum values
418
+ min_vals, max_vals = torch.aminmax(tensor, dim=-1)
419
+ x_min = torch.min(min_vals, torch.zeros_like(min_vals))
420
+ x_max = torch.max(max_vals, torch.zeros_like(max_vals))
421
+
422
+ if _max is torch.inf: # We do not need to scale the tensor. Just need to move it
423
+ return torch.ones_like(x_min)
424
+
425
+ # Calculate the scale factor
426
+ scale = (_max - _min) / (x_max - x_min)
427
+ return scale
428
+
429
+
430
+
431
+ ############## Quant ###############
432
+
433
+ @torch.enable_grad()
434
+ def learn_parameters(
435
+ x: torch.Tensor,
436
+ params: Dict[str, nn.Parameter],
437
+ qtz_func: nn.Module,
438
+ deqtz_func: nn.Module,
439
+ bits: int,
440
+ target_dtype: torch.dtype,
441
+ epochs: int = 1000,
442
+ early_stop: bool = True,
443
+ do_report: bool = False
444
+ ) -> Tuple[Dict[str, nn.Parameter], torch.Tensor]:
445
+ loss_fn = nn.MSELoss()
446
+
447
+ # Determines the initial learning rate by computing the initial loss and multiplying it by
448
+ # the order of magnitude of the loss divided by 2
449
+ quant = quantize(x, params, qtz_func, bits, target_dtype)
450
+ dequant = dequantize(quant, params, deqtz_func, bits, x.dtype)
451
+ loss = loss_fn(x, dequant)
452
+
453
+ base_lr = 0.1
454
+ exponent = int(np.floor(np.log10(loss.item())))
455
+ lr = base_lr * (10 ** (exponent // 2))
456
+
457
+ # Requires gradients in the parameters
458
+ for p in params.values():
459
+ p.requires_grad = True
460
+ p.grad = None
461
+
462
+ param_keys = list(params.keys())
463
+ param_values = list(params.values())
464
+
465
+ # Defines optimizer and loss function
466
+ optimizer = torch.optim.Adam(param_values, lr=lr)
467
+
468
+ # Contains the best loss and the best parameters
469
+ best_loss = float("inf")
470
+ best_params = None
471
+
472
+ # Used to stop the search early
473
+ min_delta = 1e-7
474
+ acc_loss = []
475
+ percent_epochs_before_stop = 0.1
476
+
477
+ for i in range(epochs):
478
+ optimizer.zero_grad()
479
+
480
+ quant = quantize(x, params, qtz_func, bits, target_dtype)
481
+ dequant = dequantize(quant, params, deqtz_func, bits, x.dtype)
482
+ loss = loss_fn(x, dequant)
483
+
484
+ if loss.isnan() or loss.isinf():
485
+ raise Exception("Loss is NaN or Inf. Stopping the search.")
486
+
487
+ loss.backward()
488
+ optimizer.step()
489
+
490
+ acc_loss.append(loss.item())
491
+
492
+ # Reports loss every 10 steps
493
+ if i % 10 == 0 and do_report:
494
+ print(f"Epoch {i}: Loss {loss.item()}")
495
+
496
+ # Optimizes the parameter search by storing the best loss and the parameters
497
+ if loss.item() < best_loss:
498
+ best_loss = loss.item()
499
+ best_params = copy.deepcopy({
500
+ k: v for k, v in params.items() if k in param_keys
501
+ })
502
+
503
+ # We also stop the search if the loss has not changed considerably during the last 10% of the epochs
504
+ if early_stop:
505
+ epochs_before_stop = int(epochs * percent_epochs_before_stop)
506
+ if i > epochs_before_stop and abs(acc_loss[i - epochs_before_stop] - acc_loss[i]) < min_delta:
507
+ break
508
+
509
+ # No longer requires gradients in the parameters
510
+ for p in best_params.values():
511
+ p.requires_grad = False
512
+ p.grad = None
513
+
514
+ if do_report:
515
+ print(f"Best loss: {best_loss}")
516
+ return best_params, acc_loss
517
+ else:
518
+ return best_params
519
+
520
+
521
+ def quantize(
522
+ x: torch.Tensor,
523
+ params: Dict[str, nn.Parameter],
524
+ func: nn.Module,
525
+ bits: int,
526
+ target_dtype: torch.dtype = torch.int8
527
+ ) -> torch.Tensor:
528
+ quant_min, quant_max = get_min_max_from_bits_signed(bits)
529
+ x = x.transpose(0, 1) # Aligns shapes
530
+ x = func(x=x, **params)
531
+ x = x.transpose(0, 1)
532
+ x = torch.clamp(round_func_BPDA(x), quant_min, quant_max).to(target_dtype)
533
+ return x
534
+
535
+
536
+ def dequantize(
537
+ x: torch.Tensor,
538
+ params: Dict[str, nn.Parameter],
539
+ func: nn.Module,
540
+ bits: int,
541
+ out_dtype: torch.dtype
542
+ ) -> torch.Tensor:
543
+ x = x.to(dtype=out_dtype)
544
+ x = x.transpose(0, 1)
545
+ x = func(x=x, **params)
546
+ x = x.transpose(0, 1)
547
+ return x
548
+
549
+
550
+ def round_func_BPDA(input):
551
+ # This is equivalent to replacing round function (non-differentiable) with
552
+ # an identity function (differentiable) in the backward pass (a straight-through estimator).
553
+ forward_value = torch.round(input)
554
+ out = input.clone()
555
+ out.data = forward_value.data
556
+ return out
557
+
558
+
559
+ def get_min_max_from_bits_signed(bit_width: int) -> Tuple[int, int]:
560
+ return -2 ** (bit_width - 1), 2 ** (bit_width - 1) - 1
561
+
562
+
563
+
564
+ ############## Numpy ###############
565
+
566
+ def np_domain_guard(
567
+ x: np.ndarray,
568
+ min: float = None,
569
+ max: float = None,
570
+ posinf: float = None,
571
+ neginf: float = None,
572
+ nan: float = None
573
+ ) -> np.ndarray:
574
+ """Guard a tensor to a valid domain."""
575
+ x = np.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan)
576
+ if min is not None or max is not None:
577
+ x = np.clip(x, min, max)
578
+ return x
579
+
580
+
581
+ def np_replace_num(x: np.ndarray, num: float, to: float) -> np.ndarray:
582
+ """Replace a number in a tensor with another number.
583
+
584
+ Args:
585
+ x (np.ndarray): The input tensor.
586
+ num (float): The number to replace.
587
+ to (float): The number to replace with.
588
+
589
+ Returns:
590
+ np.ndarray: The tensor with the number replaced.
591
+ """
592
+ return np.where(x == num, to, x)
593
+
594
+
595
+ def np_guarded_power(x: np.ndarray, exp: float) -> np.ndarray:
596
+ """Guard the power operation to a valid domain."""
597
+ return np.power(x, exp) if exp >= 1 else np.power(np.maximum(x, 0), exp)
598
+
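For orientation, the pieces defined above compose into a single fake-quant round trip: init_params searches for _0, derives the symmetric scale _s, and fine-tunes both with Adam; quantize and dequantize then apply the learned pair with int8 rounding. A minimal usage sketch follows (not part of the uploaded file; it assumes this fn.py is importable as a module named fn and that the input is a 2-D per-channel weight tensor):

    import torch
    from fn import init_params, quantize, dequantize, quantization, dequantization  # hypothetical import of the file above

    w = torch.randn(8, 16)                        # toy weights: 8 output channels, 16 inputs
    params = init_params(w, bits=8)               # random search for _0, scale init, 500 Adam steps
    q = quantize(w, params, quantization, bits=8, target_dtype=torch.int8)     # int8 codes
    w_hat = dequantize(q, params, dequantization, bits=8, out_dtype=w.dtype)   # reconstruction
    print(torch.nn.MSELoss()(w, w_hat).item())    # round-trip distortion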
fn_gen/rnd_search_t_no_sched/16/loss.png ADDED
fn_gen/rnd_search_t_no_sched/16/quantization.png ADDED
fn_gen/rnd_search_t_no_sched/18/distortion.png ADDED
fn_gen/rnd_search_t_no_sched/18/expressions.txt ADDED
@@ -0,0 +1,2 @@
1
+ x/_s
2
+ _s*x
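These two expressions are the plain symmetric linear pair: quantization divides by the learned scale _s and dequantization multiplies it back. With the int8 rounding and clamping applied by quantize/dequantize in the fn.py below, the round trip reduces to the usual linear fake-quant, sketched here with example values:

    import torch

    s = torch.tensor(0.05)                                          # example per-channel scale
    x = torch.tensor([-1.3, -0.2, 0.0, 0.7, 1.2])
    q = torch.clamp(torch.round(x / s), -128, 127).to(torch.int8)   # x/_s, rounded into int8
    x_hat = s * q.float()                                           # _s*x on the way back
    print(q.tolist(), x_hat.tolist())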
fn_gen/rnd_search_t_no_sched/18/fn.py ADDED
@@ -0,0 +1,512 @@
1
+ from __future__ import annotations
2
+
3
+ import torch
4
+ from torch import amin # Imported so generated expressions can use amin directly
5
+ import copy
6
+ import torch.nn as nn
7
+ import numpy as np
8
+
9
+ from scipy.optimize import curve_fit
10
+ from typing import Dict, Any, Tuple, List, Callable
11
+
12
+
13
+ def quantization(x, **params):
14
+ return (x * torch.div(1, replace_num(params['_s'], num=0, to=10000)))
15
+
16
+
17
+ def dequantization(x, **params):
18
+ return (params['_s'] * x)
19
+
20
+
21
+ def init_linear_scale( # Symmetric scale. From the study folder
22
+ x: torch.Tensor,
23
+ **kwargs: Dict[str, Any],
24
+ ) -> torch.Tensor:
25
+ assert "bits" in kwargs, "bits must be provided."
26
+ assert "params" in kwargs, "params must be provided."
27
+ assert "qtz_func" in kwargs, "qtz_func must be provided."
28
+
29
+ bits = kwargs.get('bits')
30
+ params = kwargs.get('params')
31
+ qtz_func = kwargs.get('qtz_func')
32
+
33
+ x_ = x.transpose(0, 1)
34
+ x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs))
35
+ x_ = x_.transpose(0, 1)
36
+
37
+ quant_min, quant_max = get_min_max_from_bits_signed(bits)
38
+ min_vals, max_vals = torch.aminmax(x_, dim=1)
39
+ min_vals = torch.min(min_vals, torch.zeros_like(min_vals))
40
+ max_vals = torch.max(max_vals, torch.zeros_like(max_vals))
41
+
42
+ eps = torch.finfo(torch.float32).eps
43
+
44
+ abs_max_val_per_ch = torch.max(-min_vals, max_vals)
45
+ scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2)
46
+
47
+ scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device)
48
+
49
+ # Introduces some noise in scale
50
+ # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything
51
+ # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything.
52
+ # NOTE(diogo): This has been disproven. The noise does not help the learning process but I still
53
+ # left it here for future reference. Will be removed later.
54
+ # scale = scale + 0.01 * torch.randn_like(scale)
55
+
56
+ return scale
57
+
58
+
59
+ def init_params(x: torch.Tensor, **kwargs: Dict[str, Any]) -> Dict[str, nn.Parameter]:
60
+ params = {
61
+ }
62
+ params['_s'] = init_linear_scale(x, params=params, qtz_func=quantization, **kwargs)
63
+ params = {k: nn.Parameter(v, requires_grad=False) for k, v in params.items()}
64
+
65
+ if 'post_init_hook' in kwargs:
66
+ kwargs['post_init_hook'](parameters=params)
67
+
68
+ params = learn_parameters(x, params,
69
+ qtz_func=quantization,
70
+ deqtz_func=dequantization,
71
+ bits=kwargs['bits'],
72
+ target_dtype=torch.int8,
73
+ epochs=500,
74
+ early_stop=False,
75
+ )
76
+ if 'post_train_hook' in kwargs:
77
+ kwargs['post_train_hook'](parameters=params)
78
+
79
+ return params
80
+
81
+
82
+ ############### Numpy Qtz ###############
83
+
84
+
85
+ def np_quantization(x, _s):
86
+ return (x * np.divide(1, np_replace_num(_s, num=0, to=10000)))
87
+
88
+
89
+ def np_dequantization(x, _s):
90
+ return (_s * x)
91
+
92
+
93
+ def fit_func(x, _s):
94
+ x_ = np_quantization(x, _s)
95
+ x_ = np_dequantization(x_, _s)
96
+ return x_
97
+
98
+
99
+
100
+ ############### HELPERS ###############
101
+
102
+ def domain_guard(
103
+ x: torch.Tensor,
104
+ min: float = None,
105
+ max: float = None,
106
+ posinf: float = None,
107
+ neginf: float = None,
108
+ nan: float = None
109
+ ) -> torch.Tensor:
110
+ """Guard a tensor to a valid domain."""
111
+ x = torch.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan)
112
+ if min is not None or max is not None:
113
+ x = torch.clamp(x, min=min, max=max)
114
+ return x
115
+
116
+
117
+ def replace_num(x: torch.Tensor, num: float, to: float) -> torch.Tensor:
118
+ """Replace a number in a tensor with another number.
119
+
120
+ Args:
121
+ x (torch.Tensor): The input tensor.
122
+ num (float): The number to replace.
123
+ to (float): The number to replace with.
124
+
125
+ Returns:
126
+ torch.Tensor: The tensor with the number replaced.
127
+ """
128
+ return torch.where(x == num, to, x)
129
+
130
+
131
+ def guarded_torch_power(x: torch.Tensor, exp: float) -> torch.Tensor:
132
+ """Guard the power operation to a valid domain."""
133
+ return torch.pow(x, exp) if exp >= 1 else torch.pow(torch.relu(x), exp)
134
+
135
+
136
+ def init_ones(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
137
+ val = torch.amin(x, dim=1)
138
+ return torch.ones_like(val, dtype=torch.float32, device=x.device)
139
+
140
+
141
+ def init_rand(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
142
+ val = torch.amin(x, dim=1)
143
+ return torch.randn_like(val, dtype=torch.float32, device=x.device)
144
+
145
+
146
+ def init_space_search(
147
+ x: torch.Tensor,
148
+ **kwargs: Dict[str, Any],
149
+ ) -> torch.Tensor:
150
+
151
+ def _build_initial_param(tensor: torch.Tensor, max_initial: int, n_params: int):
152
+ """Generates the initial set of parameters. The first iteration generates 10 times more parameters."""
153
+ for _ in range(n_params * 10): # The first iteration generates 10 times more parameters
154
+ yield init_rand(tensor) * max_initial # Generates n_params in range [-max_initial, max_initial]
155
+
156
+ def _search_param(tensors: List[torch.Tensor], n_params):
157
+ """Takes the best parameters and generates new parameters around the mean of the best parameters."""
158
+ torch_tensors = torch.stack(tensors)
159
+ min_vals, max_vals = torch.aminmax(torch_tensors, dim=0)
160
+ abs_max_val_per_ch = torch.max(-min_vals, max_vals)
161
+ mean = torch.mean(torch_tensors, dim=0)
162
+ for _ in range(n_params): # Generates n_params around the mean of the tensors
163
+ yield torch.randn_like(min_vals) * abs_max_val_per_ch + mean
164
+
165
+ def _calc(x, qtz_func, deqtz_func, **params):
166
+ x_ = x.transpose(0, 1)
167
+ x_ = qtz_func(x=x_, **params)
168
+ x_ = deqtz_func(x=x_, **params)
169
+ x_ = x_.transpose(0, 1)
170
+ return x_
171
+
172
+ assert "qtz_func" in kwargs, "qtz_func must be provided."
173
+ assert "deqtz_func" in kwargs, "deqtz_func must be provided."
174
+ assert "params_list" in kwargs, "params list must be provided."
175
+ assert "param" in kwargs, "param must be provided."
176
+
177
+ qtz_func = kwargs.get('qtz_func')
178
+ deqtz_func = kwargs.get('deqtz_func')
179
+ params_list = kwargs.get('params_list')
180
+ param = kwargs.get('param')
181
+
182
+ n_runs = 50 # Number of runs to try to find the best parameters
183
+ n_random_params = 50 # Number of random parameters to generate
184
+ n_best_to_pick = 5 # Number of best parameters to pick after each run
185
+ max_initial = 10000 # Maximum value to initialize the parameters
186
+
187
+ # Initializes the parameters
188
+ base_params = { p: init_ones(x, **kwargs) for p in params_list if p != param }
189
+ params = _build_initial_param(x, max_initial, n_random_params)
190
+
191
+ # Performs the search
192
+ for _ in range(n_runs):
193
+
194
+ best_params = []
195
+ for param_ in params:
196
+ try:
197
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: param_})
198
+ loss_ones = nn.MSELoss()(x, x_)
199
+
200
+ if len(best_params) < n_best_to_pick:
201
+ best_params.append((param_, loss_ones.item()))
202
+ best_params = sorted(best_params, key=lambda x: x[1])
203
+ elif loss_ones < best_params[-1][1]:
204
+ best_params[-1] = (param_, loss_ones.item())
205
+ best_params = sorted(best_params, key=lambda x: x[1])
206
+
207
+ except Exception: # The parameters might not be valid for the function's domain
208
+ continue
209
+
210
+ # Generates new parameters around the mean
211
+ params = _search_param([p for p, _ in best_params], n_random_params)
212
+
213
+ # Checks if the best parameter is better than the init_ones
214
+ p_ones = init_ones(x, **kwargs)
215
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_ones})
216
+ loss_ones = nn.MSELoss()(x, x_)
217
+
218
+ # Checks if the best parameter is better than the init_rand
219
+ p_rand = init_rand(x, **kwargs)
220
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_rand})
221
+ loss_rand = nn.MSELoss()(x, x_)
222
+
223
+ if loss_rand < best_params[0][1] and loss_rand < loss_ones:
224
+ return p_rand
225
+ elif loss_ones < best_params[0][1] and loss_ones < loss_rand:
226
+ return p_ones
227
+ else:
228
+ return best_params[0][0]
229
+
230
+
231
+ def init_linear_scale( # Symmetric scale. From the study folder
232
+ x: torch.Tensor,
233
+ **kwargs: Dict[str, Any],
234
+ ) -> torch.Tensor:
235
+ assert "bits" in kwargs, "bits must be provided."
236
+ assert "params" in kwargs, "params must be provided."
237
+ assert "qtz_func" in kwargs, "qtz_func must be provided."
238
+
239
+ bits = kwargs.get('bits')
240
+ params = kwargs.get('params')
241
+ qtz_func = kwargs.get('qtz_func')
242
+
243
+ x_ = x.transpose(0, 1)
244
+ x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs))
245
+ x_ = x_.transpose(0, 1)
246
+
247
+ quant_min, quant_max = get_min_max_from_bits_signed(bits)
248
+ min_vals, max_vals = torch.aminmax(x_, dim=1)
249
+ min_vals = torch.min(min_vals, torch.zeros_like(min_vals))
250
+ max_vals = torch.max(max_vals, torch.zeros_like(max_vals))
251
+
252
+ eps = torch.finfo(torch.float32).eps
253
+
254
+ abs_max_val_per_ch = torch.max(-min_vals, max_vals)
255
+ scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2)
256
+
257
+ scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device)
258
+
259
+ # Introduces some noise in scale
260
+ # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything
261
+ # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything.
262
+ # NOTE(diogo): This has been disproven. The noise does not help the learning process but I still
263
+ # left it here for future reference. Will be removed later.
264
+ # scale = scale + 0.01 * torch.randn_like(scale)
265
+
266
+ return scale
267
+
268
+
269
+ def init_non_linear_regression_fit(
270
+ x: torch.Tensor,
271
+ **kwargs: Dict[str, Any],
272
+ ) -> torch.Tensor:
273
+
274
+ assert "params_list" in kwargs, "params list must be provided."
275
+ assert "np_fit_func" in kwargs, "np_fit_func must be provided."
276
+ assert "p0" in kwargs, "p0 must be provided."
277
+ np_fit_func = kwargs.get('np_fit_func')
278
+ params_list = kwargs.get('params_list')
279
+ p0 = kwargs.get('p0')
280
+
281
+ def _fit(xdata: np.ndarray, ydata: np.ndarray, func: Callable, p0: List[float]):
282
+ popt, _ = curve_fit(
283
+ func,
284
+ xdata,
285
+ ydata,
286
+ maxfev=1000,
287
+ p0=p0,
288
+ method='lm'
289
+ )
290
+ return popt
291
+
292
+ # 1. Needs to convert the torch tensor to numpy tensor
293
+ xdata = x.cpu().numpy()
294
+
295
+ # 2. Sorts the data so that it makes it easier to fit to it
296
+ sorted_xdata = np.sort(xdata, axis=-1)
297
+
298
+ p0 = {k: v.cpu().numpy() for k, v in p0.items()}
299
+ params_list = sorted(params_list) # We need to make sure that it matches the numpy fit func arg order
300
+
301
+ # 3. Finds the best parameters for each channel
302
+ try:
303
+ params = []
304
+ for i in range(sorted_xdata.shape[0]):
305
+ xdata_ = sorted_xdata[i]
306
+ p0_ = [p0[p][i] for p in params_list]
307
+ ch_params = _fit(xdata_, xdata_, np_fit_func, p0_)
308
+ params.append(ch_params)
309
+
310
+ # 4. Builds the parameters
311
+ result = {}
312
+ for i, p in enumerate(params_list):
313
+ result[p] = torch.tensor([p_[i] for p_ in params], dtype=torch.float32).to(x.device)
314
+
315
+ return result
316
+
317
+ except ValueError as e:
318
+ print(f"Could not fit the function with error: {e}")
319
+ print(f"Using fallback result...")
320
+ return {
321
+ k: torch.tensor(v, dtype=torch.float32).to(x.device) for k, v in p0.items()
322
+ }
323
+
324
+
325
+ def init_zeros(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
326
+ val = torch.amin(x, dim=1)
327
+ return torch.zeros_like(val, dtype=torch.float32, device=x.device)
328
+
329
+
330
+ def init_inner_scale(tensor: torch.Tensor, _min: float = torch.inf, _max: float = torch.inf) -> torch.Tensor:
331
+ # Calculate the original minimum and maximum values
332
+ min_vals, max_vals = torch.aminmax(tensor, dim=-1)
333
+ x_min = torch.min(min_vals, torch.zeros_like(min_vals))
334
+ x_max = torch.max(max_vals, torch.zeros_like(max_vals))
335
+
336
+ if _max is torch.inf: # We do not need to scale the tensor. Just need to move it
337
+ return torch.ones_like(x_min)
338
+
339
+ # Calculate the scale factor
340
+ scale = (_max - _min) / (x_max - x_min)
341
+ return scale
342
+
343
+
344
+
345
+ ############## Quant ###############
346
+
347
+ @torch.enable_grad()
348
+ def learn_parameters(
349
+ x: torch.Tensor,
350
+ params: Dict[str, nn.Parameter],
351
+ qtz_func: nn.Module,
352
+ deqtz_func: nn.Module,
353
+ bits: int,
354
+ target_dtype: torch.dtype,
355
+ epochs: int = 1000,
356
+ early_stop: bool = True,
357
+ do_report: bool = False
358
+ ) -> Tuple[Dict[str, nn.Parameter], torch.Tensor]:
359
+ loss_fn = nn.MSELoss()
360
+
361
+ # Determines the initial learning rate by computing the initial loss and multiplying it by
362
+ # the order of magnitude of the loss divided by 2
363
+ quant = quantize(x, params, qtz_func, bits, target_dtype)
364
+ dequant = dequantize(quant, params, deqtz_func, bits, x.dtype)
365
+ loss = loss_fn(x, dequant)
366
+
367
+ base_lr = 0.1
368
+ exponent = int(np.floor(np.log10(loss.item())))
369
+ lr = base_lr * (10 ** (exponent // 2))
370
+
371
+ # Requires gradients in the parameters
372
+ for p in params.values():
373
+ p.requires_grad = True
374
+ p.grad = None
375
+
376
+ param_keys = list(params.keys())
377
+ param_values = list(params.values())
378
+
379
+ # Defines optimizer and loss function
380
+ optimizer = torch.optim.Adam(param_values, lr=lr)
381
+
382
+ # Contains the best loss and the best parameters
383
+ best_loss = float("inf")
384
+ best_params = None
385
+
386
+ # Used to stop the search early
387
+ min_delta = 1e-7
388
+ acc_loss = []
389
+ percent_epochs_before_stop = 0.1
390
+
391
+ for i in range(epochs):
392
+ optimizer.zero_grad()
393
+
394
+ quant = quantize(x, params, qtz_func, bits, target_dtype)
395
+ dequant = dequantize(quant, params, deqtz_func, bits, x.dtype)
396
+ loss = loss_fn(x, dequant)
397
+
398
+ if loss.isnan() or loss.isinf():
399
+ raise Exception("Loss is NaN or Inf. Stopping the search.")
400
+
401
+ loss.backward()
402
+ optimizer.step()
403
+
404
+ acc_loss.append(loss.item())
405
+
406
+ # Reports loss every 10 steps
407
+ if i % 10 == 0 and do_report:
408
+ print(f"Epoch {i}: Loss {loss.item()}")
409
+
410
+ # Optimizes the parameter search by storing the best loss and the parameters
411
+ if loss.item() < best_loss:
412
+ best_loss = loss.item()
413
+ best_params = copy.deepcopy({
414
+ k: v for k, v in params.items() if k in param_keys
415
+ })
416
+
417
+ # We also stop the search if the loss has not changed considerably during the last 10% of the epochs
418
+ if early_stop:
419
+ epochs_before_stop = int(epochs * percent_epochs_before_stop)
420
+ if i > epochs_before_stop and abs(acc_loss[i - epochs_before_stop] - acc_loss[i]) < min_delta:
421
+ break
422
+
423
+ # No longer requires gradients in the parameters
424
+ for p in best_params.values():
425
+ p.requires_grad = False
426
+ p.grad = None
427
+
428
+ if do_report:
429
+ print(f"Best loss: {best_loss}")
430
+ return best_params, acc_loss
431
+ else:
432
+ return best_params
433
+
434
+
435
+ def quantize(
436
+ x: torch.Tensor,
437
+ params: Dict[str, nn.Parameter],
438
+ func: nn.Module,
439
+ bits: int,
440
+ target_dtype: torch.dtype = torch.int8
441
+ ) -> torch.Tensor:
442
+ quant_min, quant_max = get_min_max_from_bits_signed(bits)
443
+ x = x.transpose(0, 1) # Aligns shapes
444
+ x = func(x=x, **params)
445
+ x = x.transpose(0, 1)
446
+ x = torch.clamp(round_func_BPDA(x), quant_min, quant_max).to(target_dtype)
447
+ return x
448
+
449
+
450
+ def dequantize(
451
+ x: torch.Tensor,
452
+ params: Dict[str, nn.Parameter],
453
+ func: nn.Module,
454
+ bits: int,
455
+ out_dtype: torch.dtype
456
+ ) -> torch.Tensor:
457
+ x = x.to(dtype=out_dtype)
458
+ x = x.transpose(0, 1)
459
+ x = func(x=x, **params)
460
+ x = x.transpose(0, 1)
461
+ return x
462
+
463
+
464
+ def round_func_BPDA(input):
465
+ # This is equivalent to replacing round function (non-differentiable) with
466
+ # an identity function (differentiable) in the backward pass (a straight-through estimator).
467
+ forward_value = torch.round(input)
468
+ out = input.clone()
469
+ out.data = forward_value.data
470
+ return out
471
+
472
+
473
+ def get_min_max_from_bits_signed(bit_width: int) -> Tuple[int, int]:
474
+ return -2 ** (bit_width - 1), 2 ** (bit_width - 1) - 1
475
+
476
+
477
+
478
+ ############## Numpy ###############
479
+
480
+ def np_domain_guard(
481
+ x: np.ndarray,
482
+ min: float = None,
483
+ max: float = None,
484
+ posinf: float = None,
485
+ neginf: float = None,
486
+ nan: float = None
487
+ ) -> np.ndarray:
488
+ """Guard a tensor to a valid domain."""
489
+ x = np.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan)
490
+ if min is not None or max is not None:
491
+ x = np.clip(x, min, max)
492
+ return x
493
+
494
+
495
+ def np_replace_num(x: np.ndarray, num: float, to: float) -> np.ndarray:
496
+ """Replace a number in a tensor with another number.
497
+
498
+ Args:
499
+ x (np.ndarray): The input tensor.
500
+ num (float): The number to replace.
501
+ to (float): The number to replace with.
502
+
503
+ Returns:
504
+ np.ndarray: The tensor with the number replaced.
505
+ """
506
+ return np.where(x == num, to, x)
507
+
508
+
509
+ def np_guarded_power(x: np.ndarray, exp: float) -> np.ndarray:
510
+ """Guard the power operation to a valid domain."""
511
+ return np.power(x, exp) if exp >= 1 else np.power(np.maximum(x, 0), exp)
512
+
fn_gen/rnd_search_t_no_sched/18/loss.png ADDED
fn_gen/rnd_search_t_no_sched/18/quantization.png ADDED
fn_gen/rnd_search_t_no_sched/2/distortion.png ADDED
fn_gen/rnd_search_t_no_sched/2/expressions.txt ADDED
@@ -0,0 +1,2 @@
1
+ x**3/_s
2
+ (_s*x)**(1/3)
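This pair is a cubic companding curve: inputs are cubed and divided by the scale before rounding, and the cube root undoes it on dequantization. Because guarded_torch_power in the fn.py below clamps negative arguments to zero for fractional exponents, only the non-negative branch inverts exactly; a quick standalone check of that branch (example values):

    import torch

    s = torch.tensor(2.0)                            # example scale
    x = torch.tensor([0.1, 0.5, 1.0, 1.5])
    q = x.pow(3) / s                                 # x**3/_s (rounding omitted to isolate the transform)
    x_hat = (s * q).clamp(min=0).pow(1.0 / 3)        # (_s*x)**(1/3) with the same relu-style guard
    print(torch.allclose(x, x_hat))                  # True on the non-negative branch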
fn_gen/rnd_search_t_no_sched/2/fn.py ADDED
@@ -0,0 +1,512 @@
1
+ from __future__ import annotations
2
+
3
+ import torch
4
+ from torch import amin # Imported so generated expressions can use amin directly
5
+ import copy
6
+ import torch.nn as nn
7
+ import numpy as np
8
+
9
+ from scipy.optimize import curve_fit
10
+ from typing import Dict, Any, Tuple, List, Callable
11
+
12
+
13
+ def quantization(x, **params):
14
+ return (torch.div(1, replace_num(params['_s'], num=0, to=10000)) * guarded_torch_power(x, torch.tensor(3)))
15
+
16
+
17
+ def dequantization(x, **params):
18
+ return guarded_torch_power((params['_s'] * x), 1 / 3)
19
+
20
+
21
+ def init_linear_scale( # Symmetric scale. From the study folder
22
+ x: torch.Tensor,
23
+ **kwargs: Dict[str, Any],
24
+ ) -> torch.Tensor:
25
+ assert "bits" in kwargs, "bits must be provided."
26
+ assert "params" in kwargs, "params must be provided."
27
+ assert "qtz_func" in kwargs, "qtz_func must be provided."
28
+
29
+ bits = kwargs.get('bits')
30
+ params = kwargs.get('params')
31
+ qtz_func = kwargs.get('qtz_func')
32
+
33
+ x_ = x.transpose(0, 1)
34
+ x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs))
35
+ x_ = x_.transpose(0, 1)
36
+
37
+ quant_min, quant_max = get_min_max_from_bits_signed(bits)
38
+ min_vals, max_vals = torch.aminmax(x_, dim=1)
39
+ min_vals = torch.min(min_vals, torch.zeros_like(min_vals))
40
+ max_vals = torch.max(max_vals, torch.zeros_like(max_vals))
41
+
42
+ eps = torch.finfo(torch.float32).eps
43
+
44
+ abs_max_val_per_ch = torch.max(-min_vals, max_vals)
45
+ scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2)
46
+
47
+ scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device)
48
+
49
+ # Introduces some noise in scale
50
+ # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything
52
+ # NOTE(diogo): This has been disproven. The noise does not help the learning process but I still
53
+ # left it here for future reference. Will be removed later.
54
+ # scale = scale + 0.01 * torch.randn_like(scale)
55
+
56
+ return scale
57
+
58
+
59
+ def init_params(x: torch.Tensor, **kwargs: Dict[str, Any]) -> Dict[str, nn.Parameter]:
60
+ params = {
61
+ }
62
+ params['_s'] = init_linear_scale(x, params=params, qtz_func=quantization, **kwargs)
63
+ params = {k: nn.Parameter(v, requires_grad=False) for k, v in params.items()}
64
+
65
+ if 'post_init_hook' in kwargs:
66
+ kwargs['post_init_hook'](parameters=params)
67
+
68
+ params = learn_parameters(x, params,
69
+ qtz_func=quantization,
70
+ deqtz_func=dequantization,
71
+ bits=kwargs['bits'],
72
+ target_dtype=torch.int8,
73
+ epochs=500,
74
+ early_stop=False,
75
+ )
76
+ if 'post_train_hook' in kwargs:
77
+ kwargs['post_train_hook'](parameters=params)
78
+
79
+ return params
80
+
81
+
82
+ ############### Numpy Qtz ###############
83
+
84
+
85
+ def np_quantization(x, _s):
86
+ return (np.divide(1, np_replace_num(_s, num=0, to=10000)) * np_guarded_power(x, np.array(3)))
87
+
88
+
89
+ def np_dequantization(x, _s):
90
+ return np_guarded_power((_s * x), 1 / 3)
91
+
92
+
93
+ def fit_func(x, _s):
94
+ x_ = np_quantization(x, _s)
95
+ x_ = np_dequantization(x_, _s)
96
+ return x_
97
+
98
+
99
+
100
+ ############### HELPERS ###############
101
+
102
+ def domain_guard(
103
+ x: torch.Tensor,
104
+ min: float = None,
105
+ max: float = None,
106
+ posinf: float = None,
107
+ neginf: float = None,
108
+ nan: float = None
109
+ ) -> torch.Tensor:
110
+ """Guard a tensor to a valid domain."""
111
+ x = torch.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan)
112
+ if min is not None or max is not None:
113
+ x = torch.clamp(x, min=min, max=max)
114
+ return x
115
+
116
+
117
+ def replace_num(x: torch.Tensor, num: float, to: float) -> torch.Tensor:
118
+ """Replace a number in a tensor with another number.
119
+
120
+ Args:
121
+ x (torch.Tensor): The input tensor.
122
+ num (float): The number to replace.
123
+ to (float): The number to replace with.
124
+
125
+ Returns:
126
+ torch.Tensor: The tensor with the number replaced.
127
+ """
128
+ return torch.where(x == num, to, x)
129
+
130
+
131
+ def guarded_torch_power(x: torch.Tensor, exp: float) -> torch.Tensor:
132
+ """Guard the power operation to a valid domain."""
133
+ return torch.pow(x, exp) if exp >= 1 else torch.pow(torch.relu(x), exp)
134
+
135
+
136
+ def init_ones(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
137
+ val = torch.amin(x, dim=1)
138
+ return torch.ones_like(val, dtype=torch.float32, device=x.device)
139
+
140
+
141
+ def init_rand(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
142
+ val = torch.amin(x, dim=1)
143
+ return torch.randn_like(val, dtype=torch.float32, device=x.device)
144
+
145
+
146
+ def init_space_search(
147
+ x: torch.Tensor,
148
+ **kwargs: Dict[str, Any],
149
+ ) -> torch.Tensor:
150
+
151
+ def _build_initial_param(tensor: torch.Tensor, max_initial: int, n_params: int):
152
+ """Generates the initial set of parameters. The first iteration generates 10 times more parameters."""
153
+ for _ in range(n_params * 10): # The first iteration generates 10 times more parameters
154
+ yield init_rand(tensor) * max_initial # Generates n_params in range [-max_initial, max_initial]
155
+
156
+ def _search_param(tensors: List[torch.Tensor], n_params):
157
+ """Takes the best parameters and generates new parameters around the mean of the best parameters."""
158
+ torch_tensors = torch.stack(tensors)
159
+ min_vals, max_vals = torch.aminmax(torch_tensors, dim=0)
160
+ abs_max_val_per_ch = torch.max(-min_vals, max_vals)
161
+ mean = torch.mean(torch_tensors, dim=0)
162
+ for _ in range(n_params): # Generates n_params around the mean of the tensors
163
+ yield torch.randn_like(min_vals) * abs_max_val_per_ch + mean
164
+
165
+ def _calc(x, qtz_func, deqtz_func, **params):
166
+ x_ = x.transpose(0, 1)
167
+ x_ = qtz_func(x=x_, **params)
168
+ x_ = deqtz_func(x=x_, **params)
169
+ x_ = x_.transpose(0, 1)
170
+ return x_
171
+
172
+ assert "qtz_func" in kwargs, "qtz_func must be provided."
173
+ assert "deqtz_func" in kwargs, "deqtz_func must be provided."
174
+ assert "params_list" in kwargs, "params list must be provided."
175
+ assert "param" in kwargs, "param must be provided."
176
+
177
+ qtz_func = kwargs.get('qtz_func')
178
+ deqtz_func = kwargs.get('deqtz_func')
179
+ params_list = kwargs.get('params_list')
180
+ param = kwargs.get('param')
181
+
182
+ n_runs = 50 # Number of runs to try to find the best parameters
183
+ n_random_params = 50 # Number of random parameters to generate
184
+ n_best_to_pick = 5 # Number of best parameters to pick after each run
185
+ max_initial = 10000 # Maximum value to initialize the parameters
186
+
187
+ # Initializes the parameters
188
+ base_params = { p: init_ones(x, **kwargs) for p in params_list if p != param }
189
+ params = _build_initial_param(x, max_initial, n_random_params)
190
+
191
+ # Performs the search
192
+ for _ in range(n_runs):
193
+
194
+ best_params = []
195
+ for param_ in params:
196
+ try:
197
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: param_})
198
+ loss_ones = nn.MSELoss()(x, x_)
199
+
200
+ if len(best_params) < n_best_to_pick:
201
+ best_params.append((param_, loss_ones.item()))
202
+ best_params = sorted(best_params, key=lambda x: x[1])
203
+ elif loss_ones < best_params[-1][1]:
204
+ best_params[-1] = (param_, loss_ones.item())
205
+ best_params = sorted(best_params, key=lambda x: x[1])
206
+
207
+ except Exception: # The parameters might not be valid for the function's domain
208
+ continue
209
+
210
+ # Generates new parameters around the mean
211
+ params = _search_param([p for p, _ in best_params], n_random_params)
212
+
213
+ # Checks if the best parameter is better than the init_ones
214
+ p_ones = init_ones(x, **kwargs)
215
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_ones})
216
+ loss_ones = nn.MSELoss()(x, x_)
217
+
218
+ # Checks if the best parameter is better than the init_rand
219
+ p_rand = init_rand(x, **kwargs)
220
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_rand})
221
+ loss_rand = nn.MSELoss()(x, x_)
222
+
223
+ if loss_rand < best_params[0][1] and loss_rand < loss_ones:
224
+ return p_rand
225
+ elif loss_ones < best_params[0][1] and loss_ones < loss_rand:
226
+ return p_ones
227
+ else:
228
+ return best_params[0][0]
229
+
230
+
231
+ def init_linear_scale( # Symmetric scale. From the study folder
232
+ x: torch.Tensor,
233
+ **kwargs: Dict[str, Any],
234
+ ) -> torch.Tensor:
235
+ assert "bits" in kwargs, "bits must be provided."
236
+ assert "params" in kwargs, "params must be provided."
237
+ assert "qtz_func" in kwargs, "qtz_func must be provided."
238
+
239
+ bits = kwargs.get('bits')
240
+ params = kwargs.get('params')
241
+ qtz_func = kwargs.get('qtz_func')
242
+
243
+ x_ = x.transpose(0, 1)
244
+ x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs))
245
+ x_ = x_.transpose(0, 1)
246
+
247
+ quant_min, quant_max = get_min_max_from_bits_signed(bits)
248
+ min_vals, max_vals = torch.aminmax(x_, dim=1)
249
+ min_vals = torch.min(min_vals, torch.zeros_like(min_vals))
250
+ max_vals = torch.max(max_vals, torch.zeros_like(max_vals))
251
+
252
+ eps = torch.finfo(torch.float32).eps
253
+
254
+ abs_max_val_per_ch = torch.max(-min_vals, max_vals)
255
+ scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2)
256
+
257
+ scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device)
258
+
259
+ # Introduces some noise in scale
260
+ # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything
262
+ # NOTE(diogo): This has been disproven. The noise does not help the learning process but I still
263
+ # left it here for future reference. Will be removed later.
264
+ # scale = scale + 0.01 * torch.randn_like(scale)
265
+
266
+ return scale
267
+
268
+
269
+ def init_non_linear_regression_fit(
270
+ x: torch.Tensor,
271
+ **kwargs: Dict[str, Any],
272
+ ) -> torch.Tensor:
273
+
274
+ assert "params_list" in kwargs, "params list must be provided."
275
+ assert "np_fit_func" in kwargs, "np_fit_func must be provided."
276
+ assert "p0" in kwargs, "p0 must be provided."
277
+ np_fit_func = kwargs.get('np_fit_func')
278
+ params_list = kwargs.get('params_list')
279
+ p0 = kwargs.get('p0')
280
+
281
+ def _fit(xdata: np.ndarray, ydata: np.ndarray, func: Callable, p0: List[float]):
282
+ popt, _ = curve_fit(
283
+ func,
284
+ xdata,
285
+ ydata,
286
+ maxfev=1000,
287
+ p0=p0,
288
+ method='lm'
289
+ )
290
+ return popt
291
+
292
+ # 1. Needs to convert the torch tensor to numpy tensor
293
+ xdata = x.cpu().numpy()
294
+
295
+ # 2. Sorts the data so that it makes it easier to fit to it
296
+ sorted_xdata = np.sort(xdata, axis=-1)
297
+
298
+ p0 = {k: v.cpu().numpy() for k, v in p0.items()}
299
+ params_list = sorted(params_list) # We need to make sure that it matches the numpy fit func arg order
300
+
301
+ # 3. Finds the best parameters for each channel
302
+ try:
303
+ params = []
304
+ for i in range(sorted_xdata.shape[0]):
305
+ xdata_ = sorted_xdata[i]
306
+ p0_ = [p0[p][i] for p in params_list]
307
+ ch_params = _fit(xdata_, xdata_, np_fit_func, p0_)
308
+ params.append(ch_params)
309
+
310
+ # 4. Builds the parameters
311
+ result = {}
312
+ for i, p in enumerate(params_list):
313
+ result[p] = torch.tensor([p_[i] for p_ in params], dtype=torch.float32).to(x.device)
314
+
315
+ return result
316
+
317
+ except ValueError as e:
318
+ print(f"Could not fit the function with error: {e}")
319
+ print(f"Using fallback result...")
320
+ return {
321
+ k: torch.tensor(v, dtype=torch.float32).to(x.device) for k, v in p0.items()
322
+ }
323
+
324
+
325
+ def init_zeros(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
326
+ val = torch.amin(x, dim=1)
327
+ return torch.zeros_like(val, dtype=torch.float32, device=x.device)
328
+
329
+
330
+ def init_inner_scale(tensor: torch.Tensor, _min: float = torch.inf, _max: float = torch.inf) -> torch.Tensor:
331
+ # Calculate the original minimum and maximum values
332
+ min_vals, max_vals = torch.aminmax(tensor, dim=-1)
333
+ x_min = torch.min(min_vals, torch.zeros_like(min_vals))
334
+ x_max = torch.max(max_vals, torch.zeros_like(max_vals))
335
+
336
+ if _max is torch.inf: # We do not need to scale the tensor. Just need to move it
337
+ return torch.ones_like(x_min)
338
+
339
+ # Calculate the scale factor
340
+ scale = (_max - _min) / (x_max - x_min)
341
+ return scale
342
+
343
+
344
+
345
+ ############## Quant ###############
346
+
347
+ @torch.enable_grad()
348
+ def learn_parameters(
349
+ x: torch.Tensor,
350
+ params: Dict[str, nn.Parameter],
351
+ qtz_func: nn.Module,
352
+ deqtz_func: nn.Module,
353
+ bits: int,
354
+ target_dtype: torch.dtype,
355
+ epochs: int = 1000,
356
+ early_stop: bool = True,
357
+ do_report: bool = False
358
+ ) -> Tuple[Dict[str, nn.Parameter], torch.Tensor]:
359
+ loss_fn = nn.MSELoss()
360
+
361
+ # Determines the initial learning rate by computing the initial loss and multiplying it by
362
+ # the order of magnitude of the loss divided by 2
363
+ quant = quantize(x, params, qtz_func, bits, target_dtype)
364
+ dequant = dequantize(quant, params, deqtz_func, bits, x.dtype)
365
+ loss = loss_fn(x, dequant)
366
+
367
+ base_lr = 0.1
368
+ exponent = int(np.floor(np.log10(loss.item())))
369
+ lr = base_lr * (10 ** (exponent // 2))
370
+
371
+ # Requires gradients in the parameters
372
+ for p in params.values():
373
+ p.requires_grad = True
374
+ p.grad = None
375
+
376
+ param_keys = list(params.keys())
377
+ param_values = list(params.values())
378
+
379
+ # Defines optimizer and loss function
380
+ optimizer = torch.optim.Adam(param_values, lr=lr)
381
+
382
+ # Contains the best loss and the best parameters
383
+ best_loss = float("inf")
384
+ best_params = None
385
+
386
+ # Used to stop the search early
387
+ min_delta = 1e-7
388
+ acc_loss = []
389
+ percent_epochs_before_stop = 0.1
390
+
391
+ for i in range(epochs):
392
+ optimizer.zero_grad()
393
+
394
+ quant = quantize(x, params, qtz_func, bits, target_dtype)
395
+ dequant = dequantize(quant, params, deqtz_func, bits, x.dtype)
396
+ loss = loss_fn(x, dequant)
397
+
398
+ if loss.isnan() or loss.isinf():
399
+ raise Exception("Loss is NaN or Inf. Stopping the search.")
400
+
401
+ loss.backward()
402
+ optimizer.step()
403
+
404
+ acc_loss.append(loss.item())
405
+
406
+ # Reports loss every 10 steps
407
+ if i % 10 == 0 and do_report:
408
+ print(f"Epoch {i}: Loss {loss.item()}")
409
+
410
+ # Optimizes the parameter search by storing the best loss and the parameters
411
+ if loss.item() < best_loss:
412
+ best_loss = loss.item()
413
+ best_params = copy.deepcopy({
414
+ k: v for k, v in params.items() if k in param_keys
415
+ })
416
+
417
+ # We also stop the search if the loss has not changed considerably during the last 10% of the epochs
418
+ if early_stop:
419
+ epochs_before_stop = int(epochs * percent_epochs_before_stop)
420
+ if i > epochs_before_stop and abs(acc_loss[i - epochs_before_stop] - acc_loss[i]) < min_delta:
421
+ break
422
+
423
+ # No longer requires gradients in the parameters
424
+ for p in best_params.values():
425
+ p.requires_grad = False
426
+ p.grad = None
427
+
428
+ if do_report:
429
+ print(f"Best loss: {best_loss}")
430
+ return best_params, acc_loss
431
+ else:
432
+ return best_params
433
+
434
+
435
+ def quantize(
436
+ x: torch.Tensor,
437
+ params: Dict[str, nn.Parameter],
438
+ func: nn.Module,
439
+ bits: int,
440
+ target_dtype: torch.dtype = torch.int8
441
+ ) -> torch.Tensor:
442
+ quant_min, quant_max = get_min_max_from_bits_signed(bits)
443
+ x = x.transpose(0, 1) # Aligns shapes
444
+ x = func(x=x, **params)
445
+ x = x.transpose(0, 1)
446
+ x = torch.clamp(round_func_BPDA(x), quant_min, quant_max).to(target_dtype)
447
+ return x
448
+
449
+
450
+ def dequantize(
451
+ x: torch.Tensor,
452
+ params: Dict[str, nn.Parameter],
453
+ func: nn.Module,
454
+ bits: int,
455
+ out_dtype: torch.dtype
456
+ ) -> torch.Tensor:
457
+ x = x.to(dtype=out_dtype)
458
+ x = x.transpose(0, 1)
459
+ x = func(x=x, **params)
460
+ x = x.transpose(0, 1)
461
+ return x
462
+
463
+
464
+ def round_func_BPDA(input):
465
+ # This is equivalent to replacing round function (non-differentiable) with
466
+ # an identity function (differentiable) in the backward pass (a straight-through estimator).
467
+ forward_value = torch.round(input)
468
+ out = input.clone()
469
+ out.data = forward_value.data
470
+ return out
471
+
472
+
473
+ def get_min_max_from_bits_signed(bit_width: int) -> Tuple[int, int]:
474
+ return -2 ** (bit_width - 1), 2 ** (bit_width - 1) - 1
475
+
476
+
477
+
478
+ ############## Numpy ###############
479
+
480
+ def np_domain_guard(
481
+ x: np.ndarray,
482
+ min: float = None,
483
+ max: float = None,
484
+ posinf: float = None,
485
+ neginf: float = None,
486
+ nan: float = None
487
+ ) -> np.ndarray:
488
+ """Guard a tensor to a valid domain."""
489
+ x = np.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan)
490
+ if min is not None or max is not None:
491
+ x = np.clip(x, min, max)
492
+ return x
493
+
494
+
495
+ def np_replace_num(x: np.ndarray, num: float, to: float) -> np.ndarray:
496
+ """Replace a number in a tensor with another number.
497
+
498
+ Args:
499
+ x (np.ndarray): The input tensor.
500
+ num (float): The number to replace.
501
+ to (float): The number to replace with.
502
+
503
+ Returns:
504
+ np.ndarray: The tensor with the number replaced.
505
+ """
506
+ return np.where(x == num, to, x)
507
+
508
+
509
+ def np_guarded_power(x: np.ndarray, exp: float) -> np.ndarray:
510
+ """Guard the power operation to a valid domain."""
511
+ return np.power(x, exp) if exp >= 1 else np.power(np.maximum(x, 0), exp)
512
+
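Across all of these fn.py variants, learn_parameters derives its Adam learning rate from the initial reconstruction loss: lr = 0.1 * 10 ** (floor(log10(loss)) // 2), i.e. the base rate of 0.1 scaled by ten raised to half the loss's order of magnitude (floor division). A small worked example of the rule (the loss values here are assumed, not measured):

    import numpy as np

    for loss in (3.2e-4, 0.07, 1.5):
        exponent = int(np.floor(np.log10(loss)))   # order of magnitude of the initial loss
        lr = 0.1 * (10 ** (exponent // 2))         # same rule as in learn_parameters
        print(f"loss={loss:g} -> exponent={exponent} -> lr={lr:g}")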
fn_gen/rnd_search_t_no_sched/2/loss.png ADDED
fn_gen/rnd_search_t_no_sched/2/quantization.png ADDED