Diogo-V committed
Commit: 73acbd2
Parent: ca05132

Upload learned functions

This view is limited to 50 files because the commit contains too many changes.
Files changed (50)
  1. fn_gen/ones_naive/0/distortion.png +0 -0
  2. fn_gen/ones_naive/0/expressions.txt +2 -0
  3. fn_gen/ones_naive/0/fn.py +455 -0
  4. fn_gen/ones_naive/0/loss.png +0 -0
  5. fn_gen/ones_naive/0/quantization.png +0 -0
  6. fn_gen/ones_naive/1/distortion.png +0 -0
  7. fn_gen/ones_naive/1/expressions.txt +2 -0
  8. fn_gen/ones_naive/1/fn.py +454 -0
  9. fn_gen/ones_naive/1/loss.png +0 -0
  10. fn_gen/ones_naive/1/quantization.png +0 -0
  11. fn_gen/ones_naive/10/distortion.png +0 -0
  12. fn_gen/ones_naive/10/expressions.txt +2 -0
  13. fn_gen/ones_naive/10/fn.py +455 -0
  14. fn_gen/ones_naive/10/loss.png +0 -0
  15. fn_gen/ones_naive/10/quantization.png +0 -0
  16. fn_gen/ones_naive/11/distortion.png +0 -0
  17. fn_gen/ones_naive/11/expressions.txt +2 -0
  18. fn_gen/ones_naive/11/fn.py +455 -0
  19. fn_gen/ones_naive/11/loss.png +0 -0
  20. fn_gen/ones_naive/11/quantization.png +0 -0
  21. fn_gen/ones_naive/12/distortion.png +0 -0
  22. fn_gen/ones_naive/12/expressions.txt +2 -0
  23. fn_gen/ones_naive/12/fn.py +455 -0
  24. fn_gen/ones_naive/12/loss.png +0 -0
  25. fn_gen/ones_naive/12/quantization.png +0 -0
  26. fn_gen/ones_naive/13/distortion.png +0 -0
  27. fn_gen/ones_naive/13/expressions.txt +2 -0
  28. fn_gen/ones_naive/13/fn.py +455 -0
  29. fn_gen/ones_naive/13/loss.png +0 -0
  30. fn_gen/ones_naive/13/quantization.png +0 -0
  31. fn_gen/ones_naive/14/distortion.png +0 -0
  32. fn_gen/ones_naive/14/expressions.txt +2 -0
  33. fn_gen/ones_naive/14/fn.py +455 -0
  34. fn_gen/ones_naive/14/loss.png +0 -0
  35. fn_gen/ones_naive/14/quantization.png +0 -0
  36. fn_gen/ones_naive/15/distortion.png +0 -0
  37. fn_gen/ones_naive/15/expressions.txt +2 -0
  38. fn_gen/ones_naive/15/fn.py +455 -0
  39. fn_gen/ones_naive/15/loss.png +0 -0
  40. fn_gen/ones_naive/15/quantization.png +0 -0
  41. fn_gen/ones_naive/16/distortion.png +0 -0
  42. fn_gen/ones_naive/16/expressions.txt +2 -0
  43. fn_gen/ones_naive/16/fn.py +455 -0
  44. fn_gen/ones_naive/16/loss.png +0 -0
  45. fn_gen/ones_naive/16/quantization.png +0 -0
  46. fn_gen/ones_naive/17/distortion.png +0 -0
  47. fn_gen/ones_naive/17/expressions.txt +2 -0
  48. fn_gen/ones_naive/17/fn.py +455 -0
  49. fn_gen/ones_naive/17/loss.png +0 -0
  50. fn_gen/ones_naive/17/quantization.png +0 -0
fn_gen/ones_naive/0/distortion.png ADDED
fn_gen/ones_naive/0/expressions.txt ADDED
@@ -0,0 +1,2 @@
+ acos(_0*x)/_s
+ cos(_s*x)/_0
fn_gen/ones_naive/0/fn.py ADDED
@@ -0,0 +1,455 @@
+ from __future__ import annotations
+
+ import torch
+ from torch import amin # Necessary for arcsin
+ import copy
+ import torch.nn as nn
+ import numpy as np
+
+ from scipy.optimize import curve_fit
+ from typing import Dict, Any, Tuple, List, Callable
+
+
+ def quantization(x, **params):
+     return (torch.div(1, replace_num(params['_s'], num=0, to=10000)) * torch.acos(domain_guard((params['_0'] * x), min=-0.99999, max=0.99999, nan=0)))
+
+
+ def dequantization(x, **params):
+     return (torch.div(1, replace_num(params['_0'], num=0, to=10000)) * torch.cos((params['_s'] * x)))
+
+
+ def init_params(x: torch.Tensor, **kwargs: Dict[str, Any]) -> Dict[str, nn.Parameter]:
+     params = {
+         '_0': init_ones(x, qtz_func=quantization, deqtz_func=dequantization, param='_0', params_list=['_0', '_s'], **kwargs),
+     }
+     params['_s'] = init_ones(x, params=params, qtz_func=quantization, **kwargs)
+     params = {k: nn.Parameter(v, requires_grad=False) for k, v in params.items()}
+
+     if 'post_init_hook' in kwargs:
+         kwargs['post_init_hook'](parameters=params)
+
+
+     if 'post_train_hook' in kwargs:
+         kwargs['post_train_hook'](parameters=params)
+
+     return params
+
+
+ ############### Numpy Qtz ###############
+
+
+ def np_quantization(x, _0, _s):
+     return (np.divide(1, np_replace_num(_s, num=0, to=10000)) * np.arccos(np_domain_guard((_0 * x), min=-0.99999, max=0.99999, nan=0)))
+
+
+ def np_dequantization(x, _0, _s):
+     return (np.divide(1, np_replace_num(_0, num=0, to=10000)) * np.cos((_s * x)))
+
+
+ def fit_func(x, _0, _s):
+     x_ = np_quantization(x, _0, _s)
+     x_ = np_dequantization(x_, _0, _s)
+     return x_
+
+
+
+ ############### HELPERS ###############
+
+ def domain_guard(
+     x: torch.Tensor,
+     min: float = None,
+     max: float = None,
+     posinf: float = None,
+     neginf: float = None,
+     nan: float = None
+ ) -> torch.Tensor:
+     """Guard a tensor to a valid domain."""
+     x = torch.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan)
+     if min is not None or max is not None:
+         x = torch.clamp(x, min=min, max=max)
+     return x
+
+
+ def replace_num(x: torch.Tensor, num: float, to: float) -> torch.Tensor:
+     """Replace a number in a tensor with another number.
+
+     Args:
+         x (torch.Tensor): The input tensor.
+         num (float): The number to replace.
+         to (float): The number to replace with.
+
+     Returns:
+         torch.Tensor: The tensor with the number replaced.
+     """
+     return torch.where(x == num, to, x)
+
+
+ def guarded_torch_power(x: torch.Tensor, exp: float) -> torch.Tensor:
+     """Guard the power operation to a valid domain."""
+     return torch.pow(x, exp) if exp >= 1 else torch.pow(torch.relu(x), exp)
+
+
+ def init_ones(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
+     val = torch.amin(x, dim=1)
+     return torch.ones_like(val, dtype=torch.float32, device=x.device)
+
+
+ def init_rand(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
+     val = torch.amin(x, dim=1)
+     return torch.randn_like(val, dtype=torch.float32, device=x.device)
+
+
+ def init_space_search(
+     x: torch.Tensor,
+     **kwargs: Dict[str, Any],
+ ) -> torch.Tensor:
+
+     def _build_initial_param(tensor: torch.Tensor, max_initial: int, n_params: int):
+         """Generates the initial set of parameters. The first iteration generates 10 times more parameters."""
+         for _ in range(n_params * 10): # The first iteration generates 10 times more parameters
+             yield init_rand(tensor) * max_initial # Generates n_params in range [-max_initial, max_initial]
+
+     def _search_param(tensors: List[torch.tensor], n_params):
+         """Takes the best parameters and generates new parameters around the mean of the best parameters."""
+         torch_tensors = torch.stack(tensors)
+         min_vals, max_vals = torch.aminmax(torch_tensors, dim=0)
+         abs_max_val_per_ch = torch.max(-min_vals, max_vals)
+         mean = torch.mean(torch_tensors, dim=0)
+         for _ in range(n_params): # Generates n_params around the mean of the tensors
+             yield torch.randn_like(min_vals) * abs_max_val_per_ch + mean
+
+     def _calc(x, qtz_func, deqtz_func, **params):
+         x_ = x.transpose(0, 1)
+         x_ = qtz_func(x=x_, **params)
+         x_ = deqtz_func(x=x_, **params)
+         x_ = x_.transpose(0, 1)
+         return x_
+
+     assert "qtz_func" in kwargs, "qtz_func must be provided."
+     assert "deqtz_func" in kwargs, "deqtz_func must be provided."
+     assert "params_list" in kwargs, "params list must be provided."
+     assert "param" in kwargs, "param must be provided."
+
+     qtz_func = kwargs.get('qtz_func')
+     deqtz_func = kwargs.get('deqtz_func')
+     params_list = kwargs.get('params_list')
+     param = kwargs.get('param')
+
+     n_runs = 50 # Number of runs to try to find the best parameters
+     n_random_params = 50 # Number of random parameters to generate
+     n_best_to_pick = 5 # Number of best parameters to pick after each run
+     max_initial = 10000 # Maximum value to initialize the parameters
+
+     # Initializes the parameters
+     base_params = { p: init_ones(x, **kwargs) for p in params_list if p != param }
+     params = _build_initial_param(x, max_initial, n_random_params)
+
+     # Performs the search
+     for _ in range(n_runs):
+
+         best_params = []
+         for param_ in params:
+             try:
+                 x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: param_})
+                 loss_ones = nn.MSELoss()(x, x_)
+
+                 if len(best_params) < n_best_to_pick:
+                     best_params.append((param_, loss_ones.item()))
+                     best_params = sorted(best_params, key=lambda x: x[1])
+                 elif loss_ones < best_params[-1][1]:
+                     best_params[-1] = (param_, loss_ones.item())
+                     best_params = sorted(best_params, key=lambda x: x[1])
+
+             except Exception: # The parameters might not be valid for the function's domain
+                 continue
+
+         # Generates new parameters around the mean
+         params = _search_param([p for p, _ in best_params], n_random_params)
+
+     # Checks if the best parameter is better than the init_ones
+     p_ones = init_ones(x, **kwargs)
+     x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_ones})
+     loss_ones = nn.MSELoss()(x, x_)
+
+     # Checks if the best parameter is better than the init_rand
+     p_rand = init_rand(x, **kwargs)
+     x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_rand})
+     loss_rand = nn.MSELoss()(x, x_)
+
+     if loss_rand < best_params[0][1] and loss_rand < loss_ones:
+         return p_rand
+     elif loss_ones < best_params[0][1] and loss_ones < loss_rand:
+         return p_ones
+     else:
+         return best_params[0][0]
+
+
+ def init_linear_scale( # Symmetric scale. From the study folder
+     x: torch.Tensor,
+     **kwargs: Dict[str, Any],
+ ) -> torch.Tensor:
+     assert "bits" in kwargs, "bits must be provided."
+     assert "params" in kwargs, "params must be provided."
+     assert "qtz_func" in kwargs, "qtz_func must be provided."
+
+     bits = kwargs.get('bits')
+     params = kwargs.get('params')
+     qtz_func = kwargs.get('qtz_func')
+
+     x_ = x.transpose(0, 1)
+     x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs))
+     x_ = x_.transpose(0, 1)
+
+     quant_min, quant_max = get_min_max_from_bits_signed(bits)
+     min_vals, max_vals = torch.aminmax(x_, dim=1)
+     min_vals = torch.min(min_vals, torch.zeros_like(min_vals))
+     max_vals = torch.max(max_vals, torch.zeros_like(max_vals))
+
+     eps = torch.finfo(torch.float32).eps
+
+     abs_max_val_per_ch = torch.max(-min_vals, max_vals)
+     scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2)
+
+     scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device)
+
+     # Introduces some noise in scale
+     # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything
+     scale = scale + 0.01 * torch.randn_like(scale)
+     return scale
+
+
+ def init_non_linear_regression_fit(
+     x: torch.Tensor,
+     **kwargs: Dict[str, Any],
+ ) -> torch.Tensor:
+
+     assert "params_list" in kwargs, "params list must be provided."
+     assert "np_fit_func" in kwargs, "np_fit_func must be provided."
+     assert "p0" in kwargs, "p0 must be provided."
+     np_fit_func = kwargs.get('np_fit_func')
+     params_list = kwargs.get('params_list')
+     p0 = kwargs.get('p0')
+
+     def _fit(xdata: np.ndarray, ydata: np.ndarray, func: Callable, p0: List[float]):
+         popt, _ = curve_fit(
+             func,
+             xdata,
+             ydata,
+             maxfev=1000,
+             p0=p0,
+             method='lm'
+         )
+         return popt
+
+     # 1. Needs to convert the torch tensor to numpy tensor
+     xdata = x.cpu().numpy()
+
+     # 2. Sorts the data so that it makes it easier to fit to it
+     sorted_xdata = np.sort(xdata, axis=-1)
+
+     p0 = {k: v.cpu().numpy() for k, v in p0.items()}
+     params_list = sorted(params_list) # We need to make sure that it matches the numpy fit func arg order
+
+     # 3. Finds the best parameters for each channel
+     try:
+         params = []
+         for i in range(sorted_xdata.shape[0]):
+             xdata_ = sorted_xdata[i]
+             p0_ = [p0[p][i] for p in params_list]
+             ch_params = _fit(xdata_, xdata_, np_fit_func, p0_)
+             params.append(ch_params)
+
+         # 4. Builds the parameters
+         result = {}
+         for i, p in enumerate(params_list):
+             result[p] = torch.tensor([p_[i] for p_ in params], dtype=torch.float32).to(x.device)
+
+         return result
+
+     except ValueError as e:
+         print(f"Could not fit the function with error: {e}")
+         print(f"Using fallback result...")
+         return {
+             k: torch.tensor(v, dtype=torch.float32).to(x.device) for k, v in p0.items()
+         }
+
+
+ def init_zeros(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
+     val = torch.amin(x, dim=1)
+     return torch.zeros_like(val, dtype=torch.float32, device=x.device)
+
+
+ def init_inner_scale(tensor: torch.Tensor, _min: float = torch.inf, _max: float = torch.inf) -> torch.Tensor:
+     # Calculate the original minimum and maximum values
+     min_vals, max_vals = torch.aminmax(tensor, dim=-1)
+     x_min = torch.min(min_vals, torch.zeros_like(min_vals))
+     x_max = torch.max(max_vals, torch.zeros_like(max_vals))
+
+     if _max is torch.inf: # We do not need to scale the tensor. Just need to move it
+         return torch.ones_like(x_min)
+
+     # Calculate the scale factor
+     scale = (_max - _min) / (x_max - x_min)
+     return scale
+
+
+
+ ############## Quant ###############
+
+ @torch.enable_grad()
+ def learn_parameters(
+     x: torch.Tensor,
+     params: Dict[str, nn.Parameter],
+     qtz_func: nn.Module,
+     deqtz_func: nn.Module,
+     bits: int,
+     target_dtype: torch.dtype,
+     epochs: int = 1000,
+     early_stop: bool = True,
+     do_report: bool = False
+ ) -> Tuple[Dict[str, nn.Parameter], torch.Tensor]:
+
+     # Requires gradients in the parameters
+     for p in params.values():
+         p.requires_grad = True
+         p.grad = None
+
+     param_keys = list(params.keys())
+     param_values = list(params.values())
+
+     # Defines optimizer and loss function
+     optimizer = torch.optim.Adam(param_values, lr=0.001)
+     scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=1.0, end_factor=0.001, total_iters=epochs // 10)
+     loss_fn = nn.MSELoss()
+
+     # Contains the best loss and the best parameters
+     best_loss = float("inf")
+     best_params = None
+
+     # Used to stop the search early
+     min_delta = 1e-7
+     acc_loss = []
+     percent_epochs_before_stop = 0.1
+
+     for i in range(epochs):
+         optimizer.zero_grad()
+
+         quant = quantize(x, params, qtz_func, bits, target_dtype)
+         dequant = dequantize(quant, params, deqtz_func, bits, x.dtype)
+         loss = loss_fn(x, dequant)
+
+         if loss.isnan() or loss.isinf():
+             raise Exception("Loss is NaN or Inf. Stopping the search.")
+
+         loss.backward()
+         optimizer.step()
+         scheduler.step()
+
+         acc_loss.append(loss.item())
+
+         # Reports loss every 10 steps
+         if i % 10 == 0 and do_report:
+             print(f"Epoch {i}: Loss {loss.item()}")
+
+         # Optimizes the parameter search by storing the best loss and the parameters
+         if loss.item() < best_loss:
+             best_loss = loss.item()
+             best_params = copy.deepcopy({
+                 k: v for k, v in params.items() if k in param_keys
+             })
+
+         # We also stop the search if the loss has not changed considerably during the last 10% of epochs
+         if early_stop:
+             epochs_before_stop = int(epochs * percent_epochs_before_stop)
+             if i > epochs_before_stop and abs(acc_loss[i - epochs_before_stop] - acc_loss[i]) < min_delta:
+                 break
+
+     # No longer requires gradients in the parameters
+     for p in best_params.values():
+         p.requires_grad = False
+         p.grad = None
+
+     if do_report:
+         return best_params, acc_loss
+     else:
+         return best_params
+
+
+ def quantize(
+     x: torch.Tensor,
+     params: Dict[str, nn.Parameter],
+     func: nn.Module,
+     bits: int,
+     target_dtype: torch.dtype = torch.int8
+ ) -> torch.Tensor:
+     quant_min, quant_max = get_min_max_from_bits_signed(bits)
+     x = x.transpose(0, 1) # Aligns shapes
+     x = func(x=x, **params)
+     x = x.transpose(0, 1)
+     x = torch.clamp(round_func_BPDA(x), quant_min, quant_max).to(target_dtype)
+     return x
+
+
+ def dequantize(
+     x: torch.Tensor,
+     params: Dict[str, nn.Parameter],
+     func: nn.Module,
+     bits: int,
+     out_dtype: torch.dtype
+ ) -> torch.Tensor:
+     x = x.to(dtype=out_dtype)
+     x = x.transpose(0, 1)
+     x = func(x=x, **params)
+     x = x.transpose(0, 1)
+     return x
+
+
+ def round_func_BPDA(input):
+     # This is equivalent to replacing round function (non-differentiable) with
+     # an identity function (differentiable) only when backward.
+     forward_value = torch.round(input)
+     out = input.clone()
+     out.data = forward_value.data
+     return out
+
+
+ def get_min_max_from_bits_signed(bit_width: int) -> Tuple[int, int]:
+     return -2 ** (bit_width - 1), 2 ** (bit_width - 1) - 1
+
+
+
+ ############## Numpy ###############
+
+ def np_domain_guard(
+     x: np.ndarray,
+     min: float = None,
+     max: float = None,
+     posinf: float = None,
+     neginf: float = None,
+     nan: float = None
+ ) -> np.ndarray:
+     """Guard a tensor to a valid domain."""
+     x = np.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan)
+     if min is not None or max is not None:
+         x = np.clip(x, min, max)
+     return x
+
+
+ def np_replace_num(x: np.ndarray, num: float, to: float) -> np.ndarray:
+     """Replace a number in a tensor with another number.
+
+     Args:
+         x (np.ndarray): The input tensor.
+         num (float): The number to replace.
+         to (float): The number to replace with.
+
+     Returns:
+         np.ndarray: The tensor with the number replaced.
+     """
+     return np.where(x == num, to, x)
+
+
+ def np_guarded_power(x: np.ndarray, exp: float) -> np.ndarray:
+     """Guard the power operation to a valid domain."""
+     return np.power(x, exp) if exp >= 1 else np.power(np.maximum(x, 0), exp)
+
fn_gen/ones_naive/0/loss.png ADDED
fn_gen/ones_naive/0/quantization.png ADDED
fn_gen/ones_naive/1/distortion.png ADDED
fn_gen/ones_naive/1/expressions.txt ADDED
@@ -0,0 +1,2 @@
+ x**3/_s
+ (_s*x)**(1/3)
fn_gen/ones_naive/1/fn.py ADDED
@@ -0,0 +1,454 @@
+ from __future__ import annotations
+
+ import torch
+ from torch import amin # Necessary for arcsin
+ import copy
+ import torch.nn as nn
+ import numpy as np
+
+ from scipy.optimize import curve_fit
+ from typing import Dict, Any, Tuple, List, Callable
+
+
+ def quantization(x, **params):
+     return (torch.div(1, replace_num(params['_s'], num=0, to=10000)) * guarded_torch_power(x, torch.tensor(3)))
+
+
+ def dequantization(x, **params):
+     return guarded_torch_power((params['_s'] * x), 1 / 3)
+
+
+ def init_params(x: torch.Tensor, **kwargs: Dict[str, Any]) -> Dict[str, nn.Parameter]:
+     params = {
+     }
+     params['_s'] = init_ones(x, params=params, qtz_func=quantization, **kwargs)
+     params = {k: nn.Parameter(v, requires_grad=False) for k, v in params.items()}
+
+     if 'post_init_hook' in kwargs:
+         kwargs['post_init_hook'](parameters=params)
+
+
+     if 'post_train_hook' in kwargs:
+         kwargs['post_train_hook'](parameters=params)
+
+     return params
+
+
+ ############### Numpy Qtz ###############
+
+
+ def np_quantization(x, _s):
+     return (np.divide(1, np_replace_num(_s, num=0, to=10000)) * np_guarded_power(x, np.array(3)))
+
+
+ def np_dequantization(x, _s):
+     return np_guarded_power((_s * x), 1 / 3)
+
+
+ def fit_func(x, _s):
+     x_ = np_quantization(x, _s)
+     x_ = np_dequantization(x_, _s)
+     return x_
+
+
+
+ ############### HELPERS ###############
+
+ def domain_guard(
+     x: torch.Tensor,
+     min: float = None,
+     max: float = None,
+     posinf: float = None,
+     neginf: float = None,
+     nan: float = None
+ ) -> torch.Tensor:
+     """Guard a tensor to a valid domain."""
+     x = torch.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan)
+     if min is not None or max is not None:
+         x = torch.clamp(x, min=min, max=max)
+     return x
+
+
+ def replace_num(x: torch.Tensor, num: float, to: float) -> torch.Tensor:
+     """Replace a number in a tensor with another number.
+
+     Args:
+         x (torch.Tensor): The input tensor.
+         num (float): The number to replace.
+         to (float): The number to replace with.
+
+     Returns:
+         torch.Tensor: The tensor with the number replaced.
+     """
+     return torch.where(x == num, to, x)
+
+
+ def guarded_torch_power(x: torch.Tensor, exp: float) -> torch.Tensor:
+     """Guard the power operation to a valid domain."""
+     return torch.pow(x, exp) if exp >= 1 else torch.pow(torch.relu(x), exp)
+
+
+ def init_ones(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
+     val = torch.amin(x, dim=1)
+     return torch.ones_like(val, dtype=torch.float32, device=x.device)
+
+
+ def init_rand(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
+     val = torch.amin(x, dim=1)
+     return torch.randn_like(val, dtype=torch.float32, device=x.device)
+
+
+ def init_space_search(
+     x: torch.Tensor,
+     **kwargs: Dict[str, Any],
+ ) -> torch.Tensor:
+
+     def _build_initial_param(tensor: torch.Tensor, max_initial: int, n_params: int):
+         """Generates the initial set of parameters. The first iteration generates 10 times more parameters."""
+         for _ in range(n_params * 10): # The first iteration generates 10 times more parameters
+             yield init_rand(tensor) * max_initial # Generates n_params in range [-max_initial, max_initial]
+
+     def _search_param(tensors: List[torch.tensor], n_params):
+         """Takes the best parameters and generates new parameters around the mean of the best parameters."""
+         torch_tensors = torch.stack(tensors)
+         min_vals, max_vals = torch.aminmax(torch_tensors, dim=0)
+         abs_max_val_per_ch = torch.max(-min_vals, max_vals)
+         mean = torch.mean(torch_tensors, dim=0)
+         for _ in range(n_params): # Generates n_params around the mean of the tensors
+             yield torch.randn_like(min_vals) * abs_max_val_per_ch + mean
+
+     def _calc(x, qtz_func, deqtz_func, **params):
+         x_ = x.transpose(0, 1)
+         x_ = qtz_func(x=x_, **params)
+         x_ = deqtz_func(x=x_, **params)
+         x_ = x_.transpose(0, 1)
+         return x_
+
+     assert "qtz_func" in kwargs, "qtz_func must be provided."
+     assert "deqtz_func" in kwargs, "deqtz_func must be provided."
+     assert "params_list" in kwargs, "params list must be provided."
+     assert "param" in kwargs, "param must be provided."
+
+     qtz_func = kwargs.get('qtz_func')
+     deqtz_func = kwargs.get('deqtz_func')
+     params_list = kwargs.get('params_list')
+     param = kwargs.get('param')
+
+     n_runs = 50 # Number of runs to try to find the best parameters
+     n_random_params = 50 # Number of random parameters to generate
+     n_best_to_pick = 5 # Number of best parameters to pick after each run
+     max_initial = 10000 # Maximum value to initialize the parameters
+
+     # Initializes the parameters
+     base_params = { p: init_ones(x, **kwargs) for p in params_list if p != param }
+     params = _build_initial_param(x, max_initial, n_random_params)
+
+     # Performs the search
+     for _ in range(n_runs):
+
+         best_params = []
+         for param_ in params:
+             try:
+                 x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: param_})
+                 loss_ones = nn.MSELoss()(x, x_)
+
+                 if len(best_params) < n_best_to_pick:
+                     best_params.append((param_, loss_ones.item()))
+                     best_params = sorted(best_params, key=lambda x: x[1])
+                 elif loss_ones < best_params[-1][1]:
+                     best_params[-1] = (param_, loss_ones.item())
+                     best_params = sorted(best_params, key=lambda x: x[1])
+
+             except Exception: # The parameters might not be valid for the function's domain
+                 continue
+
+         # Generates new parameters around the mean
+         params = _search_param([p for p, _ in best_params], n_random_params)
+
+     # Checks if the best parameter is better than the init_ones
+     p_ones = init_ones(x, **kwargs)
+     x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_ones})
+     loss_ones = nn.MSELoss()(x, x_)
+
+     # Checks if the best parameter is better than the init_rand
+     p_rand = init_rand(x, **kwargs)
+     x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_rand})
+     loss_rand = nn.MSELoss()(x, x_)
+
+     if loss_rand < best_params[0][1] and loss_rand < loss_ones:
+         return p_rand
+     elif loss_ones < best_params[0][1] and loss_ones < loss_rand:
+         return p_ones
+     else:
+         return best_params[0][0]
+
+
+ def init_linear_scale( # Symmetric scale. From the study folder
+     x: torch.Tensor,
+     **kwargs: Dict[str, Any],
+ ) -> torch.Tensor:
+     assert "bits" in kwargs, "bits must be provided."
+     assert "params" in kwargs, "params must be provided."
+     assert "qtz_func" in kwargs, "qtz_func must be provided."
+
+     bits = kwargs.get('bits')
+     params = kwargs.get('params')
+     qtz_func = kwargs.get('qtz_func')
+
+     x_ = x.transpose(0, 1)
+     x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs))
+     x_ = x_.transpose(0, 1)
+
+     quant_min, quant_max = get_min_max_from_bits_signed(bits)
+     min_vals, max_vals = torch.aminmax(x_, dim=1)
+     min_vals = torch.min(min_vals, torch.zeros_like(min_vals))
+     max_vals = torch.max(max_vals, torch.zeros_like(max_vals))
+
+     eps = torch.finfo(torch.float32).eps
+
+     abs_max_val_per_ch = torch.max(-min_vals, max_vals)
+     scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2)
+
+     scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device)
+
+     # Introduces some noise in scale
+     # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything
+     scale = scale + 0.01 * torch.randn_like(scale)
+     return scale
+
+
+ def init_non_linear_regression_fit(
+     x: torch.Tensor,
+     **kwargs: Dict[str, Any],
+ ) -> torch.Tensor:
+
+     assert "params_list" in kwargs, "params list must be provided."
+     assert "np_fit_func" in kwargs, "np_fit_func must be provided."
+     assert "p0" in kwargs, "p0 must be provided."
+     np_fit_func = kwargs.get('np_fit_func')
+     params_list = kwargs.get('params_list')
+     p0 = kwargs.get('p0')
+
+     def _fit(xdata: np.ndarray, ydata: np.ndarray, func: Callable, p0: List[float]):
+         popt, _ = curve_fit(
+             func,
+             xdata,
+             ydata,
+             maxfev=1000,
+             p0=p0,
+             method='lm'
+         )
+         return popt
+
+     # 1. Needs to convert the torch tensor to numpy tensor
+     xdata = x.cpu().numpy()
+
+     # 2. Sorts the data so that it makes it easier to fit to it
+     sorted_xdata = np.sort(xdata, axis=-1)
+
+     p0 = {k: v.cpu().numpy() for k, v in p0.items()}
+     params_list = sorted(params_list) # We need to make sure that it matches the numpy fit func arg order
+
+     # 3. Finds the best parameters for each channel
+     try:
+         params = []
+         for i in range(sorted_xdata.shape[0]):
+             xdata_ = sorted_xdata[i]
+             p0_ = [p0[p][i] for p in params_list]
+             ch_params = _fit(xdata_, xdata_, np_fit_func, p0_)
+             params.append(ch_params)
+
+         # 4. Builds the parameters
+         result = {}
+         for i, p in enumerate(params_list):
+             result[p] = torch.tensor([p_[i] for p_ in params], dtype=torch.float32).to(x.device)
+
+         return result
+
+     except ValueError as e:
+         print(f"Could not fit the function with error: {e}")
+         print(f"Using fallback result...")
+         return {
+             k: torch.tensor(v, dtype=torch.float32).to(x.device) for k, v in p0.items()
+         }
+
+
+ def init_zeros(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
+     val = torch.amin(x, dim=1)
+     return torch.zeros_like(val, dtype=torch.float32, device=x.device)
+
+
+ def init_inner_scale(tensor: torch.Tensor, _min: float = torch.inf, _max: float = torch.inf) -> torch.Tensor:
+     # Calculate the original minimum and maximum values
+     min_vals, max_vals = torch.aminmax(tensor, dim=-1)
+     x_min = torch.min(min_vals, torch.zeros_like(min_vals))
+     x_max = torch.max(max_vals, torch.zeros_like(max_vals))
+
+     if _max is torch.inf: # We do not need to scale the tensor. Just need to move it
+         return torch.ones_like(x_min)
+
+     # Calculate the scale factor
+     scale = (_max - _min) / (x_max - x_min)
+     return scale
+
+
+
+ ############## Quant ###############
+
+ @torch.enable_grad()
+ def learn_parameters(
+     x: torch.Tensor,
+     params: Dict[str, nn.Parameter],
+     qtz_func: nn.Module,
+     deqtz_func: nn.Module,
+     bits: int,
+     target_dtype: torch.dtype,
+     epochs: int = 1000,
+     early_stop: bool = True,
+     do_report: bool = False
+ ) -> Tuple[Dict[str, nn.Parameter], torch.Tensor]:
+
+     # Requires gradients in the parameters
+     for p in params.values():
+         p.requires_grad = True
+         p.grad = None
+
+     param_keys = list(params.keys())
+     param_values = list(params.values())
+
+     # Defines optimizer and loss function
+     optimizer = torch.optim.Adam(param_values, lr=0.001)
+     scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=1.0, end_factor=0.001, total_iters=epochs // 10)
+     loss_fn = nn.MSELoss()
+
+     # Contains the best loss and the best parameters
+     best_loss = float("inf")
+     best_params = None
+
+     # Used to stop the search early
+     min_delta = 1e-7
+     acc_loss = []
+     percent_epochs_before_stop = 0.1
+
+     for i in range(epochs):
+         optimizer.zero_grad()
+
+         quant = quantize(x, params, qtz_func, bits, target_dtype)
+         dequant = dequantize(quant, params, deqtz_func, bits, x.dtype)
+         loss = loss_fn(x, dequant)
+
+         if loss.isnan() or loss.isinf():
+             raise Exception("Loss is NaN or Inf. Stopping the search.")
+
+         loss.backward()
+         optimizer.step()
+         scheduler.step()
+
+         acc_loss.append(loss.item())
+
+         # Reports loss every 10 steps
+         if i % 10 == 0 and do_report:
+             print(f"Epoch {i}: Loss {loss.item()}")
+
+         # Optimizes the parameter search by storing the best loss and the parameters
+         if loss.item() < best_loss:
+             best_loss = loss.item()
+             best_params = copy.deepcopy({
+                 k: v for k, v in params.items() if k in param_keys
+             })
+
+         # We also stop the search if the loss has not changed considerably during the last 10% of epochs
+         if early_stop:
+             epochs_before_stop = int(epochs * percent_epochs_before_stop)
+             if i > epochs_before_stop and abs(acc_loss[i - epochs_before_stop] - acc_loss[i]) < min_delta:
+                 break
+
+     # No longer requires gradients in the parameters
+     for p in best_params.values():
+         p.requires_grad = False
+         p.grad = None
+
+     if do_report:
+         return best_params, acc_loss
+     else:
+         return best_params
+
+
+ def quantize(
+     x: torch.Tensor,
+     params: Dict[str, nn.Parameter],
+     func: nn.Module,
+     bits: int,
+     target_dtype: torch.dtype = torch.int8
+ ) -> torch.Tensor:
+     quant_min, quant_max = get_min_max_from_bits_signed(bits)
+     x = x.transpose(0, 1) # Aligns shapes
+     x = func(x=x, **params)
+     x = x.transpose(0, 1)
+     x = torch.clamp(round_func_BPDA(x), quant_min, quant_max).to(target_dtype)
+     return x
+
+
+ def dequantize(
+     x: torch.Tensor,
+     params: Dict[str, nn.Parameter],
+     func: nn.Module,
+     bits: int,
+     out_dtype: torch.dtype
+ ) -> torch.Tensor:
+     x = x.to(dtype=out_dtype)
+     x = x.transpose(0, 1)
+     x = func(x=x, **params)
+     x = x.transpose(0, 1)
+     return x
+
+
+ def round_func_BPDA(input):
+     # This is equivalent to replacing round function (non-differentiable) with
+     # an identity function (differentiable) only when backward.
+     forward_value = torch.round(input)
+     out = input.clone()
+     out.data = forward_value.data
+     return out
+
+
+ def get_min_max_from_bits_signed(bit_width: int) -> Tuple[int, int]:
+     return -2 ** (bit_width - 1), 2 ** (bit_width - 1) - 1
+
+
+
+ ############## Numpy ###############
+
+ def np_domain_guard(
+     x: np.ndarray,
+     min: float = None,
+     max: float = None,
+     posinf: float = None,
+     neginf: float = None,
+     nan: float = None
+ ) -> np.ndarray:
+     """Guard a tensor to a valid domain."""
+     x = np.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan)
+     if min is not None or max is not None:
+         x = np.clip(x, min, max)
+     return x
+
+
+ def np_replace_num(x: np.ndarray, num: float, to: float) -> np.ndarray:
+     """Replace a number in a tensor with another number.
+
+     Args:
+         x (np.ndarray): The input tensor.
+         num (float): The number to replace.
+         to (float): The number to replace with.
+
+     Returns:
+         np.ndarray: The tensor with the number replaced.
+     """
+     return np.where(x == num, to, x)
+
+
+ def np_guarded_power(x: np.ndarray, exp: float) -> np.ndarray:
+     """Guard the power operation to a valid domain."""
+     return np.power(x, exp) if exp >= 1 else np.power(np.maximum(x, 0), exp)
+
fn_gen/ones_naive/1/loss.png ADDED
fn_gen/ones_naive/1/quantization.png ADDED
fn_gen/ones_naive/10/distortion.png ADDED
fn_gen/ones_naive/10/expressions.txt ADDED
@@ -0,0 +1,2 @@
+ acosh(_0*x)/_s
+ cosh(_s*x)/_0
fn_gen/ones_naive/10/fn.py ADDED
@@ -0,0 +1,455 @@
+ from __future__ import annotations
+
+ import torch
+ from torch import amin # Necessary for arcsin
+ import copy
+ import torch.nn as nn
+ import numpy as np
+
+ from scipy.optimize import curve_fit
+ from typing import Dict, Any, Tuple, List, Callable
+
+
+ def quantization(x, **params):
+     return (torch.div(1, replace_num(params['_s'], num=0, to=10000)) * torch.acosh(domain_guard((params['_0'] * x), min=1, nan=1)))
+
+
+ def dequantization(x, **params):
+     return (torch.div(1, replace_num(params['_0'], num=0, to=10000)) * torch.cosh((params['_s'] * x)))
+
+
+ def init_params(x: torch.Tensor, **kwargs: Dict[str, Any]) -> Dict[str, nn.Parameter]:
+     params = {
+         '_0': init_ones(x, qtz_func=quantization, deqtz_func=dequantization, param='_0', params_list=['_0', '_s'], **kwargs),
+     }
+     params['_s'] = init_ones(x, params=params, qtz_func=quantization, **kwargs)
+     params = {k: nn.Parameter(v, requires_grad=False) for k, v in params.items()}
+
+     if 'post_init_hook' in kwargs:
+         kwargs['post_init_hook'](parameters=params)
+
+
+     if 'post_train_hook' in kwargs:
+         kwargs['post_train_hook'](parameters=params)
+
+     return params
+
+
+ ############### Numpy Qtz ###############
+
+
+ def np_quantization(x, _0, _s):
+     return (np.divide(1, np_replace_num(_s, num=0, to=10000)) * np.arccosh(np_domain_guard((_0 * x), min=1, nan=1)))
+
+
+ def np_dequantization(x, _0, _s):
+     return (np.divide(1, np_replace_num(_0, num=0, to=10000)) * np.cosh((_s * x)))
+
+
+ def fit_func(x, _0, _s):
+     x_ = np_quantization(x, _0, _s)
+     x_ = np_dequantization(x_, _0, _s)
+     return x_
+
+
+
+ ############### HELPERS ###############
+
+ def domain_guard(
+     x: torch.Tensor,
+     min: float = None,
+     max: float = None,
+     posinf: float = None,
+     neginf: float = None,
+     nan: float = None
+ ) -> torch.Tensor:
+     """Guard a tensor to a valid domain."""
+     x = torch.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan)
+     if min is not None or max is not None:
+         x = torch.clamp(x, min=min, max=max)
+     return x
+
+
+ def replace_num(x: torch.Tensor, num: float, to: float) -> torch.Tensor:
+     """Replace a number in a tensor with another number.
+
+     Args:
+         x (torch.Tensor): The input tensor.
+         num (float): The number to replace.
+         to (float): The number to replace with.
+
+     Returns:
+         torch.Tensor: The tensor with the number replaced.
+     """
+     return torch.where(x == num, to, x)
+
+
+ def guarded_torch_power(x: torch.Tensor, exp: float) -> torch.Tensor:
+     """Guard the power operation to a valid domain."""
+     return torch.pow(x, exp) if exp >= 1 else torch.pow(torch.relu(x), exp)
+
+
+ def init_ones(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
+     val = torch.amin(x, dim=1)
+     return torch.ones_like(val, dtype=torch.float32, device=x.device)
+
+
+ def init_rand(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
+     val = torch.amin(x, dim=1)
+     return torch.randn_like(val, dtype=torch.float32, device=x.device)
+
+
+ def init_space_search(
+     x: torch.Tensor,
+     **kwargs: Dict[str, Any],
+ ) -> torch.Tensor:
+
+     def _build_initial_param(tensor: torch.Tensor, max_initial: int, n_params: int):
+         """Generates the initial set of parameters. The first iteration generates 10 times more parameters."""
+         for _ in range(n_params * 10): # The first iteration generates 10 times more parameters
+             yield init_rand(tensor) * max_initial # Generates n_params in range [-max_initial, max_initial]
+
+     def _search_param(tensors: List[torch.tensor], n_params):
+         """Takes the best parameters and generates new parameters around the mean of the best parameters."""
+         torch_tensors = torch.stack(tensors)
+         min_vals, max_vals = torch.aminmax(torch_tensors, dim=0)
+         abs_max_val_per_ch = torch.max(-min_vals, max_vals)
+         mean = torch.mean(torch_tensors, dim=0)
+         for _ in range(n_params): # Generates n_params around the mean of the tensors
+             yield torch.randn_like(min_vals) * abs_max_val_per_ch + mean
+
+     def _calc(x, qtz_func, deqtz_func, **params):
+         x_ = x.transpose(0, 1)
+         x_ = qtz_func(x=x_, **params)
+         x_ = deqtz_func(x=x_, **params)
+         x_ = x_.transpose(0, 1)
+         return x_
+
+     assert "qtz_func" in kwargs, "qtz_func must be provided."
+     assert "deqtz_func" in kwargs, "deqtz_func must be provided."
+     assert "params_list" in kwargs, "params list must be provided."
+     assert "param" in kwargs, "param must be provided."
+
+     qtz_func = kwargs.get('qtz_func')
+     deqtz_func = kwargs.get('deqtz_func')
+     params_list = kwargs.get('params_list')
+     param = kwargs.get('param')
+
+     n_runs = 50 # Number of runs to try to find the best parameters
+     n_random_params = 50 # Number of random parameters to generate
+     n_best_to_pick = 5 # Number of best parameters to pick after each run
+     max_initial = 10000 # Maximum value to initialize the parameters
+
+     # Initializes the parameters
+     base_params = { p: init_ones(x, **kwargs) for p in params_list if p != param }
+     params = _build_initial_param(x, max_initial, n_random_params)
+
+     # Performs the search
+     for _ in range(n_runs):
+
+         best_params = []
+         for param_ in params:
+             try:
+                 x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: param_})
+                 loss_ones = nn.MSELoss()(x, x_)
+
+                 if len(best_params) < n_best_to_pick:
+                     best_params.append((param_, loss_ones.item()))
+                     best_params = sorted(best_params, key=lambda x: x[1])
+                 elif loss_ones < best_params[-1][1]:
+                     best_params[-1] = (param_, loss_ones.item())
+                     best_params = sorted(best_params, key=lambda x: x[1])
+
+             except Exception: # The parameters might not be valid for the function's domain
+                 continue
+
+         # Generates new parameters around the mean
+         params = _search_param([p for p, _ in best_params], n_random_params)
+
+     # Checks if the best parameter is better than the init_ones
+     p_ones = init_ones(x, **kwargs)
+     x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_ones})
+     loss_ones = nn.MSELoss()(x, x_)
+
+     # Checks if the best parameter is better than the init_rand
+     p_rand = init_rand(x, **kwargs)
+     x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_rand})
+     loss_rand = nn.MSELoss()(x, x_)
+
+     if loss_rand < best_params[0][1] and loss_rand < loss_ones:
+         return p_rand
+     elif loss_ones < best_params[0][1] and loss_ones < loss_rand:
+         return p_ones
+     else:
+         return best_params[0][0]
+
+
+ def init_linear_scale( # Symmetric scale. From the study folder
+     x: torch.Tensor,
+     **kwargs: Dict[str, Any],
+ ) -> torch.Tensor:
+     assert "bits" in kwargs, "bits must be provided."
+     assert "params" in kwargs, "params must be provided."
+     assert "qtz_func" in kwargs, "qtz_func must be provided."
+
+     bits = kwargs.get('bits')
+     params = kwargs.get('params')
+     qtz_func = kwargs.get('qtz_func')
+
+     x_ = x.transpose(0, 1)
+     x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs))
+     x_ = x_.transpose(0, 1)
+
+     quant_min, quant_max = get_min_max_from_bits_signed(bits)
+     min_vals, max_vals = torch.aminmax(x_, dim=1)
+     min_vals = torch.min(min_vals, torch.zeros_like(min_vals))
+     max_vals = torch.max(max_vals, torch.zeros_like(max_vals))
+
+     eps = torch.finfo(torch.float32).eps
+
+     abs_max_val_per_ch = torch.max(-min_vals, max_vals)
+     scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2)
+
+     scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device)
+
+     # Introduces some noise in scale
+     # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything
+     scale = scale + 0.01 * torch.randn_like(scale)
+     return scale
+
+
+ def init_non_linear_regression_fit(
+     x: torch.Tensor,
+     **kwargs: Dict[str, Any],
+ ) -> torch.Tensor:
+
+     assert "params_list" in kwargs, "params list must be provided."
+     assert "np_fit_func" in kwargs, "np_fit_func must be provided."
+     assert "p0" in kwargs, "p0 must be provided."
+     np_fit_func = kwargs.get('np_fit_func')
+     params_list = kwargs.get('params_list')
+     p0 = kwargs.get('p0')
+
+     def _fit(xdata: np.ndarray, ydata: np.ndarray, func: Callable, p0: List[float]):
+         popt, _ = curve_fit(
+             func,
+             xdata,
+             ydata,
+             maxfev=1000,
+             p0=p0,
+             method='lm'
+         )
+         return popt
+
+     # 1. Needs to convert the torch tensor to numpy tensor
+     xdata = x.cpu().numpy()
+
+     # 2. Sorts the data so that it makes it easier to fit to it
+     sorted_xdata = np.sort(xdata, axis=-1)
+
+     p0 = {k: v.cpu().numpy() for k, v in p0.items()}
+     params_list = sorted(params_list) # We need to make sure that it matches the numpy fit func arg order
+
+     # 3. Finds the best parameters for each channel
+     try:
+         params = []
+         for i in range(sorted_xdata.shape[0]):
+             xdata_ = sorted_xdata[i]
+             p0_ = [p0[p][i] for p in params_list]
+             ch_params = _fit(xdata_, xdata_, np_fit_func, p0_)
+             params.append(ch_params)
+
+         # 4. Builds the parameters
+         result = {}
+         for i, p in enumerate(params_list):
+             result[p] = torch.tensor([p_[i] for p_ in params], dtype=torch.float32).to(x.device)
+
+         return result
+
+     except ValueError as e:
+         print(f"Could not fit the function with error: {e}")
+         print(f"Using fallback result...")
+         return {
+             k: torch.tensor(v, dtype=torch.float32).to(x.device) for k, v in p0.items()
+         }
+
+
+ def init_zeros(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
+     val = torch.amin(x, dim=1)
+     return torch.zeros_like(val, dtype=torch.float32, device=x.device)
+
+
+ def init_inner_scale(tensor: torch.Tensor, _min: float = torch.inf, _max: float = torch.inf) -> torch.Tensor:
+     # Calculate the original minimum and maximum values
+     min_vals, max_vals = torch.aminmax(tensor, dim=-1)
+     x_min = torch.min(min_vals, torch.zeros_like(min_vals))
+     x_max = torch.max(max_vals, torch.zeros_like(max_vals))
+
+     if _max is torch.inf: # We do not need to scale the tensor. Just need to move it
+         return torch.ones_like(x_min)
+
+     # Calculate the scale factor
+     scale = (_max - _min) / (x_max - x_min)
+     return scale
+
+
+
+ ############## Quant ###############
+
+ @torch.enable_grad()
+ def learn_parameters(
+     x: torch.Tensor,
+     params: Dict[str, nn.Parameter],
+     qtz_func: nn.Module,
+     deqtz_func: nn.Module,
+     bits: int,
+     target_dtype: torch.dtype,
+     epochs: int = 1000,
+     early_stop: bool = True,
+     do_report: bool = False
+ ) -> Tuple[Dict[str, nn.Parameter], torch.Tensor]:
+
+     # Requires gradients in the parameters
+     for p in params.values():
+         p.requires_grad = True
+         p.grad = None
+
+     param_keys = list(params.keys())
+     param_values = list(params.values())
+
+     # Defines optimizer and loss function
+     optimizer = torch.optim.Adam(param_values, lr=0.001)
+     scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=1.0, end_factor=0.001, total_iters=epochs // 10)
+     loss_fn = nn.MSELoss()
+
+     # Contains the best loss and the best parameters
+     best_loss = float("inf")
+     best_params = None
+
+     # Used to stop the search early
+     min_delta = 1e-7
+     acc_loss = []
+     percent_epochs_before_stop = 0.1
+
+     for i in range(epochs):
+         optimizer.zero_grad()
+
+         quant = quantize(x, params, qtz_func, bits, target_dtype)
+         dequant = dequantize(quant, params, deqtz_func, bits, x.dtype)
+         loss = loss_fn(x, dequant)
+
+         if loss.isnan() or loss.isinf():
+             raise Exception("Loss is NaN or Inf. Stopping the search.")
+
+         loss.backward()
+         optimizer.step()
+         scheduler.step()
+
+         acc_loss.append(loss.item())
+
+         # Reports loss every 10 steps
+         if i % 10 == 0 and do_report:
+             print(f"Epoch {i}: Loss {loss.item()}")
+
+         # Optimizes the parameter search by storing the best loss and the parameters
+         if loss.item() < best_loss:
+             best_loss = loss.item()
+             best_params = copy.deepcopy({
+                 k: v for k, v in params.items() if k in param_keys
+             })
+
+         # We also stop the search if the loss has not changed considerably during the last 10% of epochs
+         if early_stop:
+             epochs_before_stop = int(epochs * percent_epochs_before_stop)
+             if i > epochs_before_stop and abs(acc_loss[i - epochs_before_stop] - acc_loss[i]) < min_delta:
+                 break
+
+     # No longer requires gradients in the parameters
+     for p in best_params.values():
+         p.requires_grad = False
+         p.grad = None
+
+     if do_report:
+         return best_params, acc_loss
+     else:
+         return best_params
+
+
+ def quantize(
+     x: torch.Tensor,
+     params: Dict[str, nn.Parameter],
+     func: nn.Module,
+     bits: int,
+     target_dtype: torch.dtype = torch.int8
+ ) -> torch.Tensor:
+     quant_min, quant_max = get_min_max_from_bits_signed(bits)
+     x = x.transpose(0, 1) # Aligns shapes
+     x = func(x=x, **params)
+     x = x.transpose(0, 1)
+     x = torch.clamp(round_func_BPDA(x), quant_min, quant_max).to(target_dtype)
+     return x
+
+
+ def dequantize(
+     x: torch.Tensor,
+     params: Dict[str, nn.Parameter],
+     func: nn.Module,
+     bits: int,
+     out_dtype: torch.dtype
+ ) -> torch.Tensor:
+     x = x.to(dtype=out_dtype)
+     x = x.transpose(0, 1)
+     x = func(x=x, **params)
+     x = x.transpose(0, 1)
+     return x
+
+
+ def round_func_BPDA(input):
+     # This is equivalent to replacing round function (non-differentiable) with
+     # an identity function (differentiable) only when backward.
+     forward_value = torch.round(input)
+     out = input.clone()
+     out.data = forward_value.data
+     return out
+
+
+ def get_min_max_from_bits_signed(bit_width: int) -> Tuple[int, int]:
+     return -2 ** (bit_width - 1), 2 ** (bit_width - 1) - 1
+
+
+
+ ############## Numpy ###############
+
+ def np_domain_guard(
+     x: np.ndarray,
+     min: float = None,
+     max: float = None,
+     posinf: float = None,
+     neginf: float = None,
+     nan: float = None
+ ) -> np.ndarray:
+     """Guard a tensor to a valid domain."""
+     x = np.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan)
+     if min is not None or max is not None:
+         x = np.clip(x, min, max)
+     return x
+
+
+ def np_replace_num(x: np.ndarray, num: float, to: float) -> np.ndarray:
+     """Replace a number in a tensor with another number.
+
+     Args:
+         x (np.ndarray): The input tensor.
+         num (float): The number to replace.
+         to (float): The number to replace with.
+
+     Returns:
+         np.ndarray: The tensor with the number replaced.
+     """
+     return np.where(x == num, to, x)
+
+
+ def np_guarded_power(x: np.ndarray, exp: float) -> np.ndarray:
+     """Guard the power operation to a valid domain."""
+     return np.power(x, exp) if exp >= 1 else np.power(np.maximum(x, 0), exp)
+
fn_gen/ones_naive/10/loss.png ADDED
fn_gen/ones_naive/10/quantization.png ADDED
fn_gen/ones_naive/11/distortion.png ADDED
fn_gen/ones_naive/11/expressions.txt ADDED
@@ -0,0 +1,2 @@
+ atanh(_0*x)/_s
+ tanh(_s*x)/_0
fn_gen/ones_naive/11/fn.py ADDED
@@ -0,0 +1,455 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import torch
4
+ from torch import amin # Necessary for arcsin
5
+ import copy
6
+ import torch.nn as nn
7
+ import numpy as np
8
+
9
+ from scipy.optimize import curve_fit
10
+ from typing import Dict, Any, Tuple, List, Callable
11
+
12
+
13
+ def quantization(x, **params):
14
+ return (torch.div(1, replace_num(params['_s'], num=0, to=10000)) * torch.atanh(domain_guard((params['_0'] * x), min=-0.9999, max=0.9999, nan=0)))
15
+
16
+
17
+ def dequantization(x, **params):
18
+ return (torch.div(1, replace_num(params['_0'], num=0, to=10000)) * torch.tanh((params['_s'] * x)))
19
+
20
+
21
+ def init_params(x: torch.Tensor, **kwargs: Dict[str, Any]) -> Dict[str, nn.Parameter]:
22
+ params = {
23
+ '_0': init_ones(x, qtz_func=quantization, deqtz_func=dequantization, param='_0', params_list=['_0', '_s'], **kwargs),
24
+ }
25
+ params['_s'] = init_ones(x, params=params, qtz_func=quantization, **kwargs)
26
+ params = {k: nn.Parameter(v, requires_grad=False) for k, v in params.items()}
27
+
28
+ if 'post_init_hook' in kwargs:
29
+ kwargs['post_init_hook'](parameters=params)
30
+
31
+
32
+ if 'post_train_hook' in kwargs:
33
+ kwargs['post_train_hook'](parameters=params)
34
+
35
+ return params
36
+
37
+
38
+ ############### Numpy Qtz ###############
39
+
40
+
41
+ def np_quantization(x, _0, _s):
42
+ return (np.divide(1, np_replace_num(_s, num=0, to=10000)) * np.arctanh(np_domain_guard((_0 * x), min=-0.9999, max=0.9999, nan=0)))
43
+
44
+
45
+ def np_dequantization(x, _0, _s):
46
+ return (np.divide(1, np_replace_num(_0, num=0, to=10000)) * np.tanh((_s * x)))
47
+
48
+
49
+ def fit_func(x, _0, _s):
50
+ x_ = np_quantization(x, _0, _s)
51
+ x_ = np_dequantization(x_, _0, _s)
52
+ return x_
53
+
54
+
55
+
56
+ ############### HELPERS ###############
57
+
58
+ def domain_guard(
59
+ x: torch.Tensor,
60
+ min: float = None,
61
+ max: float = None,
62
+ posinf: float = None,
63
+ neginf: float = None,
64
+ nan: float = None
65
+ ) -> torch.Tensor:
66
+ """Guard a tensor to a valid domain."""
67
+ x = torch.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan)
68
+ if min is not None or max is not None:
69
+ x = torch.clamp(x, min=min, max=max)
70
+ return x
71
+
72
+
73
+ def replace_num(x: torch.Tensor, num: float, to: float) -> torch.Tensor:
74
+ """Replace a number in a tensor with another number.
75
+
76
+ Args:
77
+ x (torch.Tensor): The input tensor.
78
+ num (float): The number to replace.
79
+ to (float): The number to replace with.
80
+
81
+ Returns:
82
+ torch.Tensor: The tensor with the number replaced.
83
+ """
84
+ return torch.where(x == num, to, x)
85
+
86
+
87
+ def guarded_torch_power(x: torch.Tensor, exp: float) -> torch.Tensor:
88
+ """Guard the power operation to a valid domain."""
89
+ return torch.pow(x, exp) if exp >= 1 else torch.pow(torch.relu(x), exp)
90
+
91
+
92
+ def init_ones(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
93
+ val = torch.amin(x, dim=1)
94
+ return torch.ones_like(val, dtype=torch.float32, device=x.device)
95
+
96
+
97
+ def init_rand(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
98
+ val = torch.amin(x, dim=1)
99
+ return torch.randn_like(val, dtype=torch.float32, device=x.device)
100
+
101
+
102
+ def init_space_search(
103
+ x: torch.Tensor,
104
+ **kwargs: Dict[str, Any],
105
+ ) -> torch.Tensor:
106
+
107
+ def _build_initial_param(tensor: torch.Tensor, max_initial: int, n_params: int):
108
+ """Generates the initial set of parameters. The first iteration generates 10 times more parameters."""
109
+ for _ in range(n_params * 10): # The first iteration generates 10 times more parameters
110
+ yield init_rand(tensor) * max_initial # Generates n_params random params scaled by max_initial (normal draws, not strictly bounded)
111
+
112
+ def _search_param(tensors: List[torch.tensor], n_params):
113
+ """Takes the best parameters and generates new parameters around the mean of the best parameters."""
114
+ torch_tensors = torch.stack(tensors)
115
+ min_vals, max_vals = torch.aminmax(torch_tensors, dim=0)
116
+ abs_max_val_per_ch = torch.max(-min_vals, max_vals)
117
+ mean = torch.mean(torch_tensors, dim=0)
118
+ for _ in range(n_params): # Generates n_params around the mean of the tensors
119
+ yield torch.randn_like(min_vals) * abs_max_val_per_ch + mean
120
+
121
+ def _calc(x, qtz_func, deqtz_func, **params):
122
+ x_ = x.transpose(0, 1)
123
+ x_ = qtz_func(x=x_, **params)
124
+ x_ = deqtz_func(x=x_, **params)
125
+ x_ = x_.transpose(0, 1)
126
+ return x_
127
+
128
+ assert "qtz_func" in kwargs, "qtz_func must be provided."
129
+ assert "deqtz_func" in kwargs, "deqtz_func must be provided."
130
+ assert "params_list" in kwargs, "params list must be provided."
131
+ assert "param" in kwargs, "param must be provided."
132
+
133
+ qtz_func = kwargs.get('qtz_func')
134
+ deqtz_func = kwargs.get('deqtz_func')
135
+ params_list = kwargs.get('params_list')
136
+ param = kwargs.get('param')
137
+
138
+ n_runs = 50 # Number of runs to try to find the best parameters
139
+ n_random_params = 50 # Number of random parameters to generate
140
+ n_best_to_pick = 5 # Number of best parameters to pick after each run
141
+ max_initial = 10000 # Maximum value to initialize the parameters
142
+
143
+ # Initializes the parameters
144
+ base_params = { p: init_ones(x, **kwargs) for p in params_list if p != param }
145
+ params = _build_initial_param(x, max_initial, n_random_params)
146
+
147
+ # Performs the search
148
+ for _ in range(n_runs):
149
+
150
+ best_params = []
151
+ for param_ in params:
152
+ try:
153
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: param_})
154
+ loss_ones = nn.MSELoss()(x, x_)
155
+
156
+ if len(best_params) < n_best_to_pick:
157
+ best_params.append((param_, loss_ones.item()))
158
+ best_params = sorted(best_params, key=lambda x: x[1])
159
+ elif loss_ones < best_params[-1][1]:
160
+ best_params[-1] = (param_, loss_ones.item())
161
+ best_params = sorted(best_params, key=lambda x: x[1])
162
+
163
+ except Exception: # The parameters might not be valid for the function's domain
164
+ continue
165
+
166
+ # Generates new parameters around the mean
167
+ params = _search_param([p for p, _ in best_params], n_random_params)
168
+
169
+ # Checks if the best parameter is better than the init_ones
170
+ p_ones = init_ones(x, **kwargs)
171
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_ones})
172
+ loss_ones = nn.MSELoss()(x, x_)
173
+
174
+ # Checks if the best parameter is better than the init_rand
175
+ p_rand = init_rand(x, **kwargs)
176
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_rand})
177
+ loss_rand = nn.MSELoss()(x, x_)
178
+
179
+ if loss_rand < best_params[0][1] and loss_rand < loss_ones:
180
+ return p_rand
181
+ elif loss_ones < best_params[0][1] and loss_ones < loss_rand:
182
+ return p_ones
183
+ else:
184
+ return best_params[0][0]
185
+
186
+
187
+ def init_linear_scale( # Symmetric scale. From the study folder
188
+ x: torch.Tensor,
189
+ **kwargs: Dict[str, Any],
190
+ ) -> torch.Tensor:
191
+ assert "bits" in kwargs, "bits must be provided."
192
+ assert "params" in kwargs, "params must be provided."
193
+ assert "qtz_func" in kwargs, "qtz_func must be provided."
194
+
195
+ bits = kwargs.get('bits')
196
+ params = kwargs.get('params')
197
+ qtz_func = kwargs.get('qtz_func')
198
+
199
+ x_ = x.transpose(0, 1)
200
+ x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs))
201
+ x_ = x_.transpose(0, 1)
202
+
203
+ quant_min, quant_max = get_min_max_from_bits_signed(bits)
204
+ min_vals, max_vals = torch.aminmax(x_, dim=1)
205
+ min_vals = torch.min(min_vals, torch.zeros_like(min_vals))
206
+ max_vals = torch.max(max_vals, torch.zeros_like(max_vals))
207
+
208
+ eps = torch.finfo(torch.float32).eps
209
+
210
+ abs_max_val_per_ch = torch.max(-min_vals, max_vals)
211
+ scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2)
212
+
213
+ scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device)
214
+
215
+ # Introduces some noise in scale
216
+ # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything
217
+ scale = scale + 0.01 * torch.randn_like(scale)
218
+ return scale
219
+
220
+
221
+ def init_non_linear_regression_fit(
222
+ x: torch.Tensor,
223
+ **kwargs: Dict[str, Any],
224
+ ) -> torch.Tensor:
225
+
226
+ assert "params_list" in kwargs, "params list must be provided."
227
+ assert "np_fit_func" in kwargs, "np_fit_func must be provided."
228
+ assert "p0" in kwargs, "p0 must be provided."
229
+ np_fit_func = kwargs.get('np_fit_func')
230
+ params_list = kwargs.get('params_list')
231
+ p0 = kwargs.get('p0')
232
+
233
+ def _fit(xdata: np.ndarray, ydata: np.ndarray, func: Callable, p0: List[float]):
234
+ popt, _ = curve_fit(
235
+ func,
236
+ xdata,
237
+ ydata,
238
+ maxfev=1000,
239
+ p0=p0,
240
+ method='lm'
241
+ )
242
+ return popt
243
+
244
+ # 1. Needs to convert the torch tensor to numpy tensor
245
+ xdata = x.cpu().numpy()
246
+
247
+ # 2. Sorts the data so that it makes it easier to fit to it
248
+ sorted_xdata = np.sort(xdata, axis=-1)
249
+
250
+ p0 = {k: v.cpu().numpy() for k, v in p0.items()}
251
+ params_list = sorted(params_list) # We need to make sure that it matches the numpy fit func arg order
252
+
253
+ # 3. Finds the best parameters for each channel
254
+ try:
255
+ params = []
256
+ for i in range(sorted_xdata.shape[0]):
257
+ xdata_ = sorted_xdata[i]
258
+ p0_ = [p0[p][i] for p in params_list]
259
+ ch_params = _fit(xdata_, xdata_, np_fit_func, p0_)
260
+ params.append(ch_params)
261
+
262
+ # 4. Builds the parameters
263
+ result = {}
264
+ for i, p in enumerate(params_list):
265
+ result[p] = torch.tensor([p_[i] for p_ in params], dtype=torch.float32).to(x.device)
266
+
267
+ return result
268
+
269
+ except ValueError as e:
270
+ print(f"Could not fit the function with error: {e}")
271
+ print(f"Using fallback result...")
272
+ return {
273
+ k: torch.tensor(v, dtype=torch.float32).to(x.device) for k, v in p0.items()
274
+ }
275
+
276
+
277
+ def init_zeros(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
278
+ val = torch.amin(x, dim=1)
279
+ return torch.zeros_like(val, dtype=torch.float32, device=x.device)
280
+
281
+
282
+ def init_inner_scale(tensor: torch.Tensor, _min: float = torch.inf, _max: float = torch.inf) -> torch.Tensor:
283
+ # Calculate the original minimum and maximum values
284
+ min_vals, max_vals = torch.aminmax(tensor, dim=-1)
285
+ x_min = torch.min(min_vals, torch.zeros_like(min_vals))
286
+ x_max = torch.max(max_vals, torch.zeros_like(max_vals))
287
+
288
+ if _max is torch.inf: # We do not need to scale the tensor. Just need to move it
289
+ return torch.ones_like(x_min)
290
+
291
+ # Calculate the scale factor
292
+ scale = (_max - _min) / (x_max - x_min)
293
+ return scale
294
+
295
+
296
+
297
+ ############## Quant ###############
298
+
299
+ @torch.enable_grad()
300
+ def learn_parameters(
301
+ x: torch.Tensor,
302
+ params: Dict[str, nn.Parameter],
303
+ qtz_func: nn.Module,
304
+ deqtz_func: nn.Module,
305
+ bits: int,
306
+ target_dtype: torch.dtype,
307
+ epochs: int = 1000,
308
+ early_stop: bool = True,
309
+ do_report: bool = False
310
+ ) -> Tuple[Dict[str, nn.Parameter], torch.Tensor]:
311
+
312
+ # Requires gradients in the parameters
313
+ for p in params.values():
314
+ p.requires_grad = True
315
+ p.grad = None
316
+
317
+ param_keys = list(params.keys())
318
+ param_values = list(params.values())
319
+
320
+ # Defines optimizer and loss function
321
+ optimizer = torch.optim.Adam(param_values, lr=0.001)
322
+ scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=1.0, end_factor=0.001, total_iters=epochs // 10)
323
+ loss_fn = nn.MSELoss()
324
+
325
+ # Contains the best loss and the best parameters
326
+ best_loss = float("inf")
327
+ best_params = None
328
+
329
+ # Used to stop the search early
330
+ min_delta = 1e-7
331
+ acc_loss = []
332
+ percent_epochs_before_stop = 0.1
333
+
334
+ for i in range(epochs):
335
+ optimizer.zero_grad()
336
+
337
+ quant = quantize(x, params, qtz_func, bits, target_dtype)
338
+ dequant = dequantize(quant, params, deqtz_func, bits, x.dtype)
339
+ loss = loss_fn(x, dequant)
340
+
341
+ if loss.isnan() or loss.isinf():
342
+ raise Exception("Loss is NaN or Inf. Stopping the search.")
343
+
344
+ loss.backward()
345
+ optimizer.step()
346
+ scheduler.step()
347
+
348
+ acc_loss.append(loss.item())
349
+
350
+ # Reports loss every 10 steps
351
+ if i % 10 == 0 and do_report:
352
+ print(f"Epoch {i}: Loss {loss.item()}")
353
+
354
+ # Optimizes the parameter search by storing the best loss and the parameters
355
+ if loss.item() < best_loss:
356
+ best_loss = loss.item()
357
+ best_params = copy.deepcopy({
358
+ k: v for k, v in params.items() if k in param_keys
359
+ })
360
+
361
+ # We also stop the search early if the loss has not decreased appreciably during the last 10% of epochs
362
+ if early_stop:
363
+ epochs_before_stop = int(epochs * percent_epochs_before_stop)
364
+ if i > epochs_before_stop and abs(acc_loss[i - epochs_before_stop] - acc_loss[i]) < min_delta:
365
+ break
366
+
367
+ # No longer requires gradients in the parameters
368
+ for p in best_params.values():
369
+ p.requires_grad = False
370
+ p.grad = None
371
+
372
+ if do_report:
373
+ return best_params, acc_loss
374
+ else:
375
+ return best_params
376
+
377
+
378
+ def quantize(
379
+ x: torch.Tensor,
380
+ params: Dict[str, nn.Parameter],
381
+ func: nn.Module,
382
+ bits: int,
383
+ target_dtype: torch.dtype = torch.int8
384
+ ) -> torch.Tensor:
385
+ quant_min, quant_max = get_min_max_from_bits_signed(bits)
386
+ x = x.transpose(0, 1) # Aligns shapes
387
+ x = func(x=x, **params)
388
+ x = x.transpose(0, 1)
389
+ x = torch.clamp(round_func_BPDA(x), quant_min, quant_max).to(target_dtype)
390
+ return x
391
+
392
+
393
+ def dequantize(
394
+ x: torch.Tensor,
395
+ params: Dict[str, nn.Parameter],
396
+ func: nn.Module,
397
+ bits: int,
398
+ out_dtype: torch.dtype
399
+ ) -> torch.Tensor:
400
+ x = x.to(dtype=out_dtype)
401
+ x = x.transpose(0, 1)
402
+ x = func(x=x, **params)
403
+ x = x.transpose(0, 1)
404
+ return x
405
+
406
+
407
+ def round_func_BPDA(input):
408
+ # This is equivalent to replacing round function (non-differentiable) with
409
+ # an identity function (differentiable) only when backward.
410
+ forward_value = torch.round(input)
411
+ out = input.clone()
412
+ out.data = forward_value.data
413
+ return out
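+ # In other words, a straight-through estimator: the forward pass uses round(), the backward
+ # pass treats it as the identity, so gradients flow through the quantization step.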
414
+
415
+
416
+ def get_min_max_from_bits_signed(bit_width: int) -> Tuple[int, int]:
417
+ return -2 ** (bit_width - 1), 2 ** (bit_width - 1) - 1
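+ # Example: bit_width=8 gives (-128, 127), the signed torch.int8 range.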
418
+
419
+
420
+
421
+ ############## Numpy ###############
422
+
423
+ def np_domain_guard(
424
+ x: np.ndarray,
425
+ min: float = None,
426
+ max: float = None,
427
+ posinf: float = None,
428
+ neginf: float = None,
429
+ nan: float = None
430
+ ) -> np.ndarray:
431
+ """Guard a tensor to a valid domain."""
432
+ x = np.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan)
433
+ if min is not None or max is not None:
434
+ x = np.clip(x, min, max)
435
+ return x
436
+
437
+
438
+ def np_replace_num(x: np.ndarray, num: float, to: float) -> np.ndarray:
439
+ """Replace a number in a tensor with another number.
440
+
441
+ Args:
442
+ x (np.ndarray): The input tensor.
443
+ num (float): The number to replace.
444
+ to (float): The number to replace with.
445
+
446
+ Returns:
447
+ np.ndarray: The tensor with the number replaced.
448
+ """
449
+ return np.where(x == num, to, x)
450
+
451
+
452
+ def np_guarded_power(x: np.ndarray, exp: float) -> np.ndarray:
453
+ """Guard the power operation to a valid domain."""
454
+ return np.power(x, exp) if exp >= 1 else np.power(np.maximum(x, 0), exp)
455
+
fn_gen/ones_naive/11/loss.png ADDED
fn_gen/ones_naive/11/quantization.png ADDED
fn_gen/ones_naive/12/distortion.png ADDED
fn_gen/ones_naive/12/expressions.txt ADDED
@@ -0,0 +1,2 @@
1
+ sinh(_0*x)/_s
2
+ log(_s*x - sqrt(_s**2*x**2 + 1))/_0
fn_gen/ones_naive/12/fn.py ADDED
@@ -0,0 +1,455 @@
1
+ from __future__ import annotations
2
+
3
+ import torch
4
+ from torch import amin # Necessary for arcsin
5
+ import copy
6
+ import torch.nn as nn
7
+ import numpy as np
8
+
9
+ from scipy.optimize import curve_fit
10
+ from typing import Dict, Any, Tuple, List, Callable
11
+
12
+
13
+ def quantization(x, **params):
14
+ return (torch.div(1, replace_num(params['_s'], num=0, to=10000)) * torch.sinh((params['_0'] * x)))
15
+
16
+
17
+ def dequantization(x, **params):
18
+ return (torch.div(1, replace_num(params['_0'], num=0, to=10000)) * torch.log(domain_guard(((torch.tensor(-1) * torch.sqrt(domain_guard((torch.tensor(1) + (guarded_torch_power(params['_s'], torch.tensor(2)) * guarded_torch_power(x, torch.tensor(2)))), min=0.1, nan=0.1))) + (params['_s'] * x)), min=1e-5, nan=1e-5)))
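+ # Note: the analytic inverse of sinh is log(y + sqrt(y**2 + 1)); the generated expression uses
+ # s*x - sqrt(1 + s**2*x**2), which is never positive, so the outer domain_guard's min=1e-5
+ # clamp is what keeps the log defined here.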
19
+
20
+
21
+ def init_params(x: torch.Tensor, **kwargs: Dict[str, Any]) -> Dict[str, nn.Parameter]:
22
+ params = {
23
+ '_0': init_ones(x, qtz_func=quantization, deqtz_func=dequantization, param='_0', params_list=['_0', '_s'], **kwargs),
24
+ }
25
+ params['_s'] = init_ones(x, params=params, qtz_func=quantization, **kwargs)
26
+ params = {k: nn.Parameter(v, requires_grad=False) for k, v in params.items()}
27
+
28
+ if 'post_init_hook' in kwargs:
29
+ kwargs['post_init_hook'](parameters=params)
30
+
31
+
32
+ if 'post_train_hook' in kwargs:
33
+ kwargs['post_train_hook'](parameters=params)
34
+
35
+ return params
36
+
37
+
38
+ ############### Numpy Qtz ###############
39
+
40
+
41
+ def np_quantization(x, _0, _s):
42
+ return (np.divide(1, np_replace_num(_s, num=0, to=10000)) * np.sinh((_0 * x)))
43
+
44
+
45
+ def np_dequantization(x, _0, _s):
46
+ return (np.divide(1, np_replace_num(_0, num=0, to=10000)) * np.log(np_domain_guard(((np.array(-1) * np.sqrt(np_domain_guard((np.array(1) + (np_guarded_power(_s, np.array(2)) * np_guarded_power(x, np.array(2)))), min=0.1, nan=0.1))) + (_s * x)), min=1e-5, nan=1e-5)))
47
+
48
+
49
+ def fit_func(x, _0, _s):
50
+ x_ = np_quantization(x, _0, _s)
51
+ x_ = np_dequantization(x_, _0, _s)
52
+ return x_
53
+
54
+
55
+
56
+ ############### HELPERS ###############
57
+
58
+ def domain_guard(
59
+ x: torch.Tensor,
60
+ min: float = None,
61
+ max: float = None,
62
+ posinf: float = None,
63
+ neginf: float = None,
64
+ nan: float = None
65
+ ) -> torch.Tensor:
66
+ """Guard a tensor to a valid domain."""
67
+ x = torch.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan)
68
+ if min is not None or max is not None:
69
+ x = torch.clamp(x, min=min, max=max)
70
+ return x
71
+
72
+
73
+ def replace_num(x: torch.Tensor, num: float, to: float) -> torch.Tensor:
74
+ """Replace a number in a tensor with another number.
75
+
76
+ Args:
77
+ x (torch.Tensor): The input tensor.
78
+ num (float): The number to replace.
79
+ to (float): The number to replace with.
80
+
81
+ Returns:
82
+ torch.Tensor: The tensor with the number replaced.
83
+ """
84
+ return torch.where(x == num, to, x)
85
+
86
+
87
+ def guarded_torch_power(x: torch.Tensor, exp: float) -> torch.Tensor:
88
+ """Guard the power operation to a valid domain."""
89
+ return torch.pow(x, exp) if exp >= 1 else torch.pow(torch.relu(x), exp)
90
+
91
+
92
+ def init_ones(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
93
+ val = torch.amin(x, dim=1)
94
+ return torch.ones_like(val, dtype=torch.float32, device=x.device)
95
+
96
+
97
+ def init_rand(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
98
+ val = torch.amin(x, dim=1)
99
+ return torch.randn_like(val, dtype=torch.float32, device=x.device)
100
+
101
+
102
+ def init_space_search(
103
+ x: torch.Tensor,
104
+ **kwargs: Dict[str, Any],
105
+ ) -> torch.Tensor:
106
+
107
+ def _build_initial_param(tensor: torch.Tensor, max_initial: int, n_params: int):
108
+ """Generates the initial set of parameters. The first iteration generates 10 times more parameters."""
109
+ for _ in range(n_params * 10): # The first iteration generates 10 times more parameters
110
+ yield init_rand(tensor) * max_initial # Generates n_params random params scaled by max_initial (normal draws, not strictly bounded)
111
+
112
+ def _search_param(tensors: List[torch.tensor], n_params):
113
+ """Takes the best parameters and generates new parameters around the mean of the best parameters."""
114
+ torch_tensors = torch.stack(tensors)
115
+ min_vals, max_vals = torch.aminmax(torch_tensors, dim=0)
116
+ abs_max_val_per_ch = torch.max(-min_vals, max_vals)
117
+ mean = torch.mean(torch_tensors, dim=0)
118
+ for _ in range(n_params): # Generates n_params around the mean of the tensors
119
+ yield torch.randn_like(min_vals) * abs_max_val_per_ch + mean
120
+
121
+ def _calc(x, qtz_func, deqtz_func, **params):
122
+ x_ = x.transpose(0, 1)
123
+ x_ = qtz_func(x=x_, **params)
124
+ x_ = deqtz_func(x=x_, **params)
125
+ x_ = x_.transpose(0, 1)
126
+ return x_
127
+
128
+ assert "qtz_func" in kwargs, "qtz_func must be provided."
129
+ assert "deqtz_func" in kwargs, "deqtz_func must be provided."
130
+ assert "params_list" in kwargs, "params list must be provided."
131
+ assert "param" in kwargs, "param must be provided."
132
+
133
+ qtz_func = kwargs.get('qtz_func')
134
+ deqtz_func = kwargs.get('deqtz_func')
135
+ params_list = kwargs.get('params_list')
136
+ param = kwargs.get('param')
137
+
138
+ n_runs = 50 # Number of runs to try to find the best parameters
139
+ n_random_params = 50 # Number of random parameters to generate
140
+ n_best_to_pick = 5 # Number of best parameters to pick after each run
141
+ max_initial = 10000 # Maximum value to initialize the parameters
142
+
143
+ # Initializes the parameters
144
+ base_params = { p: init_ones(x, **kwargs) for p in params_list if p != param }
145
+ params = _build_initial_param(x, max_initial, n_random_params)
146
+
147
+ # Performs the search
148
+ for _ in range(n_runs):
149
+
150
+ best_params = []
151
+ for param_ in params:
152
+ try:
153
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: param_})
154
+ loss_ones = nn.MSELoss()(x, x_)
155
+
156
+ if len(best_params) < n_best_to_pick:
157
+ best_params.append((param_, loss_ones.item()))
158
+ best_params = sorted(best_params, key=lambda x: x[1])
159
+ elif loss_ones < best_params[-1][1]:
160
+ best_params[-1] = (param_, loss_ones.item())
161
+ best_params = sorted(best_params, key=lambda x: x[1])
162
+
163
+ except Exception: # The parameters might not be valid for the function's domain
164
+ continue
165
+
166
+ # Generates new parameters around the mean
167
+ params = _search_param([p for p, _ in best_params], n_random_params)
168
+
169
+ # Checks if the best parameter is better than the init_ones
170
+ p_ones = init_ones(x, **kwargs)
171
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_ones})
172
+ loss_ones = nn.MSELoss()(x, x_)
173
+
174
+ # Checks if the best parameter is better than the init_rand
175
+ p_rand = init_rand(x, **kwargs)
176
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_rand})
177
+ loss_rand = nn.MSELoss()(x, x_)
178
+
179
+ if loss_rand < best_params[0][1] and loss_rand < loss_ones:
180
+ return p_rand
181
+ elif loss_ones < best_params[0][1] and loss_ones < loss_rand:
182
+ return p_ones
183
+ else:
184
+ return best_params[0][0]
185
+
186
+
187
+ def init_linear_scale( # Symmetric scale. From the study folder
188
+ x: torch.Tensor,
189
+ **kwargs: Dict[str, Any],
190
+ ) -> torch.Tensor:
191
+ assert "bits" in kwargs, "bits must be provided."
192
+ assert "params" in kwargs, "params must be provided."
193
+ assert "qtz_func" in kwargs, "qtz_func must be provided."
194
+
195
+ bits = kwargs.get('bits')
196
+ params = kwargs.get('params')
197
+ qtz_func = kwargs.get('qtz_func')
198
+
199
+ x_ = x.transpose(0, 1)
200
+ x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs))
201
+ x_ = x_.transpose(0, 1)
202
+
203
+ quant_min, quant_max = get_min_max_from_bits_signed(bits)
204
+ min_vals, max_vals = torch.aminmax(x_, dim=1)
205
+ min_vals = torch.min(min_vals, torch.zeros_like(min_vals))
206
+ max_vals = torch.max(max_vals, torch.zeros_like(max_vals))
207
+
208
+ eps = torch.finfo(torch.float32).eps
209
+
210
+ abs_max_val_per_ch = torch.max(-min_vals, max_vals)
211
+ scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2)
212
+
213
+ scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device)
214
+
215
+ # Introduces some noise in scale
216
+ # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything
217
+ scale = scale + 0.01 * torch.randn_like(scale)
218
+ return scale
219
+
220
+
221
+ def init_non_linear_regression_fit(
222
+ x: torch.Tensor,
223
+ **kwargs: Dict[str, Any],
224
+ ) -> torch.Tensor:
225
+
226
+ assert "params_list" in kwargs, "params list must be provided."
227
+ assert "np_fit_func" in kwargs, "np_fit_func must be provided."
228
+ assert "p0" in kwargs, "p0 must be provided."
229
+ np_fit_func = kwargs.get('np_fit_func')
230
+ params_list = kwargs.get('params_list')
231
+ p0 = kwargs.get('p0')
232
+
233
+ def _fit(xdata: np.ndarray, ydata: np.ndarray, func: Callable, p0: List[float]):
234
+ popt, _ = curve_fit(
235
+ func,
236
+ xdata,
237
+ ydata,
238
+ maxfev=1000,
239
+ p0=p0,
240
+ method='lm'
241
+ )
242
+ return popt
243
+
244
+ # 1. Needs to convert the torch tensor to numpy tensor
245
+ xdata = x.cpu().numpy()
246
+
247
+ # 2. Sorts the data so that it makes it easier to fit to it
248
+ sorted_xdata = np.sort(xdata, axis=-1)
249
+
250
+ p0 = {k: v.cpu().numpy() for k, v in p0.items()}
251
+ params_list = sorted(params_list) # We need to make sure that it matches the numpy fit func arg order
252
+
253
+ # 3. Finds the best parameters for each channel
254
+ try:
255
+ params = []
256
+ for i in range(sorted_xdata.shape[0]):
257
+ xdata_ = sorted_xdata[i]
258
+ p0_ = [p0[p][i] for p in params_list]
259
+ ch_params = _fit(xdata_, xdata_, np_fit_func, p0_)
260
+ params.append(ch_params)
261
+
262
+ # 4. Builds the parameters
263
+ result = {}
264
+ for i, p in enumerate(params_list):
265
+ result[p] = torch.tensor([p_[i] for p_ in params], dtype=torch.float32).to(x.device)
266
+
267
+ return result
268
+
269
+ except ValueError as e:
270
+ print(f"Could not fit the function with error: {e}")
271
+ print(f"Using fallback result...")
272
+ return {
273
+ k: torch.tensor(v, dtype=torch.float32).to(x.device) for k, v in p0.items()
274
+ }
275
+
276
+
277
+ def init_zeros(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
278
+ val = torch.amin(x, dim=1)
279
+ return torch.zeros_like(val, dtype=torch.float32, device=x.device)
280
+
281
+
282
+ def init_inner_scale(tensor: torch.Tensor, _min: float = torch.inf, _max: float = torch.inf) -> torch.Tensor:
283
+ # Calculate the original minimum and maximum values
284
+ min_vals, max_vals = torch.aminmax(tensor, dim=-1)
285
+ x_min = torch.min(min_vals, torch.zeros_like(min_vals))
286
+ x_max = torch.max(max_vals, torch.zeros_like(max_vals))
287
+
288
+ if _max is torch.inf: # We do not need to scale the tensor. Just need to move it
289
+ return torch.ones_like(x_min)
290
+
291
+ # Calculate the scale factor
292
+ scale = (_max - _min) / (x_max - x_min)
293
+ return scale
294
+
295
+
296
+
297
+ ############## Quant ###############
298
+
299
+ @torch.enable_grad()
300
+ def learn_parameters(
301
+ x: torch.Tensor,
302
+ params: Dict[str, nn.Parameter],
303
+ qtz_func: nn.Module,
304
+ deqtz_func: nn.Module,
305
+ bits: int,
306
+ target_dtype: torch.dtype,
307
+ epochs: int = 1000,
308
+ early_stop: bool = True,
309
+ do_report: bool = False
310
+ ) -> Tuple[Dict[str, nn.Parameter], torch.Tensor]:
311
+
312
+ # Requires gradients in the parameters
313
+ for p in params.values():
314
+ p.requires_grad = True
315
+ p.grad = None
316
+
317
+ param_keys = list(params.keys())
318
+ param_values = list(params.values())
319
+
320
+ # Defines optimizer and loss function
321
+ optimizer = torch.optim.Adam(param_values, lr=0.001)
322
+ scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=1.0, end_factor=0.001, total_iters=epochs // 10)
323
+ loss_fn = nn.MSELoss()
324
+
325
+ # Contains the best loss and the best parameters
326
+ best_loss = float("inf")
327
+ best_params = None
328
+
329
+ # Used to stop the search early
330
+ min_delta = 1e-7
331
+ acc_loss = []
332
+ percent_epochs_before_stop = 0.1
333
+
334
+ for i in range(epochs):
335
+ optimizer.zero_grad()
336
+
337
+ quant = quantize(x, params, qtz_func, bits, target_dtype)
338
+ dequant = dequantize(quant, params, deqtz_func, bits, x.dtype)
339
+ loss = loss_fn(x, dequant)
340
+
341
+ if loss.isnan() or loss.isinf():
342
+ raise Exception("Loss is NaN or Inf. Stopping the search.")
343
+
344
+ loss.backward()
345
+ optimizer.step()
346
+ scheduler.step()
347
+
348
+ acc_loss.append(loss.item())
349
+
350
+ # Reports loss every 10 steps
351
+ if i % 10 == 0 and do_report:
352
+ print(f"Epoch {i}: Loss {loss.item()}")
353
+
354
+ # Optimizes the parameter search by storing the best loss and the parameters
355
+ if loss.item() < best_loss:
356
+ best_loss = loss.item()
357
+ best_params = copy.deepcopy({
358
+ k: v for k, v in params.items() if k in param_keys
359
+ })
360
+
361
+ # We also stop the search early if the loss has not decreased appreciably during the last 10% of epochs
362
+ if early_stop:
363
+ epochs_before_stop = int(epochs * percent_epochs_before_stop)
364
+ if i > epochs_before_stop and abs(acc_loss[i - epochs_before_stop] - acc_loss[i]) < min_delta:
365
+ break
366
+
367
+ # No longer requires gradients in the parameters
368
+ for p in best_params.values():
369
+ p.requires_grad = False
370
+ p.grad = None
371
+
372
+ if do_report:
373
+ return best_params, acc_loss
374
+ else:
375
+ return best_params
376
+
377
+
378
+ def quantize(
379
+ x: torch.Tensor,
380
+ params: Dict[str, nn.Parameter],
381
+ func: nn.Module,
382
+ bits: int,
383
+ target_dtype: torch.dtype = torch.int8
384
+ ) -> torch.Tensor:
385
+ quant_min, quant_max = get_min_max_from_bits_signed(bits)
386
+ x = x.transpose(0, 1) # Aligns shapes
387
+ x = func(x=x, **params)
388
+ x = x.transpose(0, 1)
389
+ x = torch.clamp(round_func_BPDA(x), quant_min, quant_max).to(target_dtype)
390
+ return x
391
+
392
+
393
+ def dequantize(
394
+ x: torch.Tensor,
395
+ params: Dict[str, nn.Parameter],
396
+ func: nn.Module,
397
+ bits: int,
398
+ out_dtype: torch.dtype
399
+ ) -> torch.Tensor:
400
+ x = x.to(dtype=out_dtype)
401
+ x = x.transpose(0, 1)
402
+ x = func(x=x, **params)
403
+ x = x.transpose(0, 1)
404
+ return x
405
+
406
+
407
+ def round_func_BPDA(input):
408
+ # This is equivalent to replacing round function (non-differentiable) with
409
+ # an identity function (differentiable) only when backward.
410
+ forward_value = torch.round(input)
411
+ out = input.clone()
412
+ out.data = forward_value.data
413
+ return out
414
+
415
+
416
+ def get_min_max_from_bits_signed(bit_width: int) -> Tuple[int, int]:
417
+ return -2 ** (bit_width - 1), 2 ** (bit_width - 1) - 1
418
+
419
+
420
+
421
+ ############## Numpy ###############
422
+
423
+ def np_domain_guard(
424
+ x: np.ndarray,
425
+ min: float = None,
426
+ max: float = None,
427
+ posinf: float = None,
428
+ neginf: float = None,
429
+ nan: float = None
430
+ ) -> np.ndarray:
431
+ """Guard a tensor to a valid domain."""
432
+ x = np.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan)
433
+ if min is not None or max is not None:
434
+ x = np.clip(x, min, max)
435
+ return x
436
+
437
+
438
+ def np_replace_num(x: np.ndarray, num: float, to: float) -> np.ndarray:
439
+ """Replace a number in a tensor with another number.
440
+
441
+ Args:
442
+ x (np.ndarray): The input tensor.
443
+ num (float): The number to replace.
444
+ to (float): The number to replace with.
445
+
446
+ Returns:
447
+ np.ndarray: The tensor with the number replaced.
448
+ """
449
+ return np.where(x == num, to, x)
450
+
451
+
452
+ def np_guarded_power(x: np.ndarray, exp: float) -> np.ndarray:
453
+ """Guard the power operation to a valid domain."""
454
+ return np.power(x, exp) if exp >= 1 else np.power(np.maximum(x, 0), exp)
455
+
fn_gen/ones_naive/12/loss.png ADDED
fn_gen/ones_naive/12/quantization.png ADDED
fn_gen/ones_naive/13/distortion.png ADDED
fn_gen/ones_naive/13/expressions.txt ADDED
@@ -0,0 +1,2 @@
1
+ asinh(_0*x)/_s
2
+ sinh(_s*x)/_0
fn_gen/ones_naive/13/fn.py ADDED
@@ -0,0 +1,455 @@
1
+ from __future__ import annotations
2
+
3
+ import torch
4
+ from torch import amin # Necessary for arcsin
5
+ import copy
6
+ import torch.nn as nn
7
+ import numpy as np
8
+
9
+ from scipy.optimize import curve_fit
10
+ from typing import Dict, Any, Tuple, List, Callable
11
+
12
+
13
+ def quantization(x, **params):
14
+ return (torch.div(1, replace_num(params['_s'], num=0, to=10000)) * torch.asinh((params['_0'] * x)))
15
+
16
+
17
+ def dequantization(x, **params):
18
+ return (torch.div(1, replace_num(params['_0'], num=0, to=10000)) * torch.sinh((params['_s'] * x)))
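+ # asinh and sinh are exact inverses over all real inputs, so this pair needs no domain_guard.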
19
+
20
+
21
+ def init_params(x: torch.Tensor, **kwargs: Dict[str, Any]) -> Dict[str, nn.Parameter]:
22
+ params = {
23
+ '_0': init_ones(x, qtz_func=quantization, deqtz_func=dequantization, param='_0', params_list=['_0', '_s'], **kwargs),
24
+ }
25
+ params['_s'] = init_ones(x, params=params, qtz_func=quantization, **kwargs)
26
+ params = {k: nn.Parameter(v, requires_grad=False) for k, v in params.items()}
27
+
28
+ if 'post_init_hook' in kwargs:
29
+ kwargs['post_init_hook'](parameters=params)
30
+
31
+
32
+ if 'post_train_hook' in kwargs:
33
+ kwargs['post_train_hook'](parameters=params)
34
+
35
+ return params
36
+
37
+
38
+ ############### Numpy Qtz ###############
39
+
40
+
41
+ def np_quantization(x, _0, _s):
42
+ return (np.divide(1, np_replace_num(_s, num=0, to=10000)) * np.arcsinh((_0 * x)))
43
+
44
+
45
+ def np_dequantization(x, _0, _s):
46
+ return (np.divide(1, np_replace_num(_0, num=0, to=10000)) * np.sinh((_s * x)))
47
+
48
+
49
+ def fit_func(x, _0, _s):
50
+ x_ = np_quantization(x, _0, _s)
51
+ x_ = np_dequantization(x_, _0, _s)
52
+ return x_
53
+
54
+
55
+
56
+ ############### HELPERS ###############
57
+
58
+ def domain_guard(
59
+ x: torch.Tensor,
60
+ min: float = None,
61
+ max: float = None,
62
+ posinf: float = None,
63
+ neginf: float = None,
64
+ nan: float = None
65
+ ) -> torch.Tensor:
66
+ """Guard a tensor to a valid domain."""
67
+ x = torch.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan)
68
+ if min is not None or max is not None:
69
+ x = torch.clamp(x, min=min, max=max)
70
+ return x
71
+
72
+
73
+ def replace_num(x: torch.Tensor, num: float, to: float) -> torch.Tensor:
74
+ """Replace a number in a tensor with another number.
75
+
76
+ Args:
77
+ x (torch.Tensor): The input tensor.
78
+ num (float): The number to replace.
79
+ to (float): The number to replace with.
80
+
81
+ Returns:
82
+ torch.Tensor: The tensor with the number replaced.
83
+ """
84
+ return torch.where(x == num, to, x)
85
+
86
+
87
+ def guarded_torch_power(x: torch.Tensor, exp: float) -> torch.Tensor:
88
+ """Guard the power operation to a valid domain."""
89
+ return torch.pow(x, exp) if exp >= 1 else torch.pow(torch.relu(x), exp)
90
+
91
+
92
+ def init_ones(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
93
+ val = torch.amin(x, dim=1)
94
+ return torch.ones_like(val, dtype=torch.float32, device=x.device)
95
+
96
+
97
+ def init_rand(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
98
+ val = torch.amin(x, dim=1)
99
+ return torch.randn_like(val, dtype=torch.float32, device=x.device)
100
+
101
+
102
+ def init_space_search(
103
+ x: torch.Tensor,
104
+ **kwargs: Dict[str, Any],
105
+ ) -> torch.Tensor:
106
+
107
+ def _build_initial_param(tensor: torch.Tensor, max_initial: int, n_params: int):
108
+ """Generates the initial set of parameters. The first iteration generates 10 times more parameters."""
109
+ for _ in range(n_params * 10): # The first iteration generates 10 times more parameters
110
+ yield init_rand(tensor) * max_initial # Generates n_params random params scaled by max_initial (normal draws, not strictly bounded)
111
+
112
+ def _search_param(tensors: List[torch.tensor], n_params):
113
+ """Takes the best parameters and generates new parameters around the mean of the best parameters."""
114
+ torch_tensors = torch.stack(tensors)
115
+ min_vals, max_vals = torch.aminmax(torch_tensors, dim=0)
116
+ abs_max_val_per_ch = torch.max(-min_vals, max_vals)
117
+ mean = torch.mean(torch_tensors, dim=0)
118
+ for _ in range(n_params): # Generates n_params around the mean of the tensors
119
+ yield torch.randn_like(min_vals) * abs_max_val_per_ch + mean
120
+
121
+ def _calc(x, qtz_func, deqtz_func, **params):
122
+ x_ = x.transpose(0, 1)
123
+ x_ = qtz_func(x=x_, **params)
124
+ x_ = deqtz_func(x=x_, **params)
125
+ x_ = x_.transpose(0, 1)
126
+ return x_
127
+
128
+ assert "qtz_func" in kwargs, "qtz_func must be provided."
129
+ assert "deqtz_func" in kwargs, "deqtz_func must be provided."
130
+ assert "params_list" in kwargs, "params list must be provided."
131
+ assert "param" in kwargs, "param must be provided."
132
+
133
+ qtz_func = kwargs.get('qtz_func')
134
+ deqtz_func = kwargs.get('deqtz_func')
135
+ params_list = kwargs.get('params_list')
136
+ param = kwargs.get('param')
137
+
138
+ n_runs = 50 # Number of runs to try to find the best parameters
139
+ n_random_params = 50 # Number of random parameters to generate
140
+ n_best_to_pick = 5 # Number of best parameters to pick after each run
141
+ max_initial = 10000 # Maximum value to initialize the parameters
142
+
143
+ # Initializes the parameters
144
+ base_params = { p: init_ones(x, **kwargs) for p in params_list if p != param }
145
+ params = _build_initial_param(x, max_initial, n_random_params)
146
+
147
+ # Performs the search
148
+ for _ in range(n_runs):
149
+
150
+ best_params = []
151
+ for param_ in params:
152
+ try:
153
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: param_})
154
+ loss_ones = nn.MSELoss()(x, x_)
155
+
156
+ if len(best_params) < n_best_to_pick:
157
+ best_params.append((param_, loss_ones.item()))
158
+ best_params = sorted(best_params, key=lambda x: x[1])
159
+ elif loss_ones < best_params[-1][1]:
160
+ best_params[-1] = (param_, loss_ones.item())
161
+ best_params = sorted(best_params, key=lambda x: x[1])
162
+
163
+ except Exception: # The parameters might not be valid for the function's domain
164
+ continue
165
+
166
+ # Generates new parameters around the mean
167
+ params = _search_param([p for p, _ in best_params], n_random_params)
168
+
169
+ # Checks if the best parameter is better than the init_ones
170
+ p_ones = init_ones(x, **kwargs)
171
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_ones})
172
+ loss_ones = nn.MSELoss()(x, x_)
173
+
174
+ # Checks if the best parameter is better than the init_rand
175
+ p_rand = init_rand(x, **kwargs)
176
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_rand})
177
+ loss_rand = nn.MSELoss()(x, x_)
178
+
179
+ if loss_rand < best_params[0][1] and loss_rand < loss_ones:
180
+ return p_rand
181
+ elif loss_ones < best_params[0][1] and loss_ones < loss_rand:
182
+ return p_ones
183
+ else:
184
+ return best_params[0][0]
185
+
186
+
187
+ def init_linear_scale( # Symmetric scale. From the study folder
188
+ x: torch.Tensor,
189
+ **kwargs: Dict[str, Any],
190
+ ) -> torch.Tensor:
191
+ assert "bits" in kwargs, "bits must be provided."
192
+ assert "params" in kwargs, "params must be provided."
193
+ assert "qtz_func" in kwargs, "qtz_func must be provided."
194
+
195
+ bits = kwargs.get('bits')
196
+ params = kwargs.get('params')
197
+ qtz_func = kwargs.get('qtz_func')
198
+
199
+ x_ = x.transpose(0, 1)
200
+ x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs))
201
+ x_ = x_.transpose(0, 1)
202
+
203
+ quant_min, quant_max = get_min_max_from_bits_signed(bits)
204
+ min_vals, max_vals = torch.aminmax(x_, dim=1)
205
+ min_vals = torch.min(min_vals, torch.zeros_like(min_vals))
206
+ max_vals = torch.max(max_vals, torch.zeros_like(max_vals))
207
+
208
+ eps = torch.finfo(torch.float32).eps
209
+
210
+ abs_max_val_per_ch = torch.max(-min_vals, max_vals)
211
+ scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2)
212
+
213
+ scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device)
214
+
215
+ # Introduces some noise in scale
216
+ # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything
217
+ scale = scale + 0.01 * torch.randn_like(scale)
218
+ return scale
219
+
220
+
221
+ def init_non_linear_regression_fit(
222
+ x: torch.Tensor,
223
+ **kwargs: Dict[str, Any],
224
+ ) -> torch.Tensor:
225
+
226
+ assert "params_list" in kwargs, "params list must be provided."
227
+ assert "np_fit_func" in kwargs, "np_fit_func must be provided."
228
+ assert "p0" in kwargs, "p0 must be provided."
229
+ np_fit_func = kwargs.get('np_fit_func')
230
+ params_list = kwargs.get('params_list')
231
+ p0 = kwargs.get('p0')
232
+
233
+ def _fit(xdata: np.ndarray, ydata: np.ndarray, func: Callable, p0: List[float]):
234
+ popt, _ = curve_fit(
235
+ func,
236
+ xdata,
237
+ ydata,
238
+ maxfev=1000,
239
+ p0=p0,
240
+ method='lm'
241
+ )
242
+ return popt
243
+
244
+ # 1. Needs to convert the torch tensor to numpy tensor
245
+ xdata = x.cpu().numpy()
246
+
247
+ # 2. Sorts the data so that it makes it easier to fit to it
248
+ sorted_xdata = np.sort(xdata, axis=-1)
249
+
250
+ p0 = {k: v.cpu().numpy() for k, v in p0.items()}
251
+ params_list = sorted(params_list) # We need to make sure that it matches the numpy fit func arg order
252
+
253
+ # 3. Finds the best parameters for each channel
254
+ try:
255
+ params = []
256
+ for i in range(sorted_xdata.shape[0]):
257
+ xdata_ = sorted_xdata[i]
258
+ p0_ = [p0[p][i] for p in params_list]
259
+ ch_params = _fit(xdata_, xdata_, np_fit_func, p0_)
260
+ params.append(ch_params)
261
+
262
+ # 4. Builds the parameters
263
+ result = {}
264
+ for i, p in enumerate(params_list):
265
+ result[p] = torch.tensor([p_[i] for p_ in params], dtype=torch.float32).to(x.device)
266
+
267
+ return result
268
+
269
+ except ValueError as e:
270
+ print(f"Could not fit the function with error: {e}")
271
+ print(f"Using fallback result...")
272
+ return {
273
+ k: torch.tensor(v, dtype=torch.float32).to(x.device) for k, v in p0.items()
274
+ }
275
+
276
+
277
+ def init_zeros(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
278
+ val = torch.amin(x, dim=1)
279
+ return torch.zeros_like(val, dtype=torch.float32, device=x.device)
280
+
281
+
282
+ def init_inner_scale(tensor: torch.Tensor, _min: float = torch.inf, _max: float = torch.inf) -> torch.Tensor:
283
+ # Calculate the original minimum and maximum values
284
+ min_vals, max_vals = torch.aminmax(tensor, dim=-1)
285
+ x_min = torch.min(min_vals, torch.zeros_like(min_vals))
286
+ x_max = torch.max(max_vals, torch.zeros_like(max_vals))
287
+
288
+ if _max is torch.inf: # We do not need to scale the tensor. Just need to move it
289
+ return torch.ones_like(x_min)
290
+
291
+ # Calculate the scale factor
292
+ scale = (_max - _min) / (x_max - x_min)
293
+ return scale
294
+
295
+
296
+
297
+ ############## Quant ###############
298
+
299
+ @torch.enable_grad()
300
+ def learn_parameters(
301
+ x: torch.Tensor,
302
+ params: Dict[str, nn.Parameter],
303
+ qtz_func: nn.Module,
304
+ deqtz_func: nn.Module,
305
+ bits: int,
306
+ target_dtype: torch.dtype,
307
+ epochs: int = 1000,
308
+ early_stop: bool = True,
309
+ do_report: bool = False
310
+ ) -> Tuple[Dict[str, nn.Parameter], torch.Tensor]:
311
+
312
+ # Requires gradients in the parameters
313
+ for p in params.values():
314
+ p.requires_grad = True
315
+ p.grad = None
316
+
317
+ param_keys = list(params.keys())
318
+ param_values = list(params.values())
319
+
320
+ # Defines optimizer and loss function
321
+ optimizer = torch.optim.Adam(param_values, lr=0.001)
322
+ scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=1.0, end_factor=0.001, total_iters=epochs // 10)
323
+ loss_fn = nn.MSELoss()
324
+
325
+ # Contains the best loss and the best parameters
326
+ best_loss = float("inf")
327
+ best_params = None
328
+
329
+ # Used to stop the search early
330
+ min_delta = 1e-7
331
+ acc_loss = []
332
+ percent_epochs_before_stop = 0.1
333
+
334
+ for i in range(epochs):
335
+ optimizer.zero_grad()
336
+
337
+ quant = quantize(x, params, qtz_func, bits, target_dtype)
338
+ dequant = dequantize(quant, params, deqtz_func, bits, x.dtype)
339
+ loss = loss_fn(x, dequant)
340
+
341
+ if loss.isnan() or loss.isinf():
342
+ raise Exception("Loss is NaN or Inf. Stopping the search.")
343
+
344
+ loss.backward()
345
+ optimizer.step()
346
+ scheduler.step()
347
+
348
+ acc_loss.append(loss.item())
349
+
350
+ # Reports loss every 10 steps
351
+ if i % 10 == 0 and do_report:
352
+ print(f"Epoch {i}: Loss {loss.item()}")
353
+
354
+ # Optimizes the parameter search by storing the best loss and the parameters
355
+ if loss.item() < best_loss:
356
+ best_loss = loss.item()
357
+ best_params = copy.deepcopy({
358
+ k: v for k, v in params.items() if k in param_keys
359
+ })
360
+
361
+ # We also stop the search early if the loss has not decreased appreciably during the last 10% of epochs
362
+ if early_stop:
363
+ epochs_before_stop = int(epochs * percent_epochs_before_stop)
364
+ if i > epochs_before_stop and abs(acc_loss[i - epochs_before_stop] - acc_loss[i]) < min_delta:
365
+ break
366
+
367
+ # No longer requires gradients in the parameters
368
+ for p in best_params.values():
369
+ p.requires_grad = False
370
+ p.grad = None
371
+
372
+ if do_report:
373
+ return best_params, acc_loss
374
+ else:
375
+ return best_params
376
+
377
+
378
+ def quantize(
379
+ x: torch.Tensor,
380
+ params: Dict[str, nn.Parameter],
381
+ func: nn.Module,
382
+ bits: int,
383
+ target_dtype: torch.dtype = torch.int8
384
+ ) -> torch.Tensor:
385
+ quant_min, quant_max = get_min_max_from_bits_signed(bits)
386
+ x = x.transpose(0, 1) # Aligns shapes
387
+ x = func(x=x, **params)
388
+ x = x.transpose(0, 1)
389
+ x = torch.clamp(round_func_BPDA(x), quant_min, quant_max).to(target_dtype)
390
+ return x
391
+
392
+
393
+ def dequantize(
394
+ x: torch.Tensor,
395
+ params: Dict[str, nn.Parameter],
396
+ func: nn.Module,
397
+ bits: int,
398
+ out_dtype: torch.dtype
399
+ ) -> torch.Tensor:
400
+ x = x.to(dtype=out_dtype)
401
+ x = x.transpose(0, 1)
402
+ x = func(x=x, **params)
403
+ x = x.transpose(0, 1)
404
+ return x
405
+
406
+
407
+ def round_func_BPDA(input):
408
+ # This is equivalent to replacing round function (non-differentiable) with
409
+ # an identity function (differentiable) only when backward.
410
+ forward_value = torch.round(input)
411
+ out = input.clone()
412
+ out.data = forward_value.data
413
+ return out
414
+
415
+
416
+ def get_min_max_from_bits_signed(bit_width: int) -> Tuple[int, int]:
417
+ return -2 ** (bit_width - 1), 2 ** (bit_width - 1) - 1
418
+
419
+
420
+
421
+ ############## Numpy ###############
422
+
423
+ def np_domain_guard(
424
+ x: np.ndarray,
425
+ min: float = None,
426
+ max: float = None,
427
+ posinf: float = None,
428
+ neginf: float = None,
429
+ nan: float = None
430
+ ) -> np.ndarray:
431
+ """Guard a tensor to a valid domain."""
432
+ x = np.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan)
433
+ if min is not None or max is not None:
434
+ x = np.clip(x, min, max)
435
+ return x
436
+
437
+
438
+ def np_replace_num(x: np.ndarray, num: float, to: float) -> np.ndarray:
439
+ """Replace a number in a tensor with another number.
440
+
441
+ Args:
442
+ x (np.ndarray): The input tensor.
443
+ num (float): The number to replace.
444
+ to (float): The number to replace with.
445
+
446
+ Returns:
447
+ np.ndarray: The tensor with the number replaced.
448
+ """
449
+ return np.where(x == num, to, x)
450
+
451
+
452
+ def np_guarded_power(x: np.ndarray, exp: float) -> np.ndarray:
453
+ """Guard the power operation to a valid domain."""
454
+ return np.power(x, exp) if exp >= 1 else np.power(np.maximum(x, 0), exp)
455
+
fn_gen/ones_naive/13/loss.png ADDED
fn_gen/ones_naive/13/quantization.png ADDED
fn_gen/ones_naive/14/distortion.png ADDED
fn_gen/ones_naive/14/expressions.txt ADDED
@@ -0,0 +1,2 @@
1
+ sqrt(_0*x)/_s
2
+ _s**2*x**2/_0
fn_gen/ones_naive/14/fn.py ADDED
@@ -0,0 +1,455 @@
1
+ from __future__ import annotations
2
+
3
+ import torch
4
+ from torch import amin # Necessary for arcsin
5
+ import copy
6
+ import torch.nn as nn
7
+ import numpy as np
8
+
9
+ from scipy.optimize import curve_fit
10
+ from typing import Dict, Any, Tuple, List, Callable
11
+
12
+
13
+ def quantization(x, **params):
14
+ return (torch.div(1, replace_num(params['_s'], num=0, to=10000)) * torch.sqrt(domain_guard((params['_0'] * x), min=0.1, nan=0.1)))
15
+
16
+
17
+ def dequantization(x, **params):
18
+ return (torch.div(1, replace_num(params['_0'], num=0, to=10000)) * guarded_torch_power(params['_s'], torch.tensor(2)) * guarded_torch_power(x, torch.tensor(2)))
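+ # sqrt requires a non-negative argument, so _0*x is clamped to at least 0.1 before the sqrt;
+ # dequantization then squares back via guarded_torch_power (exp=2 falls through to plain torch.pow).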
19
+
20
+
21
+ def init_params(x: torch.Tensor, **kwargs: Dict[str, Any]) -> Dict[str, nn.Parameter]:
22
+ params = {
23
+ '_0': init_ones(x, qtz_func=quantization, deqtz_func=dequantization, param='_0', params_list=['_0', '_s'], **kwargs),
24
+ }
25
+ params['_s'] = init_ones(x, params=params, qtz_func=quantization, **kwargs)
26
+ params = {k: nn.Parameter(v, requires_grad=False) for k, v in params.items()}
27
+
28
+ if 'post_init_hook' in kwargs:
29
+ kwargs['post_init_hook'](parameters=params)
30
+
31
+
32
+ if 'post_train_hook' in kwargs:
33
+ kwargs['post_train_hook'](parameters=params)
34
+
35
+ return params
36
+
37
+
38
+ ############### Numpy Qtz ###############
39
+
40
+
41
+ def np_quantization(x, _0, _s):
42
+ return (np.divide(1, np_replace_num(_s, num=0, to=10000)) * np.sqrt(np_domain_guard((_0 * x), min=0.1, nan=0.1)))
43
+
44
+
45
+ def np_dequantization(x, _0, _s):
46
+ return (np.divide(1, np_replace_num(_0, num=0, to=10000)) * np_guarded_power(_s, np.array(2)) * np_guarded_power(x, np.array(2)))
47
+
48
+
49
+ def fit_func(x, _0, _s):
50
+ x_ = np_quantization(x, _0, _s)
51
+ x_ = np_dequantization(x_, _0, _s)
52
+ return x_
53
+
54
+
55
+
56
+ ############### HELPERS ###############
57
+
58
+ def domain_guard(
59
+ x: torch.Tensor,
60
+ min: float = None,
61
+ max: float = None,
62
+ posinf: float = None,
63
+ neginf: float = None,
64
+ nan: float = None
65
+ ) -> torch.Tensor:
66
+ """Guard a tensor to a valid domain."""
67
+ x = torch.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan)
68
+ if min is not None or max is not None:
69
+ x = torch.clamp(x, min=min, max=max)
70
+ return x
71
+
72
+
73
+ def replace_num(x: torch.Tensor, num: float, to: float) -> torch.Tensor:
74
+ """Replace a number in a tensor with another number.
75
+
76
+ Args:
77
+ x (torch.Tensor): The input tensor.
78
+ num (float): The number to replace.
79
+ to (float): The number to replace with.
80
+
81
+ Returns:
82
+ torch.Tensor: The tensor with the number replaced.
83
+ """
84
+ return torch.where(x == num, to, x)
85
+
86
+
87
+ def guarded_torch_power(x: torch.Tensor, exp: float) -> torch.Tensor:
88
+ """Guard the power operation to a valid domain."""
89
+ return torch.pow(x, exp) if exp >= 1 else torch.pow(torch.relu(x), exp)
90
+
91
+
92
+ def init_ones(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
93
+ val = torch.amin(x, dim=1)
94
+ return torch.ones_like(val, dtype=torch.float32, device=x.device)
95
+
96
+
97
+ def init_rand(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
98
+ val = torch.amin(x, dim=1)
99
+ return torch.randn_like(val, dtype=torch.float32, device=x.device)
100
+
101
+
102
+ def init_space_search(
103
+ x: torch.Tensor,
104
+ **kwargs: Dict[str, Any],
105
+ ) -> torch.Tensor:
106
+
107
+ def _build_initial_param(tensor: torch.Tensor, max_initial: int, n_params: int):
108
+ """Generates the initial set of parameters. The first iteration generates 10 times more parameters."""
109
+ for _ in range(n_params * 10): # The first iteration generates 10 times more parameters
110
+ yield init_rand(tensor) * max_initial # Generates n_params random params scaled by max_initial (normal draws, not strictly bounded)
111
+
112
+ def _search_param(tensors: List[torch.tensor], n_params):
113
+ """Takes the best parameters and generates new parameters around the mean of the best parameters."""
114
+ torch_tensors = torch.stack(tensors)
115
+ min_vals, max_vals = torch.aminmax(torch_tensors, dim=0)
116
+ abs_max_val_per_ch = torch.max(-min_vals, max_vals)
117
+ mean = torch.mean(torch_tensors, dim=0)
118
+ for _ in range(n_params): # Generates n_params around the mean of the tensors
119
+ yield torch.randn_like(min_vals) * abs_max_val_per_ch + mean
120
+
121
+ def _calc(x, qtz_func, deqtz_func, **params):
122
+ x_ = x.transpose(0, 1)
123
+ x_ = qtz_func(x=x_, **params)
124
+ x_ = deqtz_func(x=x_, **params)
125
+ x_ = x_.transpose(0, 1)
126
+ return x_
127
+
128
+ assert "qtz_func" in kwargs, "qtz_func must be provided."
129
+ assert "deqtz_func" in kwargs, "deqtz_func must be provided."
130
+ assert "params_list" in kwargs, "params list must be provided."
131
+ assert "param" in kwargs, "param must be provided."
132
+
133
+ qtz_func = kwargs.get('qtz_func')
134
+ deqtz_func = kwargs.get('deqtz_func')
135
+ params_list = kwargs.get('params_list')
136
+ param = kwargs.get('param')
137
+
138
+ n_runs = 50 # Number of runs to try to find the best parameters
139
+ n_random_params = 50 # Number of random parameters to generate
140
+ n_best_to_pick = 5 # Number of best parameters to pick after each run
141
+ max_initial = 10000 # Maximum value to initialize the parameters
142
+
143
+ # Initializes the parameters
144
+ base_params = { p: init_ones(x, **kwargs) for p in params_list if p != param }
145
+ params = _build_initial_param(x, max_initial, n_random_params)
146
+
147
+ # Performs the search
148
+ for _ in range(n_runs):
149
+
150
+ best_params = []
151
+ for param_ in params:
152
+ try:
153
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: param_})
154
+ loss_ones = nn.MSELoss()(x, x_)
155
+
156
+ if len(best_params) < n_best_to_pick:
157
+ best_params.append((param_, loss_ones.item()))
158
+ best_params = sorted(best_params, key=lambda x: x[1])
159
+ elif loss_ones < best_params[-1][1]:
160
+ best_params[-1] = (param_, loss_ones.item())
161
+ best_params = sorted(best_params, key=lambda x: x[1])
162
+
163
+ except Exception: # The parameters might not be valid for the function's domain
164
+ continue
165
+
166
+ # Generates new parameters around the mean
167
+ params = _search_param([p for p, _ in best_params], n_random_params)
168
+
169
+ # Checks if the best parameter is better than the init_ones
170
+ p_ones = init_ones(x, **kwargs)
171
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_ones})
172
+ loss_ones = nn.MSELoss()(x, x_)
173
+
174
+ # Checks if the best parameter is better than the init_rand
175
+ p_rand = init_rand(x, **kwargs)
176
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_rand})
177
+ loss_rand = nn.MSELoss()(x, x_)
178
+
179
+ if loss_rand < best_params[0][1] and loss_rand < loss_ones:
180
+ return p_rand
181
+ elif loss_ones < best_params[0][1] and loss_ones < loss_rand:
182
+ return p_ones
183
+ else:
184
+ return best_params[0][0]
185
+
186
+
187
+ def init_linear_scale( # Symmetric scale. From the study folder
188
+ x: torch.Tensor,
189
+ **kwargs: Dict[str, Any],
190
+ ) -> torch.Tensor:
191
+ assert "bits" in kwargs, "bits must be provided."
192
+ assert "params" in kwargs, "params must be provided."
193
+ assert "qtz_func" in kwargs, "qtz_func must be provided."
194
+
195
+ bits = kwargs.get('bits')
196
+ params = kwargs.get('params')
197
+ qtz_func = kwargs.get('qtz_func')
198
+
199
+ x_ = x.transpose(0, 1)
200
+ x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs))
201
+ x_ = x_.transpose(0, 1)
202
+
203
+ quant_min, quant_max = get_min_max_from_bits_signed(bits)
204
+ min_vals, max_vals = torch.aminmax(x_, dim=1)
205
+ min_vals = torch.min(min_vals, torch.zeros_like(min_vals))
206
+ max_vals = torch.max(max_vals, torch.zeros_like(max_vals))
207
+
208
+ eps = torch.finfo(torch.float32).eps
209
+
210
+ abs_max_val_per_ch = torch.max(-min_vals, max_vals)
211
+ scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2)
212
+
213
+ scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device)
214
+
215
+ # Introduces some noise in scale
216
+ # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything
217
+ scale = scale + 0.01 * torch.randn_like(scale)
218
+ return scale
219
+
220
+
221
+ def init_non_linear_regression_fit(
222
+ x: torch.Tensor,
223
+ **kwargs: Dict[str, Any],
224
+ ) -> torch.Tensor:
225
+
226
+ assert "params_list" in kwargs, "params list must be provided."
227
+ assert "np_fit_func" in kwargs, "np_fit_func must be provided."
228
+ assert "p0" in kwargs, "p0 must be provided."
229
+ np_fit_func = kwargs.get('np_fit_func')
230
+ params_list = kwargs.get('params_list')
231
+ p0 = kwargs.get('p0')
232
+
233
+ def _fit(xdata: np.ndarray, ydata: np.ndarray, func: Callable, p0: List[float]):
234
+ popt, _ = curve_fit(
235
+ func,
236
+ xdata,
237
+ ydata,
238
+ maxfev=1000,
239
+ p0=p0,
240
+ method='lm'
241
+ )
242
+ return popt
243
+
244
+ # 1. Needs to convert the torch tensor to numpy tensor
245
+ xdata = x.cpu().numpy()
246
+
247
+ # 2. Sorts the data to make the curve fit easier
248
+ sorted_xdata = np.sort(xdata, axis=-1)
249
+
250
+ p0 = {k: v.cpu().numpy() for k, v in p0.items()}
251
+ params_list = sorted(params_list) # We need to make sure that it matches the numpy fit func arg order
252
+
253
+ # 3. Finds the best parameters for each channel
254
+ try:
255
+ params = []
256
+ for i in range(sorted_xdata.shape[0]):
257
+ xdata_ = sorted_xdata[i]
258
+ p0_ = [p0[p][i] for p in params_list]
259
+ ch_params = _fit(xdata_, xdata_, np_fit_func, p0_)
260
+ params.append(ch_params)
261
+
262
+ # 4. Builds the parameters
263
+ result = {}
264
+ for i, p in enumerate(params_list):
265
+ result[p] = torch.tensor([p_[i] for p_ in params], dtype=torch.float32).to(x.device)
266
+
267
+ return result
268
+
269
+ except ValueError as e:
270
+ print(f"Could not fit the function with error: {e}")
271
+ print("Using fallback result...")
272
+ return {
273
+ k: torch.tensor(v, dtype=torch.float32).to(x.device) for k, v in p0.items()
274
+ }
275
+
276
+
277
+ def init_zeros(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
278
+ val = torch.amin(x, dim=1)
279
+ return torch.zeros_like(val, dtype=torch.float32, device=x.device)
280
+
281
+
282
+ def init_inner_scale(tensor: torch.Tensor, _min: float = torch.inf, _max: float = torch.inf) -> torch.Tensor:
283
+ # Calculate the original minimum and maximum values
284
+ min_vals, max_vals = torch.aminmax(tensor, dim=-1)
285
+ x_min = torch.min(min_vals, torch.zeros_like(min_vals))
286
+ x_max = torch.max(max_vals, torch.zeros_like(max_vals))
287
+
288
+ if _max is torch.inf: # We do not need to scale the tensor. Just need to move it
289
+ return torch.ones_like(x_min)
290
+
291
+ # Calculate the scale factor
292
+ scale = (_max - _min) / (x_max - x_min)
293
+ return scale
294
+
295
+
296
+
297
+ ############## Quant ###############
298
+
299
+ @torch.enable_grad()
300
+ def learn_parameters(
301
+ x: torch.Tensor,
302
+ params: Dict[str, nn.Parameter],
303
+ qtz_func: nn.Module,
304
+ deqtz_func: nn.Module,
305
+ bits: int,
306
+ target_dtype: torch.dtype,
307
+ epochs: int = 1000,
308
+ early_stop: bool = True,
309
+ do_report: bool = False
310
+ ) -> Tuple[Dict[str, nn.Parameter], torch.Tensor]:
311
+
312
+ # Requires gradients in the parameters
313
+ for p in params.values():
314
+ p.requires_grad = True
315
+ p.grad = None
316
+
317
+ param_keys = list(params.keys())
318
+ param_values = list(params.values())
319
+
320
+ # Defines optimizer and loss function
321
+ optimizer = torch.optim.Adam(param_values, lr=0.001)
322
+ scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=1.0, end_factor=0.001, total_iters=epochs // 10)
323
+ loss_fn = nn.MSELoss()
324
+
325
+ # Contains the best loss and the best parameters
326
+ best_loss = float("inf")
327
+ best_params = None
328
+
329
+ # Used to stop the search early
330
+ min_delta = 1e-7
331
+ acc_loss = []
332
+ percent_epochs_before_stop = 0.1
333
+
334
+ for i in range(epochs):
335
+ optimizer.zero_grad()
336
+
337
+ quant = quantize(x, params, qtz_func, bits, target_dtype)
338
+ dequant = dequantize(quant, params, deqtz_func, bits, x.dtype)
339
+ loss = loss_fn(x, dequant)
340
+
341
+ if loss.isnan() or loss.isinf():
342
+ raise Exception("Loss is NaN or Inf. Stopping the search.")
343
+
344
+ loss.backward()
345
+ optimizer.step()
346
+ scheduler.step()
347
+
348
+ acc_loss.append(loss.item())
349
+
350
+ # Reports loss every 10 steps
351
+ if i % 10 == 0 and do_report:
352
+ print(f"Epoch {i}: Loss {loss.item()}")
353
+
354
+ # Optimizes the parameter search by storing the best loss and the parameters
355
+ if loss.item() < best_loss:
356
+ best_loss = loss.item()
357
+ best_params = copy.deepcopy({
358
+ k: v for k, v in params.items() if k in param_keys
359
+ })
360
+
361
+ # We also stop the search if the loss has not changed considerably during the last 10% of epochs
362
+ if early_stop:
363
+ epochs_before_stop = int(epochs * percent_epochs_before_stop)
364
+ if i > epochs_before_stop and abs(acc_loss[i - epochs_before_stop] - acc_loss[i]) < min_delta:
365
+ break
366
+
367
+ # No longer requires gradients in the parameters
368
+ for p in best_params.values():
369
+ p.requires_grad = False
370
+ p.grad = None
371
+
372
+ if do_report:
373
+ return best_params, acc_loss
374
+ else:
375
+ return best_params
376
+
377
+
378
+ def quantize(
379
+ x: torch.Tensor,
380
+ params: Dict[str, nn.Parameter],
381
+ func: nn.Module,
382
+ bits: int,
383
+ target_dtype: torch.dtype = torch.int8
384
+ ) -> torch.Tensor:
385
+ quant_min, quant_max = get_min_max_from_bits_signed(bits)
386
+ x = x.transpose(0, 1) # Aligns shapes
387
+ x = func(x=x, **params)
388
+ x = x.transpose(0, 1)
389
+ x = torch.clamp(round_func_BPDA(x), quant_min, quant_max).to(target_dtype)
390
+ return x
391
+
392
+
393
+ def dequantize(
394
+ x: torch.Tensor,
395
+ params: Dict[str, nn.Parameter],
396
+ func: nn.Module,
397
+ bits: int,
398
+ out_dtype: torch.dtype
399
+ ) -> torch.Tensor:
400
+ x = x.to(dtype=out_dtype)
401
+ x = x.transpose(0, 1)
402
+ x = func(x=x, **params)
403
+ x = x.transpose(0, 1)
404
+ return x
405
+
406
+
407
+ def round_func_BPDA(input):
408
+ # This is equivalent to replacing round function (non-differentiable) with
409
+ # an identity function (differentiable) only in the backward pass.
410
+ forward_value = torch.round(input)
411
+ out = input.clone()
412
+ out.data = forward_value.data
413
+ return out
414
+
415
+
416
+ def get_min_max_from_bits_signed(bit_width: int) -> Tuple[int, int]:
417
+ return -2 ** (bit_width - 1), 2 ** (bit_width - 1) - 1
418
+
419
+
420
+
421
+ ############## Numpy ###############
422
+
423
+ def np_domain_guard(
424
+ x: np.ndarray,
425
+ min: float = None,
426
+ max: float = None,
427
+ posinf: float = None,
428
+ neginf: float = None,
429
+ nan: float = None
430
+ ) -> np.ndarray:
431
+ """Guard a tensor to a valid domain."""
432
+ x = np.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan)
433
+ if min is not None or max is not None:
434
+ x = np.clip(x, min, max)
435
+ return x
436
+
437
+
438
+ def np_replace_num(x: np.ndarray, num: float, to: float) -> np.ndarray:
439
+ """Replace a number in a tensor with another number.
440
+
441
+ Args:
442
+ x (np.ndarray): The input tensor.
443
+ num (float): The number to replace.
444
+ to (float): The number to replace with.
445
+
446
+ Returns:
447
+ np.ndarray: The tensor with the number replaced.
448
+ """
449
+ return np.where(x == num, to, x)
450
+
451
+
452
+ def np_guarded_power(x: np.ndarray, exp: float) -> np.ndarray:
453
+ """Guard the power operation to a valid domain."""
454
+ return np.power(x, exp) if exp >= 1 else np.power(np.maximum(x, 0), exp)
455
+
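The fn.py files in this commit route rounding through round_func_BPDA so that gradients can flow through the otherwise non-differentiable torch.round. A minimal check of that straight-through behaviour, not part of the repo and assuming the file above is importable as a module named fn, could look like:

import torch
from fn import round_func_BPDA  # hypothetical import path for the fn.py above

x = torch.tensor([0.2, 1.7, -0.6], requires_grad=True)
y = round_func_BPDA(x)   # forward: rounded values 0., 2., -1.
y.sum().backward()
print(x.grad)            # backward: all ones, i.e. the gradient of the identity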
fn_gen/ones_naive/14/loss.png ADDED
fn_gen/ones_naive/14/quantization.png ADDED
fn_gen/ones_naive/15/distortion.png ADDED
fn_gen/ones_naive/15/expressions.txt ADDED
@@ -0,0 +1,2 @@
1
+ sin(_0*x)/_s
2
+ asin(_s*x)/_0
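These two expressions are the quantization map and its inverse for this candidate: the first is applied before rounding, the second undoes it at dequantization time. A quick numpy sanity check of the round trip with both parameters fixed at 1 (the ones_naive initialization), written here purely as an illustration:

import numpy as np

x = np.linspace(-1.0, 1.0, 5)
q = np.sin(1.0 * x) / 1.0          # sin(_0*x)/_s with _0 = _s = 1
x_rec = np.arcsin(1.0 * q) / 1.0   # asin(_s*x)/_0 with _0 = _s = 1
assert np.allclose(x, x_rec)       # asin inverts sin on this range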
fn_gen/ones_naive/15/fn.py ADDED
@@ -0,0 +1,455 @@
1
+ from __future__ import annotations
2
+
3
+ import torch
4
+ from torch import amin # Imported by the generated template; the code below calls torch.amin directly
5
+ import copy
6
+ import torch.nn as nn
7
+ import numpy as np
8
+
9
+ from scipy.optimize import curve_fit
10
+ from typing import Dict, Any, Tuple, List, Callable
11
+
12
+
13
+ def quantization(x, **params):
14
+ return (torch.div(1, replace_num(params['_s'], num=0, to=10000)) * torch.sin((params['_0'] * x)))
15
+
16
+
17
+ def dequantization(x, **params):
18
+ return (torch.div(1, replace_num(params['_0'], num=0, to=10000)) * torch.asin(domain_guard((params['_s'] * x), min=-0.99999, max=0.99999, nan=0)))
19
+
20
+
21
+ def init_params(x: torch.Tensor, **kwargs: Dict[str, Any]) -> Dict[str, nn.Parameter]:
22
+ params = {
23
+ '_0': init_ones(x, qtz_func=quantization, deqtz_func=dequantization, param='_0', params_list=['_0', '_s'], **kwargs),
24
+ }
25
+ params['_s'] = init_ones(x, params=params, qtz_func=quantization, **kwargs)
26
+ params = {k: nn.Parameter(v, requires_grad=False) for k, v in params.items()}
27
+
28
+ if 'post_init_hook' in kwargs:
29
+ kwargs['post_init_hook'](parameters=params)
30
+
31
+
32
+ if 'post_train_hook' in kwargs:
33
+ kwargs['post_train_hook'](parameters=params)
34
+
35
+ return params
36
+
37
+
38
+ ############### Numpy Qtz ###############
39
+
40
+
41
+ def np_quantization(x, _0, _s):
42
+ return (np.divide(1, np_replace_num(_s, num=0, to=10000)) * np.sin((_0 * x)))
43
+
44
+
45
+ def np_dequantization(x, _0, _s):
46
+ return (np.divide(1, np_replace_num(_0, num=0, to=10000)) * np.arcsin(np_domain_guard((_s * x), min=-0.99999, max=0.99999, nan=0)))
47
+
48
+
49
+ def fit_func(x, _0, _s):
50
+ x_ = np_quantization(x, _0, _s)
51
+ x_ = np_dequantization(x_, _0, _s)
52
+ return x_
53
+
54
+
55
+
56
+ ############### HELPERS ###############
57
+
58
+ def domain_guard(
59
+ x: torch.Tensor,
60
+ min: float = None,
61
+ max: float = None,
62
+ posinf: float = None,
63
+ neginf: float = None,
64
+ nan: float = None
65
+ ) -> torch.Tensor:
66
+ """Guard a tensor to a valid domain."""
67
+ x = torch.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan)
68
+ if min is not None or max is not None:
69
+ x = torch.clamp(x, min=min, max=max)
70
+ return x
71
+
72
+
73
+ def replace_num(x: torch.Tensor, num: float, to: float) -> torch.Tensor:
74
+ """Replace a number in a tensor with another number.
75
+
76
+ Args:
77
+ x (torch.Tensor): The input tensor.
78
+ num (float): The number to replace.
79
+ to (float): The number to replace with.
80
+
81
+ Returns:
82
+ torch.Tensor: The tensor with the number replaced.
83
+ """
84
+ return torch.where(x == num, to, x)
85
+
86
+
87
+ def guarded_torch_power(x: torch.Tensor, exp: float) -> torch.Tensor:
88
+ """Guard the power operation to a valid domain."""
89
+ return torch.pow(x, exp) if exp >= 1 else torch.pow(torch.relu(x), exp)
90
+
91
+
92
+ def init_ones(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
93
+ val = torch.amin(x, dim=1)
94
+ return torch.ones_like(val, dtype=torch.float32, device=x.device)
95
+
96
+
97
+ def init_rand(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
98
+ val = torch.amin(x, dim=1)
99
+ return torch.randn_like(val, dtype=torch.float32, device=x.device)
100
+
101
+
102
+ def init_space_search(
103
+ x: torch.Tensor,
104
+ **kwargs: Dict[str, Any],
105
+ ) -> torch.Tensor:
106
+
107
+ def _build_initial_param(tensor: torch.Tensor, max_initial: int, n_params: int):
108
+ """Generates the initial set of parameters. The first iteration generates 10 times more parameters."""
109
+ for _ in range(n_params * 10): # The first iteration generates 10 times more parameters
110
+ yield init_rand(tensor) * max_initial # Generates n_params in range [-max_initial, max_initial]
111
+
112
+ def _search_param(tensors: List[torch.Tensor], n_params):
113
+ """Takes the best parameters and generates new parameters around the mean of the best parameters."""
114
+ torch_tensors = torch.stack(tensors)
115
+ min_vals, max_vals = torch.aminmax(torch_tensors, dim=0)
116
+ abs_max_val_per_ch = torch.max(-min_vals, max_vals)
117
+ mean = torch.mean(torch_tensors, dim=0)
118
+ for _ in range(n_params): # Generates n_params around the mean of the tensors
119
+ yield torch.randn_like(min_vals) * abs_max_val_per_ch + mean
120
+
121
+ def _calc(x, qtz_func, deqtz_func, **params):
122
+ x_ = x.transpose(0, 1)
123
+ x_ = qtz_func(x=x_, **params)
124
+ x_ = deqtz_func(x=x_, **params)
125
+ x_ = x_.transpose(0, 1)
126
+ return x_
127
+
128
+ assert "qtz_func" in kwargs, "qtz_func must be provided."
129
+ assert "deqtz_func" in kwargs, "deqtz_func must be provided."
130
+ assert "params_list" in kwargs, "params list must be provided."
131
+ assert "param" in kwargs, "param must be provided."
132
+
133
+ qtz_func = kwargs.get('qtz_func')
134
+ deqtz_func = kwargs.get('deqtz_func')
135
+ params_list = kwargs.get('params_list')
136
+ param = kwargs.get('param')
137
+
138
+ n_runs = 50 # Number of runs to try to find the best parameters
139
+ n_random_params = 50 # Number of random parameters to generate
140
+ n_best_to_pick = 5 # Number of best parameters to pick after each run
141
+ max_initial = 10000 # Maximum value to initialize the parameters
142
+
143
+ # Initializes the parameters
144
+ base_params = { p: init_ones(x, **kwargs) for p in params_list if p != param }
145
+ params = _build_initial_param(x, max_initial, n_random_params)
146
+
147
+ # Performs the search
148
+ for _ in range(n_runs):
149
+
150
+ best_params = []
151
+ for param_ in params:
152
+ try:
153
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: param_})
154
+ loss_ones = nn.MSELoss()(x, x_)
155
+
156
+ if len(best_params) < n_best_to_pick:
157
+ best_params.append((param_, loss_ones.item()))
158
+ best_params = sorted(best_params, key=lambda x: x[1])
159
+ elif loss_ones < best_params[-1][1]:
160
+ best_params[-1] = (param_, loss_ones.item())
161
+ best_params = sorted(best_params, key=lambda x: x[1])
162
+
163
+ except Exception: # The parameters might not be valid for the function's domain
164
+ continue
165
+
166
+ # Generates new parameters around the mean
167
+ params = _search_param([p for p, _ in best_params], n_random_params)
168
+
169
+ # Checks if the best parameter is better than the init_ones
170
+ p_ones = init_ones(x, **kwargs)
171
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_ones})
172
+ loss_ones = nn.MSELoss()(x, x_)
173
+
174
+ # Checks if the best parameter is better than the init_rand
175
+ p_rand = init_rand(x, **kwargs)
176
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_rand})
177
+ loss_rand = nn.MSELoss()(x, x_)
178
+
179
+ if loss_rand < best_params[0][1] and loss_rand < loss_ones:
180
+ return p_rand
181
+ elif loss_ones < best_params[0][1] and loss_ones < loss_rand:
182
+ return p_ones
183
+ else:
184
+ return best_params[0][0]
185
+
186
+
187
+ def init_linear_scale( # Symmetric scale. From the study folder
188
+ x: torch.Tensor,
189
+ **kwargs: Dict[str, Any],
190
+ ) -> torch.Tensor:
191
+ assert "bits" in kwargs, "bits must be provided."
192
+ assert "params" in kwargs, "params must be provided."
193
+ assert "qtz_func" in kwargs, "qtz_func must be provided."
194
+
195
+ bits = kwargs.get('bits')
196
+ params = kwargs.get('params')
197
+ qtz_func = kwargs.get('qtz_func')
198
+
199
+ x_ = x.transpose(0, 1)
200
+ x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs))
201
+ x_ = x_.transpose(0, 1)
202
+
203
+ quant_min, quant_max = get_min_max_from_bits_signed(bits)
204
+ min_vals, max_vals = torch.aminmax(x_, dim=1)
205
+ min_vals = torch.min(min_vals, torch.zeros_like(min_vals))
206
+ max_vals = torch.max(max_vals, torch.zeros_like(max_vals))
207
+
208
+ eps = torch.finfo(torch.float32).eps
209
+
210
+ abs_max_val_per_ch = torch.max(-min_vals, max_vals)
211
+ scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2)
212
+
213
+ scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device)
214
+
215
+ # Introduces some noise in scale
216
+ # Without this noise the scale does not learn anything and the accuracy stays at 0.0
217
+ scale = scale + 0.01 * torch.randn_like(scale)
218
+ return scale
219
+
220
+
221
+ def init_non_linear_regression_fit(
222
+ x: torch.Tensor,
223
+ **kwargs: Dict[str, Any],
224
+ ) -> torch.Tensor:
225
+
226
+ assert "params_list" in kwargs, "params list must be provided."
227
+ assert "np_fit_func" in kwargs, "np_fit_func must be provided."
228
+ assert "p0" in kwargs, "p0 must be provided."
229
+ np_fit_func = kwargs.get('np_fit_func')
230
+ params_list = kwargs.get('params_list')
231
+ p0 = kwargs.get('p0')
232
+
233
+ def _fit(xdata: np.ndarray, ydata: np.ndarray, func: Callable, p0: List[float]):
234
+ popt, _ = curve_fit(
235
+ func,
236
+ xdata,
237
+ ydata,
238
+ maxfev=1000,
239
+ p0=p0,
240
+ method='lm'
241
+ )
242
+ return popt
243
+
244
+ # 1. Needs to convert the torch tensor to numpy tensor
245
+ xdata = x.cpu().numpy()
246
+
247
+ # 2. Sorts the data to make the curve fit easier
248
+ sorted_xdata = np.sort(xdata, axis=-1)
249
+
250
+ p0 = {k: v.cpu().numpy() for k, v in p0.items()}
251
+ params_list = sorted(params_list) # We need to make sure that it matches the numpy fit func arg order
252
+
253
+ # 3. Finds the best parameters for each channel
254
+ try:
255
+ params = []
256
+ for i in range(sorted_xdata.shape[0]):
257
+ xdata_ = sorted_xdata[i]
258
+ p0_ = [p0[p][i] for p in params_list]
259
+ ch_params = _fit(xdata_, xdata_, np_fit_func, p0_)
260
+ params.append(ch_params)
261
+
262
+ # 4. Builds the parameters
263
+ result = {}
264
+ for i, p in enumerate(params_list):
265
+ result[p] = torch.tensor([p_[i] for p_ in params], dtype=torch.float32).to(x.device)
266
+
267
+ return result
268
+
269
+ except ValueError as e:
270
+ print(f"Could not fit the function with error: {e}")
271
+ print("Using fallback result...")
272
+ return {
273
+ k: torch.tensor(v, dtype=torch.float32).to(x.device) for k, v in p0.items()
274
+ }
275
+
276
+
277
+ def init_zeros(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
278
+ val = torch.amin(x, dim=1)
279
+ return torch.zeros_like(val, dtype=torch.float32, device=x.device)
280
+
281
+
282
+ def init_inner_scale(tensor: torch.Tensor, _min: float = torch.inf, _max: float = torch.inf) -> torch.Tensor:
283
+ # Calculate the original minimum and maximum values
284
+ min_vals, max_vals = torch.aminmax(tensor, dim=-1)
285
+ x_min = torch.min(min_vals, torch.zeros_like(min_vals))
286
+ x_max = torch.max(max_vals, torch.zeros_like(max_vals))
287
+
288
+ if _max is torch.inf: # We do not need to scale the tensor. Just need to move it
289
+ return torch.ones_like(x_min)
290
+
291
+ # Calculate the scale factor
292
+ scale = (_max - _min) / (x_max - x_min)
293
+ return scale
294
+
295
+
296
+
297
+ ############## Quant ###############
298
+
299
+ @torch.enable_grad()
300
+ def learn_parameters(
301
+ x: torch.Tensor,
302
+ params: Dict[str, nn.Parameter],
303
+ qtz_func: nn.Module,
304
+ deqtz_func: nn.Module,
305
+ bits: int,
306
+ target_dtype: torch.dtype,
307
+ epochs: int = 1000,
308
+ early_stop: bool = True,
309
+ do_report: bool = False
310
+ ) -> Tuple[Dict[str, nn.Parameter], torch.Tensor]:
311
+
312
+ # Requires gradients in the parameters
313
+ for p in params.values():
314
+ p.requires_grad = True
315
+ p.grad = None
316
+
317
+ param_keys = list(params.keys())
318
+ param_values = list(params.values())
319
+
320
+ # Defines optimizer and loss function
321
+ optimizer = torch.optim.Adam(param_values, lr=0.001)
322
+ scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=1.0, end_factor=0.001, total_iters=epochs // 10)
323
+ loss_fn = nn.MSELoss()
324
+
325
+ # Contains the best loss and the best parameters
326
+ best_loss = float("inf")
327
+ best_params = None
328
+
329
+ # Used to stop the search early
330
+ min_delta = 1e-7
331
+ acc_loss = []
332
+ percent_epochs_before_stop = 0.1
333
+
334
+ for i in range(epochs):
335
+ optimizer.zero_grad()
336
+
337
+ quant = quantize(x, params, qtz_func, bits, target_dtype)
338
+ dequant = dequantize(quant, params, deqtz_func, bits, x.dtype)
339
+ loss = loss_fn(x, dequant)
340
+
341
+ if loss.isnan() or loss.isinf():
342
+ raise Exception("Loss is NaN or Inf. Stopping the search.")
343
+
344
+ loss.backward()
345
+ optimizer.step()
346
+ scheduler.step()
347
+
348
+ acc_loss.append(loss.item())
349
+
350
+ # Reports loss every 10 steps
351
+ if i % 10 == 0 and do_report:
352
+ print(f"Epoch {i}: Loss {loss.item()}")
353
+
354
+ # Optimizes the parameter search by storing the best loss and the parameters
355
+ if loss.item() < best_loss:
356
+ best_loss = loss.item()
357
+ best_params = copy.deepcopy({
358
+ k: v for k, v in params.items() if k in param_keys
359
+ })
360
+
361
+ # We also stop the search if the loss has not changed considerably during the last 10% of epochs
362
+ if early_stop:
363
+ epochs_before_stop = int(epochs * percent_epochs_before_stop)
364
+ if i > epochs_before_stop and abs(acc_loss[i - epochs_before_stop] - acc_loss[i]) < min_delta:
365
+ break
366
+
367
+ # No longer requires gradients in the parameters
368
+ for p in best_params.values():
369
+ p.requires_grad = False
370
+ p.grad = None
371
+
372
+ if do_report:
373
+ return best_params, acc_loss
374
+ else:
375
+ return best_params
376
+
377
+
378
+ def quantize(
379
+ x: torch.Tensor,
380
+ params: Dict[str, nn.Parameter],
381
+ func: nn.Module,
382
+ bits: int,
383
+ target_dtype: torch.dtype = torch.int8
384
+ ) -> torch.Tensor:
385
+ quant_min, quant_max = get_min_max_from_bits_signed(bits)
386
+ x = x.transpose(0, 1) # Aligns shapes
387
+ x = func(x=x, **params)
388
+ x = x.transpose(0, 1)
389
+ x = torch.clamp(round_func_BPDA(x), quant_min, quant_max).to(target_dtype)
390
+ return x
391
+
392
+
393
+ def dequantize(
394
+ x: torch.Tensor,
395
+ params: Dict[str, nn.Parameter],
396
+ func: nn.Module,
397
+ bits: int,
398
+ out_dtype: torch.dtype
399
+ ) -> torch.Tensor:
400
+ x = x.to(dtype=out_dtype)
401
+ x = x.transpose(0, 1)
402
+ x = func(x=x, **params)
403
+ x = x.transpose(0, 1)
404
+ return x
405
+
406
+
407
+ def round_func_BPDA(input):
408
+ # This is equivalent to replacing round function (non-differentiable) with
409
+ # an identity function (differentiable) only in the backward pass.
410
+ forward_value = torch.round(input)
411
+ out = input.clone()
412
+ out.data = forward_value.data
413
+ return out
414
+
415
+
416
+ def get_min_max_from_bits_signed(bit_width: int) -> Tuple[int, int]:
417
+ return -2 ** (bit_width - 1), 2 ** (bit_width - 1) - 1
418
+
419
+
420
+
421
+ ############## Numpy ###############
422
+
423
+ def np_domain_guard(
424
+ x: np.ndarray,
425
+ min: float = None,
426
+ max: float = None,
427
+ posinf: float = None,
428
+ neginf: float = None,
429
+ nan: float = None
430
+ ) -> np.ndarray:
431
+ """Guard a tensor to a valid domain."""
432
+ x = np.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan)
433
+ if min is not None or max is not None:
434
+ x = np.clip(x, min, max)
435
+ return x
436
+
437
+
438
+ def np_replace_num(x: np.ndarray, num: float, to: float) -> np.ndarray:
439
+ """Replace a number in a tensor with another number.
440
+
441
+ Args:
442
+ x (np.ndarray): The input tensor.
443
+ num (float): The number to replace.
444
+ to (float): The number to replace with.
445
+
446
+ Returns:
447
+ np.ndarray: The tensor with the number replaced.
448
+ """
449
+ return np.where(x == num, to, x)
450
+
451
+
452
+ def np_guarded_power(x: np.ndarray, exp: float) -> np.ndarray:
453
+ """Guard the power operation to a valid domain."""
454
+ return np.power(x, exp) if exp >= 1 else np.power(np.maximum(x, 0), exp)
455
+
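Every fn.py uploaded here exposes the same helpers (init_params, quantization/dequantization, quantize, dequantize, learn_parameters). A minimal usage sketch for this sin/asin variant, assuming the file has been copied somewhere importable as fn and using a made-up 8x32 weight matrix:

import torch
import fn  # hypothetical module name for fn_gen/ones_naive/15/fn.py

w = torch.randn(8, 32)                    # one row per channel
params = fn.init_params(w)                # '_0' and '_s' start as ones here
q = fn.quantize(w, params, fn.quantization, bits=8, target_dtype=torch.int8)
w_hat = fn.dequantize(q, params, fn.dequantization, bits=8, out_dtype=w.dtype)
print(torch.nn.functional.mse_loss(w_hat, w))  # round-trip reconstruction error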
fn_gen/ones_naive/15/loss.png ADDED
fn_gen/ones_naive/15/quantization.png ADDED
fn_gen/ones_naive/16/distortion.png ADDED
fn_gen/ones_naive/16/expressions.txt ADDED
@@ -0,0 +1,2 @@
1
+ tan(_0*x)/_s
2
+ atan(_s*x)/_0
fn_gen/ones_naive/16/fn.py ADDED
@@ -0,0 +1,455 @@
1
+ from __future__ import annotations
2
+
3
+ import torch
4
+ from torch import amin # Imported by the generated template; the code below calls torch.amin directly
5
+ import copy
6
+ import torch.nn as nn
7
+ import numpy as np
8
+
9
+ from scipy.optimize import curve_fit
10
+ from typing import Dict, Any, Tuple, List, Callable
11
+
12
+
13
+ def quantization(x, **params):
14
+ return (torch.div(1, replace_num(params['_s'], num=0, to=10000)) * torch.tan(domain_guard((params['_0'] * x), posinf=1, neginf=-1, nan=0)))
15
+
16
+
17
+ def dequantization(x, **params):
18
+ return (torch.div(1, replace_num(params['_0'], num=0, to=10000)) * torch.atan((params['_s'] * x)))
19
+
20
+
21
+ def init_params(x: torch.Tensor, **kwargs: Dict[str, Any]) -> Dict[str, nn.Parameter]:
22
+ params = {
23
+ '_0': init_ones(x, qtz_func=quantization, deqtz_func=dequantization, param='_0', params_list=['_0', '_s'], **kwargs),
24
+ }
25
+ params['_s'] = init_ones(x, params=params, qtz_func=quantization, **kwargs)
26
+ params = {k: nn.Parameter(v, requires_grad=False) for k, v in params.items()}
27
+
28
+ if 'post_init_hook' in kwargs:
29
+ kwargs['post_init_hook'](parameters=params)
30
+
31
+
32
+ if 'post_train_hook' in kwargs:
33
+ kwargs['post_train_hook'](parameters=params)
34
+
35
+ return params
36
+
37
+
38
+ ############### Numpy Qtz ###############
39
+
40
+
41
+ def np_quantization(x, _0, _s):
42
+ return (np.divide(1, np_replace_num(_s, num=0, to=10000)) * np.tan(np_domain_guard((_0 * x), posinf=1, neginf=-1, nan=0)))
43
+
44
+
45
+ def np_dequantization(x, _0, _s):
46
+ return (np.divide(1, np_replace_num(_0, num=0, to=10000)) * np.arctan((_s * x)))
47
+
48
+
49
+ def fit_func(x, _0, _s):
50
+ x_ = np_quantization(x, _0, _s)
51
+ x_ = np_dequantization(x_, _0, _s)
52
+ return x_
53
+
54
+
55
+
56
+ ############### HELPERS ###############
57
+
58
+ def domain_guard(
59
+ x: torch.Tensor,
60
+ min: float = None,
61
+ max: float = None,
62
+ posinf: float = None,
63
+ neginf: float = None,
64
+ nan: float = None
65
+ ) -> torch.Tensor:
66
+ """Guard a tensor to a valid domain."""
67
+ x = torch.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan)
68
+ if min is not None or max is not None:
69
+ x = torch.clamp(x, min=min, max=max)
70
+ return x
71
+
72
+
73
+ def replace_num(x: torch.Tensor, num: float, to: float) -> torch.Tensor:
74
+ """Replace a number in a tensor with another number.
75
+
76
+ Args:
77
+ x (torch.Tensor): The input tensor.
78
+ num (float): The number to replace.
79
+ to (float): The number to replace with.
80
+
81
+ Returns:
82
+ torch.Tensor: The tensor with the number replaced.
83
+ """
84
+ return torch.where(x == num, to, x)
85
+
86
+
87
+ def guarded_torch_power(x: torch.Tensor, exp: float) -> torch.Tensor:
88
+ """Guard the power operation to a valid domain."""
89
+ return torch.pow(x, exp) if exp >= 1 else torch.pow(torch.relu(x), exp)
90
+
91
+
92
+ def init_ones(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
93
+ val = torch.amin(x, dim=1)
94
+ return torch.ones_like(val, dtype=torch.float32, device=x.device)
95
+
96
+
97
+ def init_rand(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
98
+ val = torch.amin(x, dim=1)
99
+ return torch.randn_like(val, dtype=torch.float32, device=x.device)
100
+
101
+
102
+ def init_space_search(
103
+ x: torch.Tensor,
104
+ **kwargs: Dict[str, Any],
105
+ ) -> torch.Tensor:
106
+
107
+ def _build_initial_param(tensor: torch.Tensor, max_initial: int, n_params: int):
108
+ """Generates the initial set of parameters. The first iteration generates 10 times more parameters."""
109
+ for _ in range(n_params * 10): # The first iteration generates 10 times more parameters
110
+ yield init_rand(tensor) * max_initial # Generates n_params in range [-max_initial, max_initial]
111
+
112
+ def _search_param(tensors: List[torch.Tensor], n_params):
113
+ """Takes the best parameters and generates new parameters around the mean of the best parameters."""
114
+ torch_tensors = torch.stack(tensors)
115
+ min_vals, max_vals = torch.aminmax(torch_tensors, dim=0)
116
+ abs_max_val_per_ch = torch.max(-min_vals, max_vals)
117
+ mean = torch.mean(torch_tensors, dim=0)
118
+ for _ in range(n_params): # Generates n_params around the mean of the tensors
119
+ yield torch.randn_like(min_vals) * abs_max_val_per_ch + mean
120
+
121
+ def _calc(x, qtz_func, deqtz_func, **params):
122
+ x_ = x.transpose(0, 1)
123
+ x_ = qtz_func(x=x_, **params)
124
+ x_ = deqtz_func(x=x_, **params)
125
+ x_ = x_.transpose(0, 1)
126
+ return x_
127
+
128
+ assert "qtz_func" in kwargs, "qtz_func must be provided."
129
+ assert "deqtz_func" in kwargs, "deqtz_func must be provided."
130
+ assert "params_list" in kwargs, "params list must be provided."
131
+ assert "param" in kwargs, "param must be provided."
132
+
133
+ qtz_func = kwargs.get('qtz_func')
134
+ deqtz_func = kwargs.get('deqtz_func')
135
+ params_list = kwargs.get('params_list')
136
+ param = kwargs.get('param')
137
+
138
+ n_runs = 50 # Number of runs to try to find the best parameters
139
+ n_random_params = 50 # Number of random parameters to generate
140
+ n_best_to_pick = 5 # Number of best parameters to pick after each run
141
+ max_initial = 10000 # Maximum value to initialize the parameters
142
+
143
+ # Initializes the parameters
144
+ base_params = { p: init_ones(x, **kwargs) for p in params_list if p != param }
145
+ params = _build_initial_param(x, max_initial, n_random_params)
146
+
147
+ # Performs the search
148
+ for _ in range(n_runs):
149
+
150
+ best_params = []
151
+ for param_ in params:
152
+ try:
153
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: param_})
154
+ loss_ones = nn.MSELoss()(x, x_)
155
+
156
+ if len(best_params) < n_best_to_pick:
157
+ best_params.append((param_, loss_ones.item()))
158
+ best_params = sorted(best_params, key=lambda x: x[1])
159
+ elif loss_ones < best_params[-1][1]:
160
+ best_params[-1] = (param_, loss_ones.item())
161
+ best_params = sorted(best_params, key=lambda x: x[1])
162
+
163
+ except Exception: # The parameters might not be valid for the function's domain
164
+ continue
165
+
166
+ # Generates new parameters around the mean
167
+ params = _search_param([p for p, _ in best_params], n_random_params)
168
+
169
+ # Checks if the best parameter is better than the init_ones
170
+ p_ones = init_ones(x, **kwargs)
171
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_ones})
172
+ loss_ones = nn.MSELoss()(x, x_)
173
+
174
+ # Checks if the best parameter is better than the init_rand
175
+ p_rand = init_rand(x, **kwargs)
176
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_rand})
177
+ loss_rand = nn.MSELoss()(x, x_)
178
+
179
+ if loss_rand < best_params[0][1] and loss_rand < loss_ones:
180
+ return p_rand
181
+ elif loss_ones < best_params[0][1] and loss_ones < loss_rand:
182
+ return p_ones
183
+ else:
184
+ return best_params[0][0]
185
+
186
+
187
+ def init_linear_scale( # Symmetric scale. From the study folder
188
+ x: torch.Tensor,
189
+ **kwargs: Dict[str, Any],
190
+ ) -> torch.Tensor:
191
+ assert "bits" in kwargs, "bits must be provided."
192
+ assert "params" in kwargs, "params must be provided."
193
+ assert "qtz_func" in kwargs, "qtz_func must be provided."
194
+
195
+ bits = kwargs.get('bits')
196
+ params = kwargs.get('params')
197
+ qtz_func = kwargs.get('qtz_func')
198
+
199
+ x_ = x.transpose(0, 1)
200
+ x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs))
201
+ x_ = x_.transpose(0, 1)
202
+
203
+ quant_min, quant_max = get_min_max_from_bits_signed(bits)
204
+ min_vals, max_vals = torch.aminmax(x_, dim=1)
205
+ min_vals = torch.min(min_vals, torch.zeros_like(min_vals))
206
+ max_vals = torch.max(max_vals, torch.zeros_like(max_vals))
207
+
208
+ eps = torch.finfo(torch.float32).eps
209
+
210
+ abs_max_val_per_ch = torch.max(-min_vals, max_vals)
211
+ scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2)
212
+
213
+ scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device)
214
+
215
+ # Introduces some noise in scale
216
+ # Without this noise the scale does not learn anything and the accuracy stays at 0.0
217
+ scale = scale + 0.01 * torch.randn_like(scale)
218
+ return scale
219
+
220
+
221
+ def init_non_linear_regression_fit(
222
+ x: torch.Tensor,
223
+ **kwargs: Dict[str, Any],
224
+ ) -> torch.Tensor:
225
+
226
+ assert "params_list" in kwargs, "params list must be provided."
227
+ assert "np_fit_func" in kwargs, "np_fit_func must be provided."
228
+ assert "p0" in kwargs, "p0 must be provided."
229
+ np_fit_func = kwargs.get('np_fit_func')
230
+ params_list = kwargs.get('params_list')
231
+ p0 = kwargs.get('p0')
232
+
233
+ def _fit(xdata: np.ndarray, ydata: np.ndarray, func: Callable, p0: List[float]):
234
+ popt, _ = curve_fit(
235
+ func,
236
+ xdata,
237
+ ydata,
238
+ maxfev=1000,
239
+ p0=p0,
240
+ method='lm'
241
+ )
242
+ return popt
243
+
244
+ # 1. Needs to convert the torch tensor to numpy tensor
245
+ xdata = x.cpu().numpy()
246
+
247
+ # 2. Sorts the data to make the curve fit easier
248
+ sorted_xdata = np.sort(xdata, axis=-1)
249
+
250
+ p0 = {k: v.cpu().numpy() for k, v in p0.items()}
251
+ params_list = sorted(params_list) # We need to make sure that it matches the numpy fit func arg order
252
+
253
+ # 3. Finds the best parameters for each channel
254
+ try:
255
+ params = []
256
+ for i in range(sorted_xdata.shape[0]):
257
+ xdata_ = sorted_xdata[i]
258
+ p0_ = [p0[p][i] for p in params_list]
259
+ ch_params = _fit(xdata_, xdata_, np_fit_func, p0_)
260
+ params.append(ch_params)
261
+
262
+ # 4. Builds the parameters
263
+ result = {}
264
+ for i, p in enumerate(params_list):
265
+ result[p] = torch.tensor([p_[i] for p_ in params], dtype=torch.float32).to(x.device)
266
+
267
+ return result
268
+
269
+ except ValueError as e:
270
+ print(f"Could not fit the function with error: {e}")
271
+ print("Using fallback result...")
272
+ return {
273
+ k: torch.tensor(v, dtype=torch.float32).to(x.device) for k, v in p0.items()
274
+ }
275
+
276
+
277
+ def init_zeros(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
278
+ val = torch.amin(x, dim=1)
279
+ return torch.zeros_like(val, dtype=torch.float32, device=x.device)
280
+
281
+
282
+ def init_inner_scale(tensor: torch.Tensor, _min: float = torch.inf, _max: float = torch.inf) -> torch.Tensor:
283
+ # Calculate the original minimum and maximum values
284
+ min_vals, max_vals = torch.aminmax(tensor, dim=-1)
285
+ x_min = torch.min(min_vals, torch.zeros_like(min_vals))
286
+ x_max = torch.max(max_vals, torch.zeros_like(max_vals))
287
+
288
+ if _max is torch.inf: # We do not need to scale the tensor. Just need to move it
289
+ return torch.ones_like(x_min)
290
+
291
+ # Calculate the scale factor
292
+ scale = (_max - _min) / (x_max - x_min)
293
+ return scale
294
+
295
+
296
+
297
+ ############## Quant ###############
298
+
299
+ @torch.enable_grad()
300
+ def learn_parameters(
301
+ x: torch.Tensor,
302
+ params: Dict[str, nn.Parameter],
303
+ qtz_func: nn.Module,
304
+ deqtz_func: nn.Module,
305
+ bits: int,
306
+ target_dtype: torch.dtype,
307
+ epochs: int = 1000,
308
+ early_stop: bool = True,
309
+ do_report: bool = False
310
+ ) -> Tuple[Dict[str, nn.Parameter], torch.Tensor]:
311
+
312
+ # Requires gradients in the parameters
313
+ for p in params.values():
314
+ p.requires_grad = True
315
+ p.grad = None
316
+
317
+ param_keys = list(params.keys())
318
+ param_values = list(params.values())
319
+
320
+ # Defines optimizer and loss function
321
+ optimizer = torch.optim.Adam(param_values, lr=0.001)
322
+ scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=1.0, end_factor=0.001, total_iters=epochs // 10)
323
+ loss_fn = nn.MSELoss()
324
+
325
+ # Contains the best loss and the best parameters
326
+ best_loss = float("inf")
327
+ best_params = None
328
+
329
+ # Used to stop the search early
330
+ min_delta = 1e-7
331
+ acc_loss = []
332
+ percent_epochs_before_stop = 0.1
333
+
334
+ for i in range(epochs):
335
+ optimizer.zero_grad()
336
+
337
+ quant = quantize(x, params, qtz_func, bits, target_dtype)
338
+ dequant = dequantize(quant, params, deqtz_func, bits, x.dtype)
339
+ loss = loss_fn(x, dequant)
340
+
341
+ if loss.isnan() or loss.isinf():
342
+ raise Exception("Loss is NaN or Inf. Stopping the search.")
343
+
344
+ loss.backward()
345
+ optimizer.step()
346
+ scheduler.step()
347
+
348
+ acc_loss.append(loss.item())
349
+
350
+ # Reports loss every 10 steps
351
+ if i % 10 == 0 and do_report:
352
+ print(f"Epoch {i}: Loss {loss.item()}")
353
+
354
+ # Optimizes the parameter search by storing the best loss and the parameters
355
+ if loss.item() < best_loss:
356
+ best_loss = loss.item()
357
+ best_params = copy.deepcopy({
358
+ k: v for k, v in params.items() if k in param_keys
359
+ })
360
+
361
+ # We also stop the search if the loss has not changed considerably during the last 10% of epochs
362
+ if early_stop:
363
+ epochs_before_stop = int(epochs * percent_epochs_before_stop)
364
+ if i > epochs_before_stop and abs(acc_loss[i - epochs_before_stop] - acc_loss[i]) < min_delta:
365
+ break
366
+
367
+ # No longer requires gradients in the parameters
368
+ for p in best_params.values():
369
+ p.requires_grad = False
370
+ p.grad = None
371
+
372
+ if do_report:
373
+ return best_params, acc_loss
374
+ else:
375
+ return best_params
376
+
377
+
378
+ def quantize(
379
+ x: torch.Tensor,
380
+ params: Dict[str, nn.Parameter],
381
+ func: nn.Module,
382
+ bits: int,
383
+ target_dtype: torch.dtype = torch.int8
384
+ ) -> torch.Tensor:
385
+ quant_min, quant_max = get_min_max_from_bits_signed(bits)
386
+ x = x.transpose(0, 1) # Aligns shapes
387
+ x = func(x=x, **params)
388
+ x = x.transpose(0, 1)
389
+ x = torch.clamp(round_func_BPDA(x), quant_min, quant_max).to(target_dtype)
390
+ return x
391
+
392
+
393
+ def dequantize(
394
+ x: torch.Tensor,
395
+ params: Dict[str, nn.Parameter],
396
+ func: nn.Module,
397
+ bits: int,
398
+ out_dtype: torch.dtype
399
+ ) -> torch.Tensor:
400
+ x = x.to(dtype=out_dtype)
401
+ x = x.transpose(0, 1)
402
+ x = func(x=x, **params)
403
+ x = x.transpose(0, 1)
404
+ return x
405
+
406
+
407
+ def round_func_BPDA(input):
408
+ # This is equivalent to replacing round function (non-differentiable) with
409
+ # an identity function (differentiable) only in the backward pass.
410
+ forward_value = torch.round(input)
411
+ out = input.clone()
412
+ out.data = forward_value.data
413
+ return out
414
+
415
+
416
+ def get_min_max_from_bits_signed(bit_width: int) -> Tuple[int, int]:
417
+ return -2 ** (bit_width - 1), 2 ** (bit_width - 1) - 1
418
+
419
+
420
+
421
+ ############## Numpy ###############
422
+
423
+ def np_domain_guard(
424
+ x: np.ndarray,
425
+ min: float = None,
426
+ max: float = None,
427
+ posinf: float = None,
428
+ neginf: float = None,
429
+ nan: float = None
430
+ ) -> np.ndarray:
431
+ """Guard a tensor to a valid domain."""
432
+ x = np.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan)
433
+ if min is not None or max is not None:
434
+ x = np.clip(x, min, max)
435
+ return x
436
+
437
+
438
+ def np_replace_num(x: np.ndarray, num: float, to: float) -> np.ndarray:
439
+ """Replace a number in a tensor with another number.
440
+
441
+ Args:
442
+ x (np.ndarray): The input tensor.
443
+ num (float): The number to replace.
444
+ to (float): The number to replace with.
445
+
446
+ Returns:
447
+ np.ndarray: The tensor with the number replaced.
448
+ """
449
+ return np.where(x == num, to, x)
450
+
451
+
452
+ def np_guarded_power(x: np.ndarray, exp: float) -> np.ndarray:
453
+ """Guard the power operation to a valid domain."""
454
+ return np.power(x, exp) if exp >= 1 else np.power(np.maximum(x, 0), exp)
455
+
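Since this candidate quantizes with tan, which is unbounded, the generated quantization() above guards its argument with domain_guard so that any inf/NaN values are replaced before tan is applied. An illustrative standalone call with the same replacement values used in the file:

import torch

x = torch.tensor([float("inf"), float("-inf"), float("nan"), 0.5])
# equivalent to domain_guard(x, posinf=1, neginf=-1, nan=0) from fn.py
print(torch.nan_to_num(x, posinf=1, neginf=-1, nan=0))  # tensor([ 1., -1.,  0.,  0.5])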
fn_gen/ones_naive/16/loss.png ADDED
fn_gen/ones_naive/16/quantization.png ADDED
fn_gen/ones_naive/17/distortion.png ADDED
fn_gen/ones_naive/17/expressions.txt ADDED
@@ -0,0 +1,2 @@
1
+ atan(_0*x)/_s
2
+ tan(_s*x)/_0
fn_gen/ones_naive/17/fn.py ADDED
@@ -0,0 +1,455 @@
1
+ from __future__ import annotations
2
+
3
+ import torch
4
+ from torch import amin # Imported by the generated template; the code below calls torch.amin directly
5
+ import copy
6
+ import torch.nn as nn
7
+ import numpy as np
8
+
9
+ from scipy.optimize import curve_fit
10
+ from typing import Dict, Any, Tuple, List, Callable
11
+
12
+
13
+ def quantization(x, **params):
14
+ return (torch.div(1, replace_num(params['_s'], num=0, to=10000)) * torch.atan((params['_0'] * x)))
15
+
16
+
17
+ def dequantization(x, **params):
18
+ return (torch.div(1, replace_num(params['_0'], num=0, to=10000)) * torch.tan(domain_guard((params['_s'] * x), posinf=1, neginf=-1, nan=0)))
19
+
20
+
21
+ def init_params(x: torch.Tensor, **kwargs: Dict[str, Any]) -> Dict[str, nn.Parameter]:
22
+ params = {
23
+ '_0': init_ones(x, qtz_func=quantization, deqtz_func=dequantization, param='_0', params_list=['_0', '_s'], **kwargs),
24
+ }
25
+ params['_s'] = init_ones(x, params=params, qtz_func=quantization, **kwargs)
26
+ params = {k: nn.Parameter(v, requires_grad=False) for k, v in params.items()}
27
+
28
+ if 'post_init_hook' in kwargs:
29
+ kwargs['post_init_hook'](parameters=params)
30
+
31
+
32
+ if 'post_train_hook' in kwargs:
33
+ kwargs['post_train_hook'](parameters=params)
34
+
35
+ return params
36
+
37
+
38
+ ############### Numpy Qtz ###############
39
+
40
+
41
+ def np_quantization(x, _0, _s):
42
+ return (np.divide(1, np_replace_num(_s, num=0, to=10000)) * np.arctan((_0 * x)))
43
+
44
+
45
+ def np_dequantization(x, _0, _s):
46
+ return (np.divide(1, np_replace_num(_0, num=0, to=10000)) * np.tan(np_domain_guard((_s * x), posinf=1, neginf=-1, nan=0)))
47
+
48
+
49
+ def fit_func(x, _0, _s):
50
+ x_ = np_quantization(x, _0, _s)
51
+ x_ = np_dequantization(x_, _0, _s)
52
+ return x_
53
+
54
+
55
+
56
+ ############### HELPERS ###############
57
+
58
+ def domain_guard(
59
+ x: torch.Tensor,
60
+ min: float = None,
61
+ max: float = None,
62
+ posinf: float = None,
63
+ neginf: float = None,
64
+ nan: float = None
65
+ ) -> torch.Tensor:
66
+ """Guard a tensor to a valid domain."""
67
+ x = torch.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan)
68
+ if min is not None or max is not None:
69
+ x = torch.clamp(x, min=min, max=max)
70
+ return x
71
+
72
+
73
+ def replace_num(x: torch.Tensor, num: float, to: float) -> torch.Tensor:
74
+ """Replace a number in a tensor with another number.
75
+
76
+ Args:
77
+ x (torch.Tensor): The input tensor.
78
+ num (float): The number to replace.
79
+ to (float): The number to replace with.
80
+
81
+ Returns:
82
+ torch.Tensor: The tensor with the number replaced.
83
+ """
84
+ return torch.where(x == num, to, x)
85
+
86
+
87
+ def guarded_torch_power(x: torch.Tensor, exp: float) -> torch.Tensor:
88
+ """Guard the power operation to a valid domain."""
89
+ return torch.pow(x, exp) if exp >= 1 else torch.pow(torch.relu(x), exp)
90
+
91
+
92
+ def init_ones(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
93
+ val = torch.amin(x, dim=1)
94
+ return torch.ones_like(val, dtype=torch.float32, device=x.device)
95
+
96
+
97
+ def init_rand(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
98
+ val = torch.amin(x, dim=1)
99
+ return torch.randn_like(val, dtype=torch.float32, device=x.device)
100
+
101
+
102
+ def init_space_search(
103
+ x: torch.Tensor,
104
+ **kwargs: Dict[str, Any],
105
+ ) -> torch.Tensor:
106
+
107
+ def _build_initial_param(tensor: torch.Tensor, max_initial: int, n_params: int):
108
+ """Generates the initial set of parameters. The first iteration generates 10 times more parameters."""
109
+ for _ in range(n_params * 10): # The first iteration generates 10 times more parameters
110
+ yield init_rand(tensor) * max_initial # Generates n_params in range [-max_initial, max_initial]
111
+
112
+ def _search_param(tensors: List[torch.Tensor], n_params):
113
+ """Takes the best parameters and generates new parameters around the mean of the best parameters."""
114
+ torch_tensors = torch.stack(tensors)
115
+ min_vals, max_vals = torch.aminmax(torch_tensors, dim=0)
116
+ abs_max_val_per_ch = torch.max(-min_vals, max_vals)
117
+ mean = torch.mean(torch_tensors, dim=0)
118
+ for _ in range(n_params): # Generates n_params around the mean of the tensors
119
+ yield torch.randn_like(min_vals) * abs_max_val_per_ch + mean
120
+
121
+ def _calc(x, qtz_func, deqtz_func, **params):
122
+ x_ = x.transpose(0, 1)
123
+ x_ = qtz_func(x=x_, **params)
124
+ x_ = deqtz_func(x=x_, **params)
125
+ x_ = x_.transpose(0, 1)
126
+ return x_
127
+
128
+ assert "qtz_func" in kwargs, "qtz_func must be provided."
129
+ assert "deqtz_func" in kwargs, "deqtz_func must be provided."
130
+ assert "params_list" in kwargs, "params list must be provided."
131
+ assert "param" in kwargs, "param must be provided."
132
+
133
+ qtz_func = kwargs.get('qtz_func')
134
+ deqtz_func = kwargs.get('deqtz_func')
135
+ params_list = kwargs.get('params_list')
136
+ param = kwargs.get('param')
137
+
138
+ n_runs = 50 # Number of runs to try to find the best parameters
139
+ n_random_params = 50 # Number of random parameters to generate
140
+ n_best_to_pick = 5 # Number of best parameters to pick after each run
141
+ max_initial = 10000 # Maximum value to initialize the parameters
142
+
143
+ # Initializes the parameters
144
+ base_params = { p: init_ones(x, **kwargs) for p in params_list if p != param }
145
+ params = _build_initial_param(x, max_initial, n_random_params)
146
+
147
+ # Performs the search
148
+ for _ in range(n_runs):
149
+
150
+ best_params = []
151
+ for param_ in params:
152
+ try:
153
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: param_})
154
+ loss_ones = nn.MSELoss()(x, x_)
155
+
156
+ if len(best_params) < n_best_to_pick:
157
+ best_params.append((param_, loss_ones.item()))
158
+ best_params = sorted(best_params, key=lambda x: x[1])
159
+ elif loss_ones < best_params[-1][1]:
160
+ best_params[-1] = (param_, loss_ones.item())
161
+ best_params = sorted(best_params, key=lambda x: x[1])
162
+
163
+ except Exception: # The parameters might not be valid for the function's domain
164
+ continue
165
+
166
+ # Generates new parameters around the mean
167
+ params = _search_param([p for p, _ in best_params], n_random_params)
168
+
169
+ # Checks if the best parameter is better than the init_ones
170
+ p_ones = init_ones(x, **kwargs)
171
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_ones})
172
+ loss_ones = nn.MSELoss()(x, x_)
173
+
174
+ # Checks if the best parameter is better than the init_rand
175
+ p_rand = init_rand(x, **kwargs)
176
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_rand})
177
+ loss_rand = nn.MSELoss()(x, x_)
178
+
179
+ if loss_rand < best_params[0][1] and loss_rand < loss_ones:
180
+ return p_rand
181
+ elif loss_ones < best_params[0][1] and loss_ones < loss_rand:
182
+ return p_ones
183
+ else:
184
+ return best_params[0][0]
185
+
186
+
187
+ def init_linear_scale( # Symmetric scale. From the study folder
188
+ x: torch.Tensor,
189
+ **kwargs: Dict[str, Any],
190
+ ) -> torch.Tensor:
191
+ assert "bits" in kwargs, "bits must be provided."
192
+ assert "params" in kwargs, "params must be provided."
193
+ assert "qtz_func" in kwargs, "qtz_func must be provided."
194
+
195
+ bits = kwargs.get('bits')
196
+ params = kwargs.get('params')
197
+ qtz_func = kwargs.get('qtz_func')
198
+
199
+ x_ = x.transpose(0, 1)
200
+ x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs))
201
+ x_ = x_.transpose(0, 1)
202
+
203
+ quant_min, quant_max = get_min_max_from_bits_signed(bits)
204
+ min_vals, max_vals = torch.aminmax(x_, dim=1)
205
+ min_vals = torch.min(min_vals, torch.zeros_like(min_vals))
206
+ max_vals = torch.max(max_vals, torch.zeros_like(max_vals))
207
+
208
+ eps = torch.finfo(torch.float32).eps
209
+
210
+ abs_max_val_per_ch = torch.max(-min_vals, max_vals)
211
+ scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2)
212
+
213
+ scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device)
214
+
215
+ # Introduces some noise in scale
216
+ # Without this noise the scale does not learn anything and the accuracy stays at 0.0
217
+ scale = scale + 0.01 * torch.randn_like(scale)
218
+ return scale
219
+
220
+
221
+ def init_non_linear_regression_fit(
222
+ x: torch.Tensor,
223
+ **kwargs: Dict[str, Any],
224
+ ) -> torch.Tensor:
225
+
226
+ assert "params_list" in kwargs, "params list must be provided."
227
+ assert "np_fit_func" in kwargs, "np_fit_func must be provided."
228
+ assert "p0" in kwargs, "p0 must be provided."
229
+ np_fit_func = kwargs.get('np_fit_func')
230
+ params_list = kwargs.get('params_list')
231
+ p0 = kwargs.get('p0')
232
+
233
+ def _fit(xdata: np.ndarray, ydata: np.ndarray, func: Callable, p0: List[float]):
234
+ popt, _ = curve_fit(
235
+ func,
236
+ xdata,
237
+ ydata,
238
+ maxfev=1000,
239
+ p0=p0,
240
+ method='lm'
241
+ )
242
+ return popt
243
+
244
+ # 1. Needs to convert the torch tensor to numpy tensor
245
+ xdata = x.cpu().numpy()
246
+
247
+ # 2. Sorts the data to make the curve fit easier
248
+ sorted_xdata = np.sort(xdata, axis=-1)
249
+
250
+ p0 = {k: v.cpu().numpy() for k, v in p0.items()}
251
+ params_list = sorted(params_list) # We need to make sure that it matches the numpy fit func arg order
252
+
253
+ # 3. Finds the best parameters for each channel
254
+ try:
255
+ params = []
256
+ for i in range(sorted_xdata.shape[0]):
257
+ xdata_ = sorted_xdata[i]
258
+ p0_ = [p0[p][i] for p in params_list]
259
+ ch_params = _fit(xdata_, xdata_, np_fit_func, p0_)
260
+ params.append(ch_params)
261
+
262
+ # 4. Builds the parameters
263
+ result = {}
264
+ for i, p in enumerate(params_list):
265
+ result[p] = torch.tensor([p_[i] for p_ in params], dtype=torch.float32).to(x.device)
266
+
267
+ return result
268
+
269
+ except ValueError as e:
270
+ print(f"Could not fit the function with error: {e}")
271
+ print("Using fallback result...")
272
+ return {
273
+ k: torch.tensor(v, dtype=torch.float32).to(x.device) for k, v in p0.items()
274
+ }
275
+
276
+
277
+ def init_zeros(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
278
+ val = torch.amin(x, dim=1)
279
+ return torch.zeros_like(val, dtype=torch.float32, device=x.device)
280
+
281
+
282
+ def init_inner_scale(tensor: torch.Tensor, _min: float = torch.inf, _max: float = torch.inf) -> torch.Tensor:
283
+ # Calculate the original minimum and maximum values
284
+ min_vals, max_vals = torch.aminmax(tensor, dim=-1)
285
+ x_min = torch.min(min_vals, torch.zeros_like(min_vals))
286
+ x_max = torch.max(max_vals, torch.zeros_like(max_vals))
287
+
288
+ if _max is torch.inf: # We do not need to scale the tensor. Just need to move it
289
+ return torch.ones_like(x_min)
290
+
291
+ # Calculate the scale factor
292
+ scale = (_max - _min) / (x_max - x_min)
293
+ return scale
294
+
295
+
296
+
297
+ ############## Quant ###############
298
+
299
+ @torch.enable_grad()
300
+ def learn_parameters(
301
+ x: torch.Tensor,
302
+ params: Dict[str, nn.Parameter],
303
+ qtz_func: nn.Module,
304
+ deqtz_func: nn.Module,
305
+ bits: int,
306
+ target_dtype: torch.dtype,
307
+ epochs: int = 1000,
308
+ early_stop: bool = True,
309
+ do_report: bool = False
310
+ ) -> Tuple[Dict[str, nn.Parameter], torch.Tensor]:
311
+
312
+ # Requires gradients in the parameters
313
+ for p in params.values():
314
+ p.requires_grad = True
315
+ p.grad = None
316
+
317
+ param_keys = list(params.keys())
318
+ param_values = list(params.values())
319
+
320
+ # Defines optimizer and loss function
321
+ optimizer = torch.optim.Adam(param_values, lr=0.001)
322
+ scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=1.0, end_factor=0.001, total_iters=epochs // 10)
323
+ loss_fn = nn.MSELoss()
324
+
325
+ # Contains the best loss and the best parameters
326
+ best_loss = float("inf")
327
+ best_params = None
328
+
329
+ # Used to stop the search early
330
+ min_delta = 1e-7
331
+ acc_loss = []
332
+ percent_epochs_before_stop = 0.1
333
+
334
+ for i in range(epochs):
335
+ optimizer.zero_grad()
336
+
337
+ quant = quantize(x, params, qtz_func, bits, target_dtype)
338
+ dequant = dequantize(quant, params, deqtz_func, bits, x.dtype)
339
+ loss = loss_fn(x, dequant)
340
+
341
+ if loss.isnan() or loss.isinf():
342
+ raise Exception("Loss is NaN or Inf. Stopping the search.")
343
+
344
+ loss.backward()
345
+ optimizer.step()
346
+ scheduler.step()
347
+
348
+ acc_loss.append(loss.item())
349
+
350
+ # Reports loss every 10 steps
351
+ if i % 10 == 0 and do_report:
352
+ print(f"Epoch {i}: Loss {loss.item()}")
353
+
354
+ # Optimizes the parameter search by storing the best loss and the parameters
355
+ if loss.item() < best_loss:
356
+ best_loss = loss.item()
357
+ best_params = copy.deepcopy({
358
+ k: v for k, v in params.items() if k in param_keys
359
+ })
360
+
361
+ # We also stop the search if the loss has not changed considerably during the last 10% of epochs
362
+ if early_stop:
363
+ epochs_before_stop = int(epochs * percent_epochs_before_stop)
364
+ if i > epochs_before_stop and abs(acc_loss[i - epochs_before_stop] - acc_loss[i]) < min_delta:
365
+ break
366
+
367
+ # No longer requires gradients in the parameters
368
+ for p in best_params.values():
369
+ p.requires_grad = False
370
+ p.grad = None
371
+
372
+ if do_report:
373
+ return best_params, acc_loss
374
+ else:
375
+ return best_params
376
+
377
+
378
+ def quantize(
379
+ x: torch.Tensor,
380
+ params: Dict[str, nn.Parameter],
381
+ func: nn.Module,
382
+ bits: int,
383
+ target_dtype: torch.dtype = torch.int8
384
+ ) -> torch.Tensor:
385
+ quant_min, quant_max = get_min_max_from_bits_signed(bits)
386
+ x = x.transpose(0, 1) # Aligns shapes
387
+ x = func(x=x, **params)
388
+ x = x.transpose(0, 1)
389
+ x = torch.clamp(round_func_BPDA(x), quant_min, quant_max).to(target_dtype)
390
+ return x
391
+
392
+
393
+ def dequantize(
394
+ x: torch.Tensor,
395
+ params: Dict[str, nn.Parameter],
396
+ func: nn.Module,
397
+ bits: int,
398
+ out_dtype: torch.dtype
399
+ ) -> torch.Tensor:
400
+ x = x.to(dtype=out_dtype)
401
+ x = x.transpose(0, 1)
402
+ x = func(x=x, **params)
403
+ x = x.transpose(0, 1)
404
+ return x
405
+
406
+
407
+ def round_func_BPDA(input):
408
+ # This is equivalent to replacing round function (non-differentiable) with
409
+ # an identity function (differentiable) only in the backward pass.
410
+ forward_value = torch.round(input)
411
+ out = input.clone()
412
+ out.data = forward_value.data
413
+ return out
414
+
415
+
416
+ def get_min_max_from_bits_signed(bit_width: int) -> Tuple[int, int]:
417
+ return -2 ** (bit_width - 1), 2 ** (bit_width - 1) - 1
418
+
419
+
420
+
421
+ ############## Numpy ###############
422
+
423
+ def np_domain_guard(
424
+ x: np.ndarray,
425
+ min: float = None,
426
+ max: float = None,
427
+ posinf: float = None,
428
+ neginf: float = None,
429
+ nan: float = None
430
+ ) -> np.ndarray:
431
+ """Guard a tensor to a valid domain."""
432
+ x = np.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan)
433
+ if min is not None or max is not None:
434
+ x = np.clip(x, min, max)
435
+ return x
436
+
437
+
438
+ def np_replace_num(x: np.ndarray, num: float, to: float) -> np.ndarray:
439
+ """Replace a number in a tensor with another number.
440
+
441
+ Args:
442
+ x (np.ndarray): The input tensor.
443
+ num (float): The number to replace.
444
+ to (float): The number to replace with.
445
+
446
+ Returns:
447
+ np.ndarray: The tensor with the number replaced.
448
+ """
449
+ return np.where(x == num, to, x)
450
+
451
+
452
+ def np_guarded_power(x: np.ndarray, exp: float) -> np.ndarray:
453
+ """Guard the power operation to a valid domain."""
454
+ return np.power(x, exp) if exp >= 1 else np.power(np.maximum(x, 0), exp)
455
+
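For reference, get_min_max_from_bits_signed (used by quantize() above to clamp the rounded values) follows the usual two's-complement range formula; a couple of spot checks, purely illustrative:

def get_min_max_from_bits_signed(bit_width: int):
    # same formula as in fn.py: [-2**(b-1), 2**(b-1) - 1]
    return -2 ** (bit_width - 1), 2 ** (bit_width - 1) - 1

assert get_min_max_from_bits_signed(8) == (-128, 127)
assert get_min_max_from_bits_signed(4) == (-8, 7)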
fn_gen/ones_naive/17/loss.png ADDED
fn_gen/ones_naive/17/quantization.png ADDED