Diogo-V commited on
Commit
38bdf4b
·
verified ·
1 Parent(s): 6cd7f25

Upload learned functions

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. fn_gen/nlr/0/distortion.png +0 -0
  2. fn_gen/nlr/0/expressions.txt +2 -0
  3. fn_gen/nlr/0/fn.py +469 -0
  4. fn_gen/nlr/0/loss.png +0 -0
  5. fn_gen/nlr/0/quantization.png +0 -0
  6. fn_gen/nlr/1/distortion.png +0 -0
  7. fn_gen/nlr/1/expressions.txt +2 -0
  8. fn_gen/nlr/1/fn.py +469 -0
  9. fn_gen/nlr/1/loss.png +0 -0
  10. fn_gen/nlr/1/quantization.png +0 -0
  11. fn_gen/nlr/10/distortion.png +0 -0
  12. fn_gen/nlr/10/expressions.txt +2 -0
  13. fn_gen/nlr/10/fn.py +469 -0
  14. fn_gen/nlr/10/loss.png +0 -0
  15. fn_gen/nlr/10/quantization.png +0 -0
  16. fn_gen/nlr/11/distortion.png +0 -0
  17. fn_gen/nlr/11/expressions.txt +2 -0
  18. fn_gen/nlr/11/fn.py +468 -0
  19. fn_gen/nlr/11/loss.png +0 -0
  20. fn_gen/nlr/11/quantization.png +0 -0
  21. fn_gen/nlr/13/distortion.png +0 -0
  22. fn_gen/nlr/13/expressions.txt +2 -0
  23. fn_gen/nlr/13/fn.py +469 -0
  24. fn_gen/nlr/13/loss.png +0 -0
  25. fn_gen/nlr/13/quantization.png +0 -0
  26. fn_gen/nlr/14/distortion.png +0 -0
  27. fn_gen/nlr/14/expressions.txt +2 -0
  28. fn_gen/nlr/14/fn.py +469 -0
  29. fn_gen/nlr/14/loss.png +0 -0
  30. fn_gen/nlr/14/quantization.png +0 -0
  31. fn_gen/nlr/15/distortion.png +0 -0
  32. fn_gen/nlr/15/expressions.txt +2 -0
  33. fn_gen/nlr/15/fn.py +469 -0
  34. fn_gen/nlr/15/loss.png +0 -0
  35. fn_gen/nlr/15/quantization.png +0 -0
  36. fn_gen/nlr/16/distortion.png +0 -0
  37. fn_gen/nlr/16/expressions.txt +2 -0
  38. fn_gen/nlr/16/fn.py +469 -0
  39. fn_gen/nlr/16/loss.png +0 -0
  40. fn_gen/nlr/16/quantization.png +0 -0
  41. fn_gen/nlr/17/distortion.png +0 -0
  42. fn_gen/nlr/17/expressions.txt +2 -0
  43. fn_gen/nlr/17/fn.py +469 -0
  44. fn_gen/nlr/17/loss.png +0 -0
  45. fn_gen/nlr/17/quantization.png +0 -0
  46. fn_gen/nlr/3/distortion.png +0 -0
  47. fn_gen/nlr/3/expressions.txt +2 -0
  48. fn_gen/nlr/3/fn.py +468 -0
  49. fn_gen/nlr/3/loss.png +0 -0
  50. fn_gen/nlr/3/quantization.png +0 -0
fn_gen/nlr/0/distortion.png ADDED
fn_gen/nlr/0/expressions.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ tanh(_0*x)/_s
2
+ log((-_s*x - 1)/(_s*x - 1))/_0
fn_gen/nlr/0/fn.py ADDED
@@ -0,0 +1,469 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import torch
4
+ from torch import amin # Necessary for arcsin
5
+ import copy
6
+ import torch.nn as nn
7
+ import numpy as np
8
+
9
+ from scipy.optimize import curve_fit
10
+ from typing import Dict, Any, Tuple, List, Callable
11
+
12
+
13
+ def quantization(x, **params):
14
+ return (torch.div(1, replace_num(params['_s'], num=0, to=10000)) * torch.tanh((params['_0'] * x)))
15
+
16
+
17
+ def dequantization(x, **params):
18
+ return (torch.div(1, replace_num(params['_0'], num=0, to=10000)) * torch.log(domain_guard((torch.div(1, replace_num((torch.tensor(-1) + (params['_s'] * x)), num=0, to=10000)) * (torch.tensor(-1) + (torch.tensor(-1) * params['_s'] * x))), min=1e-5, nan=1e-5)))
19
+
20
+
21
+ def init_params(x: torch.Tensor, **kwargs: Dict[str, Any]) -> Dict[str, nn.Parameter]:
22
+ base_p0 = {
23
+ '_0': init_space_search(x, qtz_func=quantization, deqtz_func=dequantization, params_list=['_0', '_s'], param='_0', **kwargs),
24
+ }
25
+
26
+ base_p0['_s'] = init_linear_scale(x, qtz_func=quantization, params=base_p0, **kwargs)
27
+ if 'post_init_hook' in kwargs:
28
+ kwargs['post_init_hook'](parameters=base_p0)
29
+
30
+ params = init_non_linear_regression_fit(x, p0=base_p0, np_fit_func=fit_func, qtz_func=quantization, deqtz_func=dequantization, params_list=['_0', '_s'], **kwargs)
31
+ params = {k: nn.Parameter(v, requires_grad=False) for k, v in params.items()}
32
+ if 'post_method_hook' in kwargs:
33
+ kwargs['post_method_hook'](parameters=params)
34
+
35
+
36
+ if 'post_train_hook' in kwargs:
37
+ kwargs['post_train_hook'](parameters=params)
38
+
39
+ return params
40
+
41
+
42
+ ############### Numpy Qtz ###############
43
+
44
+
45
+ def np_quantization(x, _0, _s):
46
+ return (np.divide(1, np_replace_num(_s, num=0, to=10000)) * np.tanh((_0 * x)))
47
+
48
+
49
+ def np_dequantization(x, _0, _s):
50
+ return (np.divide(1, np_replace_num(_0, num=0, to=10000)) * np.log(np_domain_guard((np.divide(1, np_replace_num((np.array(-1) + (_s * x)), num=0, to=10000)) * (np.array(-1) + (np.array(-1) * _s * x))), min=1e-5, nan=1e-5)))
51
+
52
+
53
+ def fit_func(x, _0, _s):
54
+ x_ = np_quantization(x, _0, _s)
55
+ x_ = np_dequantization(x_, _0, _s)
56
+ return x_
57
+
58
+
59
+
60
+ ############### HELPERS ###############
61
+
62
+ def domain_guard(
63
+ x: torch.Tensor,
64
+ min: float = None,
65
+ max: float = None,
66
+ posinf: float = None,
67
+ neginf: float = None,
68
+ nan: float = None
69
+ ) -> torch.Tensor:
70
+ """Guard a tensor to a valid domain."""
71
+ x = torch.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan)
72
+ if min is not None or max is not None:
73
+ x = torch.clamp(x, min=min, max=max)
74
+ return x
75
+
76
+
77
+ def replace_num(x: torch.Tensor, num: float, to: float) -> torch.Tensor:
78
+ """Replace a number in a tensor with another number.
79
+
80
+ Args:
81
+ x (torch.Tensor): The input tensor.
82
+ num (float): The number to replace.
83
+ to (float): The number to replace with.
84
+
85
+ Returns:
86
+ torch.Tensor: The tensor with the number replaced.
87
+ """
88
+ return torch.where(x == num, to, x)
89
+
90
+
91
+ def guarded_torch_power(x: torch.Tensor, exp: float) -> torch.Tensor:
92
+ """Guard the power operation to a valid domain."""
93
+ return torch.pow(x, exp) if exp >= 1 else torch.pow(torch.relu(x), exp)
94
+
95
+
96
+ def init_ones(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
97
+ val = torch.amin(x, dim=1)
98
+ return torch.ones_like(val, dtype=torch.float32, device=x.device)
99
+
100
+
101
+ def init_rand(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
102
+ val = torch.amin(x, dim=1)
103
+ return torch.randn_like(val, dtype=torch.float32, device=x.device)
104
+
105
+
106
+ def init_space_search(
107
+ x: torch.Tensor,
108
+ **kwargs: Dict[str, Any],
109
+ ) -> torch.Tensor:
110
+
111
+ def _build_initial_param(tensor: torch.Tensor, max_initial: int, n_params: int):
112
+ """Generates the initial set of parameters. The first iteration generates 10 times more parameters."""
113
+ for _ in range(n_params * 10): # The first iteration generates 10 times more parameters
114
+ yield init_rand(tensor) * max_initial # Generates n_params in range [-max_initial, max_initial]
115
+
116
+ def _search_param(tensors: List[torch.tensor], n_params):
117
+ """Takes the best parameters and generates new parameters around the mean of the best parameters."""
118
+ torch_tensors = torch.stack(tensors)
119
+ min_vals, max_vals = torch.aminmax(torch_tensors, dim=0)
120
+ abs_max_val_per_ch = torch.max(-min_vals, max_vals)
121
+ mean = torch.mean(torch_tensors, dim=0)
122
+ for _ in range(n_params): # Generates n_params around the mean of the tensors
123
+ yield torch.randn_like(min_vals) * abs_max_val_per_ch + mean
124
+
125
+ def _calc(x, qtz_func, deqtz_func, **params):
126
+ x_ = x.transpose(0, 1)
127
+ x_ = qtz_func(x=x_, **params)
128
+ x_ = deqtz_func(x=x_, **params)
129
+ x_ = x_.transpose(0, 1)
130
+ return x_
131
+
132
+ assert "qtz_func" in kwargs, "qtz_func must be provided."
133
+ assert "deqtz_func" in kwargs, "deqtz_func must be provided."
134
+ assert "params_list" in kwargs, "params list must be provided."
135
+ assert "param" in kwargs, "param must be provided."
136
+
137
+ qtz_func = kwargs.get('qtz_func')
138
+ deqtz_func = kwargs.get('deqtz_func')
139
+ params_list = kwargs.get('params_list')
140
+ param = kwargs.get('param')
141
+
142
+ n_runs = 50 # Number of runs to try to find the best parameters
143
+ n_random_params = 50 # Number of random parameters to generate
144
+ n_best_to_pick = 5 # Number of best parameters to pick after each run
145
+ max_initial = 10000 # Maximum value to initialize the parameters
146
+
147
+ # Initializes the parameters
148
+ base_params = { p: init_ones(x, **kwargs) for p in params_list if p != param }
149
+ params = _build_initial_param(x, max_initial, n_random_params)
150
+
151
+ # Performs the search
152
+ for _ in range(n_runs):
153
+
154
+ best_params = []
155
+ for param_ in params:
156
+ try:
157
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: param_})
158
+ loss_ones = nn.MSELoss()(x, x_)
159
+
160
+ if len(best_params) < n_best_to_pick:
161
+ best_params.append((param_, loss_ones.item()))
162
+ best_params = sorted(best_params, key=lambda x: x[1])
163
+ elif loss_ones < best_params[-1][1]:
164
+ best_params[-1] = (param_, loss_ones.item())
165
+ best_params = sorted(best_params, key=lambda x: x[1])
166
+
167
+ except Exception: # The parameters might not be valid for the function's domain
168
+ continue
169
+
170
+ # Generates new parameters around the mean
171
+ params = _search_param([p for p, _ in best_params], n_random_params)
172
+
173
+ # Checks if the best parameter is better than the init_ones
174
+ p_ones = init_ones(x, **kwargs)
175
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_ones})
176
+ loss_ones = nn.MSELoss()(x, x_)
177
+
178
+ # Checks if the best parameter is better than the init_rand
179
+ p_rand = init_rand(x, **kwargs)
180
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_rand})
181
+ loss_rand = nn.MSELoss()(x, x_)
182
+
183
+ if loss_rand < best_params[0][1] and loss_rand < loss_ones:
184
+ return p_rand
185
+ elif loss_ones < best_params[0][1] and loss_ones < loss_rand:
186
+ return p_ones
187
+ else:
188
+ return best_params[0][0]
189
+
190
+
191
+ def init_linear_scale( # Symmetric scale. From the study folder
192
+ x: torch.Tensor,
193
+ **kwargs: Dict[str, Any],
194
+ ) -> torch.Tensor:
195
+ assert "bits" in kwargs, "bits must be provided."
196
+ assert "params" in kwargs, "params must be provided."
197
+ assert "qtz_func" in kwargs, "qtz_func must be provided."
198
+
199
+ bits = kwargs.get('bits')
200
+ params = kwargs.get('params')
201
+ qtz_func = kwargs.get('qtz_func')
202
+
203
+ x_ = x.transpose(0, 1)
204
+ x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs))
205
+ x_ = x_.transpose(0, 1)
206
+
207
+ quant_min, quant_max = get_min_max_from_bits_signed(bits)
208
+ min_vals, max_vals = torch.aminmax(x_, dim=1)
209
+ min_vals = torch.min(min_vals, torch.zeros_like(min_vals))
210
+ max_vals = torch.max(max_vals, torch.zeros_like(max_vals))
211
+
212
+ eps = torch.finfo(torch.float32).eps
213
+
214
+ abs_max_val_per_ch = torch.max(-min_vals, max_vals)
215
+ scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2)
216
+
217
+ scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device)
218
+
219
+ # Introduces some noise in scale
220
+ # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything
221
+ # scale = scale + 0.01 * torch.randn_like(scale)
222
+ return scale
223
+
224
+
225
+ def init_non_linear_regression_fit(
226
+ x: torch.Tensor,
227
+ **kwargs: Dict[str, Any],
228
+ ) -> torch.Tensor:
229
+
230
+ assert "params_list" in kwargs, "params list must be provided."
231
+ assert "np_fit_func" in kwargs, "np_fit_func must be provided."
232
+ assert "p0" in kwargs, "p0 must be provided."
233
+ np_fit_func = kwargs.get('np_fit_func')
234
+ params_list = kwargs.get('params_list')
235
+ p0 = kwargs.get('p0')
236
+
237
+ def _fit(xdata: np.ndarray, ydata: np.ndarray, func: Callable, p0: List[float]):
238
+ popt, _ = curve_fit(
239
+ func,
240
+ xdata,
241
+ ydata,
242
+ maxfev=1000,
243
+ p0=p0,
244
+ method='lm'
245
+ )
246
+ return popt
247
+
248
+ # 1. Needs to convert the torch tensor to numpy tensor
249
+ xdata = x.cpu().numpy()
250
+
251
+ # 2. Sorts the data so that it makes it easier to fit to it
252
+ sorted_xdata = np.sort(xdata, axis=-1)
253
+
254
+ p0 = {k: v.cpu().numpy() for k, v in p0.items()}
255
+ params_list = sorted(params_list) # We need to make sure that it matches the numpy fit func arg order
256
+
257
+ # 3. Finds the best parameters for each channel
258
+ try:
259
+ params = []
260
+ for i in range(sorted_xdata.shape[0]):
261
+ xdata_ = sorted_xdata[i]
262
+ p0_ = [p0[p][i] for p in params_list]
263
+ ch_params = _fit(xdata_, xdata_, np_fit_func, p0_)
264
+ params.append(ch_params)
265
+
266
+ # 4. Builds the parameters
267
+ result = {}
268
+ for i, p in enumerate(params_list):
269
+ result[p] = torch.tensor([p_[i] for p_ in params], dtype=torch.float32).to(x.device)
270
+
271
+ return result
272
+
273
+ except ValueError as e:
274
+ print(f"Could not fit the function with error: {e}")
275
+ print(f"Using fallback result...")
276
+ return {
277
+ k: torch.tensor(v, dtype=torch.float32).to(x.device) for k, v in p0.items()
278
+ }
279
+
280
+
281
+ def init_zeros(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
282
+ val = torch.amin(x, dim=1)
283
+ return torch.zeros_like(val, dtype=torch.float32, device=x.device)
284
+
285
+
286
+ def init_inner_scale(tensor: torch.Tensor, _min: float = torch.inf, _max: float = torch.inf) -> torch.Tensor:
287
+ # Calculate the original minimum and maximum values
288
+ min_vals, max_vals = torch.aminmax(tensor, dim=-1)
289
+ x_min = torch.min(min_vals, torch.zeros_like(min_vals))
290
+ x_max = torch.max(max_vals, torch.zeros_like(max_vals))
291
+
292
+ if _max is torch.inf: # We do not need to scale the tensor. Just need to move it
293
+ return torch.ones_like(x_min)
294
+
295
+ # Calculate the scale factor
296
+ scale = (_max - _min) / (x_max - x_min)
297
+ return scale
298
+
299
+
300
+
301
+ ############## Quant ###############
302
+
303
+ @torch.enable_grad()
304
+ def learn_parameters(
305
+ x: torch.Tensor,
306
+ params: Dict[str, nn.Parameter],
307
+ qtz_func: nn.Module,
308
+ deqtz_func: nn.Module,
309
+ bits: int,
310
+ target_dtype: torch.dtype,
311
+ epochs: int = 1000,
312
+ early_stop: bool = True,
313
+ do_report: bool = False
314
+ ) -> Tuple[Dict[str, nn.Parameter], torch.Tensor]:
315
+ loss_fn = nn.MSELoss()
316
+
317
+ # Determines the initial learning rate by computing the initial loss and multiplying it by
318
+ # the order of magnitude of the loss divided by 2
319
+ quant = quantize(x, params, qtz_func, bits, target_dtype)
320
+ dequant = dequantize(quant, params, deqtz_func, bits, x.dtype)
321
+ loss = loss_fn(x, dequant)
322
+
323
+ base_lr = 0.1
324
+ exponent = int(np.floor(np.log10(loss.item())))
325
+ lr = base_lr * (10 ** (exponent // 2))
326
+
327
+ # Requires gradients in the parameters
328
+ for p in params.values():
329
+ p.requires_grad = True
330
+ p.grad = None
331
+
332
+ param_keys = list(params.keys())
333
+ param_values = list(params.values())
334
+
335
+ # Defines optimizer and loss function
336
+ optimizer = torch.optim.Adam(param_values, lr=lr)
337
+ scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=1.0, end_factor=0.01, total_iters=epochs // 10)
338
+
339
+ # Contains the best loss and the best parameters
340
+ best_loss = float("inf")
341
+ best_params = None
342
+
343
+ # Used to stop the search early
344
+ min_delta = 1e-7
345
+ acc_loss = []
346
+ percent_epochs_before_stop = 0.1
347
+
348
+ for i in range(epochs):
349
+ optimizer.zero_grad()
350
+
351
+ quant = quantize(x, params, qtz_func, bits, target_dtype)
352
+ dequant = dequantize(quant, params, deqtz_func, bits, x.dtype)
353
+ loss = loss_fn(x, dequant)
354
+
355
+ if loss.isnan() or loss.isinf():
356
+ raise Exception("Loss is NaN or Inf. Stopping the search.")
357
+
358
+ loss.backward()
359
+ optimizer.step()
360
+ scheduler.step()
361
+
362
+ acc_loss.append(loss.item())
363
+
364
+ # Reports loss every 10 steps
365
+ if i % 10 == 0 and do_report:
366
+ print(f"Epoch {i}: Loss {loss.item()}")
367
+
368
+ # Optimizes the parameter search by storing the best loss and the parameters
369
+ if loss.item() < best_loss:
370
+ best_loss = loss.item()
371
+ best_params = copy.deepcopy({
372
+ k: v for k, v in params.items() if k in param_keys
373
+ })
374
+
375
+ # We also stop the search if the loss has not considerably during the last 10% epochs
376
+ if early_stop:
377
+ epochs_before_stop = int(epochs * percent_epochs_before_stop)
378
+ if i > epochs_before_stop and abs(acc_loss[i - epochs_before_stop] - acc_loss[i]) < min_delta:
379
+ break
380
+
381
+ # No longer requires gradients in the parameters
382
+ for p in best_params.values():
383
+ p.requires_grad = False
384
+ p.grad = None
385
+
386
+ if do_report:
387
+ return best_params, acc_loss
388
+ else:
389
+ return best_params
390
+
391
+
392
+ def quantize(
393
+ x: torch.Tensor,
394
+ params: Dict[str, nn.Parameter],
395
+ func: nn.Module,
396
+ bits: int,
397
+ target_dtype: torch.dtype = torch.int8
398
+ ) -> torch.Tensor:
399
+ quant_min, quant_max = get_min_max_from_bits_signed(bits)
400
+ x = x.transpose(0, 1) # Aligns shapes
401
+ x = func(x=x, **params)
402
+ x = x.transpose(0, 1)
403
+ x = torch.clamp(round_func_BPDA(x), quant_min, quant_max).to(target_dtype)
404
+ return x
405
+
406
+
407
+ def dequantize(
408
+ x: torch.Tensor,
409
+ params: Dict[str, nn.Parameter],
410
+ func: nn.Module,
411
+ bits: int,
412
+ out_dtype: torch.dtype
413
+ ) -> torch.Tensor:
414
+ x = x.to(dtype=out_dtype)
415
+ x = x.transpose(0, 1)
416
+ x = func(x=x, **params)
417
+ x = x.transpose(0, 1)
418
+ return x
419
+
420
+
421
+ def round_func_BPDA(input):
422
+ # This is equivalent to replacing round function (non-differentiable) with
423
+ # an identity function (differentiable) only when backward.
424
+ forward_value = torch.round(input)
425
+ out = input.clone()
426
+ out.data = forward_value.data
427
+ return out
428
+
429
+
430
+ def get_min_max_from_bits_signed(bit_width: int) -> Tuple[int, int]:
431
+ return -2 ** (bit_width - 1), 2 ** (bit_width - 1) - 1
432
+
433
+
434
+
435
+ ############## Numpy ###############
436
+
437
+ def np_domain_guard(
438
+ x: np.ndarray,
439
+ min: float = None,
440
+ max: float = None,
441
+ posinf: float = None,
442
+ neginf: float = None,
443
+ nan: float = None
444
+ ) -> np.ndarray:
445
+ """Guard a tensor to a valid domain."""
446
+ x = np.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan)
447
+ if min is not None or max is not None:
448
+ x = np.clip(x, min, max)
449
+ return x
450
+
451
+
452
+ def np_replace_num(x: np.ndarray, num: float, to: float) -> np.ndarray:
453
+ """Replace a number in a tensor with another number.
454
+
455
+ Args:
456
+ x (np.ndarray): The input tensor.
457
+ num (float): The number to replace.
458
+ to (float): The number to replace with.
459
+
460
+ Returns:
461
+ np.ndarray: The tensor with the number replaced.
462
+ """
463
+ return np.where(x == num, to, x)
464
+
465
+
466
+ def np_guarded_power(x: np.ndarray, exp: float) -> np.ndarray:
467
+ """Guard the power operation to a valid domain."""
468
+ return np.power(x, exp) if exp >= 1 else np.power(np.maximum(x, 0), exp)
469
+
fn_gen/nlr/0/loss.png ADDED
fn_gen/nlr/0/quantization.png ADDED
fn_gen/nlr/1/distortion.png ADDED
fn_gen/nlr/1/expressions.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ (_0*x)**(1/3)/_s
2
+ _s**3*x**3/_0
fn_gen/nlr/1/fn.py ADDED
@@ -0,0 +1,469 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import torch
4
+ from torch import amin # Necessary for arcsin
5
+ import copy
6
+ import torch.nn as nn
7
+ import numpy as np
8
+
9
+ from scipy.optimize import curve_fit
10
+ from typing import Dict, Any, Tuple, List, Callable
11
+
12
+
13
+ def quantization(x, **params):
14
+ return (torch.div(1, replace_num(params['_s'], num=0, to=10000)) * guarded_torch_power((params['_0'] * x), 1 / 3))
15
+
16
+
17
+ def dequantization(x, **params):
18
+ return (torch.div(1, replace_num(params['_0'], num=0, to=10000)) * guarded_torch_power(params['_s'], torch.tensor(3)) * guarded_torch_power(x, torch.tensor(3)))
19
+
20
+
21
+ def init_params(x: torch.Tensor, **kwargs: Dict[str, Any]) -> Dict[str, nn.Parameter]:
22
+ base_p0 = {
23
+ '_0': init_space_search(x, qtz_func=quantization, deqtz_func=dequantization, params_list=['_0', '_s'], param='_0', **kwargs),
24
+ }
25
+
26
+ base_p0['_s'] = init_linear_scale(x, qtz_func=quantization, params=base_p0, **kwargs)
27
+ if 'post_init_hook' in kwargs:
28
+ kwargs['post_init_hook'](parameters=base_p0)
29
+
30
+ params = init_non_linear_regression_fit(x, p0=base_p0, np_fit_func=fit_func, qtz_func=quantization, deqtz_func=dequantization, params_list=['_0', '_s'], **kwargs)
31
+ params = {k: nn.Parameter(v, requires_grad=False) for k, v in params.items()}
32
+ if 'post_method_hook' in kwargs:
33
+ kwargs['post_method_hook'](parameters=params)
34
+
35
+
36
+ if 'post_train_hook' in kwargs:
37
+ kwargs['post_train_hook'](parameters=params)
38
+
39
+ return params
40
+
41
+
42
+ ############### Numpy Qtz ###############
43
+
44
+
45
+ def np_quantization(x, _0, _s):
46
+ return (np.divide(1, np_replace_num(_s, num=0, to=10000)) * np_guarded_power((_0 * x), 1 / 3))
47
+
48
+
49
+ def np_dequantization(x, _0, _s):
50
+ return (np.divide(1, np_replace_num(_0, num=0, to=10000)) * np_guarded_power(_s, np.array(3)) * np_guarded_power(x, np.array(3)))
51
+
52
+
53
+ def fit_func(x, _0, _s):
54
+ x_ = np_quantization(x, _0, _s)
55
+ x_ = np_dequantization(x_, _0, _s)
56
+ return x_
57
+
58
+
59
+
60
+ ############### HELPERS ###############
61
+
62
+ def domain_guard(
63
+ x: torch.Tensor,
64
+ min: float = None,
65
+ max: float = None,
66
+ posinf: float = None,
67
+ neginf: float = None,
68
+ nan: float = None
69
+ ) -> torch.Tensor:
70
+ """Guard a tensor to a valid domain."""
71
+ x = torch.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan)
72
+ if min is not None or max is not None:
73
+ x = torch.clamp(x, min=min, max=max)
74
+ return x
75
+
76
+
77
+ def replace_num(x: torch.Tensor, num: float, to: float) -> torch.Tensor:
78
+ """Replace a number in a tensor with another number.
79
+
80
+ Args:
81
+ x (torch.Tensor): The input tensor.
82
+ num (float): The number to replace.
83
+ to (float): The number to replace with.
84
+
85
+ Returns:
86
+ torch.Tensor: The tensor with the number replaced.
87
+ """
88
+ return torch.where(x == num, to, x)
89
+
90
+
91
+ def guarded_torch_power(x: torch.Tensor, exp: float) -> torch.Tensor:
92
+ """Guard the power operation to a valid domain."""
93
+ return torch.pow(x, exp) if exp >= 1 else torch.pow(torch.relu(x), exp)
94
+
95
+
96
+ def init_ones(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
97
+ val = torch.amin(x, dim=1)
98
+ return torch.ones_like(val, dtype=torch.float32, device=x.device)
99
+
100
+
101
+ def init_rand(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
102
+ val = torch.amin(x, dim=1)
103
+ return torch.randn_like(val, dtype=torch.float32, device=x.device)
104
+
105
+
106
+ def init_space_search(
107
+ x: torch.Tensor,
108
+ **kwargs: Dict[str, Any],
109
+ ) -> torch.Tensor:
110
+
111
+ def _build_initial_param(tensor: torch.Tensor, max_initial: int, n_params: int):
112
+ """Generates the initial set of parameters. The first iteration generates 10 times more parameters."""
113
+ for _ in range(n_params * 10): # The first iteration generates 10 times more parameters
114
+ yield init_rand(tensor) * max_initial # Generates n_params in range [-max_initial, max_initial]
115
+
116
+ def _search_param(tensors: List[torch.tensor], n_params):
117
+ """Takes the best parameters and generates new parameters around the mean of the best parameters."""
118
+ torch_tensors = torch.stack(tensors)
119
+ min_vals, max_vals = torch.aminmax(torch_tensors, dim=0)
120
+ abs_max_val_per_ch = torch.max(-min_vals, max_vals)
121
+ mean = torch.mean(torch_tensors, dim=0)
122
+ for _ in range(n_params): # Generates n_params around the mean of the tensors
123
+ yield torch.randn_like(min_vals) * abs_max_val_per_ch + mean
124
+
125
+ def _calc(x, qtz_func, deqtz_func, **params):
126
+ x_ = x.transpose(0, 1)
127
+ x_ = qtz_func(x=x_, **params)
128
+ x_ = deqtz_func(x=x_, **params)
129
+ x_ = x_.transpose(0, 1)
130
+ return x_
131
+
132
+ assert "qtz_func" in kwargs, "qtz_func must be provided."
133
+ assert "deqtz_func" in kwargs, "deqtz_func must be provided."
134
+ assert "params_list" in kwargs, "params list must be provided."
135
+ assert "param" in kwargs, "param must be provided."
136
+
137
+ qtz_func = kwargs.get('qtz_func')
138
+ deqtz_func = kwargs.get('deqtz_func')
139
+ params_list = kwargs.get('params_list')
140
+ param = kwargs.get('param')
141
+
142
+ n_runs = 50 # Number of runs to try to find the best parameters
143
+ n_random_params = 50 # Number of random parameters to generate
144
+ n_best_to_pick = 5 # Number of best parameters to pick after each run
145
+ max_initial = 10000 # Maximum value to initialize the parameters
146
+
147
+ # Initializes the parameters
148
+ base_params = { p: init_ones(x, **kwargs) for p in params_list if p != param }
149
+ params = _build_initial_param(x, max_initial, n_random_params)
150
+
151
+ # Performs the search
152
+ for _ in range(n_runs):
153
+
154
+ best_params = []
155
+ for param_ in params:
156
+ try:
157
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: param_})
158
+ loss_ones = nn.MSELoss()(x, x_)
159
+
160
+ if len(best_params) < n_best_to_pick:
161
+ best_params.append((param_, loss_ones.item()))
162
+ best_params = sorted(best_params, key=lambda x: x[1])
163
+ elif loss_ones < best_params[-1][1]:
164
+ best_params[-1] = (param_, loss_ones.item())
165
+ best_params = sorted(best_params, key=lambda x: x[1])
166
+
167
+ except Exception: # The parameters might not be valid for the function's domain
168
+ continue
169
+
170
+ # Generates new parameters around the mean
171
+ params = _search_param([p for p, _ in best_params], n_random_params)
172
+
173
+ # Checks if the best parameter is better than the init_ones
174
+ p_ones = init_ones(x, **kwargs)
175
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_ones})
176
+ loss_ones = nn.MSELoss()(x, x_)
177
+
178
+ # Checks if the best parameter is better than the init_rand
179
+ p_rand = init_rand(x, **kwargs)
180
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_rand})
181
+ loss_rand = nn.MSELoss()(x, x_)
182
+
183
+ if loss_rand < best_params[0][1] and loss_rand < loss_ones:
184
+ return p_rand
185
+ elif loss_ones < best_params[0][1] and loss_ones < loss_rand:
186
+ return p_ones
187
+ else:
188
+ return best_params[0][0]
189
+
190
+
191
+ def init_linear_scale( # Symmetric scale. From the study folder
192
+ x: torch.Tensor,
193
+ **kwargs: Dict[str, Any],
194
+ ) -> torch.Tensor:
195
+ assert "bits" in kwargs, "bits must be provided."
196
+ assert "params" in kwargs, "params must be provided."
197
+ assert "qtz_func" in kwargs, "qtz_func must be provided."
198
+
199
+ bits = kwargs.get('bits')
200
+ params = kwargs.get('params')
201
+ qtz_func = kwargs.get('qtz_func')
202
+
203
+ x_ = x.transpose(0, 1)
204
+ x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs))
205
+ x_ = x_.transpose(0, 1)
206
+
207
+ quant_min, quant_max = get_min_max_from_bits_signed(bits)
208
+ min_vals, max_vals = torch.aminmax(x_, dim=1)
209
+ min_vals = torch.min(min_vals, torch.zeros_like(min_vals))
210
+ max_vals = torch.max(max_vals, torch.zeros_like(max_vals))
211
+
212
+ eps = torch.finfo(torch.float32).eps
213
+
214
+ abs_max_val_per_ch = torch.max(-min_vals, max_vals)
215
+ scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2)
216
+
217
+ scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device)
218
+
219
+ # Introduces some noise in scale
220
+ # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything
221
+ # scale = scale + 0.01 * torch.randn_like(scale)
222
+ return scale
223
+
224
+
225
+ def init_non_linear_regression_fit(
226
+ x: torch.Tensor,
227
+ **kwargs: Dict[str, Any],
228
+ ) -> torch.Tensor:
229
+
230
+ assert "params_list" in kwargs, "params list must be provided."
231
+ assert "np_fit_func" in kwargs, "np_fit_func must be provided."
232
+ assert "p0" in kwargs, "p0 must be provided."
233
+ np_fit_func = kwargs.get('np_fit_func')
234
+ params_list = kwargs.get('params_list')
235
+ p0 = kwargs.get('p0')
236
+
237
+ def _fit(xdata: np.ndarray, ydata: np.ndarray, func: Callable, p0: List[float]):
238
+ popt, _ = curve_fit(
239
+ func,
240
+ xdata,
241
+ ydata,
242
+ maxfev=1000,
243
+ p0=p0,
244
+ method='lm'
245
+ )
246
+ return popt
247
+
248
+ # 1. Needs to convert the torch tensor to numpy tensor
249
+ xdata = x.cpu().numpy()
250
+
251
+ # 2. Sorts the data so that it makes it easier to fit to it
252
+ sorted_xdata = np.sort(xdata, axis=-1)
253
+
254
+ p0 = {k: v.cpu().numpy() for k, v in p0.items()}
255
+ params_list = sorted(params_list) # We need to make sure that it matches the numpy fit func arg order
256
+
257
+ # 3. Finds the best parameters for each channel
258
+ try:
259
+ params = []
260
+ for i in range(sorted_xdata.shape[0]):
261
+ xdata_ = sorted_xdata[i]
262
+ p0_ = [p0[p][i] for p in params_list]
263
+ ch_params = _fit(xdata_, xdata_, np_fit_func, p0_)
264
+ params.append(ch_params)
265
+
266
+ # 4. Builds the parameters
267
+ result = {}
268
+ for i, p in enumerate(params_list):
269
+ result[p] = torch.tensor([p_[i] for p_ in params], dtype=torch.float32).to(x.device)
270
+
271
+ return result
272
+
273
+ except ValueError as e:
274
+ print(f"Could not fit the function with error: {e}")
275
+ print(f"Using fallback result...")
276
+ return {
277
+ k: torch.tensor(v, dtype=torch.float32).to(x.device) for k, v in p0.items()
278
+ }
279
+
280
+
281
+ def init_zeros(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
282
+ val = torch.amin(x, dim=1)
283
+ return torch.zeros_like(val, dtype=torch.float32, device=x.device)
284
+
285
+
286
+ def init_inner_scale(tensor: torch.Tensor, _min: float = torch.inf, _max: float = torch.inf) -> torch.Tensor:
287
+ # Calculate the original minimum and maximum values
288
+ min_vals, max_vals = torch.aminmax(tensor, dim=-1)
289
+ x_min = torch.min(min_vals, torch.zeros_like(min_vals))
290
+ x_max = torch.max(max_vals, torch.zeros_like(max_vals))
291
+
292
+ if _max is torch.inf: # We do not need to scale the tensor. Just need to move it
293
+ return torch.ones_like(x_min)
294
+
295
+ # Calculate the scale factor
296
+ scale = (_max - _min) / (x_max - x_min)
297
+ return scale
298
+
299
+
300
+
301
+ ############## Quant ###############
302
+
303
+ @torch.enable_grad()
304
+ def learn_parameters(
305
+ x: torch.Tensor,
306
+ params: Dict[str, nn.Parameter],
307
+ qtz_func: nn.Module,
308
+ deqtz_func: nn.Module,
309
+ bits: int,
310
+ target_dtype: torch.dtype,
311
+ epochs: int = 1000,
312
+ early_stop: bool = True,
313
+ do_report: bool = False
314
+ ) -> Tuple[Dict[str, nn.Parameter], torch.Tensor]:
315
+ loss_fn = nn.MSELoss()
316
+
317
+ # Determines the initial learning rate by computing the initial loss and multiplying it by
318
+ # the order of magnitude of the loss divided by 2
319
+ quant = quantize(x, params, qtz_func, bits, target_dtype)
320
+ dequant = dequantize(quant, params, deqtz_func, bits, x.dtype)
321
+ loss = loss_fn(x, dequant)
322
+
323
+ base_lr = 0.1
324
+ exponent = int(np.floor(np.log10(loss.item())))
325
+ lr = base_lr * (10 ** (exponent // 2))
326
+
327
+ # Requires gradients in the parameters
328
+ for p in params.values():
329
+ p.requires_grad = True
330
+ p.grad = None
331
+
332
+ param_keys = list(params.keys())
333
+ param_values = list(params.values())
334
+
335
+ # Defines optimizer and loss function
336
+ optimizer = torch.optim.Adam(param_values, lr=lr)
337
+ scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=1.0, end_factor=0.01, total_iters=epochs // 10)
338
+
339
+ # Contains the best loss and the best parameters
340
+ best_loss = float("inf")
341
+ best_params = None
342
+
343
+ # Used to stop the search early
344
+ min_delta = 1e-7
345
+ acc_loss = []
346
+ percent_epochs_before_stop = 0.1
347
+
348
+ for i in range(epochs):
349
+ optimizer.zero_grad()
350
+
351
+ quant = quantize(x, params, qtz_func, bits, target_dtype)
352
+ dequant = dequantize(quant, params, deqtz_func, bits, x.dtype)
353
+ loss = loss_fn(x, dequant)
354
+
355
+ if loss.isnan() or loss.isinf():
356
+ raise Exception("Loss is NaN or Inf. Stopping the search.")
357
+
358
+ loss.backward()
359
+ optimizer.step()
360
+ scheduler.step()
361
+
362
+ acc_loss.append(loss.item())
363
+
364
+ # Reports loss every 10 steps
365
+ if i % 10 == 0 and do_report:
366
+ print(f"Epoch {i}: Loss {loss.item()}")
367
+
368
+ # Optimizes the parameter search by storing the best loss and the parameters
369
+ if loss.item() < best_loss:
370
+ best_loss = loss.item()
371
+ best_params = copy.deepcopy({
372
+ k: v for k, v in params.items() if k in param_keys
373
+ })
374
+
375
+ # We also stop the search if the loss has not considerably during the last 10% epochs
376
+ if early_stop:
377
+ epochs_before_stop = int(epochs * percent_epochs_before_stop)
378
+ if i > epochs_before_stop and abs(acc_loss[i - epochs_before_stop] - acc_loss[i]) < min_delta:
379
+ break
380
+
381
+ # No longer requires gradients in the parameters
382
+ for p in best_params.values():
383
+ p.requires_grad = False
384
+ p.grad = None
385
+
386
+ if do_report:
387
+ return best_params, acc_loss
388
+ else:
389
+ return best_params
390
+
391
+
392
+ def quantize(
393
+ x: torch.Tensor,
394
+ params: Dict[str, nn.Parameter],
395
+ func: nn.Module,
396
+ bits: int,
397
+ target_dtype: torch.dtype = torch.int8
398
+ ) -> torch.Tensor:
399
+ quant_min, quant_max = get_min_max_from_bits_signed(bits)
400
+ x = x.transpose(0, 1) # Aligns shapes
401
+ x = func(x=x, **params)
402
+ x = x.transpose(0, 1)
403
+ x = torch.clamp(round_func_BPDA(x), quant_min, quant_max).to(target_dtype)
404
+ return x
405
+
406
+
407
+ def dequantize(
408
+ x: torch.Tensor,
409
+ params: Dict[str, nn.Parameter],
410
+ func: nn.Module,
411
+ bits: int,
412
+ out_dtype: torch.dtype
413
+ ) -> torch.Tensor:
414
+ x = x.to(dtype=out_dtype)
415
+ x = x.transpose(0, 1)
416
+ x = func(x=x, **params)
417
+ x = x.transpose(0, 1)
418
+ return x
419
+
420
+
421
+ def round_func_BPDA(input):
422
+ # This is equivalent to replacing round function (non-differentiable) with
423
+ # an identity function (differentiable) only when backward.
424
+ forward_value = torch.round(input)
425
+ out = input.clone()
426
+ out.data = forward_value.data
427
+ return out
428
+
429
+
430
+ def get_min_max_from_bits_signed(bit_width: int) -> Tuple[int, int]:
431
+ return -2 ** (bit_width - 1), 2 ** (bit_width - 1) - 1
432
+
433
+
434
+
435
+ ############## Numpy ###############
436
+
437
+ def np_domain_guard(
438
+ x: np.ndarray,
439
+ min: float = None,
440
+ max: float = None,
441
+ posinf: float = None,
442
+ neginf: float = None,
443
+ nan: float = None
444
+ ) -> np.ndarray:
445
+ """Guard a tensor to a valid domain."""
446
+ x = np.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan)
447
+ if min is not None or max is not None:
448
+ x = np.clip(x, min, max)
449
+ return x
450
+
451
+
452
+ def np_replace_num(x: np.ndarray, num: float, to: float) -> np.ndarray:
453
+ """Replace a number in a tensor with another number.
454
+
455
+ Args:
456
+ x (np.ndarray): The input tensor.
457
+ num (float): The number to replace.
458
+ to (float): The number to replace with.
459
+
460
+ Returns:
461
+ np.ndarray: The tensor with the number replaced.
462
+ """
463
+ return np.where(x == num, to, x)
464
+
465
+
466
+ def np_guarded_power(x: np.ndarray, exp: float) -> np.ndarray:
467
+ """Guard the power operation to a valid domain."""
468
+ return np.power(x, exp) if exp >= 1 else np.power(np.maximum(x, 0), exp)
469
+
fn_gen/nlr/1/loss.png ADDED
fn_gen/nlr/1/quantization.png ADDED
fn_gen/nlr/10/distortion.png ADDED
fn_gen/nlr/10/expressions.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ exp(_0*x)/_s
2
+ log(_s*x)/_0
fn_gen/nlr/10/fn.py ADDED
@@ -0,0 +1,469 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import torch
4
+ from torch import amin # Necessary for arcsin
5
+ import copy
6
+ import torch.nn as nn
7
+ import numpy as np
8
+
9
+ from scipy.optimize import curve_fit
10
+ from typing import Dict, Any, Tuple, List, Callable
11
+
12
+
13
+ def quantization(x, **params):
14
+ return (torch.div(1, replace_num(params['_s'], num=0, to=10000)) * torch.exp((params['_0'] * x)))
15
+
16
+
17
+ def dequantization(x, **params):
18
+ return (torch.div(1, replace_num(params['_0'], num=0, to=10000)) * torch.log(domain_guard((params['_s'] * x), min=1e-5, nan=1e-5)))
19
+
20
+
21
+ def init_params(x: torch.Tensor, **kwargs: Dict[str, Any]) -> Dict[str, nn.Parameter]:
22
+ base_p0 = {
23
+ '_0': init_space_search(x, qtz_func=quantization, deqtz_func=dequantization, params_list=['_0', '_s'], param='_0', **kwargs),
24
+ }
25
+
26
+ base_p0['_s'] = init_linear_scale(x, qtz_func=quantization, params=base_p0, **kwargs)
27
+ if 'post_init_hook' in kwargs:
28
+ kwargs['post_init_hook'](parameters=base_p0)
29
+
30
+ params = init_non_linear_regression_fit(x, p0=base_p0, np_fit_func=fit_func, qtz_func=quantization, deqtz_func=dequantization, params_list=['_0', '_s'], **kwargs)
31
+ params = {k: nn.Parameter(v, requires_grad=False) for k, v in params.items()}
32
+ if 'post_method_hook' in kwargs:
33
+ kwargs['post_method_hook'](parameters=params)
34
+
35
+
36
+ if 'post_train_hook' in kwargs:
37
+ kwargs['post_train_hook'](parameters=params)
38
+
39
+ return params
40
+
41
+
42
+ ############### Numpy Qtz ###############
43
+
44
+
45
+ def np_quantization(x, _0, _s):
46
+ return (np.divide(1, np_replace_num(_s, num=0, to=10000)) * np.exp((_0 * x)))
47
+
48
+
49
+ def np_dequantization(x, _0, _s):
50
+ return (np.divide(1, np_replace_num(_0, num=0, to=10000)) * np.log(np_domain_guard((_s * x), min=1e-5, nan=1e-5)))
51
+
52
+
53
+ def fit_func(x, _0, _s):
54
+ x_ = np_quantization(x, _0, _s)
55
+ x_ = np_dequantization(x_, _0, _s)
56
+ return x_
57
+
58
+
59
+
60
+ ############### HELPERS ###############
61
+
62
+ def domain_guard(
63
+ x: torch.Tensor,
64
+ min: float = None,
65
+ max: float = None,
66
+ posinf: float = None,
67
+ neginf: float = None,
68
+ nan: float = None
69
+ ) -> torch.Tensor:
70
+ """Guard a tensor to a valid domain."""
71
+ x = torch.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan)
72
+ if min is not None or max is not None:
73
+ x = torch.clamp(x, min=min, max=max)
74
+ return x
75
+
76
+
77
+ def replace_num(x: torch.Tensor, num: float, to: float) -> torch.Tensor:
78
+ """Replace a number in a tensor with another number.
79
+
80
+ Args:
81
+ x (torch.Tensor): The input tensor.
82
+ num (float): The number to replace.
83
+ to (float): The number to replace with.
84
+
85
+ Returns:
86
+ torch.Tensor: The tensor with the number replaced.
87
+ """
88
+ return torch.where(x == num, to, x)
89
+
90
+
91
+ def guarded_torch_power(x: torch.Tensor, exp: float) -> torch.Tensor:
92
+ """Guard the power operation to a valid domain."""
93
+ return torch.pow(x, exp) if exp >= 1 else torch.pow(torch.relu(x), exp)
94
+
95
+
96
+ def init_ones(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
97
+ val = torch.amin(x, dim=1)
98
+ return torch.ones_like(val, dtype=torch.float32, device=x.device)
99
+
100
+
101
+ def init_rand(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
102
+ val = torch.amin(x, dim=1)
103
+ return torch.randn_like(val, dtype=torch.float32, device=x.device)
104
+
105
+
106
+ def init_space_search(
107
+ x: torch.Tensor,
108
+ **kwargs: Dict[str, Any],
109
+ ) -> torch.Tensor:
110
+
111
+ def _build_initial_param(tensor: torch.Tensor, max_initial: int, n_params: int):
112
+ """Generates the initial set of parameters. The first iteration generates 10 times more parameters."""
113
+ for _ in range(n_params * 10): # The first iteration generates 10 times more parameters
114
+ yield init_rand(tensor) * max_initial # Generates n_params in range [-max_initial, max_initial]
115
+
116
+ def _search_param(tensors: List[torch.tensor], n_params):
117
+ """Takes the best parameters and generates new parameters around the mean of the best parameters."""
118
+ torch_tensors = torch.stack(tensors)
119
+ min_vals, max_vals = torch.aminmax(torch_tensors, dim=0)
120
+ abs_max_val_per_ch = torch.max(-min_vals, max_vals)
121
+ mean = torch.mean(torch_tensors, dim=0)
122
+ for _ in range(n_params): # Generates n_params around the mean of the tensors
123
+ yield torch.randn_like(min_vals) * abs_max_val_per_ch + mean
124
+
125
+ def _calc(x, qtz_func, deqtz_func, **params):
126
+ x_ = x.transpose(0, 1)
127
+ x_ = qtz_func(x=x_, **params)
128
+ x_ = deqtz_func(x=x_, **params)
129
+ x_ = x_.transpose(0, 1)
130
+ return x_
131
+
132
+ assert "qtz_func" in kwargs, "qtz_func must be provided."
133
+ assert "deqtz_func" in kwargs, "deqtz_func must be provided."
134
+ assert "params_list" in kwargs, "params list must be provided."
135
+ assert "param" in kwargs, "param must be provided."
136
+
137
+ qtz_func = kwargs.get('qtz_func')
138
+ deqtz_func = kwargs.get('deqtz_func')
139
+ params_list = kwargs.get('params_list')
140
+ param = kwargs.get('param')
141
+
142
+ n_runs = 50 # Number of runs to try to find the best parameters
143
+ n_random_params = 50 # Number of random parameters to generate
144
+ n_best_to_pick = 5 # Number of best parameters to pick after each run
145
+ max_initial = 10000 # Maximum value to initialize the parameters
146
+
147
+ # Initializes the parameters
148
+ base_params = { p: init_ones(x, **kwargs) for p in params_list if p != param }
149
+ params = _build_initial_param(x, max_initial, n_random_params)
150
+
151
+ # Performs the search
152
+ for _ in range(n_runs):
153
+
154
+ best_params = []
155
+ for param_ in params:
156
+ try:
157
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: param_})
158
+ loss_ones = nn.MSELoss()(x, x_)
159
+
160
+ if len(best_params) < n_best_to_pick:
161
+ best_params.append((param_, loss_ones.item()))
162
+ best_params = sorted(best_params, key=lambda x: x[1])
163
+ elif loss_ones < best_params[-1][1]:
164
+ best_params[-1] = (param_, loss_ones.item())
165
+ best_params = sorted(best_params, key=lambda x: x[1])
166
+
167
+ except Exception: # The parameters might not be valid for the function's domain
168
+ continue
169
+
170
+ # Generates new parameters around the mean
171
+ params = _search_param([p for p, _ in best_params], n_random_params)
172
+
173
+ # Checks if the best parameter is better than the init_ones
174
+ p_ones = init_ones(x, **kwargs)
175
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_ones})
176
+ loss_ones = nn.MSELoss()(x, x_)
177
+
178
+ # Checks if the best parameter is better than the init_rand
179
+ p_rand = init_rand(x, **kwargs)
180
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_rand})
181
+ loss_rand = nn.MSELoss()(x, x_)
182
+
183
+ if loss_rand < best_params[0][1] and loss_rand < loss_ones:
184
+ return p_rand
185
+ elif loss_ones < best_params[0][1] and loss_ones < loss_rand:
186
+ return p_ones
187
+ else:
188
+ return best_params[0][0]
189
+
190
+
191
+ def init_linear_scale( # Symmetric scale. From the study folder
192
+ x: torch.Tensor,
193
+ **kwargs: Dict[str, Any],
194
+ ) -> torch.Tensor:
195
+ assert "bits" in kwargs, "bits must be provided."
196
+ assert "params" in kwargs, "params must be provided."
197
+ assert "qtz_func" in kwargs, "qtz_func must be provided."
198
+
199
+ bits = kwargs.get('bits')
200
+ params = kwargs.get('params')
201
+ qtz_func = kwargs.get('qtz_func')
202
+
203
+ x_ = x.transpose(0, 1)
204
+ x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs))
205
+ x_ = x_.transpose(0, 1)
206
+
207
+ quant_min, quant_max = get_min_max_from_bits_signed(bits)
208
+ min_vals, max_vals = torch.aminmax(x_, dim=1)
209
+ min_vals = torch.min(min_vals, torch.zeros_like(min_vals))
210
+ max_vals = torch.max(max_vals, torch.zeros_like(max_vals))
211
+
212
+ eps = torch.finfo(torch.float32).eps
213
+
214
+ abs_max_val_per_ch = torch.max(-min_vals, max_vals)
215
+ scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2)
216
+
217
+ scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device)
218
+
219
+ # Introduces some noise in scale
220
+ # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything
221
+ # scale = scale + 0.01 * torch.randn_like(scale)
222
+ return scale
223
+
224
+
225
+ def init_non_linear_regression_fit(
226
+ x: torch.Tensor,
227
+ **kwargs: Dict[str, Any],
228
+ ) -> torch.Tensor:
229
+
230
+ assert "params_list" in kwargs, "params list must be provided."
231
+ assert "np_fit_func" in kwargs, "np_fit_func must be provided."
232
+ assert "p0" in kwargs, "p0 must be provided."
233
+ np_fit_func = kwargs.get('np_fit_func')
234
+ params_list = kwargs.get('params_list')
235
+ p0 = kwargs.get('p0')
236
+
237
+ def _fit(xdata: np.ndarray, ydata: np.ndarray, func: Callable, p0: List[float]):
238
+ popt, _ = curve_fit(
239
+ func,
240
+ xdata,
241
+ ydata,
242
+ maxfev=1000,
243
+ p0=p0,
244
+ method='lm'
245
+ )
246
+ return popt
247
+
248
+ # 1. Needs to convert the torch tensor to numpy tensor
249
+ xdata = x.cpu().numpy()
250
+
251
+ # 2. Sorts the data so that it makes it easier to fit to it
252
+ sorted_xdata = np.sort(xdata, axis=-1)
253
+
254
+ p0 = {k: v.cpu().numpy() for k, v in p0.items()}
255
+ params_list = sorted(params_list) # We need to make sure that it matches the numpy fit func arg order
256
+
257
+ # 3. Finds the best parameters for each channel
258
+ try:
259
+ params = []
260
+ for i in range(sorted_xdata.shape[0]):
261
+ xdata_ = sorted_xdata[i]
262
+ p0_ = [p0[p][i] for p in params_list]
263
+ ch_params = _fit(xdata_, xdata_, np_fit_func, p0_)
264
+ params.append(ch_params)
265
+
266
+ # 4. Builds the parameters
267
+ result = {}
268
+ for i, p in enumerate(params_list):
269
+ result[p] = torch.tensor([p_[i] for p_ in params], dtype=torch.float32).to(x.device)
270
+
271
+ return result
272
+
273
+ except ValueError as e:
274
+ print(f"Could not fit the function with error: {e}")
275
+ print(f"Using fallback result...")
276
+ return {
277
+ k: torch.tensor(v, dtype=torch.float32).to(x.device) for k, v in p0.items()
278
+ }
279
+
280
+
281
+ def init_zeros(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
282
+ val = torch.amin(x, dim=1)
283
+ return torch.zeros_like(val, dtype=torch.float32, device=x.device)
284
+
285
+
286
+ def init_inner_scale(tensor: torch.Tensor, _min: float = torch.inf, _max: float = torch.inf) -> torch.Tensor:
287
+ # Calculate the original minimum and maximum values
288
+ min_vals, max_vals = torch.aminmax(tensor, dim=-1)
289
+ x_min = torch.min(min_vals, torch.zeros_like(min_vals))
290
+ x_max = torch.max(max_vals, torch.zeros_like(max_vals))
291
+
292
+ if _max is torch.inf: # We do not need to scale the tensor. Just need to move it
293
+ return torch.ones_like(x_min)
294
+
295
+ # Calculate the scale factor
296
+ scale = (_max - _min) / (x_max - x_min)
297
+ return scale
298
+
299
+
300
+
301
+ ############## Quant ###############
302
+
303
+ @torch.enable_grad()
304
+ def learn_parameters(
305
+ x: torch.Tensor,
306
+ params: Dict[str, nn.Parameter],
307
+ qtz_func: nn.Module,
308
+ deqtz_func: nn.Module,
309
+ bits: int,
310
+ target_dtype: torch.dtype,
311
+ epochs: int = 1000,
312
+ early_stop: bool = True,
313
+ do_report: bool = False
314
+ ) -> Tuple[Dict[str, nn.Parameter], torch.Tensor]:
315
+ loss_fn = nn.MSELoss()
316
+
317
+ # Determines the initial learning rate by computing the initial loss and multiplying it by
318
+ # the order of magnitude of the loss divided by 2
319
+ quant = quantize(x, params, qtz_func, bits, target_dtype)
320
+ dequant = dequantize(quant, params, deqtz_func, bits, x.dtype)
321
+ loss = loss_fn(x, dequant)
322
+
323
+ base_lr = 0.1
324
+ exponent = int(np.floor(np.log10(loss.item())))
325
+ lr = base_lr * (10 ** (exponent // 2))
326
+
327
+ # Requires gradients in the parameters
328
+ for p in params.values():
329
+ p.requires_grad = True
330
+ p.grad = None
331
+
332
+ param_keys = list(params.keys())
333
+ param_values = list(params.values())
334
+
335
+ # Defines optimizer and loss function
336
+ optimizer = torch.optim.Adam(param_values, lr=lr)
337
+ scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=1.0, end_factor=0.01, total_iters=epochs // 10)
338
+
339
+ # Contains the best loss and the best parameters
340
+ best_loss = float("inf")
341
+ best_params = None
342
+
343
+ # Used to stop the search early
344
+ min_delta = 1e-7
345
+ acc_loss = []
346
+ percent_epochs_before_stop = 0.1
347
+
348
+ for i in range(epochs):
349
+ optimizer.zero_grad()
350
+
351
+ quant = quantize(x, params, qtz_func, bits, target_dtype)
352
+ dequant = dequantize(quant, params, deqtz_func, bits, x.dtype)
353
+ loss = loss_fn(x, dequant)
354
+
355
+ if loss.isnan() or loss.isinf():
356
+ raise Exception("Loss is NaN or Inf. Stopping the search.")
357
+
358
+ loss.backward()
359
+ optimizer.step()
360
+ scheduler.step()
361
+
362
+ acc_loss.append(loss.item())
363
+
364
+ # Reports loss every 10 steps
365
+ if i % 10 == 0 and do_report:
366
+ print(f"Epoch {i}: Loss {loss.item()}")
367
+
368
+ # Optimizes the parameter search by storing the best loss and the parameters
369
+ if loss.item() < best_loss:
370
+ best_loss = loss.item()
371
+ best_params = copy.deepcopy({
372
+ k: v for k, v in params.items() if k in param_keys
373
+ })
374
+
375
+ # We also stop the search if the loss has not considerably during the last 10% epochs
376
+ if early_stop:
377
+ epochs_before_stop = int(epochs * percent_epochs_before_stop)
378
+ if i > epochs_before_stop and abs(acc_loss[i - epochs_before_stop] - acc_loss[i]) < min_delta:
379
+ break
380
+
381
+ # No longer requires gradients in the parameters
382
+ for p in best_params.values():
383
+ p.requires_grad = False
384
+ p.grad = None
385
+
386
+ if do_report:
387
+ return best_params, acc_loss
388
+ else:
389
+ return best_params
390
+
391
+
392
+ def quantize(
393
+ x: torch.Tensor,
394
+ params: Dict[str, nn.Parameter],
395
+ func: nn.Module,
396
+ bits: int,
397
+ target_dtype: torch.dtype = torch.int8
398
+ ) -> torch.Tensor:
399
+ quant_min, quant_max = get_min_max_from_bits_signed(bits)
400
+ x = x.transpose(0, 1) # Aligns shapes
401
+ x = func(x=x, **params)
402
+ x = x.transpose(0, 1)
403
+ x = torch.clamp(round_func_BPDA(x), quant_min, quant_max).to(target_dtype)
404
+ return x
405
+
406
+
407
+ def dequantize(
408
+ x: torch.Tensor,
409
+ params: Dict[str, nn.Parameter],
410
+ func: nn.Module,
411
+ bits: int,
412
+ out_dtype: torch.dtype
413
+ ) -> torch.Tensor:
414
+ x = x.to(dtype=out_dtype)
415
+ x = x.transpose(0, 1)
416
+ x = func(x=x, **params)
417
+ x = x.transpose(0, 1)
418
+ return x
419
+
420
+
421
+ def round_func_BPDA(input):
422
+ # This is equivalent to replacing round function (non-differentiable) with
423
+ # an identity function (differentiable) only when backward.
424
+ forward_value = torch.round(input)
425
+ out = input.clone()
426
+ out.data = forward_value.data
427
+ return out
428
+
429
+
430
+ def get_min_max_from_bits_signed(bit_width: int) -> Tuple[int, int]:
431
+ return -2 ** (bit_width - 1), 2 ** (bit_width - 1) - 1
432
+
433
+
434
+
435
+ ############## Numpy ###############
436
+
437
+ def np_domain_guard(
438
+ x: np.ndarray,
439
+ min: float = None,
440
+ max: float = None,
441
+ posinf: float = None,
442
+ neginf: float = None,
443
+ nan: float = None
444
+ ) -> np.ndarray:
445
+ """Guard a tensor to a valid domain."""
446
+ x = np.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan)
447
+ if min is not None or max is not None:
448
+ x = np.clip(x, min, max)
449
+ return x
450
+
451
+
452
+ def np_replace_num(x: np.ndarray, num: float, to: float) -> np.ndarray:
453
+ """Replace a number in a tensor with another number.
454
+
455
+ Args:
456
+ x (np.ndarray): The input tensor.
457
+ num (float): The number to replace.
458
+ to (float): The number to replace with.
459
+
460
+ Returns:
461
+ np.ndarray: The tensor with the number replaced.
462
+ """
463
+ return np.where(x == num, to, x)
464
+
465
+
466
+ def np_guarded_power(x: np.ndarray, exp: float) -> np.ndarray:
467
+ """Guard the power operation to a valid domain."""
468
+ return np.power(x, exp) if exp >= 1 else np.power(np.maximum(x, 0), exp)
469
+
fn_gen/nlr/10/loss.png ADDED
fn_gen/nlr/10/quantization.png ADDED
fn_gen/nlr/11/distortion.png ADDED
fn_gen/nlr/11/expressions.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ x**2/_s
2
+ sqrt(_s*x)
fn_gen/nlr/11/fn.py ADDED
@@ -0,0 +1,468 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import torch
4
+ from torch import amin # Necessary for arcsin
5
+ import copy
6
+ import torch.nn as nn
7
+ import numpy as np
8
+
9
+ from scipy.optimize import curve_fit
10
+ from typing import Dict, Any, Tuple, List, Callable
11
+
12
+
13
+ def quantization(x, **params):
14
+ return (torch.div(1, replace_num(params['_s'], num=0, to=10000)) * guarded_torch_power(x, torch.tensor(2)))
15
+
16
+
17
+ def dequantization(x, **params):
18
+ return torch.sqrt(domain_guard((params['_s'] * x), min=0.1, nan=0.1))
19
+
20
+
21
+ def init_params(x: torch.Tensor, **kwargs: Dict[str, Any]) -> Dict[str, nn.Parameter]:
22
+ base_p0 = {
23
+ }
24
+
25
+ base_p0['_s'] = init_linear_scale(x, qtz_func=quantization, params=base_p0, **kwargs)
26
+ if 'post_init_hook' in kwargs:
27
+ kwargs['post_init_hook'](parameters=base_p0)
28
+
29
+ params = init_non_linear_regression_fit(x, p0=base_p0, np_fit_func=fit_func, qtz_func=quantization, deqtz_func=dequantization, params_list=['_s'], **kwargs)
30
+ params = {k: nn.Parameter(v, requires_grad=False) for k, v in params.items()}
31
+ if 'post_method_hook' in kwargs:
32
+ kwargs['post_method_hook'](parameters=params)
33
+
34
+
35
+ if 'post_train_hook' in kwargs:
36
+ kwargs['post_train_hook'](parameters=params)
37
+
38
+ return params
39
+
40
+
41
+ ############### Numpy Qtz ###############
42
+
43
+
44
+ def np_quantization(x, _s):
45
+ return (np.divide(1, np_replace_num(_s, num=0, to=10000)) * np_guarded_power(x, np.array(2)))
46
+
47
+
48
+ def np_dequantization(x, _s):
49
+ return np.sqrt(np_domain_guard((_s * x), min=0.1, nan=0.1))
50
+
51
+
52
+ def fit_func(x, _s):
53
+ x_ = np_quantization(x, _s)
54
+ x_ = np_dequantization(x_, _s)
55
+ return x_
56
+
57
+
58
+
59
+ ############### HELPERS ###############
60
+
61
+ def domain_guard(
62
+ x: torch.Tensor,
63
+ min: float = None,
64
+ max: float = None,
65
+ posinf: float = None,
66
+ neginf: float = None,
67
+ nan: float = None
68
+ ) -> torch.Tensor:
69
+ """Guard a tensor to a valid domain."""
70
+ x = torch.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan)
71
+ if min is not None or max is not None:
72
+ x = torch.clamp(x, min=min, max=max)
73
+ return x
74
+
75
+
76
+ def replace_num(x: torch.Tensor, num: float, to: float) -> torch.Tensor:
77
+ """Replace a number in a tensor with another number.
78
+
79
+ Args:
80
+ x (torch.Tensor): The input tensor.
81
+ num (float): The number to replace.
82
+ to (float): The number to replace with.
83
+
84
+ Returns:
85
+ torch.Tensor: The tensor with the number replaced.
86
+ """
87
+ return torch.where(x == num, to, x)
88
+
89
+
90
+ def guarded_torch_power(x: torch.Tensor, exp: float) -> torch.Tensor:
91
+ """Guard the power operation to a valid domain."""
92
+ return torch.pow(x, exp) if exp >= 1 else torch.pow(torch.relu(x), exp)
93
+
94
+
95
+ def init_ones(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
96
+ val = torch.amin(x, dim=1)
97
+ return torch.ones_like(val, dtype=torch.float32, device=x.device)
98
+
99
+
100
+ def init_rand(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
101
+ val = torch.amin(x, dim=1)
102
+ return torch.randn_like(val, dtype=torch.float32, device=x.device)
103
+
104
+
105
+ def init_space_search(
106
+ x: torch.Tensor,
107
+ **kwargs: Dict[str, Any],
108
+ ) -> torch.Tensor:
109
+
110
+ def _build_initial_param(tensor: torch.Tensor, max_initial: int, n_params: int):
111
+ """Generates the initial set of parameters. The first iteration generates 10 times more parameters."""
112
+ for _ in range(n_params * 10): # The first iteration generates 10 times more parameters
113
+ yield init_rand(tensor) * max_initial # Generates n_params in range [-max_initial, max_initial]
114
+
115
+ def _search_param(tensors: List[torch.tensor], n_params):
116
+ """Takes the best parameters and generates new parameters around the mean of the best parameters."""
117
+ torch_tensors = torch.stack(tensors)
118
+ min_vals, max_vals = torch.aminmax(torch_tensors, dim=0)
119
+ abs_max_val_per_ch = torch.max(-min_vals, max_vals)
120
+ mean = torch.mean(torch_tensors, dim=0)
121
+ for _ in range(n_params): # Generates n_params around the mean of the tensors
122
+ yield torch.randn_like(min_vals) * abs_max_val_per_ch + mean
123
+
124
+ def _calc(x, qtz_func, deqtz_func, **params):
125
+ x_ = x.transpose(0, 1)
126
+ x_ = qtz_func(x=x_, **params)
127
+ x_ = deqtz_func(x=x_, **params)
128
+ x_ = x_.transpose(0, 1)
129
+ return x_
130
+
131
+ assert "qtz_func" in kwargs, "qtz_func must be provided."
132
+ assert "deqtz_func" in kwargs, "deqtz_func must be provided."
133
+ assert "params_list" in kwargs, "params list must be provided."
134
+ assert "param" in kwargs, "param must be provided."
135
+
136
+ qtz_func = kwargs.get('qtz_func')
137
+ deqtz_func = kwargs.get('deqtz_func')
138
+ params_list = kwargs.get('params_list')
139
+ param = kwargs.get('param')
140
+
141
+ n_runs = 50 # Number of runs to try to find the best parameters
142
+ n_random_params = 50 # Number of random parameters to generate
143
+ n_best_to_pick = 5 # Number of best parameters to pick after each run
144
+ max_initial = 10000 # Maximum value to initialize the parameters
145
+
146
+ # Initializes the parameters
147
+ base_params = { p: init_ones(x, **kwargs) for p in params_list if p != param }
148
+ params = _build_initial_param(x, max_initial, n_random_params)
149
+
150
+ # Performs the search
151
+ for _ in range(n_runs):
152
+
153
+ best_params = []
154
+ for param_ in params:
155
+ try:
156
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: param_})
157
+ loss_ones = nn.MSELoss()(x, x_)
158
+
159
+ if len(best_params) < n_best_to_pick:
160
+ best_params.append((param_, loss_ones.item()))
161
+ best_params = sorted(best_params, key=lambda x: x[1])
162
+ elif loss_ones < best_params[-1][1]:
163
+ best_params[-1] = (param_, loss_ones.item())
164
+ best_params = sorted(best_params, key=lambda x: x[1])
165
+
166
+ except Exception: # The parameters might not be valid for the function's domain
167
+ continue
168
+
169
+ # Generates new parameters around the mean
170
+ params = _search_param([p for p, _ in best_params], n_random_params)
171
+
172
+ # Checks if the best parameter is better than the init_ones
173
+ p_ones = init_ones(x, **kwargs)
174
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_ones})
175
+ loss_ones = nn.MSELoss()(x, x_)
176
+
177
+ # Checks if the best parameter is better than the init_rand
178
+ p_rand = init_rand(x, **kwargs)
179
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_rand})
180
+ loss_rand = nn.MSELoss()(x, x_)
181
+
182
+ if loss_rand < best_params[0][1] and loss_rand < loss_ones:
183
+ return p_rand
184
+ elif loss_ones < best_params[0][1] and loss_ones < loss_rand:
185
+ return p_ones
186
+ else:
187
+ return best_params[0][0]
188
+
189
+
190
+ def init_linear_scale( # Symmetric scale. From the study folder
191
+ x: torch.Tensor,
192
+ **kwargs: Dict[str, Any],
193
+ ) -> torch.Tensor:
194
+ assert "bits" in kwargs, "bits must be provided."
195
+ assert "params" in kwargs, "params must be provided."
196
+ assert "qtz_func" in kwargs, "qtz_func must be provided."
197
+
198
+ bits = kwargs.get('bits')
199
+ params = kwargs.get('params')
200
+ qtz_func = kwargs.get('qtz_func')
201
+
202
+ x_ = x.transpose(0, 1)
203
+ x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs))
204
+ x_ = x_.transpose(0, 1)
205
+
206
+ quant_min, quant_max = get_min_max_from_bits_signed(bits)
207
+ min_vals, max_vals = torch.aminmax(x_, dim=1)
208
+ min_vals = torch.min(min_vals, torch.zeros_like(min_vals))
209
+ max_vals = torch.max(max_vals, torch.zeros_like(max_vals))
210
+
211
+ eps = torch.finfo(torch.float32).eps
212
+
213
+ abs_max_val_per_ch = torch.max(-min_vals, max_vals)
214
+ scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2)
215
+
216
+ scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device)
217
+
218
+ # Introduces some noise in scale
219
+ # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything
220
+ # scale = scale + 0.01 * torch.randn_like(scale)
221
+ return scale
222
+
223
+
224
+ def init_non_linear_regression_fit(
225
+ x: torch.Tensor,
226
+ **kwargs: Dict[str, Any],
227
+ ) -> torch.Tensor:
228
+
229
+ assert "params_list" in kwargs, "params list must be provided."
230
+ assert "np_fit_func" in kwargs, "np_fit_func must be provided."
231
+ assert "p0" in kwargs, "p0 must be provided."
232
+ np_fit_func = kwargs.get('np_fit_func')
233
+ params_list = kwargs.get('params_list')
234
+ p0 = kwargs.get('p0')
235
+
236
+ def _fit(xdata: np.ndarray, ydata: np.ndarray, func: Callable, p0: List[float]):
237
+ popt, _ = curve_fit(
238
+ func,
239
+ xdata,
240
+ ydata,
241
+ maxfev=1000,
242
+ p0=p0,
243
+ method='lm'
244
+ )
245
+ return popt
246
+
247
+ # 1. Needs to convert the torch tensor to numpy tensor
248
+ xdata = x.cpu().numpy()
249
+
250
+ # 2. Sorts the data so that it makes it easier to fit to it
251
+ sorted_xdata = np.sort(xdata, axis=-1)
252
+
253
+ p0 = {k: v.cpu().numpy() for k, v in p0.items()}
254
+ params_list = sorted(params_list) # We need to make sure that it matches the numpy fit func arg order
255
+
256
+ # 3. Finds the best parameters for each channel
257
+ try:
258
+ params = []
259
+ for i in range(sorted_xdata.shape[0]):
260
+ xdata_ = sorted_xdata[i]
261
+ p0_ = [p0[p][i] for p in params_list]
262
+ ch_params = _fit(xdata_, xdata_, np_fit_func, p0_)
263
+ params.append(ch_params)
264
+
265
+ # 4. Builds the parameters
266
+ result = {}
267
+ for i, p in enumerate(params_list):
268
+ result[p] = torch.tensor([p_[i] for p_ in params], dtype=torch.float32).to(x.device)
269
+
270
+ return result
271
+
272
+ except ValueError as e:
273
+ print(f"Could not fit the function with error: {e}")
274
+ print(f"Using fallback result...")
275
+ return {
276
+ k: torch.tensor(v, dtype=torch.float32).to(x.device) for k, v in p0.items()
277
+ }
278
+
279
+
280
+ def init_zeros(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
281
+ val = torch.amin(x, dim=1)
282
+ return torch.zeros_like(val, dtype=torch.float32, device=x.device)
283
+
284
+
285
+ def init_inner_scale(tensor: torch.Tensor, _min: float = torch.inf, _max: float = torch.inf) -> torch.Tensor:
286
+ # Calculate the original minimum and maximum values
287
+ min_vals, max_vals = torch.aminmax(tensor, dim=-1)
288
+ x_min = torch.min(min_vals, torch.zeros_like(min_vals))
289
+ x_max = torch.max(max_vals, torch.zeros_like(max_vals))
290
+
291
+ if _max is torch.inf: # We do not need to scale the tensor. Just need to move it
292
+ return torch.ones_like(x_min)
293
+
294
+ # Calculate the scale factor
295
+ scale = (_max - _min) / (x_max - x_min)
296
+ return scale
297
+
298
+
299
+
300
+ ############## Quant ###############
301
+
302
+ @torch.enable_grad()
303
+ def learn_parameters(
304
+ x: torch.Tensor,
305
+ params: Dict[str, nn.Parameter],
306
+ qtz_func: nn.Module,
307
+ deqtz_func: nn.Module,
308
+ bits: int,
309
+ target_dtype: torch.dtype,
310
+ epochs: int = 1000,
311
+ early_stop: bool = True,
312
+ do_report: bool = False
313
+ ) -> Tuple[Dict[str, nn.Parameter], torch.Tensor]:
314
+ loss_fn = nn.MSELoss()
315
+
316
+ # Determines the initial learning rate by computing the initial loss and multiplying it by
317
+ # the order of magnitude of the loss divided by 2
318
+ quant = quantize(x, params, qtz_func, bits, target_dtype)
319
+ dequant = dequantize(quant, params, deqtz_func, bits, x.dtype)
320
+ loss = loss_fn(x, dequant)
321
+
322
+ base_lr = 0.1
323
+ exponent = int(np.floor(np.log10(loss.item())))
324
+ lr = base_lr * (10 ** (exponent // 2))
325
+
326
+ # Requires gradients in the parameters
327
+ for p in params.values():
328
+ p.requires_grad = True
329
+ p.grad = None
330
+
331
+ param_keys = list(params.keys())
332
+ param_values = list(params.values())
333
+
334
+ # Defines optimizer and loss function
335
+ optimizer = torch.optim.Adam(param_values, lr=lr)
336
+ scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=1.0, end_factor=0.01, total_iters=epochs // 10)
337
+
338
+ # Contains the best loss and the best parameters
339
+ best_loss = float("inf")
340
+ best_params = None
341
+
342
+ # Used to stop the search early
343
+ min_delta = 1e-7
344
+ acc_loss = []
345
+ percent_epochs_before_stop = 0.1
346
+
347
+ for i in range(epochs):
348
+ optimizer.zero_grad()
349
+
350
+ quant = quantize(x, params, qtz_func, bits, target_dtype)
351
+ dequant = dequantize(quant, params, deqtz_func, bits, x.dtype)
352
+ loss = loss_fn(x, dequant)
353
+
354
+ if loss.isnan() or loss.isinf():
355
+ raise Exception("Loss is NaN or Inf. Stopping the search.")
356
+
357
+ loss.backward()
358
+ optimizer.step()
359
+ scheduler.step()
360
+
361
+ acc_loss.append(loss.item())
362
+
363
+ # Reports loss every 10 steps
364
+ if i % 10 == 0 and do_report:
365
+ print(f"Epoch {i}: Loss {loss.item()}")
366
+
367
+ # Optimizes the parameter search by storing the best loss and the parameters
368
+ if loss.item() < best_loss:
369
+ best_loss = loss.item()
370
+ best_params = copy.deepcopy({
371
+ k: v for k, v in params.items() if k in param_keys
372
+ })
373
+
374
+ # We also stop the search if the loss has not considerably during the last 10% epochs
375
+ if early_stop:
376
+ epochs_before_stop = int(epochs * percent_epochs_before_stop)
377
+ if i > epochs_before_stop and abs(acc_loss[i - epochs_before_stop] - acc_loss[i]) < min_delta:
378
+ break
379
+
380
+ # No longer requires gradients in the parameters
381
+ for p in best_params.values():
382
+ p.requires_grad = False
383
+ p.grad = None
384
+
385
+ if do_report:
386
+ return best_params, acc_loss
387
+ else:
388
+ return best_params
389
+
390
+
391
+ def quantize(
392
+ x: torch.Tensor,
393
+ params: Dict[str, nn.Parameter],
394
+ func: nn.Module,
395
+ bits: int,
396
+ target_dtype: torch.dtype = torch.int8
397
+ ) -> torch.Tensor:
398
+ quant_min, quant_max = get_min_max_from_bits_signed(bits)
399
+ x = x.transpose(0, 1) # Aligns shapes
400
+ x = func(x=x, **params)
401
+ x = x.transpose(0, 1)
402
+ x = torch.clamp(round_func_BPDA(x), quant_min, quant_max).to(target_dtype)
403
+ return x
404
+
405
+
406
+ def dequantize(
407
+ x: torch.Tensor,
408
+ params: Dict[str, nn.Parameter],
409
+ func: nn.Module,
410
+ bits: int,
411
+ out_dtype: torch.dtype
412
+ ) -> torch.Tensor:
413
+ x = x.to(dtype=out_dtype)
414
+ x = x.transpose(0, 1)
415
+ x = func(x=x, **params)
416
+ x = x.transpose(0, 1)
417
+ return x
418
+
419
+
420
+ def round_func_BPDA(input):
421
+ # This is equivalent to replacing round function (non-differentiable) with
422
+ # an identity function (differentiable) only when backward.
423
+ forward_value = torch.round(input)
424
+ out = input.clone()
425
+ out.data = forward_value.data
426
+ return out
427
+
428
+
429
+ def get_min_max_from_bits_signed(bit_width: int) -> Tuple[int, int]:
430
+ return -2 ** (bit_width - 1), 2 ** (bit_width - 1) - 1
431
+
432
+
433
+
434
+ ############## Numpy ###############
435
+
436
+ def np_domain_guard(
437
+ x: np.ndarray,
438
+ min: float = None,
439
+ max: float = None,
440
+ posinf: float = None,
441
+ neginf: float = None,
442
+ nan: float = None
443
+ ) -> np.ndarray:
444
+ """Guard a tensor to a valid domain."""
445
+ x = np.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan)
446
+ if min is not None or max is not None:
447
+ x = np.clip(x, min, max)
448
+ return x
449
+
450
+
451
+ def np_replace_num(x: np.ndarray, num: float, to: float) -> np.ndarray:
452
+ """Replace a number in a tensor with another number.
453
+
454
+ Args:
455
+ x (np.ndarray): The input tensor.
456
+ num (float): The number to replace.
457
+ to (float): The number to replace with.
458
+
459
+ Returns:
460
+ np.ndarray: The tensor with the number replaced.
461
+ """
462
+ return np.where(x == num, to, x)
463
+
464
+
465
+ def np_guarded_power(x: np.ndarray, exp: float) -> np.ndarray:
466
+ """Guard the power operation to a valid domain."""
467
+ return np.power(x, exp) if exp >= 1 else np.power(np.maximum(x, 0), exp)
468
+
fn_gen/nlr/11/loss.png ADDED
fn_gen/nlr/11/quantization.png ADDED
fn_gen/nlr/13/distortion.png ADDED
fn_gen/nlr/13/expressions.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ acos(_0*x)/_s
2
+ cos(_s*x)/_0
fn_gen/nlr/13/fn.py ADDED
@@ -0,0 +1,469 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import torch
4
+ from torch import amin # Necessary for arcsin
5
+ import copy
6
+ import torch.nn as nn
7
+ import numpy as np
8
+
9
+ from scipy.optimize import curve_fit
10
+ from typing import Dict, Any, Tuple, List, Callable
11
+
12
+
13
+ def quantization(x, **params):
14
+ return (torch.div(1, replace_num(params['_s'], num=0, to=10000)) * torch.acos(domain_guard((params['_0'] * x), min=-0.99999, max=0.99999, nan=0)))
15
+
16
+
17
+ def dequantization(x, **params):
18
+ return (torch.div(1, replace_num(params['_0'], num=0, to=10000)) * torch.cos((params['_s'] * x)))
19
+
20
+
21
+ def init_params(x: torch.Tensor, **kwargs: Dict[str, Any]) -> Dict[str, nn.Parameter]:
22
+ base_p0 = {
23
+ '_0': init_space_search(x, qtz_func=quantization, deqtz_func=dequantization, params_list=['_0', '_s'], param='_0', **kwargs),
24
+ }
25
+
26
+ base_p0['_s'] = init_linear_scale(x, qtz_func=quantization, params=base_p0, **kwargs)
27
+ if 'post_init_hook' in kwargs:
28
+ kwargs['post_init_hook'](parameters=base_p0)
29
+
30
+ params = init_non_linear_regression_fit(x, p0=base_p0, np_fit_func=fit_func, qtz_func=quantization, deqtz_func=dequantization, params_list=['_0', '_s'], **kwargs)
31
+ params = {k: nn.Parameter(v, requires_grad=False) for k, v in params.items()}
32
+ if 'post_method_hook' in kwargs:
33
+ kwargs['post_method_hook'](parameters=params)
34
+
35
+
36
+ if 'post_train_hook' in kwargs:
37
+ kwargs['post_train_hook'](parameters=params)
38
+
39
+ return params
40
+
41
+
42
+ ############### Numpy Qtz ###############
43
+
44
+
45
+ def np_quantization(x, _0, _s):
46
+ return (np.divide(1, np_replace_num(_s, num=0, to=10000)) * np.arccos(np_domain_guard((_0 * x), min=-0.99999, max=0.99999, nan=0)))
47
+
48
+
49
+ def np_dequantization(x, _0, _s):
50
+ return (np.divide(1, np_replace_num(_0, num=0, to=10000)) * np.cos((_s * x)))
51
+
52
+
53
+ def fit_func(x, _0, _s):
54
+ x_ = np_quantization(x, _0, _s)
55
+ x_ = np_dequantization(x_, _0, _s)
56
+ return x_
57
+
58
+
59
+
60
+ ############### HELPERS ###############
61
+
62
+ def domain_guard(
63
+ x: torch.Tensor,
64
+ min: float = None,
65
+ max: float = None,
66
+ posinf: float = None,
67
+ neginf: float = None,
68
+ nan: float = None
69
+ ) -> torch.Tensor:
70
+ """Guard a tensor to a valid domain."""
71
+ x = torch.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan)
72
+ if min is not None or max is not None:
73
+ x = torch.clamp(x, min=min, max=max)
74
+ return x
75
+
76
+
77
+ def replace_num(x: torch.Tensor, num: float, to: float) -> torch.Tensor:
78
+ """Replace a number in a tensor with another number.
79
+
80
+ Args:
81
+ x (torch.Tensor): The input tensor.
82
+ num (float): The number to replace.
83
+ to (float): The number to replace with.
84
+
85
+ Returns:
86
+ torch.Tensor: The tensor with the number replaced.
87
+ """
88
+ return torch.where(x == num, to, x)
89
+
90
+
91
+ def guarded_torch_power(x: torch.Tensor, exp: float) -> torch.Tensor:
92
+ """Guard the power operation to a valid domain."""
93
+ return torch.pow(x, exp) if exp >= 1 else torch.pow(torch.relu(x), exp)
94
+
95
+
96
+ def init_ones(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
97
+ val = torch.amin(x, dim=1)
98
+ return torch.ones_like(val, dtype=torch.float32, device=x.device)
99
+
100
+
101
+ def init_rand(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
102
+ val = torch.amin(x, dim=1)
103
+ return torch.randn_like(val, dtype=torch.float32, device=x.device)
104
+
105
+
106
+ def init_space_search(
107
+ x: torch.Tensor,
108
+ **kwargs: Dict[str, Any],
109
+ ) -> torch.Tensor:
110
+
111
+ def _build_initial_param(tensor: torch.Tensor, max_initial: int, n_params: int):
112
+ """Generates the initial set of parameters. The first iteration generates 10 times more parameters."""
113
+ for _ in range(n_params * 10): # The first iteration generates 10 times more parameters
114
+ yield init_rand(tensor) * max_initial # Generates n_params in range [-max_initial, max_initial]
115
+
116
+ def _search_param(tensors: List[torch.tensor], n_params):
117
+ """Takes the best parameters and generates new parameters around the mean of the best parameters."""
118
+ torch_tensors = torch.stack(tensors)
119
+ min_vals, max_vals = torch.aminmax(torch_tensors, dim=0)
120
+ abs_max_val_per_ch = torch.max(-min_vals, max_vals)
121
+ mean = torch.mean(torch_tensors, dim=0)
122
+ for _ in range(n_params): # Generates n_params around the mean of the tensors
123
+ yield torch.randn_like(min_vals) * abs_max_val_per_ch + mean
124
+
125
+ def _calc(x, qtz_func, deqtz_func, **params):
126
+ x_ = x.transpose(0, 1)
127
+ x_ = qtz_func(x=x_, **params)
128
+ x_ = deqtz_func(x=x_, **params)
129
+ x_ = x_.transpose(0, 1)
130
+ return x_
131
+
132
+ assert "qtz_func" in kwargs, "qtz_func must be provided."
133
+ assert "deqtz_func" in kwargs, "deqtz_func must be provided."
134
+ assert "params_list" in kwargs, "params list must be provided."
135
+ assert "param" in kwargs, "param must be provided."
136
+
137
+ qtz_func = kwargs.get('qtz_func')
138
+ deqtz_func = kwargs.get('deqtz_func')
139
+ params_list = kwargs.get('params_list')
140
+ param = kwargs.get('param')
141
+
142
+ n_runs = 50 # Number of runs to try to find the best parameters
143
+ n_random_params = 50 # Number of random parameters to generate
144
+ n_best_to_pick = 5 # Number of best parameters to pick after each run
145
+ max_initial = 10000 # Maximum value to initialize the parameters
146
+
147
+ # Initializes the parameters
148
+ base_params = { p: init_ones(x, **kwargs) for p in params_list if p != param }
149
+ params = _build_initial_param(x, max_initial, n_random_params)
150
+
151
+ # Performs the search
152
+ for _ in range(n_runs):
153
+
154
+ best_params = []
155
+ for param_ in params:
156
+ try:
157
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: param_})
158
+ loss_ones = nn.MSELoss()(x, x_)
159
+
160
+ if len(best_params) < n_best_to_pick:
161
+ best_params.append((param_, loss_ones.item()))
162
+ best_params = sorted(best_params, key=lambda x: x[1])
163
+ elif loss_ones < best_params[-1][1]:
164
+ best_params[-1] = (param_, loss_ones.item())
165
+ best_params = sorted(best_params, key=lambda x: x[1])
166
+
167
+ except Exception: # The parameters might not be valid for the function's domain
168
+ continue
169
+
170
+ # Generates new parameters around the mean
171
+ params = _search_param([p for p, _ in best_params], n_random_params)
172
+
173
+ # Checks if the best parameter is better than the init_ones
174
+ p_ones = init_ones(x, **kwargs)
175
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_ones})
176
+ loss_ones = nn.MSELoss()(x, x_)
177
+
178
+ # Checks if the best parameter is better than the init_rand
179
+ p_rand = init_rand(x, **kwargs)
180
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_rand})
181
+ loss_rand = nn.MSELoss()(x, x_)
182
+
183
+ if loss_rand < best_params[0][1] and loss_rand < loss_ones:
184
+ return p_rand
185
+ elif loss_ones < best_params[0][1] and loss_ones < loss_rand:
186
+ return p_ones
187
+ else:
188
+ return best_params[0][0]
189
+
190
+
191
+ def init_linear_scale( # Symmetric scale. From the study folder
192
+ x: torch.Tensor,
193
+ **kwargs: Dict[str, Any],
194
+ ) -> torch.Tensor:
195
+ assert "bits" in kwargs, "bits must be provided."
196
+ assert "params" in kwargs, "params must be provided."
197
+ assert "qtz_func" in kwargs, "qtz_func must be provided."
198
+
199
+ bits = kwargs.get('bits')
200
+ params = kwargs.get('params')
201
+ qtz_func = kwargs.get('qtz_func')
202
+
203
+ x_ = x.transpose(0, 1)
204
+ x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs))
205
+ x_ = x_.transpose(0, 1)
206
+
207
+ quant_min, quant_max = get_min_max_from_bits_signed(bits)
208
+ min_vals, max_vals = torch.aminmax(x_, dim=1)
209
+ min_vals = torch.min(min_vals, torch.zeros_like(min_vals))
210
+ max_vals = torch.max(max_vals, torch.zeros_like(max_vals))
211
+
212
+ eps = torch.finfo(torch.float32).eps
213
+
214
+ abs_max_val_per_ch = torch.max(-min_vals, max_vals)
215
+ scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2)
216
+
217
+ scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device)
218
+
219
+ # Introduces some noise in scale
220
+ # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything
221
+ # scale = scale + 0.01 * torch.randn_like(scale)
222
+ return scale
223
+
224
+
225
+ def init_non_linear_regression_fit(
226
+ x: torch.Tensor,
227
+ **kwargs: Dict[str, Any],
228
+ ) -> torch.Tensor:
229
+
230
+ assert "params_list" in kwargs, "params list must be provided."
231
+ assert "np_fit_func" in kwargs, "np_fit_func must be provided."
232
+ assert "p0" in kwargs, "p0 must be provided."
233
+ np_fit_func = kwargs.get('np_fit_func')
234
+ params_list = kwargs.get('params_list')
235
+ p0 = kwargs.get('p0')
236
+
237
+ def _fit(xdata: np.ndarray, ydata: np.ndarray, func: Callable, p0: List[float]):
238
+ popt, _ = curve_fit(
239
+ func,
240
+ xdata,
241
+ ydata,
242
+ maxfev=1000,
243
+ p0=p0,
244
+ method='lm'
245
+ )
246
+ return popt
247
+
248
+ # 1. Needs to convert the torch tensor to numpy tensor
249
+ xdata = x.cpu().numpy()
250
+
251
+ # 2. Sorts the data so that it makes it easier to fit to it
252
+ sorted_xdata = np.sort(xdata, axis=-1)
253
+
254
+ p0 = {k: v.cpu().numpy() for k, v in p0.items()}
255
+ params_list = sorted(params_list) # We need to make sure that it matches the numpy fit func arg order
256
+
257
+ # 3. Finds the best parameters for each channel
258
+ try:
259
+ params = []
260
+ for i in range(sorted_xdata.shape[0]):
261
+ xdata_ = sorted_xdata[i]
262
+ p0_ = [p0[p][i] for p in params_list]
263
+ ch_params = _fit(xdata_, xdata_, np_fit_func, p0_)
264
+ params.append(ch_params)
265
+
266
+ # 4. Builds the parameters
267
+ result = {}
268
+ for i, p in enumerate(params_list):
269
+ result[p] = torch.tensor([p_[i] for p_ in params], dtype=torch.float32).to(x.device)
270
+
271
+ return result
272
+
273
+ except ValueError as e:
274
+ print(f"Could not fit the function with error: {e}")
275
+ print(f"Using fallback result...")
276
+ return {
277
+ k: torch.tensor(v, dtype=torch.float32).to(x.device) for k, v in p0.items()
278
+ }
279
+
280
+
281
+ def init_zeros(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
282
+ val = torch.amin(x, dim=1)
283
+ return torch.zeros_like(val, dtype=torch.float32, device=x.device)
284
+
285
+
286
+ def init_inner_scale(tensor: torch.Tensor, _min: float = torch.inf, _max: float = torch.inf) -> torch.Tensor:
287
+ # Calculate the original minimum and maximum values
288
+ min_vals, max_vals = torch.aminmax(tensor, dim=-1)
289
+ x_min = torch.min(min_vals, torch.zeros_like(min_vals))
290
+ x_max = torch.max(max_vals, torch.zeros_like(max_vals))
291
+
292
+ if _max is torch.inf: # We do not need to scale the tensor. Just need to move it
293
+ return torch.ones_like(x_min)
294
+
295
+ # Calculate the scale factor
296
+ scale = (_max - _min) / (x_max - x_min)
297
+ return scale
298
+
299
+
300
+
301
+ ############## Quant ###############
302
+
303
+ @torch.enable_grad()
304
+ def learn_parameters(
305
+ x: torch.Tensor,
306
+ params: Dict[str, nn.Parameter],
307
+ qtz_func: nn.Module,
308
+ deqtz_func: nn.Module,
309
+ bits: int,
310
+ target_dtype: torch.dtype,
311
+ epochs: int = 1000,
312
+ early_stop: bool = True,
313
+ do_report: bool = False
314
+ ) -> Tuple[Dict[str, nn.Parameter], torch.Tensor]:
315
+ loss_fn = nn.MSELoss()
316
+
317
+ # Determines the initial learning rate by computing the initial loss and multiplying it by
318
+ # the order of magnitude of the loss divided by 2
319
+ quant = quantize(x, params, qtz_func, bits, target_dtype)
320
+ dequant = dequantize(quant, params, deqtz_func, bits, x.dtype)
321
+ loss = loss_fn(x, dequant)
322
+
323
+ base_lr = 0.1
324
+ exponent = int(np.floor(np.log10(loss.item())))
325
+ lr = base_lr * (10 ** (exponent // 2))
326
+
327
+ # Requires gradients in the parameters
328
+ for p in params.values():
329
+ p.requires_grad = True
330
+ p.grad = None
331
+
332
+ param_keys = list(params.keys())
333
+ param_values = list(params.values())
334
+
335
+ # Defines optimizer and loss function
336
+ optimizer = torch.optim.Adam(param_values, lr=lr)
337
+ scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=1.0, end_factor=0.01, total_iters=epochs // 10)
338
+
339
+ # Contains the best loss and the best parameters
340
+ best_loss = float("inf")
341
+ best_params = None
342
+
343
+ # Used to stop the search early
344
+ min_delta = 1e-7
345
+ acc_loss = []
346
+ percent_epochs_before_stop = 0.1
347
+
348
+ for i in range(epochs):
349
+ optimizer.zero_grad()
350
+
351
+ quant = quantize(x, params, qtz_func, bits, target_dtype)
352
+ dequant = dequantize(quant, params, deqtz_func, bits, x.dtype)
353
+ loss = loss_fn(x, dequant)
354
+
355
+ if loss.isnan() or loss.isinf():
356
+ raise Exception("Loss is NaN or Inf. Stopping the search.")
357
+
358
+ loss.backward()
359
+ optimizer.step()
360
+ scheduler.step()
361
+
362
+ acc_loss.append(loss.item())
363
+
364
+ # Reports loss every 10 steps
365
+ if i % 10 == 0 and do_report:
366
+ print(f"Epoch {i}: Loss {loss.item()}")
367
+
368
+ # Optimizes the parameter search by storing the best loss and the parameters
369
+ if loss.item() < best_loss:
370
+ best_loss = loss.item()
371
+ best_params = copy.deepcopy({
372
+ k: v for k, v in params.items() if k in param_keys
373
+ })
374
+
375
+ # We also stop the search if the loss has not considerably during the last 10% epochs
376
+ if early_stop:
377
+ epochs_before_stop = int(epochs * percent_epochs_before_stop)
378
+ if i > epochs_before_stop and abs(acc_loss[i - epochs_before_stop] - acc_loss[i]) < min_delta:
379
+ break
380
+
381
+ # No longer requires gradients in the parameters
382
+ for p in best_params.values():
383
+ p.requires_grad = False
384
+ p.grad = None
385
+
386
+ if do_report:
387
+ return best_params, acc_loss
388
+ else:
389
+ return best_params
390
+
391
+
392
+ def quantize(
393
+ x: torch.Tensor,
394
+ params: Dict[str, nn.Parameter],
395
+ func: nn.Module,
396
+ bits: int,
397
+ target_dtype: torch.dtype = torch.int8
398
+ ) -> torch.Tensor:
399
+ quant_min, quant_max = get_min_max_from_bits_signed(bits)
400
+ x = x.transpose(0, 1) # Aligns shapes
401
+ x = func(x=x, **params)
402
+ x = x.transpose(0, 1)
403
+ x = torch.clamp(round_func_BPDA(x), quant_min, quant_max).to(target_dtype)
404
+ return x
405
+
406
+
407
+ def dequantize(
408
+ x: torch.Tensor,
409
+ params: Dict[str, nn.Parameter],
410
+ func: nn.Module,
411
+ bits: int,
412
+ out_dtype: torch.dtype
413
+ ) -> torch.Tensor:
414
+ x = x.to(dtype=out_dtype)
415
+ x = x.transpose(0, 1)
416
+ x = func(x=x, **params)
417
+ x = x.transpose(0, 1)
418
+ return x
419
+
420
+
421
+ def round_func_BPDA(input):
422
+ # This is equivalent to replacing round function (non-differentiable) with
423
+ # an identity function (differentiable) only when backward.
424
+ forward_value = torch.round(input)
425
+ out = input.clone()
426
+ out.data = forward_value.data
427
+ return out
428
+
429
+
430
+ def get_min_max_from_bits_signed(bit_width: int) -> Tuple[int, int]:
431
+ return -2 ** (bit_width - 1), 2 ** (bit_width - 1) - 1
432
+
433
+
434
+
435
+ ############## Numpy ###############
436
+
437
+ def np_domain_guard(
438
+ x: np.ndarray,
439
+ min: float = None,
440
+ max: float = None,
441
+ posinf: float = None,
442
+ neginf: float = None,
443
+ nan: float = None
444
+ ) -> np.ndarray:
445
+ """Guard a tensor to a valid domain."""
446
+ x = np.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan)
447
+ if min is not None or max is not None:
448
+ x = np.clip(x, min, max)
449
+ return x
450
+
451
+
452
+ def np_replace_num(x: np.ndarray, num: float, to: float) -> np.ndarray:
453
+ """Replace a number in a tensor with another number.
454
+
455
+ Args:
456
+ x (np.ndarray): The input tensor.
457
+ num (float): The number to replace.
458
+ to (float): The number to replace with.
459
+
460
+ Returns:
461
+ np.ndarray: The tensor with the number replaced.
462
+ """
463
+ return np.where(x == num, to, x)
464
+
465
+
466
+ def np_guarded_power(x: np.ndarray, exp: float) -> np.ndarray:
467
+ """Guard the power operation to a valid domain."""
468
+ return np.power(x, exp) if exp >= 1 else np.power(np.maximum(x, 0), exp)
469
+
fn_gen/nlr/13/loss.png ADDED
fn_gen/nlr/13/quantization.png ADDED
fn_gen/nlr/14/distortion.png ADDED
fn_gen/nlr/14/expressions.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ asin(_0*x)/_s
2
+ sin(_s*x)/_0
fn_gen/nlr/14/fn.py ADDED
@@ -0,0 +1,469 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import torch
4
+ from torch import amin # Necessary for arcsin
5
+ import copy
6
+ import torch.nn as nn
7
+ import numpy as np
8
+
9
+ from scipy.optimize import curve_fit
10
+ from typing import Dict, Any, Tuple, List, Callable
11
+
12
+
13
+ def quantization(x, **params):
14
+ return (torch.div(1, replace_num(params['_s'], num=0, to=10000)) * torch.asin(domain_guard((params['_0'] * x), min=-0.99999, max=0.99999, nan=0)))
15
+
16
+
17
+ def dequantization(x, **params):
18
+ return (torch.div(1, replace_num(params['_0'], num=0, to=10000)) * torch.sin((params['_s'] * x)))
19
+
20
+
21
+ def init_params(x: torch.Tensor, **kwargs: Dict[str, Any]) -> Dict[str, nn.Parameter]:
22
+ base_p0 = {
23
+ '_0': init_space_search(x, qtz_func=quantization, deqtz_func=dequantization, params_list=['_0', '_s'], param='_0', **kwargs),
24
+ }
25
+
26
+ base_p0['_s'] = init_linear_scale(x, qtz_func=quantization, params=base_p0, **kwargs)
27
+ if 'post_init_hook' in kwargs:
28
+ kwargs['post_init_hook'](parameters=base_p0)
29
+
30
+ params = init_non_linear_regression_fit(x, p0=base_p0, np_fit_func=fit_func, qtz_func=quantization, deqtz_func=dequantization, params_list=['_0', '_s'], **kwargs)
31
+ params = {k: nn.Parameter(v, requires_grad=False) for k, v in params.items()}
32
+ if 'post_method_hook' in kwargs:
33
+ kwargs['post_method_hook'](parameters=params)
34
+
35
+
36
+ if 'post_train_hook' in kwargs:
37
+ kwargs['post_train_hook'](parameters=params)
38
+
39
+ return params
40
+
41
+
42
+ ############### Numpy Qtz ###############
43
+
44
+
45
+ def np_quantization(x, _0, _s):
46
+ return (np.divide(1, np_replace_num(_s, num=0, to=10000)) * np.arcsin(np_domain_guard((_0 * x), min=-0.99999, max=0.99999, nan=0)))
47
+
48
+
49
+ def np_dequantization(x, _0, _s):
50
+ return (np.divide(1, np_replace_num(_0, num=0, to=10000)) * np.sin((_s * x)))
51
+
52
+
53
+ def fit_func(x, _0, _s):
54
+ x_ = np_quantization(x, _0, _s)
55
+ x_ = np_dequantization(x_, _0, _s)
56
+ return x_
57
+
58
+
59
+
60
+ ############### HELPERS ###############
61
+
62
+ def domain_guard(
63
+ x: torch.Tensor,
64
+ min: float = None,
65
+ max: float = None,
66
+ posinf: float = None,
67
+ neginf: float = None,
68
+ nan: float = None
69
+ ) -> torch.Tensor:
70
+ """Guard a tensor to a valid domain."""
71
+ x = torch.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan)
72
+ if min is not None or max is not None:
73
+ x = torch.clamp(x, min=min, max=max)
74
+ return x
75
+
76
+
77
+ def replace_num(x: torch.Tensor, num: float, to: float) -> torch.Tensor:
78
+ """Replace a number in a tensor with another number.
79
+
80
+ Args:
81
+ x (torch.Tensor): The input tensor.
82
+ num (float): The number to replace.
83
+ to (float): The number to replace with.
84
+
85
+ Returns:
86
+ torch.Tensor: The tensor with the number replaced.
87
+ """
88
+ return torch.where(x == num, to, x)
89
+
90
+
91
+ def guarded_torch_power(x: torch.Tensor, exp: float) -> torch.Tensor:
92
+ """Guard the power operation to a valid domain."""
93
+ return torch.pow(x, exp) if exp >= 1 else torch.pow(torch.relu(x), exp)
94
+
95
+
96
+ def init_ones(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
97
+ val = torch.amin(x, dim=1)
98
+ return torch.ones_like(val, dtype=torch.float32, device=x.device)
99
+
100
+
101
+ def init_rand(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
102
+ val = torch.amin(x, dim=1)
103
+ return torch.randn_like(val, dtype=torch.float32, device=x.device)
104
+
105
+
106
+ def init_space_search(
107
+ x: torch.Tensor,
108
+ **kwargs: Dict[str, Any],
109
+ ) -> torch.Tensor:
110
+
111
+ def _build_initial_param(tensor: torch.Tensor, max_initial: int, n_params: int):
112
+ """Generates the initial set of parameters. The first iteration generates 10 times more parameters."""
113
+ for _ in range(n_params * 10): # The first iteration generates 10 times more parameters
114
+ yield init_rand(tensor) * max_initial # Generates n_params in range [-max_initial, max_initial]
115
+
116
+ def _search_param(tensors: List[torch.tensor], n_params):
117
+ """Takes the best parameters and generates new parameters around the mean of the best parameters."""
118
+ torch_tensors = torch.stack(tensors)
119
+ min_vals, max_vals = torch.aminmax(torch_tensors, dim=0)
120
+ abs_max_val_per_ch = torch.max(-min_vals, max_vals)
121
+ mean = torch.mean(torch_tensors, dim=0)
122
+ for _ in range(n_params): # Generates n_params around the mean of the tensors
123
+ yield torch.randn_like(min_vals) * abs_max_val_per_ch + mean
124
+
125
+ def _calc(x, qtz_func, deqtz_func, **params):
126
+ x_ = x.transpose(0, 1)
127
+ x_ = qtz_func(x=x_, **params)
128
+ x_ = deqtz_func(x=x_, **params)
129
+ x_ = x_.transpose(0, 1)
130
+ return x_
131
+
132
+ assert "qtz_func" in kwargs, "qtz_func must be provided."
133
+ assert "deqtz_func" in kwargs, "deqtz_func must be provided."
134
+ assert "params_list" in kwargs, "params list must be provided."
135
+ assert "param" in kwargs, "param must be provided."
136
+
137
+ qtz_func = kwargs.get('qtz_func')
138
+ deqtz_func = kwargs.get('deqtz_func')
139
+ params_list = kwargs.get('params_list')
140
+ param = kwargs.get('param')
141
+
142
+ n_runs = 50 # Number of runs to try to find the best parameters
143
+ n_random_params = 50 # Number of random parameters to generate
144
+ n_best_to_pick = 5 # Number of best parameters to pick after each run
145
+ max_initial = 10000 # Maximum value to initialize the parameters
146
+
147
+ # Initializes the parameters
148
+ base_params = { p: init_ones(x, **kwargs) for p in params_list if p != param }
149
+ params = _build_initial_param(x, max_initial, n_random_params)
150
+
151
+ # Performs the search
152
+ for _ in range(n_runs):
153
+
154
+ best_params = []
155
+ for param_ in params:
156
+ try:
157
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: param_})
158
+ loss_ones = nn.MSELoss()(x, x_)
159
+
160
+ if len(best_params) < n_best_to_pick:
161
+ best_params.append((param_, loss_ones.item()))
162
+ best_params = sorted(best_params, key=lambda x: x[1])
163
+ elif loss_ones < best_params[-1][1]:
164
+ best_params[-1] = (param_, loss_ones.item())
165
+ best_params = sorted(best_params, key=lambda x: x[1])
166
+
167
+ except Exception: # The parameters might not be valid for the function's domain
168
+ continue
169
+
170
+ # Generates new parameters around the mean
171
+ params = _search_param([p for p, _ in best_params], n_random_params)
172
+
173
+ # Checks if the best parameter is better than the init_ones
174
+ p_ones = init_ones(x, **kwargs)
175
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_ones})
176
+ loss_ones = nn.MSELoss()(x, x_)
177
+
178
+ # Checks if the best parameter is better than the init_rand
179
+ p_rand = init_rand(x, **kwargs)
180
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_rand})
181
+ loss_rand = nn.MSELoss()(x, x_)
182
+
183
+ if loss_rand < best_params[0][1] and loss_rand < loss_ones:
184
+ return p_rand
185
+ elif loss_ones < best_params[0][1] and loss_ones < loss_rand:
186
+ return p_ones
187
+ else:
188
+ return best_params[0][0]
189
+
190
+
191
+ def init_linear_scale( # Symmetric scale. From the study folder
192
+ x: torch.Tensor,
193
+ **kwargs: Dict[str, Any],
194
+ ) -> torch.Tensor:
195
+ assert "bits" in kwargs, "bits must be provided."
196
+ assert "params" in kwargs, "params must be provided."
197
+ assert "qtz_func" in kwargs, "qtz_func must be provided."
198
+
199
+ bits = kwargs.get('bits')
200
+ params = kwargs.get('params')
201
+ qtz_func = kwargs.get('qtz_func')
202
+
203
+ x_ = x.transpose(0, 1)
204
+ x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs))
205
+ x_ = x_.transpose(0, 1)
206
+
207
+ quant_min, quant_max = get_min_max_from_bits_signed(bits)
208
+ min_vals, max_vals = torch.aminmax(x_, dim=1)
209
+ min_vals = torch.min(min_vals, torch.zeros_like(min_vals))
210
+ max_vals = torch.max(max_vals, torch.zeros_like(max_vals))
211
+
212
+ eps = torch.finfo(torch.float32).eps
213
+
214
+ abs_max_val_per_ch = torch.max(-min_vals, max_vals)
215
+ scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2)
216
+
217
+ scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device)
218
+
219
+ # Introduces some noise in scale
220
+ # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything
221
+ # scale = scale + 0.01 * torch.randn_like(scale)
222
+ return scale
223
+
224
+
225
+ def init_non_linear_regression_fit(
226
+ x: torch.Tensor,
227
+ **kwargs: Dict[str, Any],
228
+ ) -> torch.Tensor:
229
+
230
+ assert "params_list" in kwargs, "params list must be provided."
231
+ assert "np_fit_func" in kwargs, "np_fit_func must be provided."
232
+ assert "p0" in kwargs, "p0 must be provided."
233
+ np_fit_func = kwargs.get('np_fit_func')
234
+ params_list = kwargs.get('params_list')
235
+ p0 = kwargs.get('p0')
236
+
237
+ def _fit(xdata: np.ndarray, ydata: np.ndarray, func: Callable, p0: List[float]):
238
+ popt, _ = curve_fit(
239
+ func,
240
+ xdata,
241
+ ydata,
242
+ maxfev=1000,
243
+ p0=p0,
244
+ method='lm'
245
+ )
246
+ return popt
247
+
248
+ # 1. Needs to convert the torch tensor to numpy tensor
249
+ xdata = x.cpu().numpy()
250
+
251
+ # 2. Sorts the data so that it makes it easier to fit to it
252
+ sorted_xdata = np.sort(xdata, axis=-1)
253
+
254
+ p0 = {k: v.cpu().numpy() for k, v in p0.items()}
255
+ params_list = sorted(params_list) # We need to make sure that it matches the numpy fit func arg order
256
+
257
+ # 3. Finds the best parameters for each channel
258
+ try:
259
+ params = []
260
+ for i in range(sorted_xdata.shape[0]):
261
+ xdata_ = sorted_xdata[i]
262
+ p0_ = [p0[p][i] for p in params_list]
263
+ ch_params = _fit(xdata_, xdata_, np_fit_func, p0_)
264
+ params.append(ch_params)
265
+
266
+ # 4. Builds the parameters
267
+ result = {}
268
+ for i, p in enumerate(params_list):
269
+ result[p] = torch.tensor([p_[i] for p_ in params], dtype=torch.float32).to(x.device)
270
+
271
+ return result
272
+
273
+ except ValueError as e:
274
+ print(f"Could not fit the function with error: {e}")
275
+ print(f"Using fallback result...")
276
+ return {
277
+ k: torch.tensor(v, dtype=torch.float32).to(x.device) for k, v in p0.items()
278
+ }
279
+
280
+
281
+ def init_zeros(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
282
+ val = torch.amin(x, dim=1)
283
+ return torch.zeros_like(val, dtype=torch.float32, device=x.device)
284
+
285
+
286
+ def init_inner_scale(tensor: torch.Tensor, _min: float = torch.inf, _max: float = torch.inf) -> torch.Tensor:
287
+ # Calculate the original minimum and maximum values
288
+ min_vals, max_vals = torch.aminmax(tensor, dim=-1)
289
+ x_min = torch.min(min_vals, torch.zeros_like(min_vals))
290
+ x_max = torch.max(max_vals, torch.zeros_like(max_vals))
291
+
292
+ if _max is torch.inf: # We do not need to scale the tensor. Just need to move it
293
+ return torch.ones_like(x_min)
294
+
295
+ # Calculate the scale factor
296
+ scale = (_max - _min) / (x_max - x_min)
297
+ return scale
298
+
299
+
300
+
301
+ ############## Quant ###############
302
+
303
+ @torch.enable_grad()
304
+ def learn_parameters(
305
+ x: torch.Tensor,
306
+ params: Dict[str, nn.Parameter],
307
+ qtz_func: nn.Module,
308
+ deqtz_func: nn.Module,
309
+ bits: int,
310
+ target_dtype: torch.dtype,
311
+ epochs: int = 1000,
312
+ early_stop: bool = True,
313
+ do_report: bool = False
314
+ ) -> Tuple[Dict[str, nn.Parameter], torch.Tensor]:
315
+ loss_fn = nn.MSELoss()
316
+
317
+ # Determines the initial learning rate by computing the initial loss and multiplying it by
318
+ # the order of magnitude of the loss divided by 2
319
+ quant = quantize(x, params, qtz_func, bits, target_dtype)
320
+ dequant = dequantize(quant, params, deqtz_func, bits, x.dtype)
321
+ loss = loss_fn(x, dequant)
322
+
323
+ base_lr = 0.1
324
+ exponent = int(np.floor(np.log10(loss.item())))
325
+ lr = base_lr * (10 ** (exponent // 2))
326
+
327
+ # Requires gradients in the parameters
328
+ for p in params.values():
329
+ p.requires_grad = True
330
+ p.grad = None
331
+
332
+ param_keys = list(params.keys())
333
+ param_values = list(params.values())
334
+
335
+ # Defines optimizer and loss function
336
+ optimizer = torch.optim.Adam(param_values, lr=lr)
337
+ scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=1.0, end_factor=0.01, total_iters=epochs // 10)
338
+
339
+ # Contains the best loss and the best parameters
340
+ best_loss = float("inf")
341
+ best_params = None
342
+
343
+ # Used to stop the search early
344
+ min_delta = 1e-7
345
+ acc_loss = []
346
+ percent_epochs_before_stop = 0.1
347
+
348
+ for i in range(epochs):
349
+ optimizer.zero_grad()
350
+
351
+ quant = quantize(x, params, qtz_func, bits, target_dtype)
352
+ dequant = dequantize(quant, params, deqtz_func, bits, x.dtype)
353
+ loss = loss_fn(x, dequant)
354
+
355
+ if loss.isnan() or loss.isinf():
356
+ raise Exception("Loss is NaN or Inf. Stopping the search.")
357
+
358
+ loss.backward()
359
+ optimizer.step()
360
+ scheduler.step()
361
+
362
+ acc_loss.append(loss.item())
363
+
364
+ # Reports loss every 10 steps
365
+ if i % 10 == 0 and do_report:
366
+ print(f"Epoch {i}: Loss {loss.item()}")
367
+
368
+ # Optimizes the parameter search by storing the best loss and the parameters
369
+ if loss.item() < best_loss:
370
+ best_loss = loss.item()
371
+ best_params = copy.deepcopy({
372
+ k: v for k, v in params.items() if k in param_keys
373
+ })
374
+
375
+ # We also stop the search if the loss has not considerably during the last 10% epochs
376
+ if early_stop:
377
+ epochs_before_stop = int(epochs * percent_epochs_before_stop)
378
+ if i > epochs_before_stop and abs(acc_loss[i - epochs_before_stop] - acc_loss[i]) < min_delta:
379
+ break
380
+
381
+ # No longer requires gradients in the parameters
382
+ for p in best_params.values():
383
+ p.requires_grad = False
384
+ p.grad = None
385
+
386
+ if do_report:
387
+ return best_params, acc_loss
388
+ else:
389
+ return best_params
390
+
391
+
392
+ def quantize(
393
+ x: torch.Tensor,
394
+ params: Dict[str, nn.Parameter],
395
+ func: nn.Module,
396
+ bits: int,
397
+ target_dtype: torch.dtype = torch.int8
398
+ ) -> torch.Tensor:
399
+ quant_min, quant_max = get_min_max_from_bits_signed(bits)
400
+ x = x.transpose(0, 1) # Aligns shapes
401
+ x = func(x=x, **params)
402
+ x = x.transpose(0, 1)
403
+ x = torch.clamp(round_func_BPDA(x), quant_min, quant_max).to(target_dtype)
404
+ return x
405
+
406
+
407
+ def dequantize(
408
+ x: torch.Tensor,
409
+ params: Dict[str, nn.Parameter],
410
+ func: nn.Module,
411
+ bits: int,
412
+ out_dtype: torch.dtype
413
+ ) -> torch.Tensor:
414
+ x = x.to(dtype=out_dtype)
415
+ x = x.transpose(0, 1)
416
+ x = func(x=x, **params)
417
+ x = x.transpose(0, 1)
418
+ return x
419
+
420
+
421
+ def round_func_BPDA(input):
422
+ # This is equivalent to replacing round function (non-differentiable) with
423
+ # an identity function (differentiable) only when backward.
424
+ forward_value = torch.round(input)
425
+ out = input.clone()
426
+ out.data = forward_value.data
427
+ return out
428
+
429
+
430
+ def get_min_max_from_bits_signed(bit_width: int) -> Tuple[int, int]:
431
+ return -2 ** (bit_width - 1), 2 ** (bit_width - 1) - 1
432
+
433
+
434
+
435
+ ############## Numpy ###############
436
+
437
+ def np_domain_guard(
438
+ x: np.ndarray,
439
+ min: float = None,
440
+ max: float = None,
441
+ posinf: float = None,
442
+ neginf: float = None,
443
+ nan: float = None
444
+ ) -> np.ndarray:
445
+ """Guard a tensor to a valid domain."""
446
+ x = np.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan)
447
+ if min is not None or max is not None:
448
+ x = np.clip(x, min, max)
449
+ return x
450
+
451
+
452
+ def np_replace_num(x: np.ndarray, num: float, to: float) -> np.ndarray:
453
+ """Replace a number in a tensor with another number.
454
+
455
+ Args:
456
+ x (np.ndarray): The input tensor.
457
+ num (float): The number to replace.
458
+ to (float): The number to replace with.
459
+
460
+ Returns:
461
+ np.ndarray: The tensor with the number replaced.
462
+ """
463
+ return np.where(x == num, to, x)
464
+
465
+
466
+ def np_guarded_power(x: np.ndarray, exp: float) -> np.ndarray:
467
+ """Guard the power operation to a valid domain."""
468
+ return np.power(x, exp) if exp >= 1 else np.power(np.maximum(x, 0), exp)
469
+
fn_gen/nlr/14/loss.png ADDED
fn_gen/nlr/14/quantization.png ADDED
fn_gen/nlr/15/distortion.png ADDED
fn_gen/nlr/15/expressions.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ atanh(_0*x)/_s
2
+ tanh(_s*x)/_0
fn_gen/nlr/15/fn.py ADDED
@@ -0,0 +1,469 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import torch
4
+ from torch import amin # Necessary for arcsin
5
+ import copy
6
+ import torch.nn as nn
7
+ import numpy as np
8
+
9
+ from scipy.optimize import curve_fit
10
+ from typing import Dict, Any, Tuple, List, Callable
11
+
12
+
13
+ def quantization(x, **params):
14
+ return (torch.div(1, replace_num(params['_s'], num=0, to=10000)) * torch.atanh(domain_guard((params['_0'] * x), min=-0.9999, max=0.9999, nan=0)))
15
+
16
+
17
+ def dequantization(x, **params):
18
+ return (torch.div(1, replace_num(params['_0'], num=0, to=10000)) * torch.tanh((params['_s'] * x)))
19
+
20
+
21
+ def init_params(x: torch.Tensor, **kwargs: Dict[str, Any]) -> Dict[str, nn.Parameter]:
22
+ base_p0 = {
23
+ '_0': init_space_search(x, qtz_func=quantization, deqtz_func=dequantization, params_list=['_0', '_s'], param='_0', **kwargs),
24
+ }
25
+
26
+ base_p0['_s'] = init_linear_scale(x, qtz_func=quantization, params=base_p0, **kwargs)
27
+ if 'post_init_hook' in kwargs:
28
+ kwargs['post_init_hook'](parameters=base_p0)
29
+
30
+ params = init_non_linear_regression_fit(x, p0=base_p0, np_fit_func=fit_func, qtz_func=quantization, deqtz_func=dequantization, params_list=['_0', '_s'], **kwargs)
31
+ params = {k: nn.Parameter(v, requires_grad=False) for k, v in params.items()}
32
+ if 'post_method_hook' in kwargs:
33
+ kwargs['post_method_hook'](parameters=params)
34
+
35
+
36
+ if 'post_train_hook' in kwargs:
37
+ kwargs['post_train_hook'](parameters=params)
38
+
39
+ return params
40
+
41
+
42
+ ############### Numpy Qtz ###############
43
+
44
+
45
+ def np_quantization(x, _0, _s):
46
+ return (np.divide(1, np_replace_num(_s, num=0, to=10000)) * np.arctanh(np_domain_guard((_0 * x), min=-0.9999, max=0.9999, nan=0)))
47
+
48
+
49
+ def np_dequantization(x, _0, _s):
50
+ return (np.divide(1, np_replace_num(_0, num=0, to=10000)) * np.tanh((_s * x)))
51
+
52
+
53
+ def fit_func(x, _0, _s):
54
+ x_ = np_quantization(x, _0, _s)
55
+ x_ = np_dequantization(x_, _0, _s)
56
+ return x_
57
+
58
+
59
+
60
+ ############### HELPERS ###############
61
+
62
+ def domain_guard(
63
+ x: torch.Tensor,
64
+ min: float = None,
65
+ max: float = None,
66
+ posinf: float = None,
67
+ neginf: float = None,
68
+ nan: float = None
69
+ ) -> torch.Tensor:
70
+ """Guard a tensor to a valid domain."""
71
+ x = torch.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan)
72
+ if min is not None or max is not None:
73
+ x = torch.clamp(x, min=min, max=max)
74
+ return x
75
+
76
+
77
+ def replace_num(x: torch.Tensor, num: float, to: float) -> torch.Tensor:
78
+ """Replace a number in a tensor with another number.
79
+
80
+ Args:
81
+ x (torch.Tensor): The input tensor.
82
+ num (float): The number to replace.
83
+ to (float): The number to replace with.
84
+
85
+ Returns:
86
+ torch.Tensor: The tensor with the number replaced.
87
+ """
88
+ return torch.where(x == num, to, x)
89
+
90
+
91
+ def guarded_torch_power(x: torch.Tensor, exp: float) -> torch.Tensor:
92
+ """Guard the power operation to a valid domain."""
93
+ return torch.pow(x, exp) if exp >= 1 else torch.pow(torch.relu(x), exp)
94
+
95
+
96
+ def init_ones(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
97
+ val = torch.amin(x, dim=1)
98
+ return torch.ones_like(val, dtype=torch.float32, device=x.device)
99
+
100
+
101
+ def init_rand(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
102
+ val = torch.amin(x, dim=1)
103
+ return torch.randn_like(val, dtype=torch.float32, device=x.device)
104
+
105
+
106
+ def init_space_search(
107
+ x: torch.Tensor,
108
+ **kwargs: Dict[str, Any],
109
+ ) -> torch.Tensor:
110
+
111
+ def _build_initial_param(tensor: torch.Tensor, max_initial: int, n_params: int):
112
+ """Generates the initial set of parameters. The first iteration generates 10 times more parameters."""
113
+ for _ in range(n_params * 10): # The first iteration generates 10 times more parameters
114
+ yield init_rand(tensor) * max_initial # Generates n_params in range [-max_initial, max_initial]
115
+
116
+ def _search_param(tensors: List[torch.tensor], n_params):
117
+ """Takes the best parameters and generates new parameters around the mean of the best parameters."""
118
+ torch_tensors = torch.stack(tensors)
119
+ min_vals, max_vals = torch.aminmax(torch_tensors, dim=0)
120
+ abs_max_val_per_ch = torch.max(-min_vals, max_vals)
121
+ mean = torch.mean(torch_tensors, dim=0)
122
+ for _ in range(n_params): # Generates n_params around the mean of the tensors
123
+ yield torch.randn_like(min_vals) * abs_max_val_per_ch + mean
124
+
125
+ def _calc(x, qtz_func, deqtz_func, **params):
126
+ x_ = x.transpose(0, 1)
127
+ x_ = qtz_func(x=x_, **params)
128
+ x_ = deqtz_func(x=x_, **params)
129
+ x_ = x_.transpose(0, 1)
130
+ return x_
131
+
132
+ assert "qtz_func" in kwargs, "qtz_func must be provided."
133
+ assert "deqtz_func" in kwargs, "deqtz_func must be provided."
134
+ assert "params_list" in kwargs, "params list must be provided."
135
+ assert "param" in kwargs, "param must be provided."
136
+
137
+ qtz_func = kwargs.get('qtz_func')
138
+ deqtz_func = kwargs.get('deqtz_func')
139
+ params_list = kwargs.get('params_list')
140
+ param = kwargs.get('param')
141
+
142
+ n_runs = 50 # Number of runs to try to find the best parameters
143
+ n_random_params = 50 # Number of random parameters to generate
144
+ n_best_to_pick = 5 # Number of best parameters to pick after each run
145
+ max_initial = 10000 # Maximum value to initialize the parameters
146
+
147
+ # Initializes the parameters
148
+ base_params = { p: init_ones(x, **kwargs) for p in params_list if p != param }
149
+ params = _build_initial_param(x, max_initial, n_random_params)
150
+
151
+ # Performs the search
152
+ for _ in range(n_runs):
153
+
154
+ best_params = []
155
+ for param_ in params:
156
+ try:
157
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: param_})
158
+ loss_ones = nn.MSELoss()(x, x_)
159
+
160
+ if len(best_params) < n_best_to_pick:
161
+ best_params.append((param_, loss_ones.item()))
162
+ best_params = sorted(best_params, key=lambda x: x[1])
163
+ elif loss_ones < best_params[-1][1]:
164
+ best_params[-1] = (param_, loss_ones.item())
165
+ best_params = sorted(best_params, key=lambda x: x[1])
166
+
167
+ except Exception: # The parameters might not be valid for the function's domain
168
+ continue
169
+
170
+ # Generates new parameters around the mean
171
+ params = _search_param([p for p, _ in best_params], n_random_params)
172
+
173
+ # Checks if the best parameter is better than the init_ones
174
+ p_ones = init_ones(x, **kwargs)
175
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_ones})
176
+ loss_ones = nn.MSELoss()(x, x_)
177
+
178
+ # Checks if the best parameter is better than the init_rand
179
+ p_rand = init_rand(x, **kwargs)
180
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_rand})
181
+ loss_rand = nn.MSELoss()(x, x_)
182
+
183
+ if loss_rand < best_params[0][1] and loss_rand < loss_ones:
184
+ return p_rand
185
+ elif loss_ones < best_params[0][1] and loss_ones < loss_rand:
186
+ return p_ones
187
+ else:
188
+ return best_params[0][0]
189
+
190
+
191
+ def init_linear_scale( # Symmetric scale. From the study folder
192
+ x: torch.Tensor,
193
+ **kwargs: Dict[str, Any],
194
+ ) -> torch.Tensor:
195
+ assert "bits" in kwargs, "bits must be provided."
196
+ assert "params" in kwargs, "params must be provided."
197
+ assert "qtz_func" in kwargs, "qtz_func must be provided."
198
+
199
+ bits = kwargs.get('bits')
200
+ params = kwargs.get('params')
201
+ qtz_func = kwargs.get('qtz_func')
202
+
203
+ x_ = x.transpose(0, 1)
204
+ x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs))
205
+ x_ = x_.transpose(0, 1)
206
+
207
+ quant_min, quant_max = get_min_max_from_bits_signed(bits)
208
+ min_vals, max_vals = torch.aminmax(x_, dim=1)
209
+ min_vals = torch.min(min_vals, torch.zeros_like(min_vals))
210
+ max_vals = torch.max(max_vals, torch.zeros_like(max_vals))
211
+
212
+ eps = torch.finfo(torch.float32).eps
213
+
214
+ abs_max_val_per_ch = torch.max(-min_vals, max_vals)
215
+ scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2)
216
+
217
+ scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device)
218
+
219
+ # Introduces some noise in scale
220
+ # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything
221
+ # scale = scale + 0.01 * torch.randn_like(scale)
222
+ return scale
223
+
224
+
225
+ def init_non_linear_regression_fit(
226
+ x: torch.Tensor,
227
+ **kwargs: Dict[str, Any],
228
+ ) -> torch.Tensor:
229
+
230
+ assert "params_list" in kwargs, "params list must be provided."
231
+ assert "np_fit_func" in kwargs, "np_fit_func must be provided."
232
+ assert "p0" in kwargs, "p0 must be provided."
233
+ np_fit_func = kwargs.get('np_fit_func')
234
+ params_list = kwargs.get('params_list')
235
+ p0 = kwargs.get('p0')
236
+
237
+ def _fit(xdata: np.ndarray, ydata: np.ndarray, func: Callable, p0: List[float]):
238
+ popt, _ = curve_fit(
239
+ func,
240
+ xdata,
241
+ ydata,
242
+ maxfev=1000,
243
+ p0=p0,
244
+ method='lm'
245
+ )
246
+ return popt
247
+
248
+ # 1. Needs to convert the torch tensor to numpy tensor
249
+ xdata = x.cpu().numpy()
250
+
251
+ # 2. Sorts the data so that it makes it easier to fit to it
252
+ sorted_xdata = np.sort(xdata, axis=-1)
253
+
254
+ p0 = {k: v.cpu().numpy() for k, v in p0.items()}
255
+ params_list = sorted(params_list) # We need to make sure that it matches the numpy fit func arg order
256
+
257
+ # 3. Finds the best parameters for each channel
258
+ try:
259
+ params = []
260
+ for i in range(sorted_xdata.shape[0]):
261
+ xdata_ = sorted_xdata[i]
262
+ p0_ = [p0[p][i] for p in params_list]
263
+ ch_params = _fit(xdata_, xdata_, np_fit_func, p0_)
264
+ params.append(ch_params)
265
+
266
+ # 4. Builds the parameters
267
+ result = {}
268
+ for i, p in enumerate(params_list):
269
+ result[p] = torch.tensor([p_[i] for p_ in params], dtype=torch.float32).to(x.device)
270
+
271
+ return result
272
+
273
+ except ValueError as e:
274
+ print(f"Could not fit the function with error: {e}")
275
+ print(f"Using fallback result...")
276
+ return {
277
+ k: torch.tensor(v, dtype=torch.float32).to(x.device) for k, v in p0.items()
278
+ }
279
+
280
+
281
+ def init_zeros(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
282
+ val = torch.amin(x, dim=1)
283
+ return torch.zeros_like(val, dtype=torch.float32, device=x.device)
284
+
285
+
286
+ def init_inner_scale(tensor: torch.Tensor, _min: float = torch.inf, _max: float = torch.inf) -> torch.Tensor:
287
+ # Calculate the original minimum and maximum values
288
+ min_vals, max_vals = torch.aminmax(tensor, dim=-1)
289
+ x_min = torch.min(min_vals, torch.zeros_like(min_vals))
290
+ x_max = torch.max(max_vals, torch.zeros_like(max_vals))
291
+
292
+ if _max is torch.inf: # We do not need to scale the tensor. Just need to move it
293
+ return torch.ones_like(x_min)
294
+
295
+ # Calculate the scale factor
296
+ scale = (_max - _min) / (x_max - x_min)
297
+ return scale
298
+
299
+
300
+
301
+ ############## Quant ###############
302
+
303
+ @torch.enable_grad()
304
+ def learn_parameters(
305
+ x: torch.Tensor,
306
+ params: Dict[str, nn.Parameter],
307
+ qtz_func: nn.Module,
308
+ deqtz_func: nn.Module,
309
+ bits: int,
310
+ target_dtype: torch.dtype,
311
+ epochs: int = 1000,
312
+ early_stop: bool = True,
313
+ do_report: bool = False
314
+ ) -> Tuple[Dict[str, nn.Parameter], torch.Tensor]:
315
+ loss_fn = nn.MSELoss()
316
+
317
+ # Determines the initial learning rate by computing the initial loss and multiplying it by
318
+ # the order of magnitude of the loss divided by 2
319
+ quant = quantize(x, params, qtz_func, bits, target_dtype)
320
+ dequant = dequantize(quant, params, deqtz_func, bits, x.dtype)
321
+ loss = loss_fn(x, dequant)
322
+
323
+ base_lr = 0.1
324
+ exponent = int(np.floor(np.log10(loss.item())))
325
+ lr = base_lr * (10 ** (exponent // 2))
326
+
327
+ # Requires gradients in the parameters
328
+ for p in params.values():
329
+ p.requires_grad = True
330
+ p.grad = None
331
+
332
+ param_keys = list(params.keys())
333
+ param_values = list(params.values())
334
+
335
+ # Defines optimizer and loss function
336
+ optimizer = torch.optim.Adam(param_values, lr=lr)
337
+ scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=1.0, end_factor=0.01, total_iters=epochs // 10)
338
+
339
+ # Contains the best loss and the best parameters
340
+ best_loss = float("inf")
341
+ best_params = None
342
+
343
+ # Used to stop the search early
344
+ min_delta = 1e-7
345
+ acc_loss = []
346
+ percent_epochs_before_stop = 0.1
347
+
348
+ for i in range(epochs):
349
+ optimizer.zero_grad()
350
+
351
+ quant = quantize(x, params, qtz_func, bits, target_dtype)
352
+ dequant = dequantize(quant, params, deqtz_func, bits, x.dtype)
353
+ loss = loss_fn(x, dequant)
354
+
355
+ if loss.isnan() or loss.isinf():
356
+ raise Exception("Loss is NaN or Inf. Stopping the search.")
357
+
358
+ loss.backward()
359
+ optimizer.step()
360
+ scheduler.step()
361
+
362
+ acc_loss.append(loss.item())
363
+
364
+ # Reports loss every 10 steps
365
+ if i % 10 == 0 and do_report:
366
+ print(f"Epoch {i}: Loss {loss.item()}")
367
+
368
+ # Optimizes the parameter search by storing the best loss and the parameters
369
+ if loss.item() < best_loss:
370
+ best_loss = loss.item()
371
+ best_params = copy.deepcopy({
372
+ k: v for k, v in params.items() if k in param_keys
373
+ })
374
+
375
+ # We also stop the search if the loss has not considerably during the last 10% epochs
376
+ if early_stop:
377
+ epochs_before_stop = int(epochs * percent_epochs_before_stop)
378
+ if i > epochs_before_stop and abs(acc_loss[i - epochs_before_stop] - acc_loss[i]) < min_delta:
379
+ break
380
+
381
+ # No longer requires gradients in the parameters
382
+ for p in best_params.values():
383
+ p.requires_grad = False
384
+ p.grad = None
385
+
386
+ if do_report:
387
+ return best_params, acc_loss
388
+ else:
389
+ return best_params
390
+
391
+
392
+ def quantize(
393
+ x: torch.Tensor,
394
+ params: Dict[str, nn.Parameter],
395
+ func: nn.Module,
396
+ bits: int,
397
+ target_dtype: torch.dtype = torch.int8
398
+ ) -> torch.Tensor:
399
+ quant_min, quant_max = get_min_max_from_bits_signed(bits)
400
+ x = x.transpose(0, 1) # Aligns shapes
401
+ x = func(x=x, **params)
402
+ x = x.transpose(0, 1)
403
+ x = torch.clamp(round_func_BPDA(x), quant_min, quant_max).to(target_dtype)
404
+ return x
405
+
406
+
407
+ def dequantize(
408
+ x: torch.Tensor,
409
+ params: Dict[str, nn.Parameter],
410
+ func: nn.Module,
411
+ bits: int,
412
+ out_dtype: torch.dtype
413
+ ) -> torch.Tensor:
414
+ x = x.to(dtype=out_dtype)
415
+ x = x.transpose(0, 1)
416
+ x = func(x=x, **params)
417
+ x = x.transpose(0, 1)
418
+ return x
419
+
420
+
421
+ def round_func_BPDA(input):
422
+ # This is equivalent to replacing round function (non-differentiable) with
423
+ # an identity function (differentiable) only when backward.
424
+ forward_value = torch.round(input)
425
+ out = input.clone()
426
+ out.data = forward_value.data
427
+ return out
428
+
429
+
430
+ def get_min_max_from_bits_signed(bit_width: int) -> Tuple[int, int]:
431
+ return -2 ** (bit_width - 1), 2 ** (bit_width - 1) - 1
432
+
433
+
434
+
435
+ ############## Numpy ###############
436
+
437
+ def np_domain_guard(
438
+ x: np.ndarray,
439
+ min: float = None,
440
+ max: float = None,
441
+ posinf: float = None,
442
+ neginf: float = None,
443
+ nan: float = None
444
+ ) -> np.ndarray:
445
+ """Guard a tensor to a valid domain."""
446
+ x = np.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan)
447
+ if min is not None or max is not None:
448
+ x = np.clip(x, min, max)
449
+ return x
450
+
451
+
452
+ def np_replace_num(x: np.ndarray, num: float, to: float) -> np.ndarray:
453
+ """Replace a number in a tensor with another number.
454
+
455
+ Args:
456
+ x (np.ndarray): The input tensor.
457
+ num (float): The number to replace.
458
+ to (float): The number to replace with.
459
+
460
+ Returns:
461
+ np.ndarray: The tensor with the number replaced.
462
+ """
463
+ return np.where(x == num, to, x)
464
+
465
+
466
+ def np_guarded_power(x: np.ndarray, exp: float) -> np.ndarray:
467
+ """Guard the power operation to a valid domain."""
468
+ return np.power(x, exp) if exp >= 1 else np.power(np.maximum(x, 0), exp)
469
+
fn_gen/nlr/15/loss.png ADDED
fn_gen/nlr/15/quantization.png ADDED
fn_gen/nlr/16/distortion.png ADDED
fn_gen/nlr/16/expressions.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ asinh(_0*x)/_s
2
+ sinh(_s*x)/_0
fn_gen/nlr/16/fn.py ADDED
@@ -0,0 +1,469 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import torch
4
+ from torch import amin # Necessary for arcsin
5
+ import copy
6
+ import torch.nn as nn
7
+ import numpy as np
8
+
9
+ from scipy.optimize import curve_fit
10
+ from typing import Dict, Any, Tuple, List, Callable
11
+
12
+
13
+ def quantization(x, **params):
14
+ return (torch.div(1, replace_num(params['_s'], num=0, to=10000)) * torch.asinh((params['_0'] * x)))
15
+
16
+
17
+ def dequantization(x, **params):
18
+ return (torch.div(1, replace_num(params['_0'], num=0, to=10000)) * torch.sinh((params['_s'] * x)))
19
+
20
+
21
+ def init_params(x: torch.Tensor, **kwargs: Dict[str, Any]) -> Dict[str, nn.Parameter]:
22
+ base_p0 = {
23
+ '_0': init_space_search(x, qtz_func=quantization, deqtz_func=dequantization, params_list=['_0', '_s'], param='_0', **kwargs),
24
+ }
25
+
26
+ base_p0['_s'] = init_linear_scale(x, qtz_func=quantization, params=base_p0, **kwargs)
27
+ if 'post_init_hook' in kwargs:
28
+ kwargs['post_init_hook'](parameters=base_p0)
29
+
30
+ params = init_non_linear_regression_fit(x, p0=base_p0, np_fit_func=fit_func, qtz_func=quantization, deqtz_func=dequantization, params_list=['_0', '_s'], **kwargs)
31
+ params = {k: nn.Parameter(v, requires_grad=False) for k, v in params.items()}
32
+ if 'post_method_hook' in kwargs:
33
+ kwargs['post_method_hook'](parameters=params)
34
+
35
+
36
+ if 'post_train_hook' in kwargs:
37
+ kwargs['post_train_hook'](parameters=params)
38
+
39
+ return params
40
+
41
+
42
+ ############### Numpy Qtz ###############
43
+
44
+
45
+ def np_quantization(x, _0, _s):
46
+ return (np.divide(1, np_replace_num(_s, num=0, to=10000)) * np.arcsinh((_0 * x)))
47
+
48
+
49
+ def np_dequantization(x, _0, _s):
50
+ return (np.divide(1, np_replace_num(_0, num=0, to=10000)) * np.sinh((_s * x)))
51
+
52
+
53
+ def fit_func(x, _0, _s):
54
+ x_ = np_quantization(x, _0, _s)
55
+ x_ = np_dequantization(x_, _0, _s)
56
+ return x_
57
+
58
+
59
+
60
+ ############### HELPERS ###############
61
+
62
+ def domain_guard(
63
+ x: torch.Tensor,
64
+ min: float = None,
65
+ max: float = None,
66
+ posinf: float = None,
67
+ neginf: float = None,
68
+ nan: float = None
69
+ ) -> torch.Tensor:
70
+ """Guard a tensor to a valid domain."""
71
+ x = torch.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan)
72
+ if min is not None or max is not None:
73
+ x = torch.clamp(x, min=min, max=max)
74
+ return x
75
+
76
+
77
+ def replace_num(x: torch.Tensor, num: float, to: float) -> torch.Tensor:
78
+ """Replace a number in a tensor with another number.
79
+
80
+ Args:
81
+ x (torch.Tensor): The input tensor.
82
+ num (float): The number to replace.
83
+ to (float): The number to replace with.
84
+
85
+ Returns:
86
+ torch.Tensor: The tensor with the number replaced.
87
+ """
88
+ return torch.where(x == num, to, x)
89
+
90
+
91
+ def guarded_torch_power(x: torch.Tensor, exp: float) -> torch.Tensor:
92
+ """Guard the power operation to a valid domain."""
93
+ return torch.pow(x, exp) if exp >= 1 else torch.pow(torch.relu(x), exp)
94
+
95
+
96
+ def init_ones(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
97
+ val = torch.amin(x, dim=1)
98
+ return torch.ones_like(val, dtype=torch.float32, device=x.device)
99
+
100
+
101
+ def init_rand(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
102
+ val = torch.amin(x, dim=1)
103
+ return torch.randn_like(val, dtype=torch.float32, device=x.device)
104
+
105
+
106
+ def init_space_search(
107
+ x: torch.Tensor,
108
+ **kwargs: Dict[str, Any],
109
+ ) -> torch.Tensor:
110
+
111
+ def _build_initial_param(tensor: torch.Tensor, max_initial: int, n_params: int):
112
+ """Generates the initial set of parameters. The first iteration generates 10 times more parameters."""
113
+ for _ in range(n_params * 10): # The first iteration generates 10 times more parameters
114
+ yield init_rand(tensor) * max_initial # Generates n_params in range [-max_initial, max_initial]
115
+
116
+ def _search_param(tensors: List[torch.tensor], n_params):
117
+ """Takes the best parameters and generates new parameters around the mean of the best parameters."""
118
+ torch_tensors = torch.stack(tensors)
119
+ min_vals, max_vals = torch.aminmax(torch_tensors, dim=0)
120
+ abs_max_val_per_ch = torch.max(-min_vals, max_vals)
121
+ mean = torch.mean(torch_tensors, dim=0)
122
+ for _ in range(n_params): # Generates n_params around the mean of the tensors
123
+ yield torch.randn_like(min_vals) * abs_max_val_per_ch + mean
124
+
125
+ def _calc(x, qtz_func, deqtz_func, **params):
126
+ x_ = x.transpose(0, 1)
127
+ x_ = qtz_func(x=x_, **params)
128
+ x_ = deqtz_func(x=x_, **params)
129
+ x_ = x_.transpose(0, 1)
130
+ return x_
131
+
132
+ assert "qtz_func" in kwargs, "qtz_func must be provided."
133
+ assert "deqtz_func" in kwargs, "deqtz_func must be provided."
134
+ assert "params_list" in kwargs, "params list must be provided."
135
+ assert "param" in kwargs, "param must be provided."
136
+
137
+ qtz_func = kwargs.get('qtz_func')
138
+ deqtz_func = kwargs.get('deqtz_func')
139
+ params_list = kwargs.get('params_list')
140
+ param = kwargs.get('param')
141
+
142
+ n_runs = 50 # Number of runs to try to find the best parameters
143
+ n_random_params = 50 # Number of random parameters to generate
144
+ n_best_to_pick = 5 # Number of best parameters to pick after each run
145
+ max_initial = 10000 # Maximum value to initialize the parameters
146
+
147
+ # Initializes the parameters
148
+ base_params = { p: init_ones(x, **kwargs) for p in params_list if p != param }
149
+ params = _build_initial_param(x, max_initial, n_random_params)
150
+
151
+ # Performs the search
152
+ for _ in range(n_runs):
153
+
154
+ best_params = []
155
+ for param_ in params:
156
+ try:
157
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: param_})
158
+ loss_ones = nn.MSELoss()(x, x_)
159
+
160
+ if len(best_params) < n_best_to_pick:
161
+ best_params.append((param_, loss_ones.item()))
162
+ best_params = sorted(best_params, key=lambda x: x[1])
163
+ elif loss_ones < best_params[-1][1]:
164
+ best_params[-1] = (param_, loss_ones.item())
165
+ best_params = sorted(best_params, key=lambda x: x[1])
166
+
167
+ except Exception: # The parameters might not be valid for the function's domain
168
+ continue
169
+
170
+ # Generates new parameters around the mean
171
+ params = _search_param([p for p, _ in best_params], n_random_params)
172
+
173
+ # Checks if the best parameter is better than the init_ones
174
+ p_ones = init_ones(x, **kwargs)
175
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_ones})
176
+ loss_ones = nn.MSELoss()(x, x_)
177
+
178
+ # Checks if the best parameter is better than the init_rand
179
+ p_rand = init_rand(x, **kwargs)
180
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_rand})
181
+ loss_rand = nn.MSELoss()(x, x_)
182
+
183
+ if loss_rand < best_params[0][1] and loss_rand < loss_ones:
184
+ return p_rand
185
+ elif loss_ones < best_params[0][1] and loss_ones < loss_rand:
186
+ return p_ones
187
+ else:
188
+ return best_params[0][0]
189
+
190
+
191
+ def init_linear_scale( # Symmetric scale. From the study folder
192
+ x: torch.Tensor,
193
+ **kwargs: Dict[str, Any],
194
+ ) -> torch.Tensor:
195
+ assert "bits" in kwargs, "bits must be provided."
196
+ assert "params" in kwargs, "params must be provided."
197
+ assert "qtz_func" in kwargs, "qtz_func must be provided."
198
+
199
+ bits = kwargs.get('bits')
200
+ params = kwargs.get('params')
201
+ qtz_func = kwargs.get('qtz_func')
202
+
203
+ x_ = x.transpose(0, 1)
204
+ x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs))
205
+ x_ = x_.transpose(0, 1)
206
+
207
+ quant_min, quant_max = get_min_max_from_bits_signed(bits)
208
+ min_vals, max_vals = torch.aminmax(x_, dim=1)
209
+ min_vals = torch.min(min_vals, torch.zeros_like(min_vals))
210
+ max_vals = torch.max(max_vals, torch.zeros_like(max_vals))
211
+
212
+ eps = torch.finfo(torch.float32).eps
213
+
214
+ abs_max_val_per_ch = torch.max(-min_vals, max_vals)
215
+ scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2)
216
+
217
+ scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device)
218
+
219
+ # Introduces some noise in scale
220
+ # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything
221
+ # scale = scale + 0.01 * torch.randn_like(scale)
222
+ return scale
223
+
224
+
225
+ def init_non_linear_regression_fit(
226
+ x: torch.Tensor,
227
+ **kwargs: Dict[str, Any],
228
+ ) -> torch.Tensor:
229
+
230
+ assert "params_list" in kwargs, "params list must be provided."
231
+ assert "np_fit_func" in kwargs, "np_fit_func must be provided."
232
+ assert "p0" in kwargs, "p0 must be provided."
233
+ np_fit_func = kwargs.get('np_fit_func')
234
+ params_list = kwargs.get('params_list')
235
+ p0 = kwargs.get('p0')
236
+
237
+ def _fit(xdata: np.ndarray, ydata: np.ndarray, func: Callable, p0: List[float]):
238
+ popt, _ = curve_fit(
239
+ func,
240
+ xdata,
241
+ ydata,
242
+ maxfev=1000,
243
+ p0=p0,
244
+ method='lm'
245
+ )
246
+ return popt
247
+
248
+ # 1. Needs to convert the torch tensor to numpy tensor
249
+ xdata = x.cpu().numpy()
250
+
251
+ # 2. Sorts the data so that it makes it easier to fit to it
252
+ sorted_xdata = np.sort(xdata, axis=-1)
253
+
254
+ p0 = {k: v.cpu().numpy() for k, v in p0.items()}
255
+ params_list = sorted(params_list) # We need to make sure that it matches the numpy fit func arg order
256
+
257
+ # 3. Finds the best parameters for each channel
258
+ try:
259
+ params = []
260
+ for i in range(sorted_xdata.shape[0]):
261
+ xdata_ = sorted_xdata[i]
262
+ p0_ = [p0[p][i] for p in params_list]
263
+ ch_params = _fit(xdata_, xdata_, np_fit_func, p0_)
264
+ params.append(ch_params)
265
+
266
+ # 4. Builds the parameters
267
+ result = {}
268
+ for i, p in enumerate(params_list):
269
+ result[p] = torch.tensor([p_[i] for p_ in params], dtype=torch.float32).to(x.device)
270
+
271
+ return result
272
+
273
+ except ValueError as e:
274
+ print(f"Could not fit the function with error: {e}")
275
+ print(f"Using fallback result...")
276
+ return {
277
+ k: torch.tensor(v, dtype=torch.float32).to(x.device) for k, v in p0.items()
278
+ }
279
+
280
+
281
+ def init_zeros(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
282
+ val = torch.amin(x, dim=1)
283
+ return torch.zeros_like(val, dtype=torch.float32, device=x.device)
284
+
285
+
286
+ def init_inner_scale(tensor: torch.Tensor, _min: float = torch.inf, _max: float = torch.inf) -> torch.Tensor:
287
+ # Calculate the original minimum and maximum values
288
+ min_vals, max_vals = torch.aminmax(tensor, dim=-1)
289
+ x_min = torch.min(min_vals, torch.zeros_like(min_vals))
290
+ x_max = torch.max(max_vals, torch.zeros_like(max_vals))
291
+
292
+ if _max is torch.inf: # We do not need to scale the tensor. Just need to move it
293
+ return torch.ones_like(x_min)
294
+
295
+ # Calculate the scale factor
296
+ scale = (_max - _min) / (x_max - x_min)
297
+ return scale
298
+
299
+
300
+
301
+ ############## Quant ###############
302
+
303
+ @torch.enable_grad()
304
+ def learn_parameters(
305
+ x: torch.Tensor,
306
+ params: Dict[str, nn.Parameter],
307
+ qtz_func: nn.Module,
308
+ deqtz_func: nn.Module,
309
+ bits: int,
310
+ target_dtype: torch.dtype,
311
+ epochs: int = 1000,
312
+ early_stop: bool = True,
313
+ do_report: bool = False
314
+ ) -> Tuple[Dict[str, nn.Parameter], torch.Tensor]:
315
+ loss_fn = nn.MSELoss()
316
+
317
+ # Determines the initial learning rate by computing the initial loss and multiplying it by
318
+ # the order of magnitude of the loss divided by 2
319
+ quant = quantize(x, params, qtz_func, bits, target_dtype)
320
+ dequant = dequantize(quant, params, deqtz_func, bits, x.dtype)
321
+ loss = loss_fn(x, dequant)
322
+
323
+ base_lr = 0.1
324
+ exponent = int(np.floor(np.log10(loss.item())))
325
+ lr = base_lr * (10 ** (exponent // 2))
326
+
327
+ # Requires gradients in the parameters
328
+ for p in params.values():
329
+ p.requires_grad = True
330
+ p.grad = None
331
+
332
+ param_keys = list(params.keys())
333
+ param_values = list(params.values())
334
+
335
+ # Defines optimizer and loss function
336
+ optimizer = torch.optim.Adam(param_values, lr=lr)
337
+ scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=1.0, end_factor=0.01, total_iters=epochs // 10)
338
+
339
+ # Contains the best loss and the best parameters
340
+ best_loss = float("inf")
341
+ best_params = None
342
+
343
+ # Used to stop the search early
344
+ min_delta = 1e-7
345
+ acc_loss = []
346
+ percent_epochs_before_stop = 0.1
347
+
348
+ for i in range(epochs):
349
+ optimizer.zero_grad()
350
+
351
+ quant = quantize(x, params, qtz_func, bits, target_dtype)
352
+ dequant = dequantize(quant, params, deqtz_func, bits, x.dtype)
353
+ loss = loss_fn(x, dequant)
354
+
355
+ if loss.isnan() or loss.isinf():
356
+ raise Exception("Loss is NaN or Inf. Stopping the search.")
357
+
358
+ loss.backward()
359
+ optimizer.step()
360
+ scheduler.step()
361
+
362
+ acc_loss.append(loss.item())
363
+
364
+ # Reports loss every 10 steps
365
+ if i % 10 == 0 and do_report:
366
+ print(f"Epoch {i}: Loss {loss.item()}")
367
+
368
+ # Optimizes the parameter search by storing the best loss and the parameters
369
+ if loss.item() < best_loss:
370
+ best_loss = loss.item()
371
+ best_params = copy.deepcopy({
372
+ k: v for k, v in params.items() if k in param_keys
373
+ })
374
+
375
+ # We also stop the search if the loss has not considerably during the last 10% epochs
376
+ if early_stop:
377
+ epochs_before_stop = int(epochs * percent_epochs_before_stop)
378
+ if i > epochs_before_stop and abs(acc_loss[i - epochs_before_stop] - acc_loss[i]) < min_delta:
379
+ break
380
+
381
+ # No longer requires gradients in the parameters
382
+ for p in best_params.values():
383
+ p.requires_grad = False
384
+ p.grad = None
385
+
386
+ if do_report:
387
+ return best_params, acc_loss
388
+ else:
389
+ return best_params
390
+
391
+
392
+ def quantize(
393
+ x: torch.Tensor,
394
+ params: Dict[str, nn.Parameter],
395
+ func: nn.Module,
396
+ bits: int,
397
+ target_dtype: torch.dtype = torch.int8
398
+ ) -> torch.Tensor:
399
+ quant_min, quant_max = get_min_max_from_bits_signed(bits)
400
+ x = x.transpose(0, 1) # Aligns shapes
401
+ x = func(x=x, **params)
402
+ x = x.transpose(0, 1)
403
+ x = torch.clamp(round_func_BPDA(x), quant_min, quant_max).to(target_dtype)
404
+ return x
405
+
406
+
407
+ def dequantize(
408
+ x: torch.Tensor,
409
+ params: Dict[str, nn.Parameter],
410
+ func: nn.Module,
411
+ bits: int,
412
+ out_dtype: torch.dtype
413
+ ) -> torch.Tensor:
414
+ x = x.to(dtype=out_dtype)
415
+ x = x.transpose(0, 1)
416
+ x = func(x=x, **params)
417
+ x = x.transpose(0, 1)
418
+ return x
419
+
420
+
421
+ def round_func_BPDA(input):
422
+ # This is equivalent to replacing round function (non-differentiable) with
423
+ # an identity function (differentiable) only when backward.
424
+ forward_value = torch.round(input)
425
+ out = input.clone()
426
+ out.data = forward_value.data
427
+ return out
428
+
429
+
430
+ def get_min_max_from_bits_signed(bit_width: int) -> Tuple[int, int]:
431
+ return -2 ** (bit_width - 1), 2 ** (bit_width - 1) - 1
432
+
433
+
434
+
435
+ ############## Numpy ###############
436
+
437
+ def np_domain_guard(
438
+ x: np.ndarray,
439
+ min: float = None,
440
+ max: float = None,
441
+ posinf: float = None,
442
+ neginf: float = None,
443
+ nan: float = None
444
+ ) -> np.ndarray:
445
+ """Guard a tensor to a valid domain."""
446
+ x = np.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan)
447
+ if min is not None or max is not None:
448
+ x = np.clip(x, min, max)
449
+ return x
450
+
451
+
452
+ def np_replace_num(x: np.ndarray, num: float, to: float) -> np.ndarray:
453
+ """Replace a number in a tensor with another number.
454
+
455
+ Args:
456
+ x (np.ndarray): The input tensor.
457
+ num (float): The number to replace.
458
+ to (float): The number to replace with.
459
+
460
+ Returns:
461
+ np.ndarray: The tensor with the number replaced.
462
+ """
463
+ return np.where(x == num, to, x)
464
+
465
+
466
+ def np_guarded_power(x: np.ndarray, exp: float) -> np.ndarray:
467
+ """Guard the power operation to a valid domain."""
468
+ return np.power(x, exp) if exp >= 1 else np.power(np.maximum(x, 0), exp)
469
+
fn_gen/nlr/16/loss.png ADDED
fn_gen/nlr/16/quantization.png ADDED
fn_gen/nlr/17/distortion.png ADDED
fn_gen/nlr/17/expressions.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ tan(_0*x)/_s
2
+ atan(_s*x)/_0
fn_gen/nlr/17/fn.py ADDED
@@ -0,0 +1,469 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import torch
4
+ from torch import amin # Necessary for arcsin
5
+ import copy
6
+ import torch.nn as nn
7
+ import numpy as np
8
+
9
+ from scipy.optimize import curve_fit
10
+ from typing import Dict, Any, Tuple, List, Callable
11
+
12
+
13
+ def quantization(x, **params):
14
+ return (torch.div(1, replace_num(params['_s'], num=0, to=10000)) * torch.tan(domain_guard((params['_0'] * x), posinf=1, neginf=-1, nan=0)))
15
+
16
+
17
+ def dequantization(x, **params):
18
+ return (torch.div(1, replace_num(params['_0'], num=0, to=10000)) * torch.atan((params['_s'] * x)))
19
+
20
+
21
+ def init_params(x: torch.Tensor, **kwargs: Dict[str, Any]) -> Dict[str, nn.Parameter]:
22
+ base_p0 = {
23
+ '_0': init_space_search(x, qtz_func=quantization, deqtz_func=dequantization, params_list=['_0', '_s'], param='_0', **kwargs),
24
+ }
25
+
26
+ base_p0['_s'] = init_linear_scale(x, qtz_func=quantization, params=base_p0, **kwargs)
27
+ if 'post_init_hook' in kwargs:
28
+ kwargs['post_init_hook'](parameters=base_p0)
29
+
30
+ params = init_non_linear_regression_fit(x, p0=base_p0, np_fit_func=fit_func, qtz_func=quantization, deqtz_func=dequantization, params_list=['_0', '_s'], **kwargs)
31
+ params = {k: nn.Parameter(v, requires_grad=False) for k, v in params.items()}
32
+ if 'post_method_hook' in kwargs:
33
+ kwargs['post_method_hook'](parameters=params)
34
+
35
+
36
+ if 'post_train_hook' in kwargs:
37
+ kwargs['post_train_hook'](parameters=params)
38
+
39
+ return params
40
+
41
+
42
+ ############### Numpy Qtz ###############
43
+
44
+
45
+ def np_quantization(x, _0, _s):
46
+ return (np.divide(1, np_replace_num(_s, num=0, to=10000)) * np.tan(np_domain_guard((_0 * x), posinf=1, neginf=-1, nan=0)))
47
+
48
+
49
+ def np_dequantization(x, _0, _s):
50
+ return (np.divide(1, np_replace_num(_0, num=0, to=10000)) * np.arctan((_s * x)))
51
+
52
+
53
+ def fit_func(x, _0, _s):
54
+ x_ = np_quantization(x, _0, _s)
55
+ x_ = np_dequantization(x_, _0, _s)
56
+ return x_
57
+
58
+
59
+
60
+ ############### HELPERS ###############
61
+
62
+ def domain_guard(
63
+ x: torch.Tensor,
64
+ min: float = None,
65
+ max: float = None,
66
+ posinf: float = None,
67
+ neginf: float = None,
68
+ nan: float = None
69
+ ) -> torch.Tensor:
70
+ """Guard a tensor to a valid domain."""
71
+ x = torch.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan)
72
+ if min is not None or max is not None:
73
+ x = torch.clamp(x, min=min, max=max)
74
+ return x
75
+
76
+
77
+ def replace_num(x: torch.Tensor, num: float, to: float) -> torch.Tensor:
78
+ """Replace a number in a tensor with another number.
79
+
80
+ Args:
81
+ x (torch.Tensor): The input tensor.
82
+ num (float): The number to replace.
83
+ to (float): The number to replace with.
84
+
85
+ Returns:
86
+ torch.Tensor: The tensor with the number replaced.
87
+ """
88
+ return torch.where(x == num, to, x)
89
+
90
+
91
+ def guarded_torch_power(x: torch.Tensor, exp: float) -> torch.Tensor:
92
+ """Guard the power operation to a valid domain."""
93
+ return torch.pow(x, exp) if exp >= 1 else torch.pow(torch.relu(x), exp)
94
+
95
+
96
+ def init_ones(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
97
+ val = torch.amin(x, dim=1)
98
+ return torch.ones_like(val, dtype=torch.float32, device=x.device)
99
+
100
+
101
+ def init_rand(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
102
+ val = torch.amin(x, dim=1)
103
+ return torch.randn_like(val, dtype=torch.float32, device=x.device)
104
+
105
+
106
+ def init_space_search(
107
+ x: torch.Tensor,
108
+ **kwargs: Dict[str, Any],
109
+ ) -> torch.Tensor:
110
+
111
+ def _build_initial_param(tensor: torch.Tensor, max_initial: int, n_params: int):
112
+ """Generates the initial set of parameters. The first iteration generates 10 times more parameters."""
113
+ for _ in range(n_params * 10): # The first iteration generates 10 times more parameters
114
+ yield init_rand(tensor) * max_initial # Generates n_params in range [-max_initial, max_initial]
115
+
116
+ def _search_param(tensors: List[torch.tensor], n_params):
117
+ """Takes the best parameters and generates new parameters around the mean of the best parameters."""
118
+ torch_tensors = torch.stack(tensors)
119
+ min_vals, max_vals = torch.aminmax(torch_tensors, dim=0)
120
+ abs_max_val_per_ch = torch.max(-min_vals, max_vals)
121
+ mean = torch.mean(torch_tensors, dim=0)
122
+ for _ in range(n_params): # Generates n_params around the mean of the tensors
123
+ yield torch.randn_like(min_vals) * abs_max_val_per_ch + mean
124
+
125
+ def _calc(x, qtz_func, deqtz_func, **params):
126
+ x_ = x.transpose(0, 1)
127
+ x_ = qtz_func(x=x_, **params)
128
+ x_ = deqtz_func(x=x_, **params)
129
+ x_ = x_.transpose(0, 1)
130
+ return x_
131
+
132
+ assert "qtz_func" in kwargs, "qtz_func must be provided."
133
+ assert "deqtz_func" in kwargs, "deqtz_func must be provided."
134
+ assert "params_list" in kwargs, "params list must be provided."
135
+ assert "param" in kwargs, "param must be provided."
136
+
137
+ qtz_func = kwargs.get('qtz_func')
138
+ deqtz_func = kwargs.get('deqtz_func')
139
+ params_list = kwargs.get('params_list')
140
+ param = kwargs.get('param')
141
+
142
+ n_runs = 50 # Number of runs to try to find the best parameters
143
+ n_random_params = 50 # Number of random parameters to generate
144
+ n_best_to_pick = 5 # Number of best parameters to pick after each run
145
+ max_initial = 10000 # Maximum value to initialize the parameters
146
+
147
+ # Initializes the parameters
148
+ base_params = { p: init_ones(x, **kwargs) for p in params_list if p != param }
149
+ params = _build_initial_param(x, max_initial, n_random_params)
150
+
151
+ # Performs the search
152
+ for _ in range(n_runs):
153
+
154
+ best_params = []
155
+ for param_ in params:
156
+ try:
157
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: param_})
158
+ loss_ones = nn.MSELoss()(x, x_)
159
+
160
+ if len(best_params) < n_best_to_pick:
161
+ best_params.append((param_, loss_ones.item()))
162
+ best_params = sorted(best_params, key=lambda x: x[1])
163
+ elif loss_ones < best_params[-1][1]:
164
+ best_params[-1] = (param_, loss_ones.item())
165
+ best_params = sorted(best_params, key=lambda x: x[1])
166
+
167
+ except Exception: # The parameters might not be valid for the function's domain
168
+ continue
169
+
170
+ # Generates new parameters around the mean
171
+ params = _search_param([p for p, _ in best_params], n_random_params)
172
+
173
+ # Checks if the best parameter is better than the init_ones
174
+ p_ones = init_ones(x, **kwargs)
175
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_ones})
176
+ loss_ones = nn.MSELoss()(x, x_)
177
+
178
+ # Checks if the best parameter is better than the init_rand
179
+ p_rand = init_rand(x, **kwargs)
180
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_rand})
181
+ loss_rand = nn.MSELoss()(x, x_)
182
+
183
+ if loss_rand < best_params[0][1] and loss_rand < loss_ones:
184
+ return p_rand
185
+ elif loss_ones < best_params[0][1] and loss_ones < loss_rand:
186
+ return p_ones
187
+ else:
188
+ return best_params[0][0]
189
+
190
+
191
+ def init_linear_scale( # Symmetric scale. From the study folder
192
+ x: torch.Tensor,
193
+ **kwargs: Dict[str, Any],
194
+ ) -> torch.Tensor:
195
+ assert "bits" in kwargs, "bits must be provided."
196
+ assert "params" in kwargs, "params must be provided."
197
+ assert "qtz_func" in kwargs, "qtz_func must be provided."
198
+
199
+ bits = kwargs.get('bits')
200
+ params = kwargs.get('params')
201
+ qtz_func = kwargs.get('qtz_func')
202
+
203
+ x_ = x.transpose(0, 1)
204
+ x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs))
205
+ x_ = x_.transpose(0, 1)
206
+
207
+ quant_min, quant_max = get_min_max_from_bits_signed(bits)
208
+ min_vals, max_vals = torch.aminmax(x_, dim=1)
209
+ min_vals = torch.min(min_vals, torch.zeros_like(min_vals))
210
+ max_vals = torch.max(max_vals, torch.zeros_like(max_vals))
211
+
212
+ eps = torch.finfo(torch.float32).eps
213
+
214
+ abs_max_val_per_ch = torch.max(-min_vals, max_vals)
215
+ scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2)
216
+
217
+ scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device)
218
+
219
+ # Introduces some noise in scale
220
+ # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything
221
+ # scale = scale + 0.01 * torch.randn_like(scale)
222
+ return scale
223
+
224
+
225
+ def init_non_linear_regression_fit(
226
+ x: torch.Tensor,
227
+ **kwargs: Dict[str, Any],
228
+ ) -> torch.Tensor:
229
+
230
+ assert "params_list" in kwargs, "params list must be provided."
231
+ assert "np_fit_func" in kwargs, "np_fit_func must be provided."
232
+ assert "p0" in kwargs, "p0 must be provided."
233
+ np_fit_func = kwargs.get('np_fit_func')
234
+ params_list = kwargs.get('params_list')
235
+ p0 = kwargs.get('p0')
236
+
237
+ def _fit(xdata: np.ndarray, ydata: np.ndarray, func: Callable, p0: List[float]):
238
+ popt, _ = curve_fit(
239
+ func,
240
+ xdata,
241
+ ydata,
242
+ maxfev=1000,
243
+ p0=p0,
244
+ method='lm'
245
+ )
246
+ return popt
247
+
248
+ # 1. Needs to convert the torch tensor to numpy tensor
249
+ xdata = x.cpu().numpy()
250
+
251
+ # 2. Sorts the data so that it makes it easier to fit to it
252
+ sorted_xdata = np.sort(xdata, axis=-1)
253
+
254
+ p0 = {k: v.cpu().numpy() for k, v in p0.items()}
255
+ params_list = sorted(params_list) # We need to make sure that it matches the numpy fit func arg order
256
+
257
+ # 3. Finds the best parameters for each channel
258
+ try:
259
+ params = []
260
+ for i in range(sorted_xdata.shape[0]):
261
+ xdata_ = sorted_xdata[i]
262
+ p0_ = [p0[p][i] for p in params_list]
263
+ ch_params = _fit(xdata_, xdata_, np_fit_func, p0_)
264
+ params.append(ch_params)
265
+
266
+ # 4. Builds the parameters
267
+ result = {}
268
+ for i, p in enumerate(params_list):
269
+ result[p] = torch.tensor([p_[i] for p_ in params], dtype=torch.float32).to(x.device)
270
+
271
+ return result
272
+
273
+ except ValueError as e:
274
+ print(f"Could not fit the function with error: {e}")
275
+ print(f"Using fallback result...")
276
+ return {
277
+ k: torch.tensor(v, dtype=torch.float32).to(x.device) for k, v in p0.items()
278
+ }
279
+
280
+
281
+ def init_zeros(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
282
+ val = torch.amin(x, dim=1)
283
+ return torch.zeros_like(val, dtype=torch.float32, device=x.device)
284
+
285
+
286
+ def init_inner_scale(tensor: torch.Tensor, _min: float = torch.inf, _max: float = torch.inf) -> torch.Tensor:
287
+ # Calculate the original minimum and maximum values
288
+ min_vals, max_vals = torch.aminmax(tensor, dim=-1)
289
+ x_min = torch.min(min_vals, torch.zeros_like(min_vals))
290
+ x_max = torch.max(max_vals, torch.zeros_like(max_vals))
291
+
292
+ if _max is torch.inf: # We do not need to scale the tensor. Just need to move it
293
+ return torch.ones_like(x_min)
294
+
295
+ # Calculate the scale factor
296
+ scale = (_max - _min) / (x_max - x_min)
297
+ return scale
298
+
299
+
300
+
301
+ ############## Quant ###############
302
+
303
+ @torch.enable_grad()
304
+ def learn_parameters(
305
+ x: torch.Tensor,
306
+ params: Dict[str, nn.Parameter],
307
+ qtz_func: nn.Module,
308
+ deqtz_func: nn.Module,
309
+ bits: int,
310
+ target_dtype: torch.dtype,
311
+ epochs: int = 1000,
312
+ early_stop: bool = True,
313
+ do_report: bool = False
314
+ ) -> Tuple[Dict[str, nn.Parameter], torch.Tensor]:
315
+ loss_fn = nn.MSELoss()
316
+
317
+ # Determines the initial learning rate by computing the initial loss and multiplying it by
318
+ # the order of magnitude of the loss divided by 2
319
+ quant = quantize(x, params, qtz_func, bits, target_dtype)
320
+ dequant = dequantize(quant, params, deqtz_func, bits, x.dtype)
321
+ loss = loss_fn(x, dequant)
322
+
323
+ base_lr = 0.1
324
+ exponent = int(np.floor(np.log10(loss.item())))
325
+ lr = base_lr * (10 ** (exponent // 2))
326
+
327
+ # Requires gradients in the parameters
328
+ for p in params.values():
329
+ p.requires_grad = True
330
+ p.grad = None
331
+
332
+ param_keys = list(params.keys())
333
+ param_values = list(params.values())
334
+
335
+ # Defines optimizer and loss function
336
+ optimizer = torch.optim.Adam(param_values, lr=lr)
337
+ scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=1.0, end_factor=0.01, total_iters=epochs // 10)
338
+
339
+ # Contains the best loss and the best parameters
340
+ best_loss = float("inf")
341
+ best_params = None
342
+
343
+ # Used to stop the search early
344
+ min_delta = 1e-7
345
+ acc_loss = []
346
+ percent_epochs_before_stop = 0.1
347
+
348
+ for i in range(epochs):
349
+ optimizer.zero_grad()
350
+
351
+ quant = quantize(x, params, qtz_func, bits, target_dtype)
352
+ dequant = dequantize(quant, params, deqtz_func, bits, x.dtype)
353
+ loss = loss_fn(x, dequant)
354
+
355
+ if loss.isnan() or loss.isinf():
356
+ raise Exception("Loss is NaN or Inf. Stopping the search.")
357
+
358
+ loss.backward()
359
+ optimizer.step()
360
+ scheduler.step()
361
+
362
+ acc_loss.append(loss.item())
363
+
364
+ # Reports loss every 10 steps
365
+ if i % 10 == 0 and do_report:
366
+ print(f"Epoch {i}: Loss {loss.item()}")
367
+
368
+ # Optimizes the parameter search by storing the best loss and the parameters
369
+ if loss.item() < best_loss:
370
+ best_loss = loss.item()
371
+ best_params = copy.deepcopy({
372
+ k: v for k, v in params.items() if k in param_keys
373
+ })
374
+
375
+ # We also stop the search if the loss has not considerably during the last 10% epochs
376
+ if early_stop:
377
+ epochs_before_stop = int(epochs * percent_epochs_before_stop)
378
+ if i > epochs_before_stop and abs(acc_loss[i - epochs_before_stop] - acc_loss[i]) < min_delta:
379
+ break
380
+
381
+ # No longer requires gradients in the parameters
382
+ for p in best_params.values():
383
+ p.requires_grad = False
384
+ p.grad = None
385
+
386
+ if do_report:
387
+ return best_params, acc_loss
388
+ else:
389
+ return best_params
390
+
391
+
392
+ def quantize(
393
+ x: torch.Tensor,
394
+ params: Dict[str, nn.Parameter],
395
+ func: nn.Module,
396
+ bits: int,
397
+ target_dtype: torch.dtype = torch.int8
398
+ ) -> torch.Tensor:
399
+ quant_min, quant_max = get_min_max_from_bits_signed(bits)
400
+ x = x.transpose(0, 1) # Aligns shapes
401
+ x = func(x=x, **params)
402
+ x = x.transpose(0, 1)
403
+ x = torch.clamp(round_func_BPDA(x), quant_min, quant_max).to(target_dtype)
404
+ return x
405
+
406
+
407
+ def dequantize(
408
+ x: torch.Tensor,
409
+ params: Dict[str, nn.Parameter],
410
+ func: nn.Module,
411
+ bits: int,
412
+ out_dtype: torch.dtype
413
+ ) -> torch.Tensor:
414
+ x = x.to(dtype=out_dtype)
415
+ x = x.transpose(0, 1)
416
+ x = func(x=x, **params)
417
+ x = x.transpose(0, 1)
418
+ return x
419
+
420
+
421
+ def round_func_BPDA(input):
422
+ # This is equivalent to replacing round function (non-differentiable) with
423
+ # an identity function (differentiable) only when backward.
424
+ forward_value = torch.round(input)
425
+ out = input.clone()
426
+ out.data = forward_value.data
427
+ return out
428
+
429
+
430
+ def get_min_max_from_bits_signed(bit_width: int) -> Tuple[int, int]:
431
+ return -2 ** (bit_width - 1), 2 ** (bit_width - 1) - 1
432
+
433
+
434
+
435
+ ############## Numpy ###############
436
+
437
+ def np_domain_guard(
438
+ x: np.ndarray,
439
+ min: float = None,
440
+ max: float = None,
441
+ posinf: float = None,
442
+ neginf: float = None,
443
+ nan: float = None
444
+ ) -> np.ndarray:
445
+ """Guard a tensor to a valid domain."""
446
+ x = np.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan)
447
+ if min is not None or max is not None:
448
+ x = np.clip(x, min, max)
449
+ return x
450
+
451
+
452
+ def np_replace_num(x: np.ndarray, num: float, to: float) -> np.ndarray:
453
+ """Replace a number in a tensor with another number.
454
+
455
+ Args:
456
+ x (np.ndarray): The input tensor.
457
+ num (float): The number to replace.
458
+ to (float): The number to replace with.
459
+
460
+ Returns:
461
+ np.ndarray: The tensor with the number replaced.
462
+ """
463
+ return np.where(x == num, to, x)
464
+
465
+
466
+ def np_guarded_power(x: np.ndarray, exp: float) -> np.ndarray:
467
+ """Guard the power operation to a valid domain."""
468
+ return np.power(x, exp) if exp >= 1 else np.power(np.maximum(x, 0), exp)
469
+
fn_gen/nlr/17/loss.png ADDED
fn_gen/nlr/17/quantization.png ADDED
fn_gen/nlr/3/distortion.png ADDED
fn_gen/nlr/3/expressions.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ x**3/_s
2
+ (_s*x)**(1/3)
fn_gen/nlr/3/fn.py ADDED
@@ -0,0 +1,468 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import torch
4
+ from torch import amin # Necessary for arcsin
5
+ import copy
6
+ import torch.nn as nn
7
+ import numpy as np
8
+
9
+ from scipy.optimize import curve_fit
10
+ from typing import Dict, Any, Tuple, List, Callable
11
+
12
+
13
+ def quantization(x, **params):
14
+ return (torch.div(1, replace_num(params['_s'], num=0, to=10000)) * guarded_torch_power(x, torch.tensor(3)))
15
+
16
+
17
+ def dequantization(x, **params):
18
+ return guarded_torch_power((params['_s'] * x), 1 / 3)
19
+
20
+
21
+ def init_params(x: torch.Tensor, **kwargs: Dict[str, Any]) -> Dict[str, nn.Parameter]:
22
+ base_p0 = {
23
+ }
24
+
25
+ base_p0['_s'] = init_linear_scale(x, qtz_func=quantization, params=base_p0, **kwargs)
26
+ if 'post_init_hook' in kwargs:
27
+ kwargs['post_init_hook'](parameters=base_p0)
28
+
29
+ params = init_non_linear_regression_fit(x, p0=base_p0, np_fit_func=fit_func, qtz_func=quantization, deqtz_func=dequantization, params_list=['_s'], **kwargs)
30
+ params = {k: nn.Parameter(v, requires_grad=False) for k, v in params.items()}
31
+ if 'post_method_hook' in kwargs:
32
+ kwargs['post_method_hook'](parameters=params)
33
+
34
+
35
+ if 'post_train_hook' in kwargs:
36
+ kwargs['post_train_hook'](parameters=params)
37
+
38
+ return params
39
+
40
+
41
+ ############### Numpy Qtz ###############
42
+
43
+
44
+ def np_quantization(x, _s):
45
+ return (np.divide(1, np_replace_num(_s, num=0, to=10000)) * np_guarded_power(x, np.array(3)))
46
+
47
+
48
+ def np_dequantization(x, _s):
49
+ return np_guarded_power((_s * x), 1 / 3)
50
+
51
+
52
+ def fit_func(x, _s):
53
+ x_ = np_quantization(x, _s)
54
+ x_ = np_dequantization(x_, _s)
55
+ return x_
56
+
57
+
58
+
59
+ ############### HELPERS ###############
60
+
61
+ def domain_guard(
62
+ x: torch.Tensor,
63
+ min: float = None,
64
+ max: float = None,
65
+ posinf: float = None,
66
+ neginf: float = None,
67
+ nan: float = None
68
+ ) -> torch.Tensor:
69
+ """Guard a tensor to a valid domain."""
70
+ x = torch.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan)
71
+ if min is not None or max is not None:
72
+ x = torch.clamp(x, min=min, max=max)
73
+ return x
74
+
75
+
76
+ def replace_num(x: torch.Tensor, num: float, to: float) -> torch.Tensor:
77
+ """Replace a number in a tensor with another number.
78
+
79
+ Args:
80
+ x (torch.Tensor): The input tensor.
81
+ num (float): The number to replace.
82
+ to (float): The number to replace with.
83
+
84
+ Returns:
85
+ torch.Tensor: The tensor with the number replaced.
86
+ """
87
+ return torch.where(x == num, to, x)
88
+
89
+
90
+ def guarded_torch_power(x: torch.Tensor, exp: float) -> torch.Tensor:
91
+ """Guard the power operation to a valid domain."""
92
+ return torch.pow(x, exp) if exp >= 1 else torch.pow(torch.relu(x), exp)
93
+
94
+
95
+ def init_ones(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
96
+ val = torch.amin(x, dim=1)
97
+ return torch.ones_like(val, dtype=torch.float32, device=x.device)
98
+
99
+
100
+ def init_rand(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
101
+ val = torch.amin(x, dim=1)
102
+ return torch.randn_like(val, dtype=torch.float32, device=x.device)
103
+
104
+
105
+ def init_space_search(
106
+ x: torch.Tensor,
107
+ **kwargs: Dict[str, Any],
108
+ ) -> torch.Tensor:
109
+
110
+ def _build_initial_param(tensor: torch.Tensor, max_initial: int, n_params: int):
111
+ """Generates the initial set of parameters. The first iteration generates 10 times more parameters."""
112
+ for _ in range(n_params * 10): # The first iteration generates 10 times more parameters
113
+ yield init_rand(tensor) * max_initial # Generates n_params in range [-max_initial, max_initial]
114
+
115
+ def _search_param(tensors: List[torch.tensor], n_params):
116
+ """Takes the best parameters and generates new parameters around the mean of the best parameters."""
117
+ torch_tensors = torch.stack(tensors)
118
+ min_vals, max_vals = torch.aminmax(torch_tensors, dim=0)
119
+ abs_max_val_per_ch = torch.max(-min_vals, max_vals)
120
+ mean = torch.mean(torch_tensors, dim=0)
121
+ for _ in range(n_params): # Generates n_params around the mean of the tensors
122
+ yield torch.randn_like(min_vals) * abs_max_val_per_ch + mean
123
+
124
+ def _calc(x, qtz_func, deqtz_func, **params):
125
+ x_ = x.transpose(0, 1)
126
+ x_ = qtz_func(x=x_, **params)
127
+ x_ = deqtz_func(x=x_, **params)
128
+ x_ = x_.transpose(0, 1)
129
+ return x_
130
+
131
+ assert "qtz_func" in kwargs, "qtz_func must be provided."
132
+ assert "deqtz_func" in kwargs, "deqtz_func must be provided."
133
+ assert "params_list" in kwargs, "params list must be provided."
134
+ assert "param" in kwargs, "param must be provided."
135
+
136
+ qtz_func = kwargs.get('qtz_func')
137
+ deqtz_func = kwargs.get('deqtz_func')
138
+ params_list = kwargs.get('params_list')
139
+ param = kwargs.get('param')
140
+
141
+ n_runs = 50 # Number of runs to try to find the best parameters
142
+ n_random_params = 50 # Number of random parameters to generate
143
+ n_best_to_pick = 5 # Number of best parameters to pick after each run
144
+ max_initial = 10000 # Maximum value to initialize the parameters
145
+
146
+ # Initializes the parameters
147
+ base_params = { p: init_ones(x, **kwargs) for p in params_list if p != param }
148
+ params = _build_initial_param(x, max_initial, n_random_params)
149
+
150
+ # Performs the search
151
+ for _ in range(n_runs):
152
+
153
+ best_params = []
154
+ for param_ in params:
155
+ try:
156
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: param_})
157
+ loss_ones = nn.MSELoss()(x, x_)
158
+
159
+ if len(best_params) < n_best_to_pick:
160
+ best_params.append((param_, loss_ones.item()))
161
+ best_params = sorted(best_params, key=lambda x: x[1])
162
+ elif loss_ones < best_params[-1][1]:
163
+ best_params[-1] = (param_, loss_ones.item())
164
+ best_params = sorted(best_params, key=lambda x: x[1])
165
+
166
+ except Exception: # The parameters might not be valid for the function's domain
167
+ continue
168
+
169
+ # Generates new parameters around the mean
170
+ params = _search_param([p for p, _ in best_params], n_random_params)
171
+
172
+ # Checks if the best parameter is better than the init_ones
173
+ p_ones = init_ones(x, **kwargs)
174
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_ones})
175
+ loss_ones = nn.MSELoss()(x, x_)
176
+
177
+ # Checks if the best parameter is better than the init_rand
178
+ p_rand = init_rand(x, **kwargs)
179
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_rand})
180
+ loss_rand = nn.MSELoss()(x, x_)
181
+
182
+ if loss_rand < best_params[0][1] and loss_rand < loss_ones:
183
+ return p_rand
184
+ elif loss_ones < best_params[0][1] and loss_ones < loss_rand:
185
+ return p_ones
186
+ else:
187
+ return best_params[0][0]
188
+
189
+
190
+ def init_linear_scale( # Symmetric scale. From the study folder
191
+ x: torch.Tensor,
192
+ **kwargs: Dict[str, Any],
193
+ ) -> torch.Tensor:
194
+ assert "bits" in kwargs, "bits must be provided."
195
+ assert "params" in kwargs, "params must be provided."
196
+ assert "qtz_func" in kwargs, "qtz_func must be provided."
197
+
198
+ bits = kwargs.get('bits')
199
+ params = kwargs.get('params')
200
+ qtz_func = kwargs.get('qtz_func')
201
+
202
+ x_ = x.transpose(0, 1)
203
+ x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs))
204
+ x_ = x_.transpose(0, 1)
205
+
206
+ quant_min, quant_max = get_min_max_from_bits_signed(bits)
207
+ min_vals, max_vals = torch.aminmax(x_, dim=1)
208
+ min_vals = torch.min(min_vals, torch.zeros_like(min_vals))
209
+ max_vals = torch.max(max_vals, torch.zeros_like(max_vals))
210
+
211
+ eps = torch.finfo(torch.float32).eps
212
+
213
+ abs_max_val_per_ch = torch.max(-min_vals, max_vals)
214
+ scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2)
215
+
216
+ scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device)
217
+
218
+ # Introduces some noise in scale
219
+ # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything
220
+ # scale = scale + 0.01 * torch.randn_like(scale)
221
+ return scale
222
+
223
+
224
+ def init_non_linear_regression_fit(
225
+ x: torch.Tensor,
226
+ **kwargs: Dict[str, Any],
227
+ ) -> torch.Tensor:
228
+
229
+ assert "params_list" in kwargs, "params list must be provided."
230
+ assert "np_fit_func" in kwargs, "np_fit_func must be provided."
231
+ assert "p0" in kwargs, "p0 must be provided."
232
+ np_fit_func = kwargs.get('np_fit_func')
233
+ params_list = kwargs.get('params_list')
234
+ p0 = kwargs.get('p0')
235
+
236
+ def _fit(xdata: np.ndarray, ydata: np.ndarray, func: Callable, p0: List[float]):
237
+ popt, _ = curve_fit(
238
+ func,
239
+ xdata,
240
+ ydata,
241
+ maxfev=1000,
242
+ p0=p0,
243
+ method='lm'
244
+ )
245
+ return popt
246
+
247
+ # 1. Needs to convert the torch tensor to numpy tensor
248
+ xdata = x.cpu().numpy()
249
+
250
+ # 2. Sorts the data so that it makes it easier to fit to it
251
+ sorted_xdata = np.sort(xdata, axis=-1)
252
+
253
+ p0 = {k: v.cpu().numpy() for k, v in p0.items()}
254
+ params_list = sorted(params_list) # We need to make sure that it matches the numpy fit func arg order
255
+
256
+ # 3. Finds the best parameters for each channel
257
+ try:
258
+ params = []
259
+ for i in range(sorted_xdata.shape[0]):
260
+ xdata_ = sorted_xdata[i]
261
+ p0_ = [p0[p][i] for p in params_list]
262
+ ch_params = _fit(xdata_, xdata_, np_fit_func, p0_)
263
+ params.append(ch_params)
264
+
265
+ # 4. Builds the parameters
266
+ result = {}
267
+ for i, p in enumerate(params_list):
268
+ result[p] = torch.tensor([p_[i] for p_ in params], dtype=torch.float32).to(x.device)
269
+
270
+ return result
271
+
272
+ except ValueError as e:
273
+ print(f"Could not fit the function with error: {e}")
274
+ print(f"Using fallback result...")
275
+ return {
276
+ k: torch.tensor(v, dtype=torch.float32).to(x.device) for k, v in p0.items()
277
+ }
278
+
279
+
280
+ def init_zeros(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
281
+ val = torch.amin(x, dim=1)
282
+ return torch.zeros_like(val, dtype=torch.float32, device=x.device)
283
+
284
+
285
+ def init_inner_scale(tensor: torch.Tensor, _min: float = torch.inf, _max: float = torch.inf) -> torch.Tensor:
286
+ # Calculate the original minimum and maximum values
287
+ min_vals, max_vals = torch.aminmax(tensor, dim=-1)
288
+ x_min = torch.min(min_vals, torch.zeros_like(min_vals))
289
+ x_max = torch.max(max_vals, torch.zeros_like(max_vals))
290
+
291
+ if _max is torch.inf: # We do not need to scale the tensor. Just need to move it
292
+ return torch.ones_like(x_min)
293
+
294
+ # Calculate the scale factor
295
+ scale = (_max - _min) / (x_max - x_min)
296
+ return scale
297
+
298
+
299
+
300
+ ############## Quant ###############
301
+
302
+ @torch.enable_grad()
303
+ def learn_parameters(
304
+ x: torch.Tensor,
305
+ params: Dict[str, nn.Parameter],
306
+ qtz_func: nn.Module,
307
+ deqtz_func: nn.Module,
308
+ bits: int,
309
+ target_dtype: torch.dtype,
310
+ epochs: int = 1000,
311
+ early_stop: bool = True,
312
+ do_report: bool = False
313
+ ) -> Tuple[Dict[str, nn.Parameter], torch.Tensor]:
314
+ loss_fn = nn.MSELoss()
315
+
316
+ # Determines the initial learning rate by computing the initial loss and multiplying it by
317
+ # the order of magnitude of the loss divided by 2
318
+ quant = quantize(x, params, qtz_func, bits, target_dtype)
319
+ dequant = dequantize(quant, params, deqtz_func, bits, x.dtype)
320
+ loss = loss_fn(x, dequant)
321
+
322
+ base_lr = 0.1
323
+ exponent = int(np.floor(np.log10(loss.item())))
324
+ lr = base_lr * (10 ** (exponent // 2))
325
+
326
+ # Requires gradients in the parameters
327
+ for p in params.values():
328
+ p.requires_grad = True
329
+ p.grad = None
330
+
331
+ param_keys = list(params.keys())
332
+ param_values = list(params.values())
333
+
334
+ # Defines optimizer and loss function
335
+ optimizer = torch.optim.Adam(param_values, lr=lr)
336
+ scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=1.0, end_factor=0.01, total_iters=epochs // 10)
337
+
338
+ # Contains the best loss and the best parameters
339
+ best_loss = float("inf")
340
+ best_params = None
341
+
342
+ # Used to stop the search early
343
+ min_delta = 1e-7
344
+ acc_loss = []
345
+ percent_epochs_before_stop = 0.1
346
+
347
+ for i in range(epochs):
348
+ optimizer.zero_grad()
349
+
350
+ quant = quantize(x, params, qtz_func, bits, target_dtype)
351
+ dequant = dequantize(quant, params, deqtz_func, bits, x.dtype)
352
+ loss = loss_fn(x, dequant)
353
+
354
+ if loss.isnan() or loss.isinf():
355
+ raise Exception("Loss is NaN or Inf. Stopping the search.")
356
+
357
+ loss.backward()
358
+ optimizer.step()
359
+ scheduler.step()
360
+
361
+ acc_loss.append(loss.item())
362
+
363
+ # Reports loss every 10 steps
364
+ if i % 10 == 0 and do_report:
365
+ print(f"Epoch {i}: Loss {loss.item()}")
366
+
367
+ # Optimizes the parameter search by storing the best loss and the parameters
368
+ if loss.item() < best_loss:
369
+ best_loss = loss.item()
370
+ best_params = copy.deepcopy({
371
+ k: v for k, v in params.items() if k in param_keys
372
+ })
373
+
374
+ # We also stop the search if the loss has not considerably during the last 10% epochs
375
+ if early_stop:
376
+ epochs_before_stop = int(epochs * percent_epochs_before_stop)
377
+ if i > epochs_before_stop and abs(acc_loss[i - epochs_before_stop] - acc_loss[i]) < min_delta:
378
+ break
379
+
380
+ # No longer requires gradients in the parameters
381
+ for p in best_params.values():
382
+ p.requires_grad = False
383
+ p.grad = None
384
+
385
+ if do_report:
386
+ return best_params, acc_loss
387
+ else:
388
+ return best_params
389
+
390
+
391
+ def quantize(
392
+ x: torch.Tensor,
393
+ params: Dict[str, nn.Parameter],
394
+ func: nn.Module,
395
+ bits: int,
396
+ target_dtype: torch.dtype = torch.int8
397
+ ) -> torch.Tensor:
398
+ quant_min, quant_max = get_min_max_from_bits_signed(bits)
399
+ x = x.transpose(0, 1) # Aligns shapes
400
+ x = func(x=x, **params)
401
+ x = x.transpose(0, 1)
402
+ x = torch.clamp(round_func_BPDA(x), quant_min, quant_max).to(target_dtype)
403
+ return x
404
+
405
+
406
+ def dequantize(
407
+ x: torch.Tensor,
408
+ params: Dict[str, nn.Parameter],
409
+ func: nn.Module,
410
+ bits: int,
411
+ out_dtype: torch.dtype
412
+ ) -> torch.Tensor:
413
+ x = x.to(dtype=out_dtype)
414
+ x = x.transpose(0, 1)
415
+ x = func(x=x, **params)
416
+ x = x.transpose(0, 1)
417
+ return x
418
+
419
+
420
+ def round_func_BPDA(input):
421
+ # This is equivalent to replacing round function (non-differentiable) with
422
+ # an identity function (differentiable) only when backward.
423
+ forward_value = torch.round(input)
424
+ out = input.clone()
425
+ out.data = forward_value.data
426
+ return out
427
+
428
+
429
+ def get_min_max_from_bits_signed(bit_width: int) -> Tuple[int, int]:
430
+ return -2 ** (bit_width - 1), 2 ** (bit_width - 1) - 1
431
+
432
+
433
+
434
+ ############## Numpy ###############
435
+
436
+ def np_domain_guard(
437
+ x: np.ndarray,
438
+ min: float = None,
439
+ max: float = None,
440
+ posinf: float = None,
441
+ neginf: float = None,
442
+ nan: float = None
443
+ ) -> np.ndarray:
444
+ """Guard a tensor to a valid domain."""
445
+ x = np.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan)
446
+ if min is not None or max is not None:
447
+ x = np.clip(x, min, max)
448
+ return x
449
+
450
+
451
+ def np_replace_num(x: np.ndarray, num: float, to: float) -> np.ndarray:
452
+ """Replace a number in a tensor with another number.
453
+
454
+ Args:
455
+ x (np.ndarray): The input tensor.
456
+ num (float): The number to replace.
457
+ to (float): The number to replace with.
458
+
459
+ Returns:
460
+ np.ndarray: The tensor with the number replaced.
461
+ """
462
+ return np.where(x == num, to, x)
463
+
464
+
465
+ def np_guarded_power(x: np.ndarray, exp: float) -> np.ndarray:
466
+ """Guard the power operation to a valid domain."""
467
+ return np.power(x, exp) if exp >= 1 else np.power(np.maximum(x, 0), exp)
468
+
fn_gen/nlr/3/loss.png ADDED
fn_gen/nlr/3/quantization.png ADDED