5dimension committed on
Commit 9714aa3 · verified · 1 Parent(s): 28ecb3e

Initial commit: sentinel_quantization.py

Files changed (1)
  1. sentinel_quantization.py +188 -0
sentinel_quantization.py ADDED
@@ -0,0 +1,188 @@
+ """
+ ================================================================================
+ SENTINEL QUANTIZATION
+ ================================================================================
+
+ Theory: The attracting fixed point C₁ ≈ −0.007994021805953 of the iteration
+ z_{k+1} = F(z_k) is a natural quantization center.
+
+ Key Innovation: Use Sentinel dynamical properties for model quantization:
+ - Attracting fixed point C₁ as quantization zero-point
+ - Basin boundary C₂ as precision threshold
+ - Gradient Axiom (1/e) as quantization scale
+ """
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import numpy as np
+ from typing import Dict, Tuple
+
+ class SentinelQuantizer:
+     """
+     Sentinel-aware quantization using dynamical constants.
+
+     Quantization formula:
+         q = round((w - C₁) / scale)
+         scale = max(|w|) · (1/e)   # Sentinel scale from gradient axiom
+
+     where C₁ = −0.007994021805953 is the attracting fixed point.
+     """
+
+     C1 = -0.007994021805953  # Attracting fixed point
+     INV_E = 1.0 / np.e       # Gradient axiom limit
+
+     def __init__(self, bits: int = 8):
+         self.bits = bits
+         self.qmin = -(2 ** (bits - 1))
+         self.qmax = 2 ** (bits - 1) - 1
+
+     def find_scale(self, tensor: torch.Tensor) -> float:
+         """Find optimal quantization scale using Sentinel principle."""
+         # Scale = max(|w|) · (1/e)
+         # This ensures the quantized range maps to the "stable basin"
+         max_val = tensor.abs().max().item()
+         scale = max_val * self.INV_E
+         return max(scale, 1e-8)
+
+     def quantize(self, tensor: torch.Tensor) -> Tuple[torch.Tensor, float]:
+         """
+         Quantize tensor to int8 (or the specified bit width).
+
+         Returns the quantized tensor and the scale for dequantization.
+         """
+         scale = self.find_scale(tensor)
+
+         # Shift by C₁ (attracting fixed point as zero-point)
+         shifted = tensor - self.C1
+
+         # Quantize
+         quantized = torch.round(shifted / scale)
+         quantized = torch.clamp(quantized, self.qmin, self.qmax)
+
+         return quantized, scale
+
+     def dequantize(self, quantized: torch.Tensor, scale: float) -> torch.Tensor:
+         """Dequantize back to float."""
+         return quantized * scale + self.C1
+
+     def quantize_model(self, model: nn.Module) -> Dict[str, Tuple[torch.Tensor, float]]:
+         """Quantize all parameters of a model."""
+         quantized_params = {}
+
+         for name, param in model.named_parameters():
+             if param.requires_grad:
+                 q, scale = self.quantize(param.data)
+                 quantized_params[name] = (q.to(torch.int8), scale)
+
+         return quantized_params
+
+     def dequantize_model(self, quantized_params: Dict) -> Dict[str, torch.Tensor]:
+         """Dequantize all parameters."""
+         dequantized = {}
+         for name, (q, scale) in quantized_params.items():
+             dequantized[name] = self.dequantize(q.float(), scale)
+         return dequantized
+
+
+ class SentinelQuantizedLinear(nn.Module):
+     """Linear layer with Sentinel-aware quantization."""
+
+     def __init__(self, in_features: int, out_features: int, bits: int = 8):
+         super().__init__()
+         self.in_features = in_features
+         self.out_features = out_features
+         self.bits = bits
+
+         self.weight = nn.Parameter(torch.randn(out_features, in_features))
+         self.bias = nn.Parameter(torch.zeros(out_features))
+
+         self.quantizer = SentinelQuantizer(bits)
+         self._register_quantization_params()
+
+     def _register_quantization_params(self):
+         """Register the quantization scale and int8 weights as buffers."""
+         self.register_buffer('weight_scale', torch.tensor(1.0))
+         self.register_buffer('quantized_weight', torch.zeros_like(self.weight, dtype=torch.int8))
+
+     def quantize(self):
+         """Quantize weights in-place."""
+         q, scale = self.quantizer.quantize(self.weight.data)
+         self.quantized_weight = q.to(torch.int8)  # cast to int8 to match the buffer dtype
+         self.weight_scale = torch.tensor(scale)
+
+     def dequantize(self):
+         """Dequantize weights for computation."""
+         return self.quantizer.dequantize(self.quantized_weight.float(), self.weight_scale.item())
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         """Forward pass with dequantized weights."""
+         w = self.dequantize()
+         return F.linear(x, w, self.bias)
+
+
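+ def demo_quantized_linear():
+     """Usage sketch (added for illustration): round-trip one quantized layer."""
+     layer = SentinelQuantizedLinear(16, 4, bits=8)
+     x = torch.randn(2, 16)
+     y_fp = F.linear(x, layer.weight, layer.bias)  # full-precision reference output
+     layer.quantize()                              # snapshot weights into the int8 buffer
+     y_q = layer(x)                                # forward pass uses dequantized weights
+     print(f"Max output deviation after quantization: {(y_fp - y_q).abs().max().item():.6f}")
+
+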
+ def demo_sentinel_quantization():
+     """Demo Sentinel quantization on a synthetic model."""
+     print("=" * 70)
+     print(" SENTINEL QUANTIZATION")
+     print("=" * 70)
+
+     # Synthetic model
+     model = nn.Sequential(
+         nn.Linear(784, 256),
+         nn.ReLU(),
+         nn.Linear(256, 10)
+     )
+
+     # Original model stats
+     original_params = sum(p.numel() for p in model.parameters())
+     original_size = original_params * 4  # float32 = 4 bytes
+
+     print("\n--- Original Model ---")
+     print(f" Parameters: {original_params:,}")
+     print(f" Size (FP32): {original_size / 1024:.1f} KB")
+
+     # Quantize
+     quantizer = SentinelQuantizer(bits=8)
+     quantized_params = quantizer.quantize_model(model)
+
+     # Quantized model stats: one int8 byte per weight plus a 4-byte scale per tensor
+     quantized_size = sum(q.numel() * 1 + 4 for q, _ in quantized_params.values())
+
+     print("\n--- Quantized Model (Sentinel-aware) ---")
+     print(f" Parameters: {sum(q.numel() for q, _ in quantized_params.values()):,}")
+     print(f" Size (INT8): {quantized_size / 1024:.1f} KB")
+     print(f" Compression ratio: {original_size / quantized_size:.2f}×")
+
+     # Verify dequantization quality
+     dequantized = quantizer.dequantize_model(quantized_params)
+
+     errors = []
+     for name, param in model.named_parameters():
+         if name in dequantized:
+             error = (param.data - dequantized[name]).abs().mean().item()
+             errors.append(error)
+
+     mean_error = np.mean(errors)
+     print("\n--- Dequantization Quality ---")
+     print(f" Mean absolute error: {mean_error:.6f}")
+     print(f" Attracting fixed point C₁: {SentinelQuantizer.C1:.12f}")
+     print(f" Sentinel scale factor (1/e): {SentinelQuantizer.INV_E:.6f}")
+
+     # Theoretical justification
+     print("\n--- Theoretical Justification ---")
+     print(f" C₁ = {SentinelQuantizer.C1:.12f} is the attracting fixed point")
+     print(" All negative values converge to C₁ under F(z) iteration")
+     print(" Using C₁ as zero-point: natural quantization center")
+     print(" Scale = max(|w|)·(1/e): maps to stable basin")
+
+     print(f"\n{'='*70}")
+     print(f" SENTINEL QUANTIZATION: {original_size/quantized_size:.1f}× COMPRESSION")
+     print(" WITH DYNAMICAL CONSTANTS AS QUANTIZATION PARAMETERS")
+     print(f"{'='*70}")
+
+
+ if __name__ == '__main__':
+     demo_sentinel_quantization()