Diogo-V committed
Commit ad92883
1 Parent(s): 8d7daa1

Upload learned functions

This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. fn_gen/ones_noisy_scale/0/__pycache__/fn.cpython-310.pyc +0 -0
  2. fn_gen/ones_noisy_scale/0/distortion.png +0 -0
  3. fn_gen/ones_noisy_scale/0/expressions.txt +2 -0
  4. fn_gen/ones_noisy_scale/0/fn.py +489 -0
  5. fn_gen/ones_noisy_scale/0/loss.png +0 -0
  6. fn_gen/ones_noisy_scale/0/quantization.png +0 -0
  7. fn_gen/ones_noisy_scale/1/__pycache__/fn.cpython-310.pyc +0 -0
  8. fn_gen/ones_noisy_scale/1/distortion.png +0 -0
  9. fn_gen/ones_noisy_scale/1/expressions.txt +2 -0
  10. fn_gen/ones_noisy_scale/1/fn.py +489 -0
  11. fn_gen/ones_noisy_scale/1/loss.png +0 -0
  12. fn_gen/ones_noisy_scale/1/quantization.png +0 -0
  13. fn_gen/ones_noisy_scale/10/__pycache__/fn.cpython-310.pyc +0 -0
  14. fn_gen/ones_noisy_scale/10/distortion.png +0 -0
  15. fn_gen/ones_noisy_scale/10/expressions.txt +2 -0
  16. fn_gen/ones_noisy_scale/10/fn.py +488 -0
  17. fn_gen/ones_noisy_scale/10/loss.png +0 -0
  18. fn_gen/ones_noisy_scale/10/quantization.png +0 -0
  19. fn_gen/ones_noisy_scale/11/__pycache__/fn.cpython-310.pyc +0 -0
  20. fn_gen/ones_noisy_scale/11/distortion.png +0 -0
  21. fn_gen/ones_noisy_scale/11/expressions.txt +2 -0
  22. fn_gen/ones_noisy_scale/11/fn.py +489 -0
  23. fn_gen/ones_noisy_scale/11/loss.png +0 -0
  24. fn_gen/ones_noisy_scale/11/quantization.png +0 -0
  25. fn_gen/ones_noisy_scale/12/__pycache__/fn.cpython-310.pyc +0 -0
  26. fn_gen/ones_noisy_scale/12/distortion.png +0 -0
  27. fn_gen/ones_noisy_scale/12/expressions.txt +2 -0
  28. fn_gen/ones_noisy_scale/12/fn.py +489 -0
  29. fn_gen/ones_noisy_scale/12/loss.png +0 -0
  30. fn_gen/ones_noisy_scale/12/quantization.png +0 -0
  31. fn_gen/ones_noisy_scale/13/__pycache__/fn.cpython-310.pyc +0 -0
  32. fn_gen/ones_noisy_scale/13/distortion.png +0 -0
  33. fn_gen/ones_noisy_scale/13/expressions.txt +2 -0
  34. fn_gen/ones_noisy_scale/13/fn.py +489 -0
  35. fn_gen/ones_noisy_scale/13/loss.png +0 -0
  36. fn_gen/ones_noisy_scale/13/quantization.png +0 -0
  37. fn_gen/ones_noisy_scale/14/__pycache__/fn.cpython-310.pyc +0 -0
  38. fn_gen/ones_noisy_scale/14/distortion.png +0 -0
  39. fn_gen/ones_noisy_scale/14/expressions.txt +2 -0
  40. fn_gen/ones_noisy_scale/14/fn.py +489 -0
  41. fn_gen/ones_noisy_scale/14/loss.png +0 -0
  42. fn_gen/ones_noisy_scale/14/quantization.png +0 -0
  43. fn_gen/ones_noisy_scale/15/__pycache__/fn.cpython-310.pyc +0 -0
  44. fn_gen/ones_noisy_scale/15/distortion.png +0 -0
  45. fn_gen/ones_noisy_scale/15/expressions.txt +2 -0
  46. fn_gen/ones_noisy_scale/15/fn.py +489 -0
  47. fn_gen/ones_noisy_scale/15/loss.png +0 -0
  48. fn_gen/ones_noisy_scale/15/quantization.png +0 -0
  49. fn_gen/ones_noisy_scale/16/__pycache__/fn.cpython-310.pyc +0 -0
  50. fn_gen/ones_noisy_scale/16/distortion.png +0 -0
fn_gen/ones_noisy_scale/0/__pycache__/fn.cpython-310.pyc ADDED
Binary file (13.2 kB).
fn_gen/ones_noisy_scale/0/distortion.png ADDED
fn_gen/ones_noisy_scale/0/expressions.txt ADDED
@@ -0,0 +1,2 @@
+ log(_0*x)/_s
+ exp(_s*x)/_0
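The two expressions form a quantization/dequantization pair: substituting the first into the second gives exp(_s * log(_0*x)/_s) / _0 = x whenever _0*x > 0, so the second expression is the exact inverse of the first on the positive domain. A quick standalone check of that round-trip in plain PyTorch (the values of _0, _s and x below are arbitrary illustrations, not taken from this commit):

import torch

_0, _s = torch.tensor(1.5), torch.tensor(0.5)   # arbitrary illustrative parameters
x = torch.tensor([0.1, 1.0, 2.5])               # positive inputs, inside log's domain

q = torch.log(_0 * x) / _s        # expression 1: log(_0*x)/_s
x_rec = torch.exp(_s * q) / _0    # expression 2: exp(_s*x)/_0, applied to q

assert torch.allclose(x, x_rec)   # round-trips up to floating-point error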
fn_gen/ones_noisy_scale/0/fn.py ADDED
@@ -0,0 +1,489 @@
+ from __future__ import annotations
+
+ import torch
+ from torch import amin # Necessary for arcsin
+ import copy
+ import torch.nn as nn
+ import numpy as np
+
+ from scipy.optimize import curve_fit
+ from typing import Dict, Any, Tuple, List, Callable
+
+
+ def quantization(x, **params):
+     return (torch.div(1, replace_num(params['_s'], num=0, to=10000)) * torch.log(domain_guard((params['_0'] * x), min=1e-5, nan=1e-5)))
+
+
+ def dequantization(x, **params):
+     return (torch.div(1, replace_num(params['_0'], num=0, to=10000)) * torch.exp((params['_s'] * x)))
+
+
+ def init_linear_scale( # Symmetric scale. From the study folder
+     x: torch.Tensor,
+     **kwargs: Dict[str, Any],
+ ) -> torch.Tensor:
+     assert "bits" in kwargs, "bits must be provided."
+     assert "params" in kwargs, "params must be provided."
+     assert "qtz_func" in kwargs, "qtz_func must be provided."
+
+     bits = kwargs.get('bits')
+     params = kwargs.get('params')
+     qtz_func = kwargs.get('qtz_func')
+
+     x_ = x.transpose(0, 1)
+     x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs))
+     x_ = x_.transpose(0, 1)
+
+     quant_min, quant_max = get_min_max_from_bits_signed(bits)
+     min_vals, max_vals = torch.aminmax(x_, dim=1)
+     min_vals = torch.min(min_vals, torch.zeros_like(min_vals))
+     max_vals = torch.max(max_vals, torch.zeros_like(max_vals))
+
+     eps = torch.finfo(torch.float32).eps
+
+     abs_max_val_per_ch = torch.max(-min_vals, max_vals)
+     scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2)
+
+     scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device)
+
+     # Introduces some noise in scale
+     # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything
+     scale = scale + 0.01 * torch.randn_like(scale)
+     return scale
+
+
+ def init_params(x: torch.Tensor, **kwargs: Dict[str, Any]) -> Dict[str, nn.Parameter]:
+     params = {
+         '_0': init_ones(x, qtz_func=quantization, deqtz_func=dequantization, param='_0', params_list=['_0', '_s'], **kwargs),
+     }
+     params['_s'] = init_linear_scale(x, params=params, qtz_func=quantization, **kwargs)
+     params = {k: nn.Parameter(v, requires_grad=False) for k, v in params.items()}
+
+     if 'post_init_hook' in kwargs:
+         kwargs['post_init_hook'](parameters=params)
+
+
+     if 'post_train_hook' in kwargs:
+         kwargs['post_train_hook'](parameters=params)
+
+     return params
+
+
+ ############### Numpy Qtz ###############
+
+
+ def np_quantization(x, _0, _s):
+     return (np.divide(1, np_replace_num(_s, num=0, to=10000)) * np.log(np_domain_guard((_0 * x), min=1e-5, nan=1e-5)))
+
+
+ def np_dequantization(x, _0, _s):
+     return (np.divide(1, np_replace_num(_0, num=0, to=10000)) * np.exp((_s * x)))
+
+
+ def fit_func(x, _0, _s):
+     x_ = np_quantization(x, _0, _s)
+     x_ = np_dequantization(x_, _0, _s)
+     return x_
+
+
+
+ ############### HELPERS ###############
+
+ def domain_guard(
+     x: torch.Tensor,
+     min: float = None,
+     max: float = None,
+     posinf: float = None,
+     neginf: float = None,
+     nan: float = None
+ ) -> torch.Tensor:
+     """Guard a tensor to a valid domain."""
+     x = torch.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan)
+     if min is not None or max is not None:
+         x = torch.clamp(x, min=min, max=max)
+     return x
+
+
+ def replace_num(x: torch.Tensor, num: float, to: float) -> torch.Tensor:
+     """Replace a number in a tensor with another number.
+
+     Args:
+         x (torch.Tensor): The input tensor.
+         num (float): The number to replace.
+         to (float): The number to replace with.
+
+     Returns:
+         torch.Tensor: The tensor with the number replaced.
+     """
+     return torch.where(x == num, to, x)
+
+
+ def guarded_torch_power(x: torch.Tensor, exp: float) -> torch.Tensor:
+     """Guard the power operation to a valid domain."""
+     return torch.pow(x, exp) if exp >= 1 else torch.pow(torch.relu(x), exp)
+
+
+ def init_ones(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
+     val = torch.amin(x, dim=1)
+     return torch.ones_like(val, dtype=torch.float32, device=x.device)
+
+
+ def init_rand(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
+     val = torch.amin(x, dim=1)
+     return torch.randn_like(val, dtype=torch.float32, device=x.device)
+
+
+ def init_space_search(
+     x: torch.Tensor,
+     **kwargs: Dict[str, Any],
+ ) -> torch.Tensor:
+
+     def _build_initial_param(tensor: torch.Tensor, max_initial: int, n_params: int):
+         """Generates the initial set of parameters. The first iteration generates 10 times more parameters."""
+         for _ in range(n_params * 10): # The first iteration generates 10 times more parameters
+             yield init_rand(tensor) * max_initial # Generates n_params in range [-max_initial, max_initial]
+
+     def _search_param(tensors: List[torch.tensor], n_params):
+         """Takes the best parameters and generates new parameters around the mean of the best parameters."""
+         torch_tensors = torch.stack(tensors)
+         min_vals, max_vals = torch.aminmax(torch_tensors, dim=0)
+         abs_max_val_per_ch = torch.max(-min_vals, max_vals)
+         mean = torch.mean(torch_tensors, dim=0)
+         for _ in range(n_params): # Generates n_params around the mean of the tensors
+             yield torch.randn_like(min_vals) * abs_max_val_per_ch + mean
+
+     def _calc(x, qtz_func, deqtz_func, **params):
+         x_ = x.transpose(0, 1)
+         x_ = qtz_func(x=x_, **params)
+         x_ = deqtz_func(x=x_, **params)
+         x_ = x_.transpose(0, 1)
+         return x_
+
+     assert "qtz_func" in kwargs, "qtz_func must be provided."
+     assert "deqtz_func" in kwargs, "deqtz_func must be provided."
+     assert "params_list" in kwargs, "params list must be provided."
+     assert "param" in kwargs, "param must be provided."
+
+     qtz_func = kwargs.get('qtz_func')
+     deqtz_func = kwargs.get('deqtz_func')
+     params_list = kwargs.get('params_list')
+     param = kwargs.get('param')
+
+     n_runs = 50 # Number of runs to try to find the best parameters
+     n_random_params = 50 # Number of random parameters to generate
+     n_best_to_pick = 5 # Number of best parameters to pick after each run
+     max_initial = 10000 # Maximum value to initialize the parameters
+
+     # Initializes the parameters
+     base_params = { p: init_ones(x, **kwargs) for p in params_list if p != param }
+     params = _build_initial_param(x, max_initial, n_random_params)
+
+     # Performs the search
+     for _ in range(n_runs):
+
+         best_params = []
+         for param_ in params:
+             try:
+                 x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: param_})
+                 loss_ones = nn.MSELoss()(x, x_)
+
+                 if len(best_params) < n_best_to_pick:
+                     best_params.append((param_, loss_ones.item()))
+                     best_params = sorted(best_params, key=lambda x: x[1])
+                 elif loss_ones < best_params[-1][1]:
+                     best_params[-1] = (param_, loss_ones.item())
+                     best_params = sorted(best_params, key=lambda x: x[1])
+
+             except Exception: # The parameters might not be valid for the function's domain
+                 continue
+
+         # Generates new parameters around the mean
+         params = _search_param([p for p, _ in best_params], n_random_params)
+
+     # Checks if the best parameter is better than the init_ones
+     p_ones = init_ones(x, **kwargs)
+     x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_ones})
+     loss_ones = nn.MSELoss()(x, x_)
+
+     # Checks if the best parameter is better than the init_rand
+     p_rand = init_rand(x, **kwargs)
+     x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_rand})
+     loss_rand = nn.MSELoss()(x, x_)
+
+     if loss_rand < best_params[0][1] and loss_rand < loss_ones:
+         return p_rand
+     elif loss_ones < best_params[0][1] and loss_ones < loss_rand:
+         return p_ones
+     else:
+         return best_params[0][0]
+
+
+ def init_linear_scale( # Symmetric scale. From the study folder
+     x: torch.Tensor,
+     **kwargs: Dict[str, Any],
+ ) -> torch.Tensor:
+     assert "bits" in kwargs, "bits must be provided."
+     assert "params" in kwargs, "params must be provided."
+     assert "qtz_func" in kwargs, "qtz_func must be provided."
+
+     bits = kwargs.get('bits')
+     params = kwargs.get('params')
+     qtz_func = kwargs.get('qtz_func')
+
+     x_ = x.transpose(0, 1)
+     x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs))
+     x_ = x_.transpose(0, 1)
+
+     quant_min, quant_max = get_min_max_from_bits_signed(bits)
+     min_vals, max_vals = torch.aminmax(x_, dim=1)
+     min_vals = torch.min(min_vals, torch.zeros_like(min_vals))
+     max_vals = torch.max(max_vals, torch.zeros_like(max_vals))
+
+     eps = torch.finfo(torch.float32).eps
+
+     abs_max_val_per_ch = torch.max(-min_vals, max_vals)
+     scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2)
+
+     scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device)
+
+     # Introduces some noise in scale
+     # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything
+     scale = scale + 0.01 * torch.randn_like(scale)
+     return scale
+
+
+ def init_non_linear_regression_fit(
+     x: torch.Tensor,
+     **kwargs: Dict[str, Any],
+ ) -> torch.Tensor:
+
+     assert "params_list" in kwargs, "params list must be provided."
+     assert "np_fit_func" in kwargs, "np_fit_func must be provided."
+     assert "p0" in kwargs, "p0 must be provided."
+     np_fit_func = kwargs.get('np_fit_func')
+     params_list = kwargs.get('params_list')
+     p0 = kwargs.get('p0')
+
+     def _fit(xdata: np.ndarray, ydata: np.ndarray, func: Callable, p0: List[float]):
+         popt, _ = curve_fit(
+             func,
+             xdata,
+             ydata,
+             maxfev=1000,
+             p0=p0,
+             method='lm'
+         )
+         return popt
+
+     # 1. Needs to convert the torch tensor to numpy tensor
+     xdata = x.cpu().numpy()
+
+     # 2. Sorts the data so that it makes it easier to fit to it
+     sorted_xdata = np.sort(xdata, axis=-1)
+
+     p0 = {k: v.cpu().numpy() for k, v in p0.items()}
+     params_list = sorted(params_list) # We need to make sure that it matches the numpy fit func arg order
+
+     # 3. Finds the best parameters for each channel
+     try:
+         params = []
+         for i in range(sorted_xdata.shape[0]):
+             xdata_ = sorted_xdata[i]
+             p0_ = [p0[p][i] for p in params_list]
+             ch_params = _fit(xdata_, xdata_, np_fit_func, p0_)
+             params.append(ch_params)
+
+         # 4. Builds the parameters
+         result = {}
+         for i, p in enumerate(params_list):
+             result[p] = torch.tensor([p_[i] for p_ in params], dtype=torch.float32).to(x.device)
+
+         return result
+
+     except ValueError as e:
+         print(f"Could not fit the function with error: {e}")
+         print(f"Using fallback result...")
+         return {
+             k: torch.tensor(v, dtype=torch.float32).to(x.device) for k, v in p0.items()
+         }
+
+
+ def init_zeros(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
+     val = torch.amin(x, dim=1)
+     return torch.zeros_like(val, dtype=torch.float32, device=x.device)
+
+
+ def init_inner_scale(tensor: torch.Tensor, _min: float = torch.inf, _max: float = torch.inf) -> torch.Tensor:
+     # Calculate the original minimum and maximum values
+     min_vals, max_vals = torch.aminmax(tensor, dim=-1)
+     x_min = torch.min(min_vals, torch.zeros_like(min_vals))
+     x_max = torch.max(max_vals, torch.zeros_like(max_vals))
+
+     if _max is torch.inf: # We do not need to scale the tensor. Just need to move it
+         return torch.ones_like(x_min)
+
+     # Calculate the scale factor
+     scale = (_max - _min) / (x_max - x_min)
+     return scale
+
+
+
+ ############## Quant ###############
+
+ @torch.enable_grad()
+ def learn_parameters(
+     x: torch.Tensor,
+     params: Dict[str, nn.Parameter],
+     qtz_func: nn.Module,
+     deqtz_func: nn.Module,
+     bits: int,
+     target_dtype: torch.dtype,
+     epochs: int = 1000,
+     early_stop: bool = True,
+     do_report: bool = False
+ ) -> Tuple[Dict[str, nn.Parameter], torch.Tensor]:
+
+     # Requires gradients in the parameters
+     for p in params.values():
+         p.requires_grad = True
+         p.grad = None
+
+     param_keys = list(params.keys())
+     param_values = list(params.values())
+
+     # Defines optimizer and loss function
+     optimizer = torch.optim.Adam(param_values, lr=0.001)
+     scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=1.0, end_factor=0.001, total_iters=epochs // 10)
+     loss_fn = nn.MSELoss()
+
+     # Contains the best loss and the best parameters
+     best_loss = float("inf")
+     best_params = None
+
+     # Used to stop the search early
+     min_delta = 1e-7
+     acc_loss = []
+     percent_epochs_before_stop = 0.1
+
+     for i in range(epochs):
+         optimizer.zero_grad()
+
+         quant = quantize(x, params, qtz_func, bits, target_dtype)
+         dequant = dequantize(quant, params, deqtz_func, bits, x.dtype)
+         loss = loss_fn(x, dequant)
+
+         if loss.isnan() or loss.isinf():
+             raise Exception("Loss is NaN or Inf. Stopping the search.")
+
+         loss.backward()
+         optimizer.step()
+         scheduler.step()
+
+         acc_loss.append(loss.item())
+
+         # Reports loss every 10 steps
+         if i % 10 == 0 and do_report:
+             print(f"Epoch {i}: Loss {loss.item()}")
+
+         # Optimizes the parameter search by storing the best loss and the parameters
+         if loss.item() < best_loss:
+             best_loss = loss.item()
+             best_params = copy.deepcopy({
+                 k: v for k, v in params.items() if k in param_keys
+             })
+
+         # We also stop the search if the loss has not considerably during the last 10% epochs
+         if early_stop:
+             epochs_before_stop = int(epochs * percent_epochs_before_stop)
+             if i > epochs_before_stop and abs(acc_loss[i - epochs_before_stop] - acc_loss[i]) < min_delta:
+                 break
+
+     # No longer requires gradients in the parameters
+     for p in best_params.values():
+         p.requires_grad = False
+         p.grad = None
+
+     if do_report:
+         return best_params, acc_loss
+     else:
+         return best_params
+
+
+ def quantize(
+     x: torch.Tensor,
+     params: Dict[str, nn.Parameter],
+     func: nn.Module,
+     bits: int,
+     target_dtype: torch.dtype = torch.int8
+ ) -> torch.Tensor:
+     quant_min, quant_max = get_min_max_from_bits_signed(bits)
+     x = x.transpose(0, 1) # Aligns shapes
+     x = func(x=x, **params)
+     x = x.transpose(0, 1)
+     x = torch.clamp(round_func_BPDA(x), quant_min, quant_max).to(target_dtype)
+     return x
+
+
+ def dequantize(
+     x: torch.Tensor,
+     params: Dict[str, nn.Parameter],
+     func: nn.Module,
+     bits: int,
+     out_dtype: torch.dtype
+ ) -> torch.Tensor:
+     x = x.to(dtype=out_dtype)
+     x = x.transpose(0, 1)
+     x = func(x=x, **params)
+     x = x.transpose(0, 1)
+     return x
+
+
+ def round_func_BPDA(input):
+     # This is equivalent to replacing round function (non-differentiable) with
+     # an identity function (differentiable) only when backward.
+     forward_value = torch.round(input)
+     out = input.clone()
+     out.data = forward_value.data
+     return out
+
+
+ def get_min_max_from_bits_signed(bit_width: int) -> Tuple[int, int]:
+     return -2 ** (bit_width - 1), 2 ** (bit_width - 1) - 1
+
+
+
+ ############## Numpy ###############
+
+ def np_domain_guard(
+     x: np.ndarray,
+     min: float = None,
+     max: float = None,
+     posinf: float = None,
+     neginf: float = None,
+     nan: float = None
+ ) -> np.ndarray:
+     """Guard a tensor to a valid domain."""
+     x = np.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan)
+     if min is not None or max is not None:
+         x = np.clip(x, min, max)
+     return x
+
+
+ def np_replace_num(x: np.ndarray, num: float, to: float) -> np.ndarray:
+     """Replace a number in a tensor with another number.
+
+     Args:
+         x (np.ndarray): The input tensor.
+         num (float): The number to replace.
+         to (float): The number to replace with.
+
+     Returns:
+         np.ndarray: The tensor with the number replaced.
+     """
+     return np.where(x == num, to, x)
+
+
+ def np_guarded_power(x: np.ndarray, exp: float) -> np.ndarray:
+     """Guard the power operation to a valid domain."""
+     return np.power(x, exp) if exp >= 1 else np.power(np.maximum(x, 0), exp)
+
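Taken together, init_params builds the initial '_0' and '_s' tensors, learn_parameters refines them by gradient descent through round_func_BPDA, and quantize/dequantize apply the learned pair. A minimal usage sketch of that flow, not part of the uploaded file; the 2-D (channels, features) tensor, the 8-bit setting, and the float target dtype during learning are illustrative assumptions:

import torch

w = torch.randn(64, 128)              # (channels, features), matching the dim=1 reductions above
params = init_params(w, bits=8)       # '_0' from init_ones, '_s' from the noisy symmetric scale

# Learning step: keep the quantized tensor in float here so loss.backward() can
# reach the parameters through round_func_BPDA; an integer dtype would cut the graph.
params = learn_parameters(
    w, params,
    qtz_func=quantization, deqtz_func=dequantization,
    bits=8, target_dtype=torch.float32,
    epochs=1000, early_stop=True, do_report=False,
)

q = quantize(w, params, quantization, 8, torch.int8)             # final int8 codes
w_hat = dequantize(q, params, dequantization, 8, torch.float32)  # reconstruction
print(torch.mean((w - w_hat) ** 2))                              # the MSE that learn_parameters minimizes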
fn_gen/ones_noisy_scale/0/loss.png ADDED
fn_gen/ones_noisy_scale/0/quantization.png ADDED
fn_gen/ones_noisy_scale/1/__pycache__/fn.cpython-310.pyc ADDED
Binary file (13.2 kB).
fn_gen/ones_noisy_scale/1/distortion.png ADDED
fn_gen/ones_noisy_scale/1/expressions.txt ADDED
@@ -0,0 +1,2 @@
+ exp(_0*x)/_s
+ log(_s*x)/_0
fn_gen/ones_noisy_scale/1/fn.py ADDED
@@ -0,0 +1,489 @@
1
+ from __future__ import annotations
2
+
3
+ import torch
4
+ from torch import amin # Necessary for arcsin
5
+ import copy
6
+ import torch.nn as nn
7
+ import numpy as np
8
+
9
+ from scipy.optimize import curve_fit
10
+ from typing import Dict, Any, Tuple, List, Callable
11
+
12
+
13
+ def quantization(x, **params):
14
+ return (torch.div(1, replace_num(params['_s'], num=0, to=10000)) * torch.exp((params['_0'] * x)))
15
+
16
+
17
+ def dequantization(x, **params):
18
+ return (torch.div(1, replace_num(params['_0'], num=0, to=10000)) * torch.log(domain_guard((params['_s'] * x), min=1e-5, nan=1e-5)))
19
+
20
+
21
+ def init_linear_scale( # Symmetric scale. From the study folder
22
+ x: torch.Tensor,
23
+ **kwargs: Dict[str, Any],
24
+ ) -> torch.Tensor:
25
+ assert "bits" in kwargs, "bits must be provided."
26
+ assert "params" in kwargs, "params must be provided."
27
+ assert "qtz_func" in kwargs, "qtz_func must be provided."
28
+
29
+ bits = kwargs.get('bits')
30
+ params = kwargs.get('params')
31
+ qtz_func = kwargs.get('qtz_func')
32
+
33
+ x_ = x.transpose(0, 1)
34
+ x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs))
35
+ x_ = x_.transpose(0, 1)
36
+
37
+ quant_min, quant_max = get_min_max_from_bits_signed(bits)
38
+ min_vals, max_vals = torch.aminmax(x_, dim=1)
39
+ min_vals = torch.min(min_vals, torch.zeros_like(min_vals))
40
+ max_vals = torch.max(max_vals, torch.zeros_like(max_vals))
41
+
42
+ eps = torch.finfo(torch.float32).eps
43
+
44
+ abs_max_val_per_ch = torch.max(-min_vals, max_vals)
45
+ scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2)
46
+
47
+ scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device)
48
+
49
+ # Introduces some noise in scale
50
+ # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything
51
+ scale = scale + 0.01 * torch.randn_like(scale)
52
+ return scale
53
+
54
+
55
+ def init_params(x: torch.Tensor, **kwargs: Dict[str, Any]) -> Dict[str, nn.Parameter]:
56
+ params = {
57
+ '_0': init_ones(x, qtz_func=quantization, deqtz_func=dequantization, param='_0', params_list=['_0', '_s'], **kwargs),
58
+ }
59
+ params['_s'] = init_linear_scale(x, params=params, qtz_func=quantization, **kwargs)
60
+ params = {k: nn.Parameter(v, requires_grad=False) for k, v in params.items()}
61
+
62
+ if 'post_init_hook' in kwargs:
63
+ kwargs['post_init_hook'](parameters=params)
64
+
65
+
66
+ if 'post_train_hook' in kwargs:
67
+ kwargs['post_train_hook'](parameters=params)
68
+
69
+ return params
70
+
71
+
72
+ ############### Numpy Qtz ###############
73
+
74
+
75
+ def np_quantization(x, _0, _s):
76
+ return (np.divide(1, np_replace_num(_s, num=0, to=10000)) * np.exp((_0 * x)))
77
+
78
+
79
+ def np_dequantization(x, _0, _s):
80
+ return (np.divide(1, np_replace_num(_0, num=0, to=10000)) * np.log(np_domain_guard((_s * x), min=1e-5, nan=1e-5)))
81
+
82
+
83
+ def fit_func(x, _0, _s):
84
+ x_ = np_quantization(x, _0, _s)
85
+ x_ = np_dequantization(x_, _0, _s)
86
+ return x_
87
+
88
+
89
+
90
+ ############### HELPERS ###############
91
+
92
+ def domain_guard(
93
+ x: torch.Tensor,
94
+ min: float = None,
95
+ max: float = None,
96
+ posinf: float = None,
97
+ neginf: float = None,
98
+ nan: float = None
99
+ ) -> torch.Tensor:
100
+ """Guard a tensor to a valid domain."""
101
+ x = torch.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan)
102
+ if min is not None or max is not None:
103
+ x = torch.clamp(x, min=min, max=max)
104
+ return x
105
+
106
+
107
+ def replace_num(x: torch.Tensor, num: float, to: float) -> torch.Tensor:
108
+ """Replace a number in a tensor with another number.
109
+
110
+ Args:
111
+ x (torch.Tensor): The input tensor.
112
+ num (float): The number to replace.
113
+ to (float): The number to replace with.
114
+
115
+ Returns:
116
+ torch.Tensor: The tensor with the number replaced.
117
+ """
118
+ return torch.where(x == num, to, x)
119
+
120
+
121
+ def guarded_torch_power(x: torch.Tensor, exp: float) -> torch.Tensor:
122
+ """Guard the power operation to a valid domain."""
123
+ return torch.pow(x, exp) if exp >= 1 else torch.pow(torch.relu(x), exp)
124
+
125
+
126
+ def init_ones(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
127
+ val = torch.amin(x, dim=1)
128
+ return torch.ones_like(val, dtype=torch.float32, device=x.device)
129
+
130
+
131
+ def init_rand(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
132
+ val = torch.amin(x, dim=1)
133
+ return torch.randn_like(val, dtype=torch.float32, device=x.device)
134
+
135
+
136
+ def init_space_search(
137
+ x: torch.Tensor,
138
+ **kwargs: Dict[str, Any],
139
+ ) -> torch.Tensor:
140
+
141
+ def _build_initial_param(tensor: torch.Tensor, max_initial: int, n_params: int):
142
+ """Generates the initial set of parameters. The first iteration generates 10 times more parameters."""
143
+ for _ in range(n_params * 10): # The first iteration generates 10 times more parameters
144
+ yield init_rand(tensor) * max_initial # Generates n_params in range [-max_initial, max_initial]
145
+
146
+ def _search_param(tensors: List[torch.tensor], n_params):
147
+ """Takes the best parameters and generates new parameters around the mean of the best parameters."""
148
+ torch_tensors = torch.stack(tensors)
149
+ min_vals, max_vals = torch.aminmax(torch_tensors, dim=0)
150
+ abs_max_val_per_ch = torch.max(-min_vals, max_vals)
151
+ mean = torch.mean(torch_tensors, dim=0)
152
+ for _ in range(n_params): # Generates n_params around the mean of the tensors
153
+ yield torch.randn_like(min_vals) * abs_max_val_per_ch + mean
154
+
155
+ def _calc(x, qtz_func, deqtz_func, **params):
156
+ x_ = x.transpose(0, 1)
157
+ x_ = qtz_func(x=x_, **params)
158
+ x_ = deqtz_func(x=x_, **params)
159
+ x_ = x_.transpose(0, 1)
160
+ return x_
161
+
162
+ assert "qtz_func" in kwargs, "qtz_func must be provided."
163
+ assert "deqtz_func" in kwargs, "deqtz_func must be provided."
164
+ assert "params_list" in kwargs, "params list must be provided."
165
+ assert "param" in kwargs, "param must be provided."
166
+
167
+ qtz_func = kwargs.get('qtz_func')
168
+ deqtz_func = kwargs.get('deqtz_func')
169
+ params_list = kwargs.get('params_list')
170
+ param = kwargs.get('param')
171
+
172
+ n_runs = 50 # Number of runs to try to find the best parameters
173
+ n_random_params = 50 # Number of random parameters to generate
174
+ n_best_to_pick = 5 # Number of best parameters to pick after each run
175
+ max_initial = 10000 # Maximum value to initialize the parameters
176
+
177
+ # Initializes the parameters
178
+ base_params = { p: init_ones(x, **kwargs) for p in params_list if p != param }
179
+ params = _build_initial_param(x, max_initial, n_random_params)
180
+
181
+ # Performs the search
182
+ for _ in range(n_runs):
183
+
184
+ best_params = []
185
+ for param_ in params:
186
+ try:
187
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: param_})
188
+ loss_ones = nn.MSELoss()(x, x_)
189
+
190
+ if len(best_params) < n_best_to_pick:
191
+ best_params.append((param_, loss_ones.item()))
192
+ best_params = sorted(best_params, key=lambda x: x[1])
193
+ elif loss_ones < best_params[-1][1]:
194
+ best_params[-1] = (param_, loss_ones.item())
195
+ best_params = sorted(best_params, key=lambda x: x[1])
196
+
197
+ except Exception: # The parameters might not be valid for the function's domain
198
+ continue
199
+
200
+ # Generates new parameters around the mean
201
+ params = _search_param([p for p, _ in best_params], n_random_params)
202
+
203
+ # Checks if the best parameter is better than the init_ones
204
+ p_ones = init_ones(x, **kwargs)
205
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_ones})
206
+ loss_ones = nn.MSELoss()(x, x_)
207
+
208
+ # Checks if the best parameter is better than the init_rand
209
+ p_rand = init_rand(x, **kwargs)
210
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_rand})
211
+ loss_rand = nn.MSELoss()(x, x_)
212
+
213
+ if loss_rand < best_params[0][1] and loss_rand < loss_ones:
214
+ return p_rand
215
+ elif loss_ones < best_params[0][1] and loss_ones < loss_rand:
216
+ return p_ones
217
+ else:
218
+ return best_params[0][0]
219
+
220
+
221
+ def init_linear_scale( # Symmetric scale. From the study folder
222
+ x: torch.Tensor,
223
+ **kwargs: Dict[str, Any],
224
+ ) -> torch.Tensor:
225
+ assert "bits" in kwargs, "bits must be provided."
226
+ assert "params" in kwargs, "params must be provided."
227
+ assert "qtz_func" in kwargs, "qtz_func must be provided."
228
+
229
+ bits = kwargs.get('bits')
230
+ params = kwargs.get('params')
231
+ qtz_func = kwargs.get('qtz_func')
232
+
233
+ x_ = x.transpose(0, 1)
234
+ x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs))
235
+ x_ = x_.transpose(0, 1)
236
+
237
+ quant_min, quant_max = get_min_max_from_bits_signed(bits)
238
+ min_vals, max_vals = torch.aminmax(x_, dim=1)
239
+ min_vals = torch.min(min_vals, torch.zeros_like(min_vals))
240
+ max_vals = torch.max(max_vals, torch.zeros_like(max_vals))
241
+
242
+ eps = torch.finfo(torch.float32).eps
243
+
244
+ abs_max_val_per_ch = torch.max(-min_vals, max_vals)
245
+ scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2)
246
+
247
+ scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device)
248
+
249
+ # Introduces some noise in scale
250
+ # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything
251
+ scale = scale + 0.01 * torch.randn_like(scale)
252
+ return scale
253
+
254
+
255
+ def init_non_linear_regression_fit(
256
+ x: torch.Tensor,
257
+ **kwargs: Dict[str, Any],
258
+ ) -> torch.Tensor:
259
+
260
+ assert "params_list" in kwargs, "params list must be provided."
261
+ assert "np_fit_func" in kwargs, "np_fit_func must be provided."
262
+ assert "p0" in kwargs, "p0 must be provided."
263
+ np_fit_func = kwargs.get('np_fit_func')
264
+ params_list = kwargs.get('params_list')
265
+ p0 = kwargs.get('p0')
266
+
267
+ def _fit(xdata: np.ndarray, ydata: np.ndarray, func: Callable, p0: List[float]):
268
+ popt, _ = curve_fit(
269
+ func,
270
+ xdata,
271
+ ydata,
272
+ maxfev=1000,
273
+ p0=p0,
274
+ method='lm'
275
+ )
276
+ return popt
277
+
278
+ # 1. Needs to convert the torch tensor to numpy tensor
279
+ xdata = x.cpu().numpy()
280
+
281
+ # 2. Sorts the data so that it makes it easier to fit to it
282
+ sorted_xdata = np.sort(xdata, axis=-1)
283
+
284
+ p0 = {k: v.cpu().numpy() for k, v in p0.items()}
285
+ params_list = sorted(params_list) # We need to make sure that it matches the numpy fit func arg order
286
+
287
+ # 3. Finds the best parameters for each channel
288
+ try:
289
+ params = []
290
+ for i in range(sorted_xdata.shape[0]):
291
+ xdata_ = sorted_xdata[i]
292
+ p0_ = [p0[p][i] for p in params_list]
293
+ ch_params = _fit(xdata_, xdata_, np_fit_func, p0_)
294
+ params.append(ch_params)
295
+
296
+ # 4. Builds the parameters
297
+ result = {}
298
+ for i, p in enumerate(params_list):
299
+ result[p] = torch.tensor([p_[i] for p_ in params], dtype=torch.float32).to(x.device)
300
+
301
+ return result
302
+
303
+ except ValueError as e:
304
+ print(f"Could not fit the function with error: {e}")
305
+ print(f"Using fallback result...")
306
+ return {
307
+ k: torch.tensor(v, dtype=torch.float32).to(x.device) for k, v in p0.items()
308
+ }
309
+
310
+
311
+ def init_zeros(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
312
+ val = torch.amin(x, dim=1)
313
+ return torch.zeros_like(val, dtype=torch.float32, device=x.device)
314
+
315
+
316
+ def init_inner_scale(tensor: torch.Tensor, _min: float = torch.inf, _max: float = torch.inf) -> torch.Tensor:
317
+ # Calculate the original minimum and maximum values
318
+ min_vals, max_vals = torch.aminmax(tensor, dim=-1)
319
+ x_min = torch.min(min_vals, torch.zeros_like(min_vals))
320
+ x_max = torch.max(max_vals, torch.zeros_like(max_vals))
321
+
322
+ if _max is torch.inf: # We do not need to scale the tensor. Just need to move it
323
+ return torch.ones_like(x_min)
324
+
325
+ # Calculate the scale factor
326
+ scale = (_max - _min) / (x_max - x_min)
327
+ return scale
328
+
329
+
330
+
331
+ ############## Quant ###############
332
+
333
+ @torch.enable_grad()
334
+ def learn_parameters(
335
+ x: torch.Tensor,
336
+ params: Dict[str, nn.Parameter],
337
+ qtz_func: nn.Module,
338
+ deqtz_func: nn.Module,
339
+ bits: int,
340
+ target_dtype: torch.dtype,
341
+ epochs: int = 1000,
342
+ early_stop: bool = True,
343
+ do_report: bool = False
344
+ ) -> Tuple[Dict[str, nn.Parameter], torch.Tensor]:
345
+
346
+ # Requires gradients in the parameters
347
+ for p in params.values():
348
+ p.requires_grad = True
349
+ p.grad = None
350
+
351
+ param_keys = list(params.keys())
352
+ param_values = list(params.values())
353
+
354
+ # Defines optimizer and loss function
355
+ optimizer = torch.optim.Adam(param_values, lr=0.001)
356
+ scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=1.0, end_factor=0.001, total_iters=epochs // 10)
357
+ loss_fn = nn.MSELoss()
358
+
359
+ # Contains the best loss and the best parameters
360
+ best_loss = float("inf")
361
+ best_params = None
362
+
363
+ # Used to stop the search early
364
+ min_delta = 1e-7
365
+ acc_loss = []
366
+ percent_epochs_before_stop = 0.1
367
+
368
+ for i in range(epochs):
369
+ optimizer.zero_grad()
370
+
371
+ quant = quantize(x, params, qtz_func, bits, target_dtype)
372
+ dequant = dequantize(quant, params, deqtz_func, bits, x.dtype)
373
+ loss = loss_fn(x, dequant)
374
+
375
+ if loss.isnan() or loss.isinf():
376
+ raise Exception("Loss is NaN or Inf. Stopping the search.")
377
+
378
+ loss.backward()
379
+ optimizer.step()
380
+ scheduler.step()
381
+
382
+ acc_loss.append(loss.item())
383
+
384
+ # Reports loss every 10 steps
385
+ if i % 10 == 0 and do_report:
386
+ print(f"Epoch {i}: Loss {loss.item()}")
387
+
388
+ # Optimizes the parameter search by storing the best loss and the parameters
389
+ if loss.item() < best_loss:
390
+ best_loss = loss.item()
391
+ best_params = copy.deepcopy({
392
+ k: v for k, v in params.items() if k in param_keys
393
+ })
394
+
395
+ # We also stop the search if the loss has not considerably during the last 10% epochs
396
+ if early_stop:
397
+ epochs_before_stop = int(epochs * percent_epochs_before_stop)
398
+ if i > epochs_before_stop and abs(acc_loss[i - epochs_before_stop] - acc_loss[i]) < min_delta:
399
+ break
400
+
401
+ # No longer requires gradients in the parameters
402
+ for p in best_params.values():
403
+ p.requires_grad = False
404
+ p.grad = None
405
+
406
+ if do_report:
407
+ return best_params, acc_loss
408
+ else:
409
+ return best_params
410
+
411
+
412
+ def quantize(
413
+ x: torch.Tensor,
414
+ params: Dict[str, nn.Parameter],
415
+ func: nn.Module,
416
+ bits: int,
417
+ target_dtype: torch.dtype = torch.int8
418
+ ) -> torch.Tensor:
419
+ quant_min, quant_max = get_min_max_from_bits_signed(bits)
420
+ x = x.transpose(0, 1) # Aligns shapes
421
+ x = func(x=x, **params)
422
+ x = x.transpose(0, 1)
423
+ x = torch.clamp(round_func_BPDA(x), quant_min, quant_max).to(target_dtype)
424
+ return x
425
+
426
+
427
+ def dequantize(
428
+ x: torch.Tensor,
429
+ params: Dict[str, nn.Parameter],
430
+ func: nn.Module,
431
+ bits: int,
432
+ out_dtype: torch.dtype
433
+ ) -> torch.Tensor:
434
+ x = x.to(dtype=out_dtype)
435
+ x = x.transpose(0, 1)
436
+ x = func(x=x, **params)
437
+ x = x.transpose(0, 1)
438
+ return x
439
+
440
+
441
+ def round_func_BPDA(input):
442
+ # This is equivalent to replacing round function (non-differentiable) with
443
+ # an identity function (differentiable) only when backward.
444
+ forward_value = torch.round(input)
445
+ out = input.clone()
446
+ out.data = forward_value.data
447
+ return out
448
+
449
+
450
+ def get_min_max_from_bits_signed(bit_width: int) -> Tuple[int, int]:
451
+ return -2 ** (bit_width - 1), 2 ** (bit_width - 1) - 1
452
+
453
+
454
+
455
+ ############## Numpy ###############
456
+
457
+ def np_domain_guard(
458
+ x: np.ndarray,
459
+ min: float = None,
460
+ max: float = None,
461
+ posinf: float = None,
462
+ neginf: float = None,
463
+ nan: float = None
464
+ ) -> np.ndarray:
465
+ """Guard a tensor to a valid domain."""
466
+ x = np.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan)
467
+ if min is not None or max is not None:
468
+ x = np.clip(x, min, max)
469
+ return x
470
+
471
+
472
+ def np_replace_num(x: np.ndarray, num: float, to: float) -> np.ndarray:
473
+ """Replace a number in a tensor with another number.
474
+
475
+ Args:
476
+ x (np.ndarray): The input tensor.
477
+ num (float): The number to replace.
478
+ to (float): The number to replace with.
479
+
480
+ Returns:
481
+ np.ndarray: The tensor with the number replaced.
482
+ """
483
+ return np.where(x == num, to, x)
484
+
485
+
486
+ def np_guarded_power(x: np.ndarray, exp: float) -> np.ndarray:
487
+ """Guard the power operation to a valid domain."""
488
+ return np.power(x, exp) if exp >= 1 else np.power(np.maximum(x, 0), exp)
489
+
fn_gen/ones_noisy_scale/1/loss.png ADDED
fn_gen/ones_noisy_scale/1/quantization.png ADDED
fn_gen/ones_noisy_scale/10/__pycache__/fn.cpython-310.pyc ADDED
Binary file (13.1 kB).
fn_gen/ones_noisy_scale/10/distortion.png ADDED
fn_gen/ones_noisy_scale/10/expressions.txt ADDED
@@ -0,0 +1,2 @@
+ x**2/_s
+ sqrt(_s*x)
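Unlike the log/exp pair in folder 0, x**2/_s and sqrt(_s*x) only invert each other for non-negative inputs, since squaring discards the sign; this is why the generated fn.py clamps the argument of the sqrt with domain_guard before taking the root. A small standalone illustration (values arbitrary):

import torch

_s = torch.tensor(2.0)
x = torch.tensor([-1.5, 0.5, 3.0])

q = x ** 2 / _s               # expression 1: x**2/_s
x_rec = torch.sqrt(_s * q)    # expression 2: sqrt(_s*x), applied to q

print(x_rec)                  # tensor([1.5000, 0.5000, 3.0000]) -- the sign of -1.5 is lost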
fn_gen/ones_noisy_scale/10/fn.py ADDED
@@ -0,0 +1,488 @@
1
+ from __future__ import annotations
2
+
3
+ import torch
4
+ from torch import amin # Necessary for arcsin
5
+ import copy
6
+ import torch.nn as nn
7
+ import numpy as np
8
+
9
+ from scipy.optimize import curve_fit
10
+ from typing import Dict, Any, Tuple, List, Callable
11
+
12
+
13
+ def quantization(x, **params):
14
+ return (torch.div(1, replace_num(params['_s'], num=0, to=10000)) * guarded_torch_power(x, torch.tensor(2)))
15
+
16
+
17
+ def dequantization(x, **params):
18
+ return torch.sqrt(domain_guard((params['_s'] * x), min=0.1, nan=0.1))
19
+
20
+
21
+ def init_linear_scale( # Symmetric scale. From the study folder
22
+ x: torch.Tensor,
23
+ **kwargs: Dict[str, Any],
24
+ ) -> torch.Tensor:
25
+ assert "bits" in kwargs, "bits must be provided."
26
+ assert "params" in kwargs, "params must be provided."
27
+ assert "qtz_func" in kwargs, "qtz_func must be provided."
28
+
29
+ bits = kwargs.get('bits')
30
+ params = kwargs.get('params')
31
+ qtz_func = kwargs.get('qtz_func')
32
+
33
+ x_ = x.transpose(0, 1)
34
+ x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs))
35
+ x_ = x_.transpose(0, 1)
36
+
37
+ quant_min, quant_max = get_min_max_from_bits_signed(bits)
38
+ min_vals, max_vals = torch.aminmax(x_, dim=1)
39
+ min_vals = torch.min(min_vals, torch.zeros_like(min_vals))
40
+ max_vals = torch.max(max_vals, torch.zeros_like(max_vals))
41
+
42
+ eps = torch.finfo(torch.float32).eps
43
+
44
+ abs_max_val_per_ch = torch.max(-min_vals, max_vals)
45
+ scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2)
46
+
47
+ scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device)
48
+
49
+ # Introduces some noise in scale
50
+ # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything
51
+ scale = scale + 0.01 * torch.randn_like(scale)
52
+ return scale
53
+
54
+
55
+ def init_params(x: torch.Tensor, **kwargs: Dict[str, Any]) -> Dict[str, nn.Parameter]:
56
+ params = {
57
+ }
58
+ params['_s'] = init_linear_scale(x, params=params, qtz_func=quantization, **kwargs)
59
+ params = {k: nn.Parameter(v, requires_grad=False) for k, v in params.items()}
60
+
61
+ if 'post_init_hook' in kwargs:
62
+ kwargs['post_init_hook'](parameters=params)
63
+
64
+
65
+ if 'post_train_hook' in kwargs:
66
+ kwargs['post_train_hook'](parameters=params)
67
+
68
+ return params
69
+
70
+
71
+ ############### Numpy Qtz ###############
72
+
73
+
74
+ def np_quantization(x, _s):
75
+ return (np.divide(1, np_replace_num(_s, num=0, to=10000)) * np_guarded_power(x, np.array(2)))
76
+
77
+
78
+ def np_dequantization(x, _s):
79
+ return np.sqrt(np_domain_guard((_s * x), min=0.1, nan=0.1))
80
+
81
+
82
+ def fit_func(x, _s):
83
+ x_ = np_quantization(x, _s)
84
+ x_ = np_dequantization(x_, _s)
85
+ return x_
86
+
87
+
88
+
89
+ ############### HELPERS ###############
90
+
91
+ def domain_guard(
92
+ x: torch.Tensor,
93
+ min: float = None,
94
+ max: float = None,
95
+ posinf: float = None,
96
+ neginf: float = None,
97
+ nan: float = None
98
+ ) -> torch.Tensor:
99
+ """Guard a tensor to a valid domain."""
100
+ x = torch.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan)
101
+ if min is not None or max is not None:
102
+ x = torch.clamp(x, min=min, max=max)
103
+ return x
104
+
105
+
106
+ def replace_num(x: torch.Tensor, num: float, to: float) -> torch.Tensor:
107
+ """Replace a number in a tensor with another number.
108
+
109
+ Args:
110
+ x (torch.Tensor): The input tensor.
111
+ num (float): The number to replace.
112
+ to (float): The number to replace with.
113
+
114
+ Returns:
115
+ torch.Tensor: The tensor with the number replaced.
116
+ """
117
+ return torch.where(x == num, to, x)
118
+
119
+
120
+ def guarded_torch_power(x: torch.Tensor, exp: float) -> torch.Tensor:
121
+ """Guard the power operation to a valid domain."""
122
+ return torch.pow(x, exp) if exp >= 1 else torch.pow(torch.relu(x), exp)
123
+
124
+
125
+ def init_ones(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
126
+ val = torch.amin(x, dim=1)
127
+ return torch.ones_like(val, dtype=torch.float32, device=x.device)
128
+
129
+
130
+ def init_rand(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
131
+ val = torch.amin(x, dim=1)
132
+ return torch.randn_like(val, dtype=torch.float32, device=x.device)
133
+
134
+
135
+ def init_space_search(
136
+ x: torch.Tensor,
137
+ **kwargs: Dict[str, Any],
138
+ ) -> torch.Tensor:
139
+
140
+ def _build_initial_param(tensor: torch.Tensor, max_initial: int, n_params: int):
141
+ """Generates the initial set of parameters. The first iteration generates 10 times more parameters."""
142
+ for _ in range(n_params * 10): # The first iteration generates 10 times more parameters
143
+ yield init_rand(tensor) * max_initial # Generates n_params in range [-max_initial, max_initial]
144
+
145
+ def _search_param(tensors: List[torch.tensor], n_params):
146
+ """Takes the best parameters and generates new parameters around the mean of the best parameters."""
147
+ torch_tensors = torch.stack(tensors)
148
+ min_vals, max_vals = torch.aminmax(torch_tensors, dim=0)
149
+ abs_max_val_per_ch = torch.max(-min_vals, max_vals)
150
+ mean = torch.mean(torch_tensors, dim=0)
151
+ for _ in range(n_params): # Generates n_params around the mean of the tensors
152
+ yield torch.randn_like(min_vals) * abs_max_val_per_ch + mean
153
+
154
+ def _calc(x, qtz_func, deqtz_func, **params):
155
+ x_ = x.transpose(0, 1)
156
+ x_ = qtz_func(x=x_, **params)
157
+ x_ = deqtz_func(x=x_, **params)
158
+ x_ = x_.transpose(0, 1)
159
+ return x_
160
+
161
+ assert "qtz_func" in kwargs, "qtz_func must be provided."
162
+ assert "deqtz_func" in kwargs, "deqtz_func must be provided."
163
+ assert "params_list" in kwargs, "params list must be provided."
164
+ assert "param" in kwargs, "param must be provided."
165
+
166
+ qtz_func = kwargs.get('qtz_func')
167
+ deqtz_func = kwargs.get('deqtz_func')
168
+ params_list = kwargs.get('params_list')
169
+ param = kwargs.get('param')
170
+
171
+ n_runs = 50 # Number of runs to try to find the best parameters
172
+ n_random_params = 50 # Number of random parameters to generate
173
+ n_best_to_pick = 5 # Number of best parameters to pick after each run
174
+ max_initial = 10000 # Maximum value to initialize the parameters
175
+
176
+ # Initializes the parameters
177
+ base_params = { p: init_ones(x, **kwargs) for p in params_list if p != param }
178
+ params = _build_initial_param(x, max_initial, n_random_params)
179
+
180
+ # Performs the search
181
+ for _ in range(n_runs):
182
+
183
+ best_params = []
184
+ for param_ in params:
185
+ try:
186
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: param_})
187
+ loss_ones = nn.MSELoss()(x, x_)
188
+
189
+ if len(best_params) < n_best_to_pick:
190
+ best_params.append((param_, loss_ones.item()))
191
+ best_params = sorted(best_params, key=lambda x: x[1])
192
+ elif loss_ones < best_params[-1][1]:
193
+ best_params[-1] = (param_, loss_ones.item())
194
+ best_params = sorted(best_params, key=lambda x: x[1])
195
+
196
+ except Exception: # The parameters might not be valid for the function's domain
197
+ continue
198
+
199
+ # Generates new parameters around the mean
200
+ params = _search_param([p for p, _ in best_params], n_random_params)
201
+
202
+ # Checks if the best parameter is better than the init_ones
203
+ p_ones = init_ones(x, **kwargs)
204
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_ones})
205
+ loss_ones = nn.MSELoss()(x, x_)
206
+
207
+ # Checks if the best parameter is better than the init_rand
208
+ p_rand = init_rand(x, **kwargs)
209
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_rand})
210
+ loss_rand = nn.MSELoss()(x, x_)
211
+
212
+ if loss_rand < best_params[0][1] and loss_rand < loss_ones:
213
+ return p_rand
214
+ elif loss_ones < best_params[0][1] and loss_ones < loss_rand:
215
+ return p_ones
216
+ else:
217
+ return best_params[0][0]
218
+
219
+
220
+ def init_linear_scale( # Symmetric scale. From the study folder
221
+ x: torch.Tensor,
222
+ **kwargs: Dict[str, Any],
223
+ ) -> torch.Tensor:
224
+ assert "bits" in kwargs, "bits must be provided."
225
+ assert "params" in kwargs, "params must be provided."
226
+ assert "qtz_func" in kwargs, "qtz_func must be provided."
227
+
228
+ bits = kwargs.get('bits')
229
+ params = kwargs.get('params')
230
+ qtz_func = kwargs.get('qtz_func')
231
+
232
+ x_ = x.transpose(0, 1)
233
+ x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs))
234
+ x_ = x_.transpose(0, 1)
235
+
236
+ quant_min, quant_max = get_min_max_from_bits_signed(bits)
237
+ min_vals, max_vals = torch.aminmax(x_, dim=1)
238
+ min_vals = torch.min(min_vals, torch.zeros_like(min_vals))
239
+ max_vals = torch.max(max_vals, torch.zeros_like(max_vals))
240
+
241
+ eps = torch.finfo(torch.float32).eps
242
+
243
+ abs_max_val_per_ch = torch.max(-min_vals, max_vals)
244
+ scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2)
245
+
246
+ scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device)
247
+
248
+ # Introduces some noise in scale
249
+ # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything
250
+ scale = scale + 0.01 * torch.randn_like(scale)
251
+ return scale
252
+
253
+
254
+ def init_non_linear_regression_fit(
255
+ x: torch.Tensor,
256
+ **kwargs: Dict[str, Any],
257
+ ) -> torch.Tensor:
258
+
259
+ assert "params_list" in kwargs, "params list must be provided."
260
+ assert "np_fit_func" in kwargs, "np_fit_func must be provided."
261
+ assert "p0" in kwargs, "p0 must be provided."
262
+ np_fit_func = kwargs.get('np_fit_func')
263
+ params_list = kwargs.get('params_list')
264
+ p0 = kwargs.get('p0')
265
+
266
+ def _fit(xdata: np.ndarray, ydata: np.ndarray, func: Callable, p0: List[float]):
267
+ popt, _ = curve_fit(
268
+ func,
269
+ xdata,
270
+ ydata,
271
+ maxfev=1000,
272
+ p0=p0,
273
+ method='lm'
274
+ )
275
+ return popt
276
+
277
+ # 1. Needs to convert the torch tensor to numpy tensor
278
+ xdata = x.cpu().numpy()
279
+
280
+ # 2. Sorts the data so that it makes it easier to fit to it
281
+ sorted_xdata = np.sort(xdata, axis=-1)
282
+
283
+ p0 = {k: v.cpu().numpy() for k, v in p0.items()}
284
+ params_list = sorted(params_list) # We need to make sure that it matches the numpy fit func arg order
285
+
286
+ # 3. Finds the best parameters for each channel
287
+ try:
288
+ params = []
289
+ for i in range(sorted_xdata.shape[0]):
290
+ xdata_ = sorted_xdata[i]
291
+ p0_ = [p0[p][i] for p in params_list]
292
+ ch_params = _fit(xdata_, xdata_, np_fit_func, p0_)
293
+ params.append(ch_params)
294
+
295
+ # 4. Builds the parameters
296
+ result = {}
297
+ for i, p in enumerate(params_list):
298
+ result[p] = torch.tensor([p_[i] for p_ in params], dtype=torch.float32).to(x.device)
299
+
300
+ return result
301
+
302
+ except ValueError as e:
303
+ print(f"Could not fit the function with error: {e}")
304
+ print(f"Using fallback result...")
305
+ return {
306
+ k: torch.tensor(v, dtype=torch.float32).to(x.device) for k, v in p0.items()
307
+ }
308
+
309
+
310
+ def init_zeros(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
311
+ val = torch.amin(x, dim=1)
312
+ return torch.zeros_like(val, dtype=torch.float32, device=x.device)
313
+
314
+
315
+ def init_inner_scale(tensor: torch.Tensor, _min: float = torch.inf, _max: float = torch.inf) -> torch.Tensor:
316
+ # Calculate the original minimum and maximum values
317
+ min_vals, max_vals = torch.aminmax(tensor, dim=-1)
318
+ x_min = torch.min(min_vals, torch.zeros_like(min_vals))
319
+ x_max = torch.max(max_vals, torch.zeros_like(max_vals))
320
+
321
+ if _max is torch.inf: # We do not need to scale the tensor. Just need to move it
322
+ return torch.ones_like(x_min)
323
+
324
+ # Calculate the scale factor
325
+ scale = (_max - _min) / (x_max - x_min)
326
+ return scale
327
+
328
+
329
+
330
+ ############## Quant ###############
331
+
332
+ @torch.enable_grad()
333
+ def learn_parameters(
334
+ x: torch.Tensor,
335
+ params: Dict[str, nn.Parameter],
336
+ qtz_func: nn.Module,
337
+ deqtz_func: nn.Module,
338
+ bits: int,
339
+ target_dtype: torch.dtype,
340
+ epochs: int = 1000,
341
+ early_stop: bool = True,
342
+ do_report: bool = False
343
+ ) -> Tuple[Dict[str, nn.Parameter], torch.Tensor]:
344
+
345
+ # Requires gradients in the parameters
346
+ for p in params.values():
347
+ p.requires_grad = True
348
+ p.grad = None
349
+
350
+ param_keys = list(params.keys())
351
+ param_values = list(params.values())
352
+
353
+ # Defines optimizer and loss function
354
+ optimizer = torch.optim.Adam(param_values, lr=0.001)
355
+ scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=1.0, end_factor=0.001, total_iters=epochs // 10)
356
+ loss_fn = nn.MSELoss()
357
+
358
+ # Contains the best loss and the best parameters
359
+ best_loss = float("inf")
360
+ best_params = None
361
+
362
+ # Used to stop the search early
363
+ min_delta = 1e-7
364
+ acc_loss = []
365
+ percent_epochs_before_stop = 0.1
366
+
367
+ for i in range(epochs):
368
+ optimizer.zero_grad()
369
+
370
+ quant = quantize(x, params, qtz_func, bits, target_dtype)
371
+ dequant = dequantize(quant, params, deqtz_func, bits, x.dtype)
372
+ loss = loss_fn(x, dequant)
373
+
374
+ if loss.isnan() or loss.isinf():
375
+ raise Exception("Loss is NaN or Inf. Stopping the search.")
376
+
377
+ loss.backward()
378
+ optimizer.step()
379
+ scheduler.step()
380
+
381
+ acc_loss.append(loss.item())
382
+
383
+ # Reports loss every 10 steps
384
+ if i % 10 == 0 and do_report:
385
+ print(f"Epoch {i}: Loss {loss.item()}")
386
+
387
+ # Optimizes the parameter search by storing the best loss and the parameters
388
+ if loss.item() < best_loss:
389
+ best_loss = loss.item()
390
+ best_params = copy.deepcopy({
391
+ k: v for k, v in params.items() if k in param_keys
392
+ })
393
+
394
+ # We also stop the search if the loss has not considerably during the last 10% epochs
395
+ if early_stop:
396
+ epochs_before_stop = int(epochs * percent_epochs_before_stop)
397
+ if i > epochs_before_stop and abs(acc_loss[i - epochs_before_stop] - acc_loss[i]) < min_delta:
398
+ break
399
+
400
+ # No longer requires gradients in the parameters
401
+ for p in best_params.values():
402
+ p.requires_grad = False
403
+ p.grad = None
404
+
405
+ if do_report:
406
+ return best_params, acc_loss
407
+ else:
408
+ return best_params
409
+
410
+
411
+ def quantize(
412
+ x: torch.Tensor,
413
+ params: Dict[str, nn.Parameter],
414
+ func: nn.Module,
415
+ bits: int,
416
+ target_dtype: torch.dtype = torch.int8
417
+ ) -> torch.Tensor:
418
+ quant_min, quant_max = get_min_max_from_bits_signed(bits)
419
+ x = x.transpose(0, 1) # Aligns shapes
420
+ x = func(x=x, **params)
421
+ x = x.transpose(0, 1)
422
+ x = torch.clamp(round_func_BPDA(x), quant_min, quant_max).to(target_dtype)
423
+ return x
424
+
425
+
426
+ def dequantize(
427
+ x: torch.Tensor,
428
+ params: Dict[str, nn.Parameter],
429
+ func: nn.Module,
430
+ bits: int,
431
+ out_dtype: torch.dtype
432
+ ) -> torch.Tensor:
433
+ x = x.to(dtype=out_dtype)
434
+ x = x.transpose(0, 1)
435
+ x = func(x=x, **params)
436
+ x = x.transpose(0, 1)
437
+ return x
438
+
439
+
440
+ def round_func_BPDA(input):
441
+ # This is equivalent to replacing round function (non-differentiable) with
442
+ # an identity function (differentiable) only when backward.
443
+ forward_value = torch.round(input)
444
+ out = input.clone()
445
+ out.data = forward_value.data
446
+ return out
447
+
448
+
449
+ def get_min_max_from_bits_signed(bit_width: int) -> Tuple[int, int]:
450
+ return -2 ** (bit_width - 1), 2 ** (bit_width - 1) - 1
451
+
452
+
453
+
454
+ ############## Numpy ###############
455
+
456
+ def np_domain_guard(
457
+ x: np.ndarray,
458
+ min: float = None,
459
+ max: float = None,
460
+ posinf: float = None,
461
+ neginf: float = None,
462
+ nan: float = None
463
+ ) -> np.ndarray:
464
+ """Guard a tensor to a valid domain."""
465
+ x = np.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan)
466
+ if min is not None or max is not None:
467
+ x = np.clip(x, min, max)
468
+ return x
469
+
470
+
471
+ def np_replace_num(x: np.ndarray, num: float, to: float) -> np.ndarray:
472
+ """Replace a number in a tensor with another number.
473
+
474
+ Args:
475
+ x (np.ndarray): The input tensor.
476
+ num (float): The number to replace.
477
+ to (float): The number to replace with.
478
+
479
+ Returns:
480
+ np.ndarray: The tensor with the number replaced.
481
+ """
482
+ return np.where(x == num, to, x)
483
+
484
+
485
+ def np_guarded_power(x: np.ndarray, exp: float) -> np.ndarray:
486
+ """Guard the power operation to a valid domain."""
487
+ return np.power(x, exp) if exp >= 1 else np.power(np.maximum(x, 0), exp)
488
+
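The `round_func_BPDA` helper defined above is a straight-through round: the forward pass emits `torch.round(input)`, while the backward pass treats the op as the identity, so gradients can flow through the otherwise non-differentiable rounding inside `quantize`. A minimal check of that behaviour (not part of the uploaded files, shown only for illustration):

```python
import torch

def round_func_BPDA(input):
    # Copied from fn.py above: rounded values on forward,
    # identity gradient on backward.
    forward_value = torch.round(input)
    out = input.clone()
    out.data = forward_value.data
    return out

x = torch.tensor([0.4, 1.6, -2.3], requires_grad=True)
y = round_func_BPDA(x)
y.sum().backward()

print(y.detach())  # tensor([ 0.,  2., -2.]) -- rounded forward values
print(x.grad)      # tensor([1., 1., 1.])    -- gradient passes through unchanged
```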
fn_gen/ones_noisy_scale/10/loss.png ADDED
fn_gen/ones_noisy_scale/10/quantization.png ADDED
fn_gen/ones_noisy_scale/11/__pycache__/fn.cpython-310.pyc ADDED
Binary file (13.2 kB). View file
 
fn_gen/ones_noisy_scale/11/distortion.png ADDED
fn_gen/ones_noisy_scale/11/expressions.txt ADDED
@@ -0,0 +1,2 @@
1
+ tan(_0*x)/_s
2
+ atan(_s*x)/_0
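These two expressions are the forward and inverse transforms learned in this run: quantization maps `x` to `tan(_0*x)/_s`, dequantization maps back with `atan(_s*x)/_0`, so the pair composes to the identity whenever `_0*x` stays inside (-pi/2, pi/2). A toy numpy round trip (illustrative only; `_0` and `_s` are fixed to 1 here, whereas fn.py learns them per channel):

```python
import numpy as np

_0, _s = 1.0, 1.0                 # assumed values; fn.py learns these per channel
x = np.linspace(-1.2, 1.2, 7)     # keeps _0 * x inside (-pi/2, pi/2)

q = np.tan(_0 * x) / _s           # quantization expression
x_rec = np.arctan(_s * q) / _0    # dequantization expression

print(np.allclose(x, x_rec))      # True: atan undoes tan on this domain
```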
fn_gen/ones_noisy_scale/11/fn.py ADDED
@@ -0,0 +1,489 @@
1
+ from __future__ import annotations
2
+
3
+ import torch
4
+ from torch import amin # Necessary for arcsin
5
+ import copy
6
+ import torch.nn as nn
7
+ import numpy as np
8
+
9
+ from scipy.optimize import curve_fit
10
+ from typing import Dict, Any, Tuple, List, Callable
11
+
12
+
13
+ def quantization(x, **params):
14
+ return (torch.div(1, replace_num(params['_s'], num=0, to=10000)) * torch.tan(domain_guard((params['_0'] * x), posinf=1, neginf=-1, nan=0)))
15
+
16
+
17
+ def dequantization(x, **params):
18
+ return (torch.div(1, replace_num(params['_0'], num=0, to=10000)) * torch.atan((params['_s'] * x)))
19
+
20
+
21
+ def init_linear_scale( # Symmetric scale. From the study folder
22
+ x: torch.Tensor,
23
+ **kwargs: Dict[str, Any],
24
+ ) -> torch.Tensor:
25
+ assert "bits" in kwargs, "bits must be provided."
26
+ assert "params" in kwargs, "params must be provided."
27
+ assert "qtz_func" in kwargs, "qtz_func must be provided."
28
+
29
+ bits = kwargs.get('bits')
30
+ params = kwargs.get('params')
31
+ qtz_func = kwargs.get('qtz_func')
32
+
33
+ x_ = x.transpose(0, 1)
34
+ x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs))
35
+ x_ = x_.transpose(0, 1)
36
+
37
+ quant_min, quant_max = get_min_max_from_bits_signed(bits)
38
+ min_vals, max_vals = torch.aminmax(x_, dim=1)
39
+ min_vals = torch.min(min_vals, torch.zeros_like(min_vals))
40
+ max_vals = torch.max(max_vals, torch.zeros_like(max_vals))
41
+
42
+ eps = torch.finfo(torch.float32).eps
43
+
44
+ abs_max_val_per_ch = torch.max(-min_vals, max_vals)
45
+ scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2)
46
+
47
+ scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device)
48
+
49
+ # Introduces some noise in scale
50
+ # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything
51
+ scale = scale + 0.01 * torch.randn_like(scale)
52
+ return scale
53
+
54
+
55
+ def init_params(x: torch.Tensor, **kwargs: Dict[str, Any]) -> Dict[str, nn.Parameter]:
56
+ params = {
57
+ '_0': init_ones(x, qtz_func=quantization, deqtz_func=dequantization, param='_0', params_list=['_0', '_s'], **kwargs),
58
+ }
59
+ params['_s'] = init_linear_scale(x, params=params, qtz_func=quantization, **kwargs)
60
+ params = {k: nn.Parameter(v, requires_grad=False) for k, v in params.items()}
61
+
62
+ if 'post_init_hook' in kwargs:
63
+ kwargs['post_init_hook'](parameters=params)
64
+
65
+
66
+ if 'post_train_hook' in kwargs:
67
+ kwargs['post_train_hook'](parameters=params)
68
+
69
+ return params
70
+
71
+
72
+ ############### Numpy Qtz ###############
73
+
74
+
75
+ def np_quantization(x, _0, _s):
76
+ return (np.divide(1, np_replace_num(_s, num=0, to=10000)) * np.tan(np_domain_guard((_0 * x), posinf=1, neginf=-1, nan=0)))
77
+
78
+
79
+ def np_dequantization(x, _0, _s):
80
+ return (np.divide(1, np_replace_num(_0, num=0, to=10000)) * np.arctan((_s * x)))
81
+
82
+
83
+ def fit_func(x, _0, _s):
84
+ x_ = np_quantization(x, _0, _s)
85
+ x_ = np_dequantization(x_, _0, _s)
86
+ return x_
87
+
88
+
89
+
90
+ ############### HELPERS ###############
91
+
92
+ def domain_guard(
93
+ x: torch.Tensor,
94
+ min: float = None,
95
+ max: float = None,
96
+ posinf: float = None,
97
+ neginf: float = None,
98
+ nan: float = None
99
+ ) -> torch.Tensor:
100
+ """Guard a tensor to a valid domain."""
101
+ x = torch.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan)
102
+ if min is not None or max is not None:
103
+ x = torch.clamp(x, min=min, max=max)
104
+ return x
105
+
106
+
107
+ def replace_num(x: torch.Tensor, num: float, to: float) -> torch.Tensor:
108
+ """Replace a number in a tensor with another number.
109
+
110
+ Args:
111
+ x (torch.Tensor): The input tensor.
112
+ num (float): The number to replace.
113
+ to (float): The number to replace with.
114
+
115
+ Returns:
116
+ torch.Tensor: The tensor with the number replaced.
117
+ """
118
+ return torch.where(x == num, to, x)
119
+
120
+
121
+ def guarded_torch_power(x: torch.Tensor, exp: float) -> torch.Tensor:
122
+ """Guard the power operation to a valid domain."""
123
+ return torch.pow(x, exp) if exp >= 1 else torch.pow(torch.relu(x), exp)
124
+
125
+
126
+ def init_ones(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
127
+ val = torch.amin(x, dim=1)
128
+ return torch.ones_like(val, dtype=torch.float32, device=x.device)
129
+
130
+
131
+ def init_rand(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
132
+ val = torch.amin(x, dim=1)
133
+ return torch.randn_like(val, dtype=torch.float32, device=x.device)
134
+
135
+
136
+ def init_space_search(
137
+ x: torch.Tensor,
138
+ **kwargs: Dict[str, Any],
139
+ ) -> torch.Tensor:
140
+
141
+ def _build_initial_param(tensor: torch.Tensor, max_initial: int, n_params: int):
142
+ """Generates the initial set of parameters. The first iteration generates 10 times more parameters."""
143
+ for _ in range(n_params * 10): # The first iteration generates 10 times more parameters
144
+ yield init_rand(tensor) * max_initial # Generates n_params in range [-max_initial, max_initial]
145
+
146
+ def _search_param(tensors: List[torch.Tensor], n_params):
147
+ """Takes the best parameters and generates new parameters around the mean of the best parameters."""
148
+ torch_tensors = torch.stack(tensors)
149
+ min_vals, max_vals = torch.aminmax(torch_tensors, dim=0)
150
+ abs_max_val_per_ch = torch.max(-min_vals, max_vals)
151
+ mean = torch.mean(torch_tensors, dim=0)
152
+ for _ in range(n_params): # Generates n_params around the mean of the tensors
153
+ yield torch.randn_like(min_vals) * abs_max_val_per_ch + mean
154
+
155
+ def _calc(x, qtz_func, deqtz_func, **params):
156
+ x_ = x.transpose(0, 1)
157
+ x_ = qtz_func(x=x_, **params)
158
+ x_ = deqtz_func(x=x_, **params)
159
+ x_ = x_.transpose(0, 1)
160
+ return x_
161
+
162
+ assert "qtz_func" in kwargs, "qtz_func must be provided."
163
+ assert "deqtz_func" in kwargs, "deqtz_func must be provided."
164
+ assert "params_list" in kwargs, "params list must be provided."
165
+ assert "param" in kwargs, "param must be provided."
166
+
167
+ qtz_func = kwargs.get('qtz_func')
168
+ deqtz_func = kwargs.get('deqtz_func')
169
+ params_list = kwargs.get('params_list')
170
+ param = kwargs.get('param')
171
+
172
+ n_runs = 50 # Number of runs to try to find the best parameters
173
+ n_random_params = 50 # Number of random parameters to generate
174
+ n_best_to_pick = 5 # Number of best parameters to pick after each run
175
+ max_initial = 10000 # Maximum value to initialize the parameters
176
+
177
+ # Initializes the parameters
178
+ base_params = { p: init_ones(x, **kwargs) for p in params_list if p != param }
179
+ params = _build_initial_param(x, max_initial, n_random_params)
180
+
181
+ # Performs the search
182
+ for _ in range(n_runs):
183
+
184
+ best_params = []
185
+ for param_ in params:
186
+ try:
187
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: param_})
188
+ loss_ones = nn.MSELoss()(x, x_)
189
+
190
+ if len(best_params) < n_best_to_pick:
191
+ best_params.append((param_, loss_ones.item()))
192
+ best_params = sorted(best_params, key=lambda x: x[1])
193
+ elif loss_ones < best_params[-1][1]:
194
+ best_params[-1] = (param_, loss_ones.item())
195
+ best_params = sorted(best_params, key=lambda x: x[1])
196
+
197
+ except Exception: # The parameters might not be valid for the function's domain
198
+ continue
199
+
200
+ # Generates new parameters around the mean
201
+ params = _search_param([p for p, _ in best_params], n_random_params)
202
+
203
+ # Checks if the best parameter is better than the init_ones
204
+ p_ones = init_ones(x, **kwargs)
205
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_ones})
206
+ loss_ones = nn.MSELoss()(x, x_)
207
+
208
+ # Checks if the best parameter is better than the init_rand
209
+ p_rand = init_rand(x, **kwargs)
210
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_rand})
211
+ loss_rand = nn.MSELoss()(x, x_)
212
+
213
+ if loss_rand < best_params[0][1] and loss_rand < loss_ones:
214
+ return p_rand
215
+ elif loss_ones < best_params[0][1] and loss_ones < loss_rand:
216
+ return p_ones
217
+ else:
218
+ return best_params[0][0]
219
+
220
+
221
+ def init_linear_scale( # Symmetric scale. From the study folder
222
+ x: torch.Tensor,
223
+ **kwargs: Dict[str, Any],
224
+ ) -> torch.Tensor:
225
+ assert "bits" in kwargs, "bits must be provided."
226
+ assert "params" in kwargs, "params must be provided."
227
+ assert "qtz_func" in kwargs, "qtz_func must be provided."
228
+
229
+ bits = kwargs.get('bits')
230
+ params = kwargs.get('params')
231
+ qtz_func = kwargs.get('qtz_func')
232
+
233
+ x_ = x.transpose(0, 1)
234
+ x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs))
235
+ x_ = x_.transpose(0, 1)
236
+
237
+ quant_min, quant_max = get_min_max_from_bits_signed(bits)
238
+ min_vals, max_vals = torch.aminmax(x_, dim=1)
239
+ min_vals = torch.min(min_vals, torch.zeros_like(min_vals))
240
+ max_vals = torch.max(max_vals, torch.zeros_like(max_vals))
241
+
242
+ eps = torch.finfo(torch.float32).eps
243
+
244
+ abs_max_val_per_ch = torch.max(-min_vals, max_vals)
245
+ scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2)
246
+
247
+ scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device)
248
+
249
+ # Introduces some noise in scale
250
+ # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything
251
+ scale = scale + 0.01 * torch.randn_like(scale)
252
+ return scale
253
+
254
+
255
+ def init_non_linear_regression_fit(
256
+ x: torch.Tensor,
257
+ **kwargs: Dict[str, Any],
258
+ ) -> torch.Tensor:
259
+
260
+ assert "params_list" in kwargs, "params list must be provided."
261
+ assert "np_fit_func" in kwargs, "np_fit_func must be provided."
262
+ assert "p0" in kwargs, "p0 must be provided."
263
+ np_fit_func = kwargs.get('np_fit_func')
264
+ params_list = kwargs.get('params_list')
265
+ p0 = kwargs.get('p0')
266
+
267
+ def _fit(xdata: np.ndarray, ydata: np.ndarray, func: Callable, p0: List[float]):
268
+ popt, _ = curve_fit(
269
+ func,
270
+ xdata,
271
+ ydata,
272
+ maxfev=1000,
273
+ p0=p0,
274
+ method='lm'
275
+ )
276
+ return popt
277
+
278
+ # 1. Needs to convert the torch tensor to numpy tensor
279
+ xdata = x.cpu().numpy()
280
+
281
+ # 2. Sorts the data so that it is easier to fit
282
+ sorted_xdata = np.sort(xdata, axis=-1)
283
+
284
+ p0 = {k: v.cpu().numpy() for k, v in p0.items()}
285
+ params_list = sorted(params_list) # We need to make sure that it matches the numpy fit func arg order
286
+
287
+ # 3. Finds the best parameters for each channel
288
+ try:
289
+ params = []
290
+ for i in range(sorted_xdata.shape[0]):
291
+ xdata_ = sorted_xdata[i]
292
+ p0_ = [p0[p][i] for p in params_list]
293
+ ch_params = _fit(xdata_, xdata_, np_fit_func, p0_)
294
+ params.append(ch_params)
295
+
296
+ # 4. Builds the parameters
297
+ result = {}
298
+ for i, p in enumerate(params_list):
299
+ result[p] = torch.tensor([p_[i] for p_ in params], dtype=torch.float32).to(x.device)
300
+
301
+ return result
302
+
303
+ except ValueError as e:
304
+ print(f"Could not fit the function with error: {e}")
305
+ print(f"Using fallback result...")
306
+ return {
307
+ k: torch.tensor(v, dtype=torch.float32).to(x.device) for k, v in p0.items()
308
+ }
309
+
310
+
311
+ def init_zeros(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
312
+ val = torch.amin(x, dim=1)
313
+ return torch.zeros_like(val, dtype=torch.float32, device=x.device)
314
+
315
+
316
+ def init_inner_scale(tensor: torch.Tensor, _min: float = torch.inf, _max: float = torch.inf) -> torch.Tensor:
317
+ # Calculate the original minimum and maximum values
318
+ min_vals, max_vals = torch.aminmax(tensor, dim=-1)
319
+ x_min = torch.min(min_vals, torch.zeros_like(min_vals))
320
+ x_max = torch.max(max_vals, torch.zeros_like(max_vals))
321
+
322
+ if _max is torch.inf: # We do not need to scale the tensor. Just need to move it
323
+ return torch.ones_like(x_min)
324
+
325
+ # Calculate the scale factor
326
+ scale = (_max - _min) / (x_max - x_min)
327
+ return scale
328
+
329
+
330
+
331
+ ############## Quant ###############
332
+
333
+ @torch.enable_grad()
334
+ def learn_parameters(
335
+ x: torch.Tensor,
336
+ params: Dict[str, nn.Parameter],
337
+ qtz_func: nn.Module,
338
+ deqtz_func: nn.Module,
339
+ bits: int,
340
+ target_dtype: torch.dtype,
341
+ epochs: int = 1000,
342
+ early_stop: bool = True,
343
+ do_report: bool = False
344
+ ) -> Tuple[Dict[str, nn.Parameter], torch.Tensor]:
345
+
346
+ # Requires gradients in the parameters
347
+ for p in params.values():
348
+ p.requires_grad = True
349
+ p.grad = None
350
+
351
+ param_keys = list(params.keys())
352
+ param_values = list(params.values())
353
+
354
+ # Defines optimizer and loss function
355
+ optimizer = torch.optim.Adam(param_values, lr=0.001)
356
+ scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=1.0, end_factor=0.001, total_iters=epochs // 10)
357
+ loss_fn = nn.MSELoss()
358
+
359
+ # Contains the best loss and the best parameters
360
+ best_loss = float("inf")
361
+ best_params = None
362
+
363
+ # Used to stop the search early
364
+ min_delta = 1e-7
365
+ acc_loss = []
366
+ percent_epochs_before_stop = 0.1
367
+
368
+ for i in range(epochs):
369
+ optimizer.zero_grad()
370
+
371
+ quant = quantize(x, params, qtz_func, bits, target_dtype)
372
+ dequant = dequantize(quant, params, deqtz_func, bits, x.dtype)
373
+ loss = loss_fn(x, dequant)
374
+
375
+ if loss.isnan() or loss.isinf():
376
+ raise Exception("Loss is NaN or Inf. Stopping the search.")
377
+
378
+ loss.backward()
379
+ optimizer.step()
380
+ scheduler.step()
381
+
382
+ acc_loss.append(loss.item())
383
+
384
+ # Reports loss every 10 steps
385
+ if i % 10 == 0 and do_report:
386
+ print(f"Epoch {i}: Loss {loss.item()}")
387
+
388
+ # Optimizes the parameter search by storing the best loss and the parameters
389
+ if loss.item() < best_loss:
390
+ best_loss = loss.item()
391
+ best_params = copy.deepcopy({
392
+ k: v for k, v in params.items() if k in param_keys
393
+ })
394
+
395
+ # We also stop the search if the loss has not improved considerably during the last 10% of epochs
396
+ if early_stop:
397
+ epochs_before_stop = int(epochs * percent_epochs_before_stop)
398
+ if i > epochs_before_stop and abs(acc_loss[i - epochs_before_stop] - acc_loss[i]) < min_delta:
399
+ break
400
+
401
+ # No longer requires gradients in the parameters
402
+ for p in best_params.values():
403
+ p.requires_grad = False
404
+ p.grad = None
405
+
406
+ if do_report:
407
+ return best_params, acc_loss
408
+ else:
409
+ return best_params
410
+
411
+
412
+ def quantize(
413
+ x: torch.Tensor,
414
+ params: Dict[str, nn.Parameter],
415
+ func: nn.Module,
416
+ bits: int,
417
+ target_dtype: torch.dtype = torch.int8
418
+ ) -> torch.Tensor:
419
+ quant_min, quant_max = get_min_max_from_bits_signed(bits)
420
+ x = x.transpose(0, 1) # Aligns shapes
421
+ x = func(x=x, **params)
422
+ x = x.transpose(0, 1)
423
+ x = torch.clamp(round_func_BPDA(x), quant_min, quant_max).to(target_dtype)
424
+ return x
425
+
426
+
427
+ def dequantize(
428
+ x: torch.Tensor,
429
+ params: Dict[str, nn.Parameter],
430
+ func: nn.Module,
431
+ bits: int,
432
+ out_dtype: torch.dtype
433
+ ) -> torch.Tensor:
434
+ x = x.to(dtype=out_dtype)
435
+ x = x.transpose(0, 1)
436
+ x = func(x=x, **params)
437
+ x = x.transpose(0, 1)
438
+ return x
439
+
440
+
441
+ def round_func_BPDA(input):
442
+ # This is equivalent to replacing round function (non-differentiable) with
443
+ # an identity function (differentiable) in the backward pass only.
444
+ forward_value = torch.round(input)
445
+ out = input.clone()
446
+ out.data = forward_value.data
447
+ return out
448
+
449
+
450
+ def get_min_max_from_bits_signed(bit_width: int) -> Tuple[int, int]:
451
+ return -2 ** (bit_width - 1), 2 ** (bit_width - 1) - 1
452
+
453
+
454
+
455
+ ############## Numpy ###############
456
+
457
+ def np_domain_guard(
458
+ x: np.ndarray,
459
+ min: float = None,
460
+ max: float = None,
461
+ posinf: float = None,
462
+ neginf: float = None,
463
+ nan: float = None
464
+ ) -> np.ndarray:
465
+ """Guard a tensor to a valid domain."""
466
+ x = np.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan)
467
+ if min is not None or max is not None:
468
+ x = np.clip(x, min, max)
469
+ return x
470
+
471
+
472
+ def np_replace_num(x: np.ndarray, num: float, to: float) -> np.ndarray:
473
+ """Replace a number in a tensor with another number.
474
+
475
+ Args:
476
+ x (np.ndarray): The input tensor.
477
+ num (float): The number to replace.
478
+ to (float): The number to replace with.
479
+
480
+ Returns:
481
+ np.ndarray: The tensor with the number replaced.
482
+ """
483
+ return np.where(x == num, to, x)
484
+
485
+
486
+ def np_guarded_power(x: np.ndarray, exp: float) -> np.ndarray:
487
+ """Guard the power operation to a valid domain."""
488
+ return np.power(x, exp) if exp >= 1 else np.power(np.maximum(x, 0), exp)
489
+
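Every `fn.py` in this upload exposes the same interface: `init_params` builds the per-channel parameters (here `_0` plus the noisy scale `_s`), `learn_parameters` refines them by gradient descent on the reconstruction MSE, and `quantize`/`dequantize` apply the learned transform. A minimal usage sketch (assumptions: the folder's `fn.py` is importable as `fn`, the input is a 2-D weight tensor, and 8-bit signed quantization is used):

```python
import torch
from fn import (init_params, learn_parameters, quantize, dequantize,
                quantization, dequantization)  # assumes fn.py is on the path

w = torch.randn(16, 64)              # toy per-channel weight matrix
params = init_params(w, bits=8)      # initializes '_0' and the noisy scale '_s'

params = learn_parameters(
    w, params,
    qtz_func=quantization, deqtz_func=dequantization,
    bits=8, target_dtype=torch.int8, epochs=100,
)

q = quantize(w, params, quantization, 8, target_dtype=torch.int8)
w_hat = dequantize(q, params, dequantization, 8, w.dtype)
print(torch.nn.functional.mse_loss(w, w_hat))  # reconstruction error
```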
fn_gen/ones_noisy_scale/11/loss.png ADDED
fn_gen/ones_noisy_scale/11/quantization.png ADDED
fn_gen/ones_noisy_scale/12/__pycache__/fn.cpython-310.pyc ADDED
Binary file (13.3 kB). View file
 
fn_gen/ones_noisy_scale/12/distortion.png ADDED
fn_gen/ones_noisy_scale/12/expressions.txt ADDED
@@ -0,0 +1,2 @@
1
+ cos(_0*x)/_s
2
+ acos(_s*x)/_0
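This run uses a cosine pair: `cos(_0*x)/_s` on the way in and `acos(_s*x)/_0` on the way out. Since `acos` only returns values in [0, pi], the pair is a true inverse only when `_0*x` lies in that interval, and the fn.py below additionally clamps the `acos` argument to [-0.99999, 0.99999] with `domain_guard` so it never leaves the function's domain. A toy numpy round trip (illustrative only; `_0` and `_s` fixed to 1 rather than learned):

```python
import numpy as np

_0, _s = 1.0, 1.0                          # assumed values; fn.py learns these
x = np.linspace(0.1, 3.0, 7)               # _0 * x must stay inside [0, pi]

q = np.cos(_0 * x) / _s                    # quantization expression
arg = np.clip(_s * q, -0.99999, 0.99999)   # mirrors domain_guard in fn.py
x_rec = np.arccos(arg) / _0                # dequantization expression

print(np.allclose(x, x_rec, atol=1e-3))    # True on [0, pi]; cos is not
                                           # invertible over the whole real line
```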
fn_gen/ones_noisy_scale/12/fn.py ADDED
@@ -0,0 +1,489 @@
1
+ from __future__ import annotations
2
+
3
+ import torch
4
+ from torch import amin # Necessary for arcsin
5
+ import copy
6
+ import torch.nn as nn
7
+ import numpy as np
8
+
9
+ from scipy.optimize import curve_fit
10
+ from typing import Dict, Any, Tuple, List, Callable
11
+
12
+
13
+ def quantization(x, **params):
14
+ return (torch.div(1, replace_num(params['_s'], num=0, to=10000)) * torch.cos((params['_0'] * x)))
15
+
16
+
17
+ def dequantization(x, **params):
18
+ return (torch.div(1, replace_num(params['_0'], num=0, to=10000)) * torch.acos(domain_guard((params['_s'] * x), min=-0.99999, max=0.99999, nan=0)))
19
+
20
+
21
+ def init_linear_scale( # Symmetric scale. From the study folder
22
+ x: torch.Tensor,
23
+ **kwargs: Dict[str, Any],
24
+ ) -> torch.Tensor:
25
+ assert "bits" in kwargs, "bits must be provided."
26
+ assert "params" in kwargs, "params must be provided."
27
+ assert "qtz_func" in kwargs, "qtz_func must be provided."
28
+
29
+ bits = kwargs.get('bits')
30
+ params = kwargs.get('params')
31
+ qtz_func = kwargs.get('qtz_func')
32
+
33
+ x_ = x.transpose(0, 1)
34
+ x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs))
35
+ x_ = x_.transpose(0, 1)
36
+
37
+ quant_min, quant_max = get_min_max_from_bits_signed(bits)
38
+ min_vals, max_vals = torch.aminmax(x_, dim=1)
39
+ min_vals = torch.min(min_vals, torch.zeros_like(min_vals))
40
+ max_vals = torch.max(max_vals, torch.zeros_like(max_vals))
41
+
42
+ eps = torch.finfo(torch.float32).eps
43
+
44
+ abs_max_val_per_ch = torch.max(-min_vals, max_vals)
45
+ scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2)
46
+
47
+ scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device)
48
+
49
+ # Introduces some noise in scale
50
+ # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything
51
+ scale = scale + 0.01 * torch.randn_like(scale)
52
+ return scale
53
+
54
+
55
+ def init_params(x: torch.Tensor, **kwargs: Dict[str, Any]) -> Dict[str, nn.Parameter]:
56
+ params = {
57
+ '_0': init_ones(x, qtz_func=quantization, deqtz_func=dequantization, param='_0', params_list=['_0', '_s'], **kwargs),
58
+ }
59
+ params['_s'] = init_linear_scale(x, params=params, qtz_func=quantization, **kwargs)
60
+ params = {k: nn.Parameter(v, requires_grad=False) for k, v in params.items()}
61
+
62
+ if 'post_init_hook' in kwargs:
63
+ kwargs['post_init_hook'](parameters=params)
64
+
65
+
66
+ if 'post_train_hook' in kwargs:
67
+ kwargs['post_train_hook'](parameters=params)
68
+
69
+ return params
70
+
71
+
72
+ ############### Numpy Qtz ###############
73
+
74
+
75
+ def np_quantization(x, _0, _s):
76
+ return (np.divide(1, np_replace_num(_s, num=0, to=10000)) * np.cos((_0 * x)))
77
+
78
+
79
+ def np_dequantization(x, _0, _s):
80
+ return (np.divide(1, np_replace_num(_0, num=0, to=10000)) * np.arccos(np_domain_guard((_s * x), min=-0.99999, max=0.99999, nan=0)))
81
+
82
+
83
+ def fit_func(x, _0, _s):
84
+ x_ = np_quantization(x, _0, _s)
85
+ x_ = np_dequantization(x_, _0, _s)
86
+ return x_
87
+
88
+
89
+
90
+ ############### HELPERS ###############
91
+
92
+ def domain_guard(
93
+ x: torch.Tensor,
94
+ min: float = None,
95
+ max: float = None,
96
+ posinf: float = None,
97
+ neginf: float = None,
98
+ nan: float = None
99
+ ) -> torch.Tensor:
100
+ """Guard a tensor to a valid domain."""
101
+ x = torch.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan)
102
+ if min is not None or max is not None:
103
+ x = torch.clamp(x, min=min, max=max)
104
+ return x
105
+
106
+
107
+ def replace_num(x: torch.Tensor, num: float, to: float) -> torch.Tensor:
108
+ """Replace a number in a tensor with another number.
109
+
110
+ Args:
111
+ x (torch.Tensor): The input tensor.
112
+ num (float): The number to replace.
113
+ to (float): The number to replace with.
114
+
115
+ Returns:
116
+ torch.Tensor: The tensor with the number replaced.
117
+ """
118
+ return torch.where(x == num, to, x)
119
+
120
+
121
+ def guarded_torch_power(x: torch.Tensor, exp: float) -> torch.Tensor:
122
+ """Guard the power operation to a valid domain."""
123
+ return torch.pow(x, exp) if exp >= 1 else torch.pow(torch.relu(x), exp)
124
+
125
+
126
+ def init_ones(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
127
+ val = torch.amin(x, dim=1)
128
+ return torch.ones_like(val, dtype=torch.float32, device=x.device)
129
+
130
+
131
+ def init_rand(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
132
+ val = torch.amin(x, dim=1)
133
+ return torch.randn_like(val, dtype=torch.float32, device=x.device)
134
+
135
+
136
+ def init_space_search(
137
+ x: torch.Tensor,
138
+ **kwargs: Dict[str, Any],
139
+ ) -> torch.Tensor:
140
+
141
+ def _build_initial_param(tensor: torch.Tensor, max_initial: int, n_params: int):
142
+ """Generates the initial set of parameters. The first iteration generates 10 times more parameters."""
143
+ for _ in range(n_params * 10): # The first iteration generates 10 times more parameters
144
+ yield init_rand(tensor) * max_initial # Generates n_params in range [-max_initial, max_initial]
145
+
146
+ def _search_param(tensors: List[torch.Tensor], n_params):
147
+ """Takes the best parameters and generates new parameters around the mean of the best parameters."""
148
+ torch_tensors = torch.stack(tensors)
149
+ min_vals, max_vals = torch.aminmax(torch_tensors, dim=0)
150
+ abs_max_val_per_ch = torch.max(-min_vals, max_vals)
151
+ mean = torch.mean(torch_tensors, dim=0)
152
+ for _ in range(n_params): # Generates n_params around the mean of the tensors
153
+ yield torch.randn_like(min_vals) * abs_max_val_per_ch + mean
154
+
155
+ def _calc(x, qtz_func, deqtz_func, **params):
156
+ x_ = x.transpose(0, 1)
157
+ x_ = qtz_func(x=x_, **params)
158
+ x_ = deqtz_func(x=x_, **params)
159
+ x_ = x_.transpose(0, 1)
160
+ return x_
161
+
162
+ assert "qtz_func" in kwargs, "qtz_func must be provided."
163
+ assert "deqtz_func" in kwargs, "deqtz_func must be provided."
164
+ assert "params_list" in kwargs, "params list must be provided."
165
+ assert "param" in kwargs, "param must be provided."
166
+
167
+ qtz_func = kwargs.get('qtz_func')
168
+ deqtz_func = kwargs.get('deqtz_func')
169
+ params_list = kwargs.get('params_list')
170
+ param = kwargs.get('param')
171
+
172
+ n_runs = 50 # Number of runs to try to find the best parameters
173
+ n_random_params = 50 # Number of random parameters to generate
174
+ n_best_to_pick = 5 # Number of best parameters to pick after each run
175
+ max_initial = 10000 # Maximum value to initialize the parameters
176
+
177
+ # Initializes the parameters
178
+ base_params = { p: init_ones(x, **kwargs) for p in params_list if p != param }
179
+ params = _build_initial_param(x, max_initial, n_random_params)
180
+
181
+ # Performs the search
182
+ for _ in range(n_runs):
183
+
184
+ best_params = []
185
+ for param_ in params:
186
+ try:
187
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: param_})
188
+ loss_ones = nn.MSELoss()(x, x_)
189
+
190
+ if len(best_params) < n_best_to_pick:
191
+ best_params.append((param_, loss_ones.item()))
192
+ best_params = sorted(best_params, key=lambda x: x[1])
193
+ elif loss_ones < best_params[-1][1]:
194
+ best_params[-1] = (param_, loss_ones.item())
195
+ best_params = sorted(best_params, key=lambda x: x[1])
196
+
197
+ except Exception: # The parameters might not be valid for the function's domain
198
+ continue
199
+
200
+ # Generates new parameters around the mean
201
+ params = _search_param([p for p, _ in best_params], n_random_params)
202
+
203
+ # Checks if the best parameter is better than the init_ones
204
+ p_ones = init_ones(x, **kwargs)
205
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_ones})
206
+ loss_ones = nn.MSELoss()(x, x_)
207
+
208
+ # Checks if the best parameter is better than the init_rand
209
+ p_rand = init_rand(x, **kwargs)
210
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_rand})
211
+ loss_rand = nn.MSELoss()(x, x_)
212
+
213
+ if loss_rand < best_params[0][1] and loss_rand < loss_ones:
214
+ return p_rand
215
+ elif loss_ones < best_params[0][1] and loss_ones < loss_rand:
216
+ return p_ones
217
+ else:
218
+ return best_params[0][0]
219
+
220
+
221
+ def init_linear_scale( # Symmetric scale. From the study folder
222
+ x: torch.Tensor,
223
+ **kwargs: Dict[str, Any],
224
+ ) -> torch.Tensor:
225
+ assert "bits" in kwargs, "bits must be provided."
226
+ assert "params" in kwargs, "params must be provided."
227
+ assert "qtz_func" in kwargs, "qtz_func must be provided."
228
+
229
+ bits = kwargs.get('bits')
230
+ params = kwargs.get('params')
231
+ qtz_func = kwargs.get('qtz_func')
232
+
233
+ x_ = x.transpose(0, 1)
234
+ x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs))
235
+ x_ = x_.transpose(0, 1)
236
+
237
+ quant_min, quant_max = get_min_max_from_bits_signed(bits)
238
+ min_vals, max_vals = torch.aminmax(x_, dim=1)
239
+ min_vals = torch.min(min_vals, torch.zeros_like(min_vals))
240
+ max_vals = torch.max(max_vals, torch.zeros_like(max_vals))
241
+
242
+ eps = torch.finfo(torch.float32).eps
243
+
244
+ abs_max_val_per_ch = torch.max(-min_vals, max_vals)
245
+ scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2)
246
+
247
+ scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device)
248
+
249
+ # Introduces some noise in scale
250
+ # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything
251
+ scale = scale + 0.01 * torch.randn_like(scale)
252
+ return scale
253
+
254
+
255
+ def init_non_linear_regression_fit(
256
+ x: torch.Tensor,
257
+ **kwargs: Dict[str, Any],
258
+ ) -> torch.Tensor:
259
+
260
+ assert "params_list" in kwargs, "params list must be provided."
261
+ assert "np_fit_func" in kwargs, "np_fit_func must be provided."
262
+ assert "p0" in kwargs, "p0 must be provided."
263
+ np_fit_func = kwargs.get('np_fit_func')
264
+ params_list = kwargs.get('params_list')
265
+ p0 = kwargs.get('p0')
266
+
267
+ def _fit(xdata: np.ndarray, ydata: np.ndarray, func: Callable, p0: List[float]):
268
+ popt, _ = curve_fit(
269
+ func,
270
+ xdata,
271
+ ydata,
272
+ maxfev=1000,
273
+ p0=p0,
274
+ method='lm'
275
+ )
276
+ return popt
277
+
278
+ # 1. Needs to convert the torch tensor to numpy tensor
279
+ xdata = x.cpu().numpy()
280
+
281
+ # 2. Sorts the data so that it is easier to fit
282
+ sorted_xdata = np.sort(xdata, axis=-1)
283
+
284
+ p0 = {k: v.cpu().numpy() for k, v in p0.items()}
285
+ params_list = sorted(params_list) # We need to make sure that it matches the numpy fit func arg order
286
+
287
+ # 3. Finds the best parameters for each channel
288
+ try:
289
+ params = []
290
+ for i in range(sorted_xdata.shape[0]):
291
+ xdata_ = sorted_xdata[i]
292
+ p0_ = [p0[p][i] for p in params_list]
293
+ ch_params = _fit(xdata_, xdata_, np_fit_func, p0_)
294
+ params.append(ch_params)
295
+
296
+ # 4. Builds the parameters
297
+ result = {}
298
+ for i, p in enumerate(params_list):
299
+ result[p] = torch.tensor([p_[i] for p_ in params], dtype=torch.float32).to(x.device)
300
+
301
+ return result
302
+
303
+ except ValueError as e:
304
+ print(f"Could not fit the function with error: {e}")
305
+ print(f"Using fallback result...")
306
+ return {
307
+ k: torch.tensor(v, dtype=torch.float32).to(x.device) for k, v in p0.items()
308
+ }
309
+
310
+
311
+ def init_zeros(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
312
+ val = torch.amin(x, dim=1)
313
+ return torch.zeros_like(val, dtype=torch.float32, device=x.device)
314
+
315
+
316
+ def init_inner_scale(tensor: torch.Tensor, _min: float = torch.inf, _max: float = torch.inf) -> torch.Tensor:
317
+ # Calculate the original minimum and maximum values
318
+ min_vals, max_vals = torch.aminmax(tensor, dim=-1)
319
+ x_min = torch.min(min_vals, torch.zeros_like(min_vals))
320
+ x_max = torch.max(max_vals, torch.zeros_like(max_vals))
321
+
322
+ if _max is torch.inf: # We do not need to scale the tensor. Just need to move it
323
+ return torch.ones_like(x_min)
324
+
325
+ # Calculate the scale factor
326
+ scale = (_max - _min) / (x_max - x_min)
327
+ return scale
328
+
329
+
330
+
331
+ ############## Quant ###############
332
+
333
+ @torch.enable_grad()
334
+ def learn_parameters(
335
+ x: torch.Tensor,
336
+ params: Dict[str, nn.Parameter],
337
+ qtz_func: nn.Module,
338
+ deqtz_func: nn.Module,
339
+ bits: int,
340
+ target_dtype: torch.dtype,
341
+ epochs: int = 1000,
342
+ early_stop: bool = True,
343
+ do_report: bool = False
344
+ ) -> Tuple[Dict[str, nn.Parameter], torch.Tensor]:
345
+
346
+ # Requires gradients in the parameters
347
+ for p in params.values():
348
+ p.requires_grad = True
349
+ p.grad = None
350
+
351
+ param_keys = list(params.keys())
352
+ param_values = list(params.values())
353
+
354
+ # Defines optimizer and loss function
355
+ optimizer = torch.optim.Adam(param_values, lr=0.001)
356
+ scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=1.0, end_factor=0.001, total_iters=epochs // 10)
357
+ loss_fn = nn.MSELoss()
358
+
359
+ # Contains the best loss and the best parameters
360
+ best_loss = float("inf")
361
+ best_params = None
362
+
363
+ # Used to stop the search early
364
+ min_delta = 1e-7
365
+ acc_loss = []
366
+ percent_epochs_before_stop = 0.1
367
+
368
+ for i in range(epochs):
369
+ optimizer.zero_grad()
370
+
371
+ quant = quantize(x, params, qtz_func, bits, target_dtype)
372
+ dequant = dequantize(quant, params, deqtz_func, bits, x.dtype)
373
+ loss = loss_fn(x, dequant)
374
+
375
+ if loss.isnan() or loss.isinf():
376
+ raise Exception("Loss is NaN or Inf. Stopping the search.")
377
+
378
+ loss.backward()
379
+ optimizer.step()
380
+ scheduler.step()
381
+
382
+ acc_loss.append(loss.item())
383
+
384
+ # Reports loss every 10 steps
385
+ if i % 10 == 0 and do_report:
386
+ print(f"Epoch {i}: Loss {loss.item()}")
387
+
388
+ # Optimizes the parameter search by storing the best loss and the parameters
389
+ if loss.item() < best_loss:
390
+ best_loss = loss.item()
391
+ best_params = copy.deepcopy({
392
+ k: v for k, v in params.items() if k in param_keys
393
+ })
394
+
395
+ # We also stop the search if the loss has not improved considerably during the last 10% of epochs
396
+ if early_stop:
397
+ epochs_before_stop = int(epochs * percent_epochs_before_stop)
398
+ if i > epochs_before_stop and abs(acc_loss[i - epochs_before_stop] - acc_loss[i]) < min_delta:
399
+ break
400
+
401
+ # No longer requires gradients in the parameters
402
+ for p in best_params.values():
403
+ p.requires_grad = False
404
+ p.grad = None
405
+
406
+ if do_report:
407
+ return best_params, acc_loss
408
+ else:
409
+ return best_params
410
+
411
+
412
+ def quantize(
413
+ x: torch.Tensor,
414
+ params: Dict[str, nn.Parameter],
415
+ func: nn.Module,
416
+ bits: int,
417
+ target_dtype: torch.dtype = torch.int8
418
+ ) -> torch.Tensor:
419
+ quant_min, quant_max = get_min_max_from_bits_signed(bits)
420
+ x = x.transpose(0, 1) # Aligns shapes
421
+ x = func(x=x, **params)
422
+ x = x.transpose(0, 1)
423
+ x = torch.clamp(round_func_BPDA(x), quant_min, quant_max).to(target_dtype)
424
+ return x
425
+
426
+
427
+ def dequantize(
428
+ x: torch.Tensor,
429
+ params: Dict[str, nn.Parameter],
430
+ func: nn.Module,
431
+ bits: int,
432
+ out_dtype: torch.dtype
433
+ ) -> torch.Tensor:
434
+ x = x.to(dtype=out_dtype)
435
+ x = x.transpose(0, 1)
436
+ x = func(x=x, **params)
437
+ x = x.transpose(0, 1)
438
+ return x
439
+
440
+
441
+ def round_func_BPDA(input):
442
+ # This is equivalent to replacing round function (non-differentiable) with
443
+ # an identity function (differentiable) in the backward pass only.
444
+ forward_value = torch.round(input)
445
+ out = input.clone()
446
+ out.data = forward_value.data
447
+ return out
448
+
449
+
450
+ def get_min_max_from_bits_signed(bit_width: int) -> Tuple[int, int]:
451
+ return -2 ** (bit_width - 1), 2 ** (bit_width - 1) - 1
452
+
453
+
454
+
455
+ ############## Numpy ###############
456
+
457
+ def np_domain_guard(
458
+ x: np.ndarray,
459
+ min: float = None,
460
+ max: float = None,
461
+ posinf: float = None,
462
+ neginf: float = None,
463
+ nan: float = None
464
+ ) -> np.ndarray:
465
+ """Guard a tensor to a valid domain."""
466
+ x = np.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan)
467
+ if min is not None or max is not None:
468
+ x = np.clip(x, min, max)
469
+ return x
470
+
471
+
472
+ def np_replace_num(x: np.ndarray, num: float, to: float) -> np.ndarray:
473
+ """Replace a number in a tensor with another number.
474
+
475
+ Args:
476
+ x (np.ndarray): The input tensor.
477
+ num (float): The number to replace.
478
+ to (float): The number to replace with.
479
+
480
+ Returns:
481
+ np.ndarray: The tensor with the number replaced.
482
+ """
483
+ return np.where(x == num, to, x)
484
+
485
+
486
+ def np_guarded_power(x: np.ndarray, exp: float) -> np.ndarray:
487
+ """Guard the power operation to a valid domain."""
488
+ return np.power(x, exp) if exp >= 1 else np.power(np.maximum(x, 0), exp)
489
+
fn_gen/ones_noisy_scale/12/loss.png ADDED
fn_gen/ones_noisy_scale/12/quantization.png ADDED
fn_gen/ones_noisy_scale/13/__pycache__/fn.cpython-310.pyc ADDED
Binary file (13.4 kB). View file
 
fn_gen/ones_noisy_scale/13/distortion.png ADDED
fn_gen/ones_noisy_scale/13/expressions.txt ADDED
@@ -0,0 +1,2 @@
1
+ sinh(_0*x)/_s
2
+ log(_s*x - sqrt(_s**2*x**2 + 1))/_0
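The dequantization here is the logarithmic form of the inverse hyperbolic sine, `asinh(y) = log(y + sqrt(y**2 + 1))`, except that this run's expression carries a minus sign before the square root. Since `sqrt(_s**2*x**2 + 1) >= |_s*x|`, that argument is never positive, and the fn.py below accordingly wraps the `log` input in `domain_guard(min=1e-5, nan=1e-5)` to keep it defined. A toy numpy sketch of the plus-sign identity (illustrative only; `_0` and `_s` fixed to 1 rather than learned):

```python
import numpy as np

_0, _s = 1.0, 1.0                          # assumed values; fn.py learns these
x = np.linspace(-3.0, 3.0, 7)

q = np.sinh(_0 * x) / _s                   # quantization expression
# log(y + sqrt(y**2 + 1)) is asinh(y); the expression above uses '-',
# which fn.py clamps away with domain_guard.
x_rec = np.log(_s * q + np.sqrt((_s * q) ** 2 + 1)) / _0

print(np.allclose(x, x_rec))               # True: asinh undoes sinh exactly
```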
fn_gen/ones_noisy_scale/13/fn.py ADDED
@@ -0,0 +1,489 @@
1
+ from __future__ import annotations
2
+
3
+ import torch
4
+ from torch import amin # Necessary for arcsin
5
+ import copy
6
+ import torch.nn as nn
7
+ import numpy as np
8
+
9
+ from scipy.optimize import curve_fit
10
+ from typing import Dict, Any, Tuple, List, Callable
11
+
12
+
13
+ def quantization(x, **params):
14
+ return (torch.div(1, replace_num(params['_s'], num=0, to=10000)) * torch.sinh((params['_0'] * x)))
15
+
16
+
17
+ def dequantization(x, **params):
18
+ return (torch.div(1, replace_num(params['_0'], num=0, to=10000)) * torch.log(domain_guard(((torch.tensor(-1) * torch.sqrt(domain_guard((torch.tensor(1) + (guarded_torch_power(params['_s'], torch.tensor(2)) * guarded_torch_power(x, torch.tensor(2)))), min=0.1, nan=0.1))) + (params['_s'] * x)), min=1e-5, nan=1e-5)))
19
+
20
+
21
+ def init_linear_scale( # Symmetric scale. From the study folder
22
+ x: torch.Tensor,
23
+ **kwargs: Dict[str, Any],
24
+ ) -> torch.Tensor:
25
+ assert "bits" in kwargs, "bits must be provided."
26
+ assert "params" in kwargs, "params must be provided."
27
+ assert "qtz_func" in kwargs, "qtz_func must be provided."
28
+
29
+ bits = kwargs.get('bits')
30
+ params = kwargs.get('params')
31
+ qtz_func = kwargs.get('qtz_func')
32
+
33
+ x_ = x.transpose(0, 1)
34
+ x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs))
35
+ x_ = x_.transpose(0, 1)
36
+
37
+ quant_min, quant_max = get_min_max_from_bits_signed(bits)
38
+ min_vals, max_vals = torch.aminmax(x_, dim=1)
39
+ min_vals = torch.min(min_vals, torch.zeros_like(min_vals))
40
+ max_vals = torch.max(max_vals, torch.zeros_like(max_vals))
41
+
42
+ eps = torch.finfo(torch.float32).eps
43
+
44
+ abs_max_val_per_ch = torch.max(-min_vals, max_vals)
45
+ scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2)
46
+
47
+ scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device)
48
+
49
+ # Introduces some noise in scale
50
+ # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything
51
+ scale = scale + 0.01 * torch.randn_like(scale)
52
+ return scale
53
+
54
+
55
+ def init_params(x: torch.Tensor, **kwargs: Dict[str, Any]) -> Dict[str, nn.Parameter]:
56
+ params = {
57
+ '_0': init_ones(x, qtz_func=quantization, deqtz_func=dequantization, param='_0', params_list=['_0', '_s'], **kwargs),
58
+ }
59
+ params['_s'] = init_linear_scale(x, params=params, qtz_func=quantization, **kwargs)
60
+ params = {k: nn.Parameter(v, requires_grad=False) for k, v in params.items()}
61
+
62
+ if 'post_init_hook' in kwargs:
63
+ kwargs['post_init_hook'](parameters=params)
64
+
65
+
66
+ if 'post_train_hook' in kwargs:
67
+ kwargs['post_train_hook'](parameters=params)
68
+
69
+ return params
70
+
71
+
72
+ ############### Numpy Qtz ###############
73
+
74
+
75
+ def np_quantization(x, _0, _s):
76
+ return (np.divide(1, np_replace_num(_s, num=0, to=10000)) * np.sinh((_0 * x)))
77
+
78
+
79
+ def np_dequantization(x, _0, _s):
80
+ return (np.divide(1, np_replace_num(_0, num=0, to=10000)) * np.log(np_domain_guard(((np.array(-1) * np.sqrt(np_domain_guard((np.array(1) + (np_guarded_power(_s, np.array(2)) * np_guarded_power(x, np.array(2)))), min=0.1, nan=0.1))) + (_s * x)), min=1e-5, nan=1e-5)))
81
+
82
+
83
+ def fit_func(x, _0, _s):
84
+ x_ = np_quantization(x, _0, _s)
85
+ x_ = np_dequantization(x_, _0, _s)
86
+ return x_
87
+
88
+
89
+
90
+ ############### HELPERS ###############
91
+
92
+ def domain_guard(
93
+ x: torch.Tensor,
94
+ min: float = None,
95
+ max: float = None,
96
+ posinf: float = None,
97
+ neginf: float = None,
98
+ nan: float = None
99
+ ) -> torch.Tensor:
100
+ """Guard a tensor to a valid domain."""
101
+ x = torch.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan)
102
+ if min is not None or max is not None:
103
+ x = torch.clamp(x, min=min, max=max)
104
+ return x
105
+
106
+
107
+ def replace_num(x: torch.Tensor, num: float, to: float) -> torch.Tensor:
108
+ """Replace a number in a tensor with another number.
109
+
110
+ Args:
111
+ x (torch.Tensor): The input tensor.
112
+ num (float): The number to replace.
113
+ to (float): The number to replace with.
114
+
115
+ Returns:
116
+ torch.Tensor: The tensor with the number replaced.
117
+ """
118
+ return torch.where(x == num, to, x)
119
+
120
+
121
+ def guarded_torch_power(x: torch.Tensor, exp: float) -> torch.Tensor:
122
+ """Guard the power operation to a valid domain."""
123
+ return torch.pow(x, exp) if exp >= 1 else torch.pow(torch.relu(x), exp)
124
+
125
+
126
+ def init_ones(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
127
+ val = torch.amin(x, dim=1)
128
+ return torch.ones_like(val, dtype=torch.float32, device=x.device)
129
+
130
+
131
+ def init_rand(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
132
+ val = torch.amin(x, dim=1)
133
+ return torch.randn_like(val, dtype=torch.float32, device=x.device)
134
+
135
+
136
+ def init_space_search(
137
+ x: torch.Tensor,
138
+ **kwargs: Dict[str, Any],
139
+ ) -> torch.Tensor:
140
+
141
+ def _build_initial_param(tensor: torch.Tensor, max_initial: int, n_params: int):
142
+ """Generates the initial set of parameters. The first iteration generates 10 times more parameters."""
143
+ for _ in range(n_params * 10): # The first iteration generates 10 times more parameters
144
+ yield init_rand(tensor) * max_initial # Generates n_params in range [-max_initial, max_initial]
145
+
146
+ def _search_param(tensors: List[torch.Tensor], n_params):
147
+ """Takes the best parameters and generates new parameters around the mean of the best parameters."""
148
+ torch_tensors = torch.stack(tensors)
149
+ min_vals, max_vals = torch.aminmax(torch_tensors, dim=0)
150
+ abs_max_val_per_ch = torch.max(-min_vals, max_vals)
151
+ mean = torch.mean(torch_tensors, dim=0)
152
+ for _ in range(n_params): # Generates n_params around the mean of the tensors
153
+ yield torch.randn_like(min_vals) * abs_max_val_per_ch + mean
154
+
155
+ def _calc(x, qtz_func, deqtz_func, **params):
156
+ x_ = x.transpose(0, 1)
157
+ x_ = qtz_func(x=x_, **params)
158
+ x_ = deqtz_func(x=x_, **params)
159
+ x_ = x_.transpose(0, 1)
160
+ return x_
161
+
162
+ assert "qtz_func" in kwargs, "qtz_func must be provided."
163
+ assert "deqtz_func" in kwargs, "deqtz_func must be provided."
164
+ assert "params_list" in kwargs, "params list must be provided."
165
+ assert "param" in kwargs, "param must be provided."
166
+
167
+ qtz_func = kwargs.get('qtz_func')
168
+ deqtz_func = kwargs.get('deqtz_func')
169
+ params_list = kwargs.get('params_list')
170
+ param = kwargs.get('param')
171
+
172
+ n_runs = 50 # Number of runs to try to find the best parameters
173
+ n_random_params = 50 # Number of random parameters to generate
174
+ n_best_to_pick = 5 # Number of best parameters to pick after each run
175
+ max_initial = 10000 # Maximum value to initialize the parameters
176
+
177
+ # Initializes the parameters
178
+ base_params = { p: init_ones(x, **kwargs) for p in params_list if p != param }
179
+ params = _build_initial_param(x, max_initial, n_random_params)
180
+
181
+ # Performs the search
182
+ for _ in range(n_runs):
183
+
184
+ best_params = []
185
+ for param_ in params:
186
+ try:
187
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: param_})
188
+ loss_ones = nn.MSELoss()(x, x_)
189
+
190
+ if len(best_params) < n_best_to_pick:
191
+ best_params.append((param_, loss_ones.item()))
192
+ best_params = sorted(best_params, key=lambda x: x[1])
193
+ elif loss_ones < best_params[-1][1]:
194
+ best_params[-1] = (param_, loss_ones.item())
195
+ best_params = sorted(best_params, key=lambda x: x[1])
196
+
197
+ except Exception: # The parameters might not be valid for the function's domain
198
+ continue
199
+
200
+ # Generates new parameters around the mean
201
+ params = _search_param([p for p, _ in best_params], n_random_params)
202
+
203
+ # Checks if the best parameter is better than the init_ones
204
+ p_ones = init_ones(x, **kwargs)
205
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_ones})
206
+ loss_ones = nn.MSELoss()(x, x_)
207
+
208
+ # Checks if the best parameter is better than the init_rand
209
+ p_rand = init_rand(x, **kwargs)
210
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_rand})
211
+ loss_rand = nn.MSELoss()(x, x_)
212
+
213
+ if loss_rand < best_params[0][1] and loss_rand < loss_ones:
214
+ return p_rand
215
+ elif loss_ones < best_params[0][1] and loss_ones < loss_rand:
216
+ return p_ones
217
+ else:
218
+ return best_params[0][0]
219
+
220
+
221
+ def init_linear_scale( # Symmetric scale. From the study folder
222
+ x: torch.Tensor,
223
+ **kwargs: Dict[str, Any],
224
+ ) -> torch.Tensor:
225
+ assert "bits" in kwargs, "bits must be provided."
226
+ assert "params" in kwargs, "params must be provided."
227
+ assert "qtz_func" in kwargs, "qtz_func must be provided."
228
+
229
+ bits = kwargs.get('bits')
230
+ params = kwargs.get('params')
231
+ qtz_func = kwargs.get('qtz_func')
232
+
233
+ x_ = x.transpose(0, 1)
234
+ x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs))
235
+ x_ = x_.transpose(0, 1)
236
+
237
+ quant_min, quant_max = get_min_max_from_bits_signed(bits)
238
+ min_vals, max_vals = torch.aminmax(x_, dim=1)
239
+ min_vals = torch.min(min_vals, torch.zeros_like(min_vals))
240
+ max_vals = torch.max(max_vals, torch.zeros_like(max_vals))
241
+
242
+ eps = torch.finfo(torch.float32).eps
243
+
244
+ abs_max_val_per_ch = torch.max(-min_vals, max_vals)
245
+ scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2)
246
+
247
+ scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device)
248
+
249
+ # Introduces some noise in scale
250
+ # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything
251
+ scale = scale + 0.01 * torch.randn_like(scale)
252
+ return scale
253
+
254
+
255
+ def init_non_linear_regression_fit(
256
+ x: torch.Tensor,
257
+ **kwargs: Dict[str, Any],
258
+ ) -> torch.Tensor:
259
+
260
+ assert "params_list" in kwargs, "params list must be provided."
261
+ assert "np_fit_func" in kwargs, "np_fit_func must be provided."
262
+ assert "p0" in kwargs, "p0 must be provided."
263
+ np_fit_func = kwargs.get('np_fit_func')
264
+ params_list = kwargs.get('params_list')
265
+ p0 = kwargs.get('p0')
266
+
267
+ def _fit(xdata: np.ndarray, ydata: np.ndarray, func: Callable, p0: List[float]):
268
+ popt, _ = curve_fit(
269
+ func,
270
+ xdata,
271
+ ydata,
272
+ maxfev=1000,
273
+ p0=p0,
274
+ method='lm'
275
+ )
276
+ return popt
277
+
278
+ # 1. Needs to convert the torch tensor to numpy tensor
279
+ xdata = x.cpu().numpy()
280
+
281
+ # 2. Sorts the data so that it is easier to fit
282
+ sorted_xdata = np.sort(xdata, axis=-1)
283
+
284
+ p0 = {k: v.cpu().numpy() for k, v in p0.items()}
285
+ params_list = sorted(params_list) # We need to make sure that it matches the numpy fit func arg order
286
+
287
+ # 3. Finds the best parameters for each channel
288
+ try:
289
+ params = []
290
+ for i in range(sorted_xdata.shape[0]):
291
+ xdata_ = sorted_xdata[i]
292
+ p0_ = [p0[p][i] for p in params_list]
293
+ ch_params = _fit(xdata_, xdata_, np_fit_func, p0_)
294
+ params.append(ch_params)
295
+
296
+ # 4. Builds the parameters
297
+ result = {}
298
+ for i, p in enumerate(params_list):
299
+ result[p] = torch.tensor([p_[i] for p_ in params], dtype=torch.float32).to(x.device)
300
+
301
+ return result
302
+
303
+ except ValueError as e:
304
+ print(f"Could not fit the function with error: {e}")
305
+ print(f"Using fallback result...")
306
+ return {
307
+ k: torch.tensor(v, dtype=torch.float32).to(x.device) for k, v in p0.items()
308
+ }
309
+
310
+
311
+ def init_zeros(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
312
+ val = torch.amin(x, dim=1)
313
+ return torch.zeros_like(val, dtype=torch.float32, device=x.device)
314
+
315
+
316
+ def init_inner_scale(tensor: torch.Tensor, _min: float = torch.inf, _max: float = torch.inf) -> torch.Tensor:
317
+ # Calculate the original minimum and maximum values
318
+ min_vals, max_vals = torch.aminmax(tensor, dim=-1)
319
+ x_min = torch.min(min_vals, torch.zeros_like(min_vals))
320
+ x_max = torch.max(max_vals, torch.zeros_like(max_vals))
321
+
322
+ if _max is torch.inf: # We do not need to scale the tensor. Just need to move it
323
+ return torch.ones_like(x_min)
324
+
325
+ # Calculate the scale factor
326
+ scale = (_max - _min) / (x_max - x_min)
327
+ return scale
328
+
329
+
330
+
331
+ ############## Quant ###############
332
+
333
+ @torch.enable_grad()
334
+ def learn_parameters(
335
+ x: torch.Tensor,
336
+ params: Dict[str, nn.Parameter],
337
+ qtz_func: nn.Module,
338
+ deqtz_func: nn.Module,
339
+ bits: int,
340
+ target_dtype: torch.dtype,
341
+ epochs: int = 1000,
342
+ early_stop: bool = True,
343
+ do_report: bool = False
344
+ ) -> Tuple[Dict[str, nn.Parameter], torch.Tensor]:
345
+
346
+ # Requires gradients in the parameters
347
+ for p in params.values():
348
+ p.requires_grad = True
349
+ p.grad = None
350
+
351
+ param_keys = list(params.keys())
352
+ param_values = list(params.values())
353
+
354
+ # Defines optimizer and loss function
355
+ optimizer = torch.optim.Adam(param_values, lr=0.001)
356
+ scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=1.0, end_factor=0.001, total_iters=epochs // 10)
357
+ loss_fn = nn.MSELoss()
358
+
359
+ # Contains the best loss and the best parameters
360
+ best_loss = float("inf")
361
+ best_params = None
362
+
363
+ # Used to stop the search early
364
+ min_delta = 1e-7
365
+ acc_loss = []
366
+ percent_epochs_before_stop = 0.1
367
+
368
+ for i in range(epochs):
369
+ optimizer.zero_grad()
370
+
371
+ quant = quantize(x, params, qtz_func, bits, target_dtype)
372
+ dequant = dequantize(quant, params, deqtz_func, bits, x.dtype)
373
+ loss = loss_fn(x, dequant)
374
+
375
+ if loss.isnan() or loss.isinf():
376
+ raise Exception("Loss is NaN or Inf. Stopping the search.")
377
+
378
+ loss.backward()
379
+ optimizer.step()
380
+ scheduler.step()
381
+
382
+ acc_loss.append(loss.item())
383
+
384
+ # Reports loss every 10 steps
385
+ if i % 10 == 0 and do_report:
386
+ print(f"Epoch {i}: Loss {loss.item()}")
387
+
388
+ # Optimizes the parameter search by storing the best loss and the parameters
389
+ if loss.item() < best_loss:
390
+ best_loss = loss.item()
391
+ best_params = copy.deepcopy({
392
+ k: v for k, v in params.items() if k in param_keys
393
+ })
394
+
395
+ # We also stop the search if the loss has not improved considerably during the last 10% of epochs
396
+ if early_stop:
397
+ epochs_before_stop = int(epochs * percent_epochs_before_stop)
398
+ if i > epochs_before_stop and abs(acc_loss[i - epochs_before_stop] - acc_loss[i]) < min_delta:
399
+ break
400
+
401
+ # No longer requires gradients in the parameters
402
+ for p in best_params.values():
403
+ p.requires_grad = False
404
+ p.grad = None
405
+
406
+ if do_report:
407
+ return best_params, acc_loss
408
+ else:
409
+ return best_params
410
+
411
+
412
+ def quantize(
413
+ x: torch.Tensor,
414
+ params: Dict[str, nn.Parameter],
415
+ func: nn.Module,
416
+ bits: int,
417
+ target_dtype: torch.dtype = torch.int8
418
+ ) -> torch.Tensor:
419
+ quant_min, quant_max = get_min_max_from_bits_signed(bits)
420
+ x = x.transpose(0, 1) # Aligns shapes
421
+ x = func(x=x, **params)
422
+ x = x.transpose(0, 1)
423
+ x = torch.clamp(round_func_BPDA(x), quant_min, quant_max).to(target_dtype)
424
+ return x
425
+
426
+
427
+ def dequantize(
428
+ x: torch.Tensor,
429
+ params: Dict[str, nn.Parameter],
430
+ func: nn.Module,
431
+ bits: int,
432
+ out_dtype: torch.dtype
433
+ ) -> torch.Tensor:
434
+ x = x.to(dtype=out_dtype)
435
+ x = x.transpose(0, 1)
436
+ x = func(x=x, **params)
437
+ x = x.transpose(0, 1)
438
+ return x
439
+
440
+
441
+ def round_func_BPDA(input):
442
+ # This is equivalent to replacing round function (non-differentiable) with
443
+ # an identity function (differentiable) in the backward pass only.
444
+ forward_value = torch.round(input)
445
+ out = input.clone()
446
+ out.data = forward_value.data
447
+ return out
448
+
449
+
450
+ def get_min_max_from_bits_signed(bit_width: int) -> Tuple[int, int]:
451
+ return -2 ** (bit_width - 1), 2 ** (bit_width - 1) - 1
452
+
453
+
454
+
455
+ ############## Numpy ###############
456
+
457
+ def np_domain_guard(
458
+ x: np.ndarray,
459
+ min: float = None,
460
+ max: float = None,
461
+ posinf: float = None,
462
+ neginf: float = None,
463
+ nan: float = None
464
+ ) -> np.ndarray:
465
+ """Guard a tensor to a valid domain."""
466
+ x = np.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan)
467
+ if min is not None or max is not None:
468
+ x = np.clip(x, min, max)
469
+ return x
470
+
471
+
472
+ def np_replace_num(x: np.ndarray, num: float, to: float) -> np.ndarray:
473
+ """Replace a number in a tensor with another number.
474
+
475
+ Args:
476
+ x (np.ndarray): The input tensor.
477
+ num (float): The number to replace.
478
+ to (float): The number to replace with.
479
+
480
+ Returns:
481
+ np.ndarray: The tensor with the number replaced.
482
+ """
483
+ return np.where(x == num, to, x)
484
+
485
+
486
+ def np_guarded_power(x: np.ndarray, exp: float) -> np.ndarray:
487
+ """Guard the power operation to a valid domain."""
488
+ return np.power(x, exp) if exp >= 1 else np.power(np.maximum(x, 0), exp)
489
+
fn_gen/ones_noisy_scale/13/loss.png ADDED
fn_gen/ones_noisy_scale/13/quantization.png ADDED
fn_gen/ones_noisy_scale/14/__pycache__/fn.cpython-310.pyc ADDED
Binary file (13.1 kB). View file
 
fn_gen/ones_noisy_scale/14/distortion.png ADDED
fn_gen/ones_noisy_scale/14/expressions.txt ADDED
@@ -0,0 +1,2 @@
1
+ asinh(_0*x)/_s
2
+ sinh(_s*x)/_0
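The first expression is the learned quantization map and the second its dequantization inverse, with _0 and _s as per-channel parameters. Since sinh exactly inverts asinh, any reconstruction error for this pair comes from rounding and clamping, not from the transform itself. An illustrative numpy check (not part of the upload), assuming _0 = _s = 1:

import numpy as np

x = np.linspace(-3.0, 3.0, 7)
q = np.arcsinh(1.0 * x) / 1.0   # asinh(_0*x)/_s with _0 = _s = 1
xr = np.sinh(1.0 * q) / 1.0     # sinh(_s*x)/_0
print(np.max(np.abs(x - xr)))   # ~0, up to float precision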
fn_gen/ones_noisy_scale/14/fn.py ADDED
@@ -0,0 +1,489 @@
1
+ from __future__ import annotations
2
+
3
+ import torch
4
+ from torch import amin # Necessary for arcsin
5
+ import copy
6
+ import torch.nn as nn
7
+ import numpy as np
8
+
9
+ from scipy.optimize import curve_fit
10
+ from typing import Dict, Any, Tuple, List, Callable
11
+
12
+
13
+ def quantization(x, **params):
14
+ return (torch.div(1, replace_num(params['_s'], num=0, to=10000)) * torch.asinh((params['_0'] * x)))
15
+
16
+
17
+ def dequantization(x, **params):
18
+ return (torch.div(1, replace_num(params['_0'], num=0, to=10000)) * torch.sinh((params['_s'] * x)))
19
+
20
+
21
+ def init_linear_scale( # Symmetric scale. From the study folder
22
+ x: torch.Tensor,
23
+ **kwargs: Dict[str, Any],
24
+ ) -> torch.Tensor:
25
+ assert "bits" in kwargs, "bits must be provided."
26
+ assert "params" in kwargs, "params must be provided."
27
+ assert "qtz_func" in kwargs, "qtz_func must be provided."
28
+
29
+ bits = kwargs.get('bits')
30
+ params = kwargs.get('params')
31
+ qtz_func = kwargs.get('qtz_func')
32
+
33
+ x_ = x.transpose(0, 1)
34
+ x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs))
35
+ x_ = x_.transpose(0, 1)
36
+
37
+ quant_min, quant_max = get_min_max_from_bits_signed(bits)
38
+ min_vals, max_vals = torch.aminmax(x_, dim=1)
39
+ min_vals = torch.min(min_vals, torch.zeros_like(min_vals))
40
+ max_vals = torch.max(max_vals, torch.zeros_like(max_vals))
41
+
42
+ eps = torch.finfo(torch.float32).eps
43
+
44
+ abs_max_val_per_ch = torch.max(-min_vals, max_vals)
45
+ scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2)
46
+
47
+ scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device)
48
+
49
+ # Introduces some noise in scale
50
+ # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything
51
+ scale = scale + 0.01 * torch.randn_like(scale)
52
+ return scale
53
+
54
+
55
+ def init_params(x: torch.Tensor, **kwargs: Dict[str, Any]) -> Dict[str, nn.Parameter]:
56
+ params = {
57
+ '_0': init_ones(x, qtz_func=quantization, deqtz_func=dequantization, param='_0', params_list=['_0', '_s'], **kwargs),
58
+ }
59
+ params['_s'] = init_linear_scale(x, params=params, qtz_func=quantization, **kwargs)
60
+ params = {k: nn.Parameter(v, requires_grad=False) for k, v in params.items()}
61
+
62
+ if 'post_init_hook' in kwargs:
63
+ kwargs['post_init_hook'](parameters=params)
64
+
65
+
66
+ if 'post_train_hook' in kwargs:
67
+ kwargs['post_train_hook'](parameters=params)
68
+
69
+ return params
70
+
71
+
72
+ ############### Numpy Qtz ###############
73
+
74
+
75
+ def np_quantization(x, _0, _s):
76
+ return (np.divide(1, np_replace_num(_s, num=0, to=10000)) * np.arcsinh((_0 * x)))
77
+
78
+
79
+ def np_dequantization(x, _0, _s):
80
+ return (np.divide(1, np_replace_num(_0, num=0, to=10000)) * np.sinh((_s * x)))
81
+
82
+
83
+ def fit_func(x, _0, _s):
84
+ x_ = np_quantization(x, _0, _s)
85
+ x_ = np_dequantization(x_, _0, _s)
86
+ return x_
87
+
88
+
89
+
90
+ ############### HELPERS ###############
91
+
92
+ def domain_guard(
93
+ x: torch.Tensor,
94
+ min: float = None,
95
+ max: float = None,
96
+ posinf: float = None,
97
+ neginf: float = None,
98
+ nan: float = None
99
+ ) -> torch.Tensor:
100
+ """Guard a tensor to a valid domain."""
101
+ x = torch.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan)
102
+ if min is not None or max is not None:
103
+ x = torch.clamp(x, min=min, max=max)
104
+ return x
105
+
106
+
107
+ def replace_num(x: torch.Tensor, num: float, to: float) -> torch.Tensor:
108
+ """Replace a number in a tensor with another number.
109
+
110
+ Args:
111
+ x (torch.Tensor): The input tensor.
112
+ num (float): The number to replace.
113
+ to (float): The number to replace with.
114
+
115
+ Returns:
116
+ torch.Tensor: The tensor with the number replaced.
117
+ """
118
+ return torch.where(x == num, to, x)
119
+
120
+
121
+ def guarded_torch_power(x: torch.Tensor, exp: float) -> torch.Tensor:
122
+ """Guard the power operation to a valid domain."""
123
+ return torch.pow(x, exp) if exp >= 1 else torch.pow(torch.relu(x), exp)
124
+
125
+
126
+ def init_ones(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
127
+ val = torch.amin(x, dim=1)
128
+ return torch.ones_like(val, dtype=torch.float32, device=x.device)
129
+
130
+
131
+ def init_rand(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
132
+ val = torch.amin(x, dim=1)
133
+ return torch.randn_like(val, dtype=torch.float32, device=x.device)
134
+
135
+
136
+ def init_space_search(
137
+ x: torch.Tensor,
138
+ **kwargs: Dict[str, Any],
139
+ ) -> torch.Tensor:
140
+
141
+ def _build_initial_param(tensor: torch.Tensor, max_initial: int, n_params: int):
142
+ """Generates the initial set of parameters. The first iteration generates 10 times more parameters."""
143
+ for _ in range(n_params * 10): # The first iteration generates 10 times more parameters
144
+ yield init_rand(tensor) * max_initial # Generates n_params in range [-max_initial, max_initial]
145
+
146
+ def _search_param(tensors: List[torch.Tensor], n_params):
147
+ """Takes the best parameters and generates new parameters around the mean of the best parameters."""
148
+ torch_tensors = torch.stack(tensors)
149
+ min_vals, max_vals = torch.aminmax(torch_tensors, dim=0)
150
+ abs_max_val_per_ch = torch.max(-min_vals, max_vals)
151
+ mean = torch.mean(torch_tensors, dim=0)
152
+ for _ in range(n_params): # Generates n_params around the mean of the tensors
153
+ yield torch.randn_like(min_vals) * abs_max_val_per_ch + mean
154
+
155
+ def _calc(x, qtz_func, deqtz_func, **params):
156
+ x_ = x.transpose(0, 1)
157
+ x_ = qtz_func(x=x_, **params)
158
+ x_ = deqtz_func(x=x_, **params)
159
+ x_ = x_.transpose(0, 1)
160
+ return x_
161
+
162
+ assert "qtz_func" in kwargs, "qtz_func must be provided."
163
+ assert "deqtz_func" in kwargs, "deqtz_func must be provided."
164
+ assert "params_list" in kwargs, "params list must be provided."
165
+ assert "param" in kwargs, "param must be provided."
166
+
167
+ qtz_func = kwargs.get('qtz_func')
168
+ deqtz_func = kwargs.get('deqtz_func')
169
+ params_list = kwargs.get('params_list')
170
+ param = kwargs.get('param')
171
+
172
+ n_runs = 50 # Number of runs to try to find the best parameters
173
+ n_random_params = 50 # Number of random parameters to generate
174
+ n_best_to_pick = 5 # Number of best parameters to pick after each run
175
+ max_initial = 10000 # Maximum value to initialize the parameters
176
+
177
+ # Initializes the parameters
178
+ base_params = { p: init_ones(x, **kwargs) for p in params_list if p != param }
179
+ params = _build_initial_param(x, max_initial, n_random_params)
180
+
181
+ # Performs the search
182
+ for _ in range(n_runs):
183
+
184
+ best_params = []
185
+ for param_ in params:
186
+ try:
187
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: param_})
188
+ loss_ones = nn.MSELoss()(x, x_)
189
+
190
+ if len(best_params) < n_best_to_pick:
191
+ best_params.append((param_, loss_ones.item()))
192
+ best_params = sorted(best_params, key=lambda x: x[1])
193
+ elif loss_ones < best_params[-1][1]:
194
+ best_params[-1] = (param_, loss_ones.item())
195
+ best_params = sorted(best_params, key=lambda x: x[1])
196
+
197
+ except Exception: # The parameters might not be valid for the function's domain
198
+ continue
199
+
200
+ # Generates new parameters around the mean
201
+ params = _search_param([p for p, _ in best_params], n_random_params)
202
+
203
+ # Checks if the best parameter is better than the init_ones
204
+ p_ones = init_ones(x, **kwargs)
205
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_ones})
206
+ loss_ones = nn.MSELoss()(x, x_)
207
+
208
+ # Checks if the best parameter is better than the init_rand
209
+ p_rand = init_rand(x, **kwargs)
210
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_rand})
211
+ loss_rand = nn.MSELoss()(x, x_)
212
+
213
+ if loss_rand < best_params[0][1] and loss_rand < loss_ones:
214
+ return p_rand
215
+ elif loss_ones < best_params[0][1] and loss_ones < loss_rand:
216
+ return p_ones
217
+ else:
218
+ return best_params[0][0]
219
+
220
+
221
+ def init_linear_scale( # Symmetric scale. From the study folder
222
+ x: torch.Tensor,
223
+ **kwargs: Dict[str, Any],
224
+ ) -> torch.Tensor:
225
+ assert "bits" in kwargs, "bits must be provided."
226
+ assert "params" in kwargs, "params must be provided."
227
+ assert "qtz_func" in kwargs, "qtz_func must be provided."
228
+
229
+ bits = kwargs.get('bits')
230
+ params = kwargs.get('params')
231
+ qtz_func = kwargs.get('qtz_func')
232
+
233
+ x_ = x.transpose(0, 1)
234
+ x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs))
235
+ x_ = x_.transpose(0, 1)
236
+
237
+ quant_min, quant_max = get_min_max_from_bits_signed(bits)
238
+ min_vals, max_vals = torch.aminmax(x_, dim=1)
239
+ min_vals = torch.min(min_vals, torch.zeros_like(min_vals))
240
+ max_vals = torch.max(max_vals, torch.zeros_like(max_vals))
241
+
242
+ eps = torch.finfo(torch.float32).eps
243
+
244
+ abs_max_val_per_ch = torch.max(-min_vals, max_vals)
245
+ scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2)
246
+
247
+ scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device)
248
+
249
+ # Introduces some noise in scale
250
+ # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything
251
+ scale = scale + 0.01 * torch.randn_like(scale)
252
+ return scale
253
+
254
+
255
+ def init_non_linear_regression_fit(
256
+ x: torch.Tensor,
257
+ **kwargs: Dict[str, Any],
258
+ ) -> torch.Tensor:
259
+
260
+ assert "params_list" in kwargs, "params list must be provided."
261
+ assert "np_fit_func" in kwargs, "np_fit_func must be provided."
262
+ assert "p0" in kwargs, "p0 must be provided."
263
+ np_fit_func = kwargs.get('np_fit_func')
264
+ params_list = kwargs.get('params_list')
265
+ p0 = kwargs.get('p0')
266
+
267
+ def _fit(xdata: np.ndarray, ydata: np.ndarray, func: Callable, p0: List[float]):
268
+ popt, _ = curve_fit(
269
+ func,
270
+ xdata,
271
+ ydata,
272
+ maxfev=1000,
273
+ p0=p0,
274
+ method='lm'
275
+ )
276
+ return popt
277
+
278
+ # 1. Needs to convert the torch tensor to numpy tensor
279
+ xdata = x.cpu().numpy()
280
+
281
+ # 2. Sorts the data so that it makes it easier to fit to it
282
+ sorted_xdata = np.sort(xdata, axis=-1)
283
+
284
+ p0 = {k: v.cpu().numpy() for k, v in p0.items()}
285
+ params_list = sorted(params_list) # We need to make sure that it matches the numpy fit func arg order
286
+
287
+ # 3. Finds the best parameters for each channel
288
+ try:
289
+ params = []
290
+ for i in range(sorted_xdata.shape[0]):
291
+ xdata_ = sorted_xdata[i]
292
+ p0_ = [p0[p][i] for p in params_list]
293
+ ch_params = _fit(xdata_, xdata_, np_fit_func, p0_)
294
+ params.append(ch_params)
295
+
296
+ # 4. Builds the parameters
297
+ result = {}
298
+ for i, p in enumerate(params_list):
299
+ result[p] = torch.tensor([p_[i] for p_ in params], dtype=torch.float32).to(x.device)
300
+
301
+ return result
302
+
303
+ except ValueError as e:
304
+ print(f"Could not fit the function with error: {e}")
305
+ print(f"Using fallback result...")
306
+ return {
307
+ k: torch.tensor(v, dtype=torch.float32).to(x.device) for k, v in p0.items()
308
+ }
309
+
310
+
311
+ def init_zeros(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
312
+ val = torch.amin(x, dim=1)
313
+ return torch.zeros_like(val, dtype=torch.float32, device=x.device)
314
+
315
+
316
+ def init_inner_scale(tensor: torch.Tensor, _min: float = torch.inf, _max: float = torch.inf) -> torch.Tensor:
317
+ # Calculate the original minimum and maximum values
318
+ min_vals, max_vals = torch.aminmax(tensor, dim=-1)
319
+ x_min = torch.min(min_vals, torch.zeros_like(min_vals))
320
+ x_max = torch.max(max_vals, torch.zeros_like(max_vals))
321
+
322
+ if _max is torch.inf: # We do not need to scale the tensor. Just need to move it
323
+ return torch.ones_like(x_min)
324
+
325
+ # Calculate the scale factor
326
+ scale = (_max - _min) / (x_max - x_min)
327
+ return scale
328
+
329
+
330
+
331
+ ############## Quant ###############
332
+
333
+ @torch.enable_grad()
334
+ def learn_parameters(
335
+ x: torch.Tensor,
336
+ params: Dict[str, nn.Parameter],
337
+ qtz_func: nn.Module,
338
+ deqtz_func: nn.Module,
339
+ bits: int,
340
+ target_dtype: torch.dtype,
341
+ epochs: int = 1000,
342
+ early_stop: bool = True,
343
+ do_report: bool = False
344
+ ) -> Tuple[Dict[str, nn.Parameter], torch.Tensor]:
345
+
346
+ # Requires gradients in the parameters
347
+ for p in params.values():
348
+ p.requires_grad = True
349
+ p.grad = None
350
+
351
+ param_keys = list(params.keys())
352
+ param_values = list(params.values())
353
+
354
+ # Defines optimizer and loss function
355
+ optimizer = torch.optim.Adam(param_values, lr=0.001)
356
+ scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=1.0, end_factor=0.001, total_iters=epochs // 10)
357
+ loss_fn = nn.MSELoss()
358
+
359
+ # Contains the best loss and the best parameters
360
+ best_loss = float("inf")
361
+ best_params = None
362
+
363
+ # Used to stop the search early
364
+ min_delta = 1e-7
365
+ acc_loss = []
366
+ percent_epochs_before_stop = 0.1
367
+
368
+ for i in range(epochs):
369
+ optimizer.zero_grad()
370
+
371
+ quant = quantize(x, params, qtz_func, bits, target_dtype)
372
+ dequant = dequantize(quant, params, deqtz_func, bits, x.dtype)
373
+ loss = loss_fn(x, dequant)
374
+
375
+ if loss.isnan() or loss.isinf():
376
+ raise Exception("Loss is NaN or Inf. Stopping the search.")
377
+
378
+ loss.backward()
379
+ optimizer.step()
380
+ scheduler.step()
381
+
382
+ acc_loss.append(loss.item())
383
+
384
+ # Reports loss every 10 steps
385
+ if i % 10 == 0 and do_report:
386
+ print(f"Epoch {i}: Loss {loss.item()}")
387
+
388
+ # Optimizes the parameter search by storing the best loss and the parameters
389
+ if loss.item() < best_loss:
390
+ best_loss = loss.item()
391
+ best_params = copy.deepcopy({
392
+ k: v for k, v in params.items() if k in param_keys
393
+ })
394
+
395
+ # We also stop the search if the loss has not improved considerably during the last 10% of epochs
396
+ if early_stop:
397
+ epochs_before_stop = int(epochs * percent_epochs_before_stop)
398
+ if i > epochs_before_stop and abs(acc_loss[i - epochs_before_stop] - acc_loss[i]) < min_delta:
399
+ break
400
+
401
+ # No longer requires gradients in the parameters
402
+ for p in best_params.values():
403
+ p.requires_grad = False
404
+ p.grad = None
405
+
406
+ if do_report:
407
+ return best_params, acc_loss
408
+ else:
409
+ return best_params
410
+
411
+
412
+ def quantize(
413
+ x: torch.Tensor,
414
+ params: Dict[str, nn.Parameter],
415
+ func: nn.Module,
416
+ bits: int,
417
+ target_dtype: torch.dtype = torch.int8
418
+ ) -> torch.Tensor:
419
+ quant_min, quant_max = get_min_max_from_bits_signed(bits)
420
+ x = x.transpose(0, 1) # Aligns shapes
421
+ x = func(x=x, **params)
422
+ x = x.transpose(0, 1)
423
+ x = torch.clamp(round_func_BPDA(x), quant_min, quant_max).to(target_dtype)
424
+ return x
425
+
426
+
427
+ def dequantize(
428
+ x: torch.Tensor,
429
+ params: Dict[str, nn.Parameter],
430
+ func: nn.Module,
431
+ bits: int,
432
+ out_dtype: torch.dtype
433
+ ) -> torch.Tensor:
434
+ x = x.to(dtype=out_dtype)
435
+ x = x.transpose(0, 1)
436
+ x = func(x=x, **params)
437
+ x = x.transpose(0, 1)
438
+ return x
439
+
440
+
441
+ def round_func_BPDA(input):
442
+ # This is equivalent to replacing the round function (non-differentiable) with
443
+ # an identity function (differentiable) in the backward pass only.
444
+ forward_value = torch.round(input)
445
+ out = input.clone()
446
+ out.data = forward_value.data
447
+ return out
448
+
449
+
450
+ def get_min_max_from_bits_signed(bit_width: int) -> Tuple[int, int]:
451
+ return -2 ** (bit_width - 1), 2 ** (bit_width - 1) - 1
452
+
453
+
454
+
455
+ ############## Numpy ###############
456
+
457
+ def np_domain_guard(
458
+ x: np.ndarray,
459
+ min: float = None,
460
+ max: float = None,
461
+ posinf: float = None,
462
+ neginf: float = None,
463
+ nan: float = None
464
+ ) -> np.ndarray:
465
+ """Guard a tensor to a valid domain."""
466
+ x = np.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan)
467
+ if min is not None or max is not None:
468
+ x = np.clip(x, min, max)
469
+ return x
470
+
471
+
472
+ def np_replace_num(x: np.ndarray, num: float, to: float) -> np.ndarray:
473
+ """Replace a number in a tensor with another number.
474
+
475
+ Args:
476
+ x (np.ndarray): The input tensor.
477
+ num (float): The number to replace.
478
+ to (float): The number to replace with.
479
+
480
+ Returns:
481
+ np.ndarray: The tensor with the number replaced.
482
+ """
483
+ return np.where(x == num, to, x)
484
+
485
+
486
+ def np_guarded_power(x: np.ndarray, exp: float) -> np.ndarray:
487
+ """Guard the power operation to a valid domain."""
488
+ return np.power(x, exp) if exp >= 1 else np.power(np.maximum(x, 0), exp)
489
+
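A sketch of how the pieces above fit together for a full fake-quantization round trip. This is illustrative only: it assumes the fn.py above is importable as a module named fn and that the input is a 2D weight tensor with one learned parameter per output channel.

import torch
import fn  # assumption: the fn.py above is on the Python path; adjust the import to your layout

x = torch.randn(16, 32)                 # hypothetical [out_features, in_features] weight
params = fn.init_params(x, bits=8)      # '_0' from init_ones, '_s' from the noisy symmetric scale
q = fn.quantize(x, params, fn.quantization, bits=8, target_dtype=torch.int8)
xr = fn.dequantize(q, params, fn.dequantization, bits=8, out_dtype=x.dtype)
print(q.dtype, xr.shape, (x - xr).abs().mean().item())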
fn_gen/ones_noisy_scale/14/loss.png ADDED
fn_gen/ones_noisy_scale/14/quantization.png ADDED
fn_gen/ones_noisy_scale/15/__pycache__/fn.cpython-310.pyc ADDED
Binary file (13.4 kB). View file
 
fn_gen/ones_noisy_scale/15/distortion.png ADDED
fn_gen/ones_noisy_scale/15/expressions.txt ADDED
@@ -0,0 +1,2 @@
1
+ cosh(_0*x)/_s
2
+ log(_s*x - sqrt(_s**2*x**2 - 1))/_0
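Unlike the sinh/asinh pair, cosh is even, so its inverse is only defined up to sign: log(y - sqrt(y**2 - 1)) equals -arccosh(y) for y >= 1, i.e. the dequantization expression recovers the negative branch. With _0 = _s = 1 the round trip therefore returns -|x|; the learned signs of _0 and _s determine which branch is reproduced. An illustrative numpy check (not part of the upload):

import numpy as np

x = np.linspace(-3.0, 3.0, 7)
q = np.cosh(x)                          # cosh(_0*x)/_s with _0 = _s = 1
xr = np.log(q - np.sqrt(q**2 - 1.0))    # log(_s*x - sqrt(_s**2*x**2 - 1))/_0
print(np.allclose(xr, -np.abs(x)))      # True: only |x| is recovered, the sign is lost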
fn_gen/ones_noisy_scale/15/fn.py ADDED
@@ -0,0 +1,489 @@
1
+ from __future__ import annotations
2
+
3
+ import torch
4
+ from torch import amin # Necessary for arcsin
5
+ import copy
6
+ import torch.nn as nn
7
+ import numpy as np
8
+
9
+ from scipy.optimize import curve_fit
10
+ from typing import Dict, Any, Tuple, List, Callable
11
+
12
+
13
+ def quantization(x, **params):
14
+ return (torch.div(1, replace_num(params['_s'], num=0, to=10000)) * torch.cosh((params['_0'] * x)))
15
+
16
+
17
+ def dequantization(x, **params):
18
+ return (torch.div(1, replace_num(params['_0'], num=0, to=10000)) * torch.log(domain_guard(((torch.tensor(-1) * torch.sqrt(domain_guard((torch.tensor(-1) + (guarded_torch_power(params['_s'], torch.tensor(2)) * guarded_torch_power(x, torch.tensor(2)))), min=0.1, nan=0.1))) + (params['_s'] * x)), min=1e-5, nan=1e-5)))
19
+
20
+
21
+ def init_linear_scale( # Symmetric scale. From the study folder
22
+ x: torch.Tensor,
23
+ **kwargs: Dict[str, Any],
24
+ ) -> torch.Tensor:
25
+ assert "bits" in kwargs, "bits must be provided."
26
+ assert "params" in kwargs, "params must be provided."
27
+ assert "qtz_func" in kwargs, "qtz_func must be provided."
28
+
29
+ bits = kwargs.get('bits')
30
+ params = kwargs.get('params')
31
+ qtz_func = kwargs.get('qtz_func')
32
+
33
+ x_ = x.transpose(0, 1)
34
+ x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs))
35
+ x_ = x_.transpose(0, 1)
36
+
37
+ quant_min, quant_max = get_min_max_from_bits_signed(bits)
38
+ min_vals, max_vals = torch.aminmax(x_, dim=1)
39
+ min_vals = torch.min(min_vals, torch.zeros_like(min_vals))
40
+ max_vals = torch.max(max_vals, torch.zeros_like(max_vals))
41
+
42
+ eps = torch.finfo(torch.float32).eps
43
+
44
+ abs_max_val_per_ch = torch.max(-min_vals, max_vals)
45
+ scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2)
46
+
47
+ scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device)
48
+
49
+ # Introduces some noise in scale
50
+ # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything
51
+ scale = scale + 0.01 * torch.randn_like(scale)
52
+ return scale
53
+
54
+
55
+ def init_params(x: torch.Tensor, **kwargs: Dict[str, Any]) -> Dict[str, nn.Parameter]:
56
+ params = {
57
+ '_0': init_ones(x, qtz_func=quantization, deqtz_func=dequantization, param='_0', params_list=['_0', '_s'], **kwargs),
58
+ }
59
+ params['_s'] = init_linear_scale(x, params=params, qtz_func=quantization, **kwargs)
60
+ params = {k: nn.Parameter(v, requires_grad=False) for k, v in params.items()}
61
+
62
+ if 'post_init_hook' in kwargs:
63
+ kwargs['post_init_hook'](parameters=params)
64
+
65
+
66
+ if 'post_train_hook' in kwargs:
67
+ kwargs['post_train_hook'](parameters=params)
68
+
69
+ return params
70
+
71
+
72
+ ############### Numpy Qtz ###############
73
+
74
+
75
+ def np_quantization(x, _0, _s):
76
+ return (np.divide(1, np_replace_num(_s, num=0, to=10000)) * np.cosh((_0 * x)))
77
+
78
+
79
+ def np_dequantization(x, _0, _s):
80
+ return (np.divide(1, np_replace_num(_0, num=0, to=10000)) * np.log(np_domain_guard(((np.array(-1) * np.sqrt(np_domain_guard((np.array(-1) + (np_guarded_power(_s, np.array(2)) * np_guarded_power(x, np.array(2)))), min=0.1, nan=0.1))) + (_s * x)), min=1e-5, nan=1e-5)))
81
+
82
+
83
+ def fit_func(x, _0, _s):
84
+ x_ = np_quantization(x, _0, _s)
85
+ x_ = np_dequantization(x_, _0, _s)
86
+ return x_
87
+
88
+
89
+
90
+ ############### HELPERS ###############
91
+
92
+ def domain_guard(
93
+ x: torch.Tensor,
94
+ min: float = None,
95
+ max: float = None,
96
+ posinf: float = None,
97
+ neginf: float = None,
98
+ nan: float = None
99
+ ) -> torch.Tensor:
100
+ """Guard a tensor to a valid domain."""
101
+ x = torch.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan)
102
+ if min is not None or max is not None:
103
+ x = torch.clamp(x, min=min, max=max)
104
+ return x
105
+
106
+
107
+ def replace_num(x: torch.Tensor, num: float, to: float) -> torch.Tensor:
108
+ """Replace a number in a tensor with another number.
109
+
110
+ Args:
111
+ x (torch.Tensor): The input tensor.
112
+ num (float): The number to replace.
113
+ to (float): The number to replace with.
114
+
115
+ Returns:
116
+ torch.Tensor: The tensor with the number replaced.
117
+ """
118
+ return torch.where(x == num, to, x)
119
+
120
+
121
+ def guarded_torch_power(x: torch.Tensor, exp: float) -> torch.Tensor:
122
+ """Guard the power operation to a valid domain."""
123
+ return torch.pow(x, exp) if exp >= 1 else torch.pow(torch.relu(x), exp)
124
+
125
+
126
+ def init_ones(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
127
+ val = torch.amin(x, dim=1)
128
+ return torch.ones_like(val, dtype=torch.float32, device=x.device)
129
+
130
+
131
+ def init_rand(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
132
+ val = torch.amin(x, dim=1)
133
+ return torch.randn_like(val, dtype=torch.float32, device=x.device)
134
+
135
+
136
+ def init_space_search(
137
+ x: torch.Tensor,
138
+ **kwargs: Dict[str, Any],
139
+ ) -> torch.Tensor:
140
+
141
+ def _build_initial_param(tensor: torch.Tensor, max_initial: int, n_params: int):
142
+ """Generates the initial set of parameters. The first iteration generates 10 times more parameters."""
143
+ for _ in range(n_params * 10): # The first iteration generates 10 times more parameters
144
+ yield init_rand(tensor) * max_initial # Generates n_params in range [-max_initial, max_initial]
145
+
146
+ def _search_param(tensors: List[torch.Tensor], n_params):
147
+ """Takes the best parameters and generates new parameters around the mean of the best parameters."""
148
+ torch_tensors = torch.stack(tensors)
149
+ min_vals, max_vals = torch.aminmax(torch_tensors, dim=0)
150
+ abs_max_val_per_ch = torch.max(-min_vals, max_vals)
151
+ mean = torch.mean(torch_tensors, dim=0)
152
+ for _ in range(n_params): # Generates n_params around the mean of the tensors
153
+ yield torch.randn_like(min_vals) * abs_max_val_per_ch + mean
154
+
155
+ def _calc(x, qtz_func, deqtz_func, **params):
156
+ x_ = x.transpose(0, 1)
157
+ x_ = qtz_func(x=x_, **params)
158
+ x_ = deqtz_func(x=x_, **params)
159
+ x_ = x_.transpose(0, 1)
160
+ return x_
161
+
162
+ assert "qtz_func" in kwargs, "qtz_func must be provided."
163
+ assert "deqtz_func" in kwargs, "deqtz_func must be provided."
164
+ assert "params_list" in kwargs, "params list must be provided."
165
+ assert "param" in kwargs, "param must be provided."
166
+
167
+ qtz_func = kwargs.get('qtz_func')
168
+ deqtz_func = kwargs.get('deqtz_func')
169
+ params_list = kwargs.get('params_list')
170
+ param = kwargs.get('param')
171
+
172
+ n_runs = 50 # Number of runs to try to find the best parameters
173
+ n_random_params = 50 # Number of random parameters to generate
174
+ n_best_to_pick = 5 # Number of best parameters to pick after each run
175
+ max_initial = 10000 # Maximum value to initialize the parameters
176
+
177
+ # Initializes the parameters
178
+ base_params = { p: init_ones(x, **kwargs) for p in params_list if p != param }
179
+ params = _build_initial_param(x, max_initial, n_random_params)
180
+
181
+ # Performs the search
182
+ for _ in range(n_runs):
183
+
184
+ best_params = []
185
+ for param_ in params:
186
+ try:
187
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: param_})
188
+ loss_ones = nn.MSELoss()(x, x_)
189
+
190
+ if len(best_params) < n_best_to_pick:
191
+ best_params.append((param_, loss_ones.item()))
192
+ best_params = sorted(best_params, key=lambda x: x[1])
193
+ elif loss_ones < best_params[-1][1]:
194
+ best_params[-1] = (param_, loss_ones.item())
195
+ best_params = sorted(best_params, key=lambda x: x[1])
196
+
197
+ except Exception: # The parameters might not be valid for the function's domain
198
+ continue
199
+
200
+ # Generates new parameters around the mean
201
+ params = _search_param([p for p, _ in best_params], n_random_params)
202
+
203
+ # Checks if the best parameter is better than the init_ones
204
+ p_ones = init_ones(x, **kwargs)
205
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_ones})
206
+ loss_ones = nn.MSELoss()(x, x_)
207
+
208
+ # Checks if the best parameter is better than the init_rand
209
+ p_rand = init_rand(x, **kwargs)
210
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_rand})
211
+ loss_rand = nn.MSELoss()(x, x_)
212
+
213
+ if loss_rand < best_params[0][1] and loss_rand < loss_ones:
214
+ return p_rand
215
+ elif loss_ones < best_params[0][1] and loss_ones < loss_rand:
216
+ return p_ones
217
+ else:
218
+ return best_params[0][0]
219
+
220
+
221
+ def init_linear_scale( # Symmetric scale. From the study folder
222
+ x: torch.Tensor,
223
+ **kwargs: Dict[str, Any],
224
+ ) -> torch.Tensor:
225
+ assert "bits" in kwargs, "bits must be provided."
226
+ assert "params" in kwargs, "params must be provided."
227
+ assert "qtz_func" in kwargs, "qtz_func must be provided."
228
+
229
+ bits = kwargs.get('bits')
230
+ params = kwargs.get('params')
231
+ qtz_func = kwargs.get('qtz_func')
232
+
233
+ x_ = x.transpose(0, 1)
234
+ x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs))
235
+ x_ = x_.transpose(0, 1)
236
+
237
+ quant_min, quant_max = get_min_max_from_bits_signed(bits)
238
+ min_vals, max_vals = torch.aminmax(x_, dim=1)
239
+ min_vals = torch.min(min_vals, torch.zeros_like(min_vals))
240
+ max_vals = torch.max(max_vals, torch.zeros_like(max_vals))
241
+
242
+ eps = torch.finfo(torch.float32).eps
243
+
244
+ abs_max_val_per_ch = torch.max(-min_vals, max_vals)
245
+ scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2)
246
+
247
+ scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device)
248
+
249
+ # Introduces some noise in scale
250
+ # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything
251
+ scale = scale + 0.01 * torch.randn_like(scale)
252
+ return scale
253
+
254
+
255
+ def init_non_linear_regression_fit(
256
+ x: torch.Tensor,
257
+ **kwargs: Dict[str, Any],
258
+ ) -> torch.Tensor:
259
+
260
+ assert "params_list" in kwargs, "params list must be provided."
261
+ assert "np_fit_func" in kwargs, "np_fit_func must be provided."
262
+ assert "p0" in kwargs, "p0 must be provided."
263
+ np_fit_func = kwargs.get('np_fit_func')
264
+ params_list = kwargs.get('params_list')
265
+ p0 = kwargs.get('p0')
266
+
267
+ def _fit(xdata: np.ndarray, ydata: np.ndarray, func: Callable, p0: List[float]):
268
+ popt, _ = curve_fit(
269
+ func,
270
+ xdata,
271
+ ydata,
272
+ maxfev=1000,
273
+ p0=p0,
274
+ method='lm'
275
+ )
276
+ return popt
277
+
278
+ # 1. Needs to convert the torch tensor to numpy tensor
279
+ xdata = x.cpu().numpy()
280
+
281
+ # 2. Sorts the data so that it makes it easier to fit to it
282
+ sorted_xdata = np.sort(xdata, axis=-1)
283
+
284
+ p0 = {k: v.cpu().numpy() for k, v in p0.items()}
285
+ params_list = sorted(params_list) # We need to make sure that it matches the numpy fit func arg order
286
+
287
+ # 3. Finds the best parameters for each channel
288
+ try:
289
+ params = []
290
+ for i in range(sorted_xdata.shape[0]):
291
+ xdata_ = sorted_xdata[i]
292
+ p0_ = [p0[p][i] for p in params_list]
293
+ ch_params = _fit(xdata_, xdata_, np_fit_func, p0_)
294
+ params.append(ch_params)
295
+
296
+ # 4. Builds the parameters
297
+ result = {}
298
+ for i, p in enumerate(params_list):
299
+ result[p] = torch.tensor([p_[i] for p_ in params], dtype=torch.float32).to(x.device)
300
+
301
+ return result
302
+
303
+ except ValueError as e:
304
+ print(f"Could not fit the function with error: {e}")
305
+ print(f"Using fallback result...")
306
+ return {
307
+ k: torch.tensor(v, dtype=torch.float32).to(x.device) for k, v in p0.items()
308
+ }
309
+
310
+
311
+ def init_zeros(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
312
+ val = torch.amin(x, dim=1)
313
+ return torch.zeros_like(val, dtype=torch.float32, device=x.device)
314
+
315
+
316
+ def init_inner_scale(tensor: torch.Tensor, _min: float = torch.inf, _max: float = torch.inf) -> torch.Tensor:
317
+ # Calculate the original minimum and maximum values
318
+ min_vals, max_vals = torch.aminmax(tensor, dim=-1)
319
+ x_min = torch.min(min_vals, torch.zeros_like(min_vals))
320
+ x_max = torch.max(max_vals, torch.zeros_like(max_vals))
321
+
322
+ if _max is torch.inf: # We do not need to scale the tensor. Just need to move it
323
+ return torch.ones_like(x_min)
324
+
325
+ # Calculate the scale factor
326
+ scale = (_max - _min) / (x_max - x_min)
327
+ return scale
328
+
329
+
330
+
331
+ ############## Quant ###############
332
+
333
+ @torch.enable_grad()
334
+ def learn_parameters(
335
+ x: torch.Tensor,
336
+ params: Dict[str, nn.Parameter],
337
+ qtz_func: nn.Module,
338
+ deqtz_func: nn.Module,
339
+ bits: int,
340
+ target_dtype: torch.dtype,
341
+ epochs: int = 1000,
342
+ early_stop: bool = True,
343
+ do_report: bool = False
344
+ ) -> Tuple[Dict[str, nn.Parameter], torch.Tensor]:
345
+
346
+ # Requires gradients in the parameters
347
+ for p in params.values():
348
+ p.requires_grad = True
349
+ p.grad = None
350
+
351
+ param_keys = list(params.keys())
352
+ param_values = list(params.values())
353
+
354
+ # Defines optimizer and loss function
355
+ optimizer = torch.optim.Adam(param_values, lr=0.001)
356
+ scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=1.0, end_factor=0.001, total_iters=epochs // 10)
357
+ loss_fn = nn.MSELoss()
358
+
359
+ # Contains the best loss and the best parameters
360
+ best_loss = float("inf")
361
+ best_params = None
362
+
363
+ # Used to stop the search early
364
+ min_delta = 1e-7
365
+ acc_loss = []
366
+ percent_epochs_before_stop = 0.1
367
+
368
+ for i in range(epochs):
369
+ optimizer.zero_grad()
370
+
371
+ quant = quantize(x, params, qtz_func, bits, target_dtype)
372
+ dequant = dequantize(quant, params, deqtz_func, bits, x.dtype)
373
+ loss = loss_fn(x, dequant)
374
+
375
+ if loss.isnan() or loss.isinf():
376
+ raise Exception("Loss is NaN or Inf. Stopping the search.")
377
+
378
+ loss.backward()
379
+ optimizer.step()
380
+ scheduler.step()
381
+
382
+ acc_loss.append(loss.item())
383
+
384
+ # Reports loss every 10 steps
385
+ if i % 10 == 0 and do_report:
386
+ print(f"Epoch {i}: Loss {loss.item()}")
387
+
388
+ # Optimizes the parameter search by storing the best loss and the parameters
389
+ if loss.item() < best_loss:
390
+ best_loss = loss.item()
391
+ best_params = copy.deepcopy({
392
+ k: v for k, v in params.items() if k in param_keys
393
+ })
394
+
395
+ # We also stop the search if the loss has not improved considerably during the last 10% of epochs
396
+ if early_stop:
397
+ epochs_before_stop = int(epochs * percent_epochs_before_stop)
398
+ if i > epochs_before_stop and abs(acc_loss[i - epochs_before_stop] - acc_loss[i]) < min_delta:
399
+ break
400
+
401
+ # No longer requires gradients in the parameters
402
+ for p in best_params.values():
403
+ p.requires_grad = False
404
+ p.grad = None
405
+
406
+ if do_report:
407
+ return best_params, acc_loss
408
+ else:
409
+ return best_params
410
+
411
+
412
+ def quantize(
413
+ x: torch.Tensor,
414
+ params: Dict[str, nn.Parameter],
415
+ func: nn.Module,
416
+ bits: int,
417
+ target_dtype: torch.dtype = torch.int8
418
+ ) -> torch.Tensor:
419
+ quant_min, quant_max = get_min_max_from_bits_signed(bits)
420
+ x = x.transpose(0, 1) # Aligns shapes
421
+ x = func(x=x, **params)
422
+ x = x.transpose(0, 1)
423
+ x = torch.clamp(round_func_BPDA(x), quant_min, quant_max).to(target_dtype)
424
+ return x
425
+
426
+
427
+ def dequantize(
428
+ x: torch.Tensor,
429
+ params: Dict[str, nn.Parameter],
430
+ func: nn.Module,
431
+ bits: int,
432
+ out_dtype: torch.dtype
433
+ ) -> torch.Tensor:
434
+ x = x.to(dtype=out_dtype)
435
+ x = x.transpose(0, 1)
436
+ x = func(x=x, **params)
437
+ x = x.transpose(0, 1)
438
+ return x
439
+
440
+
441
+ def round_func_BPDA(input):
442
+ # This is equivalent to replacing the round function (non-differentiable) with
443
+ # an identity function (differentiable) in the backward pass only.
444
+ forward_value = torch.round(input)
445
+ out = input.clone()
446
+ out.data = forward_value.data
447
+ return out
448
+
449
+
450
+ def get_min_max_from_bits_signed(bit_width: int) -> Tuple[int, int]:
451
+ return -2 ** (bit_width - 1), 2 ** (bit_width - 1) - 1
452
+
453
+
454
+
455
+ ############## Numpy ###############
456
+
457
+ def np_domain_guard(
458
+ x: np.ndarray,
459
+ min: float = None,
460
+ max: float = None,
461
+ posinf: float = None,
462
+ neginf: float = None,
463
+ nan: float = None
464
+ ) -> np.ndarray:
465
+ """Guard a tensor to a valid domain."""
466
+ x = np.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan)
467
+ if min is not None or max is not None:
468
+ x = np.clip(x, min, max)
469
+ return x
470
+
471
+
472
+ def np_replace_num(x: np.ndarray, num: float, to: float) -> np.ndarray:
473
+ """Replace a number in a tensor with another number.
474
+
475
+ Args:
476
+ x (np.ndarray): The input tensor.
477
+ num (float): The number to replace.
478
+ to (float): The number to replace with.
479
+
480
+ Returns:
481
+ np.ndarray: The tensor with the number replaced.
482
+ """
483
+ return np.where(x == num, to, x)
484
+
485
+
486
+ def np_guarded_power(x: np.ndarray, exp: float) -> np.ndarray:
487
+ """Guard the power operation to a valid domain."""
488
+ return np.power(x, exp) if exp >= 1 else np.power(np.maximum(x, 0), exp)
489
+
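For completeness, a hedged sketch of how learn_parameters could be driven end to end. It is not part of the upload and makes two assumptions: the fn.py above is importable as fn, and a floating-point target_dtype is passed during optimization so the MSE loss stays differentiable through the BPDA round; note that learn_parameters raises if the loss ever becomes NaN/Inf for a given function pair.

import torch
import fn  # assumption: this fn.py is importable; adjust to your layout

x = torch.randn(16, 32)                         # hypothetical 2D weight tensor
params = fn.init_params(x, bits=8)
learned = fn.learn_parameters(
    x, params,
    qtz_func=fn.quantization, deqtz_func=fn.dequantization,
    bits=8, target_dtype=torch.float32,         # assumption: float dtype keeps gradients flowing
    epochs=200, early_stop=True, do_report=False,
)
q = fn.quantize(x, learned, fn.quantization, bits=8, target_dtype=torch.int8)
xr = fn.dequantize(q, learned, fn.dequantization, bits=8, out_dtype=x.dtype)
print((x - xr).abs().mean().item())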
fn_gen/ones_noisy_scale/15/loss.png ADDED
fn_gen/ones_noisy_scale/15/quantization.png ADDED
fn_gen/ones_noisy_scale/16/__pycache__/fn.cpython-310.pyc ADDED
Binary file (13 kB). View file
 
fn_gen/ones_noisy_scale/16/distortion.png ADDED