AbstractPhil commited on
Commit
275fb95
Β·
verified Β·
1 Parent(s): 26fbfd4

Create trainer.py

Browse files
Files changed (1) hide show
  1. v18_johanna_curriculum/trainer.py +631 -0
v18_johanna_curriculum/trainer.py ADDED
@@ -0,0 +1,631 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Johanna-Tiny Full Battery Diagnostic
3
+ ======================================
4
+ Comprehensive analysis of the curriculum-trained 16-type noise model.
5
+
6
+ Tests:
7
+ 1. Per-type MSE (100 samples each, full eval)
8
+ 2. Per-type byte accuracy (discrete reconstruction precision)
9
+ 3. Geometric fingerprint per noise type (Sβ‚€, ratio, erank, CV)
10
+ 4. Cross-type omega token similarity (cosine distance matrix)
11
+ 5. Spectrum analysis per type (which modes carry which distributions)
12
+ 6. Reconstruction visualization grid (all 16 types)
13
+ 7. Zero-shot transfer: real images through noise-trained model
14
+ 8. Zero-shot transfer: text bytes through noise-trained model
15
+ 9. Piecemeal 256β†’64: can tiny do tiled reconstruction?
16
+ 10. Noise-to-noise: encode type A, does it look like type A?
17
+ 11. Effective capacity: what percentage of the signal survives?
18
+ 12. Alpha profile: what did the cross-attention learn?
19
+ """
20
+
21
+ import os
22
+ import torch
23
+ import torch.nn as nn
24
+ import torch.nn.functional as F
25
+ import torchvision.transforms as T
26
+ import math
27
+ import time
28
+ import numpy as np
29
+ import json
30
+ from collections import defaultdict
31
+
32
+ # ── Load model ───────────────────────────────────────────────────
33
+
34
+ DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
35
+
36
+ # Option 1: Load from local checkpoint
37
+ CHECKPOINT = '/content/checkpoints/best.pt'
38
+ # Option 2: Load from HuggingFace
39
+ HF_CHECKPOINT = 'AbstractPhil/geolip-SVAE'
40
+ HF_FILE = 'v18_johanna_curriculum/checkpoints/epoch_0300.pt'
41
+
42
+
43
+ def load_model():
44
+ """Load model from local or HF checkpoint."""
45
+ from huggingface_hub import hf_hub_download
46
+
47
+ # Try local first
48
+ if os.path.exists(CHECKPOINT):
49
+ path = CHECKPOINT
50
+ print(f" Loading local: {path}")
51
+ else:
52
+ path = hf_hub_download(repo_id=HF_CHECKPOINT, filename=HF_FILE, repo_type="model")
53
+ print(f" Loading HF: {HF_FILE}")
54
+
55
+ ckpt = torch.load(path, map_location='cpu', weights_only=False)
56
+ cfg = ckpt['config']
57
+ print(f" Epoch: {ckpt.get('epoch')}, MSE: {ckpt.get('test_mse', '?')}")
58
+ print(f" Config: {cfg}")
59
+
60
+ # Build model inline (same architecture)
61
+ from types import SimpleNamespace
62
+
63
+ class BoundarySmooth(nn.Module):
64
+ def __init__(self, channels=3, mid=16):
65
+ super().__init__()
66
+ self.net = nn.Sequential(nn.Conv2d(channels, mid, 3, padding=1), nn.GELU(),
67
+ nn.Conv2d(mid, channels, 3, padding=1))
68
+ nn.init.zeros_(self.net[-1].weight); nn.init.zeros_(self.net[-1].bias)
69
+ def forward(self, x): return x + self.net(x)
70
+
71
+ class SpectralCrossAttention(nn.Module):
72
+ def __init__(self, D, n_heads=4, max_alpha=0.2, alpha_init=-2.0):
73
+ super().__init__()
74
+ self.n_heads = n_heads; self.head_dim = D // n_heads
75
+ self.max_alpha = max_alpha
76
+ self.qkv = nn.Linear(D, 3*D); self.out_proj = nn.Linear(D, D)
77
+ self.norm = nn.LayerNorm(D); self.scale = self.head_dim**-0.5
78
+ self.alpha_logits = nn.Parameter(torch.full((D,), alpha_init))
79
+ @property
80
+ def alpha(self): return self.max_alpha * torch.sigmoid(self.alpha_logits)
81
+ def forward(self, S):
82
+ B, N, D = S.shape; S_n = self.norm(S)
83
+ qkv = self.qkv(S_n).reshape(B,N,3,self.n_heads,self.head_dim).permute(2,0,3,1,4)
84
+ q, k, v = qkv[0], qkv[1], qkv[2]
85
+ out = (((q @ k.transpose(-2,-1))*self.scale).softmax(-1) @ v).transpose(1,2).reshape(B,N,D)
86
+ return S * (1.0 + self.alpha.unsqueeze(0).unsqueeze(0) * torch.tanh(self.out_proj(out)))
87
+
88
+ class PatchSVAE(nn.Module):
89
+ def __init__(self, V=256, D=16, ps=16, hidden=768, depth=4, n_cross=2):
90
+ super().__init__()
91
+ self.matrix_v, self.D, self.patch_size = V, D, ps
92
+ self.patch_dim = 3*ps*ps; self.mat_dim = V*D
93
+ self.enc_in = nn.Linear(self.patch_dim, hidden)
94
+ self.enc_blocks = nn.ModuleList([nn.Sequential(
95
+ nn.LayerNorm(hidden), nn.Linear(hidden, hidden),
96
+ nn.GELU(), nn.Linear(hidden, hidden)) for _ in range(depth)])
97
+ self.enc_out = nn.Linear(hidden, self.mat_dim)
98
+ self.dec_in = nn.Linear(self.mat_dim, hidden)
99
+ self.dec_blocks = nn.ModuleList([nn.Sequential(
100
+ nn.LayerNorm(hidden), nn.Linear(hidden, hidden),
101
+ nn.GELU(), nn.Linear(hidden, hidden)) for _ in range(depth)])
102
+ self.dec_out = nn.Linear(hidden, self.patch_dim)
103
+ nn.init.orthogonal_(self.enc_out.weight)
104
+ self.cross_attn = nn.ModuleList([
105
+ SpectralCrossAttention(D, n_heads=min(4,D)) for _ in range(n_cross)])
106
+ self.boundary_smooth = BoundarySmooth(channels=3, mid=16)
107
+
108
+ def _svd(self, A):
109
+ orig = A.dtype
110
+ with torch.amp.autocast('cuda', enabled=False):
111
+ A_d = A.double()
112
+ G = torch.bmm(A_d.transpose(1,2), A_d)
113
+ G.diagonal(dim1=-2, dim2=-1).add_(1e-12)
114
+ eig, V = torch.linalg.eigh(G)
115
+ eig = eig.flip(-1); V = V.flip(-1)
116
+ S = torch.sqrt(eig.clamp(min=1e-24))
117
+ U = torch.bmm(A_d, V) / S.unsqueeze(1).clamp(min=1e-16)
118
+ Vh = V.transpose(-2,-1).contiguous()
119
+ return U.to(orig), S.to(orig), Vh.to(orig)
120
+
121
+ def encode_patches(self, patches):
122
+ B, N, _ = patches.shape
123
+ h = F.gelu(self.enc_in(patches.reshape(B*N,-1)))
124
+ for block in self.enc_blocks: h = h + block(h)
125
+ M = F.normalize(self.enc_out(h).reshape(B*N, self.matrix_v, self.D), dim=-1)
126
+ U, S, Vt = self._svd(M)
127
+ U = U.reshape(B,N,self.matrix_v,self.D); S = S.reshape(B,N,self.D)
128
+ Vt = Vt.reshape(B,N,self.D,self.D); M = M.reshape(B,N,self.matrix_v,self.D)
129
+ S_c = S
130
+ for layer in self.cross_attn: S_c = layer(S_c)
131
+ return {'U':U, 'S_orig':S, 'S':S_c, 'Vt':Vt, 'M':M}
132
+
133
+ def decode_patches(self, U, S, Vt):
134
+ B, N, V, D = U.shape
135
+ M_hat = torch.bmm(U.reshape(B*N,V,D)*S.reshape(B*N,D).unsqueeze(1), Vt.reshape(B*N,D,D))
136
+ h = F.gelu(self.dec_in(M_hat.reshape(B*N,-1)))
137
+ for block in self.dec_blocks: h = h + block(h)
138
+ return self.dec_out(h).reshape(B, N, -1)
139
+
140
+ def forward(self, images):
141
+ B, C, H, W = images.shape
142
+ ps = self.patch_size
143
+ gh, gw = H//ps, W//ps
144
+ p = images.reshape(B,C,gh,ps,gw,ps).permute(0,2,4,1,3,5).reshape(B,gh*gw,C*ps*ps)
145
+ svd = self.encode_patches(p)
146
+ dec = self.decode_patches(svd['U'], svd['S'], svd['Vt'])
147
+ dec = dec.reshape(B,gh,gw,3,ps,ps).permute(0,3,1,4,2,5).reshape(B,3,gh*ps,gw*ps)
148
+ return {'recon': self.boundary_smooth(dec), 'svd': svd, 'gh': gh, 'gw': gw}
149
+
150
+ @staticmethod
151
+ def effective_rank(S):
152
+ p = S / (S.sum(-1, keepdim=True)+1e-8); p = p.clamp(min=1e-8)
153
+ return (-(p * p.log()).sum(-1)).exp()
154
+
155
+ model = PatchSVAE(V=cfg['V'], D=cfg['D'], ps=cfg['patch_size'],
156
+ hidden=cfg['hidden'], depth=cfg['depth'],
157
+ n_cross=cfg['n_cross_layers'])
158
+ model.load_state_dict(ckpt['model_state_dict'], strict=True)
159
+ model = model.to(DEVICE).eval()
160
+ print(f" Loaded {sum(p.numel() for p in model.parameters()):,} params")
161
+ return model, cfg
162
+
163
+
164
+ # ── Noise Generators ─────────────────────────────────────────────
165
+
166
+ NOISE_NAMES = {
167
+ 0: 'gaussian', 1: 'uniform', 2: 'uniform_scaled', 3: 'poisson',
168
+ 4: 'pink', 5: 'brown', 6: 'salt_pepper', 7: 'sparse',
169
+ 8: 'block', 9: 'gradient', 10: 'checkerboard', 11: 'mixed',
170
+ 12: 'structural', 13: 'cauchy', 14: 'exponential', 15: 'laplace',
171
+ }
172
+
173
+ def _pink(shape):
174
+ w = torch.randn(shape); S = torch.fft.rfft2(w)
175
+ h, ww = shape[-2], shape[-1]
176
+ fy = torch.fft.fftfreq(h).unsqueeze(-1).expand(-1, ww//2+1)
177
+ fx = torch.fft.rfftfreq(ww).unsqueeze(0).expand(h, -1)
178
+ return torch.fft.irfft2(S / torch.sqrt(fx**2 + fy**2).clamp(min=1e-8), s=(h, ww))
179
+
180
+ def _brown(shape):
181
+ w = torch.randn(shape); S = torch.fft.rfft2(w)
182
+ h, ww = shape[-2], shape[-1]
183
+ fy = torch.fft.fftfreq(h).unsqueeze(-1).expand(-1, ww//2+1)
184
+ fx = torch.fft.rfftfreq(ww).unsqueeze(0).expand(h, -1)
185
+ return torch.fft.irfft2(S / (fx**2 + fy**2).clamp(min=1e-8), s=(h, ww))
186
+
187
+ def generate_noise(noise_type, n, s=64):
188
+ """Generate n samples of a given noise type."""
189
+ imgs = []
190
+ rng = np.random.RandomState(42)
191
+ for _ in range(n):
192
+ if noise_type == 0: img = torch.randn(3,s,s)
193
+ elif noise_type == 1: img = torch.rand(3,s,s)*2-1
194
+ elif noise_type == 2: img = (torch.rand(3,s,s)-0.5)*4
195
+ elif noise_type == 3:
196
+ lam = rng.uniform(0.5, 20.0)
197
+ img = torch.poisson(torch.full((3,s,s), lam))/lam - 1.0
198
+ elif noise_type == 4: img = _pink((3,s,s)); img = img/(img.std()+1e-8)
199
+ elif noise_type == 5: img = _brown((3,s,s)); img = img/(img.std()+1e-8)
200
+ elif noise_type == 6:
201
+ img = torch.where(torch.rand(3,s,s)>0.5, torch.ones(3,s,s)*2, -torch.ones(3,s,s)*2)
202
+ img = img + torch.randn(3,s,s)*0.1
203
+ elif noise_type == 7: img = torch.randn(3,s,s)*(torch.rand(3,s,s)>0.9).float()*3
204
+ elif noise_type == 8:
205
+ b = rng.randint(2,16); sm = torch.randn(3,s//b+1,s//b+1)
206
+ img = F.interpolate(sm.unsqueeze(0), size=s, mode='nearest').squeeze(0)
207
+ elif noise_type == 9:
208
+ gy = torch.linspace(-2,2,s).unsqueeze(1).expand(s,s)
209
+ gx = torch.linspace(-2,2,s).unsqueeze(0).expand(s,s)
210
+ a = rng.uniform(0, 2*math.pi)
211
+ img = (math.cos(a)*gx + math.sin(a)*gy).unsqueeze(0).expand(3,-1,-1) + torch.randn(3,s,s)*0.5
212
+ elif noise_type == 10:
213
+ cs = rng.randint(2,16); cy = torch.arange(s)//cs; cx = torch.arange(s)//cs
214
+ img = ((cy.unsqueeze(1)+cx.unsqueeze(0))%2).float().unsqueeze(0).expand(3,-1,-1)*2-1 + torch.randn(3,s,s)*0.3
215
+ elif noise_type == 11:
216
+ alpha = rng.uniform(0.2, 0.8)
217
+ img = alpha*torch.randn(3,s,s) + (1-alpha)*(torch.rand(3,s,s)*2-1)
218
+ elif noise_type == 12:
219
+ img = torch.zeros(3,s,s); h2 = s//2
220
+ img[:,:h2,:h2] = torch.randn(3,h2,h2)
221
+ img[:,:h2,h2:] = torch.rand(3,h2,h2)*2-1
222
+ img[:,h2:,:h2] = _pink((3,h2,h2))/2
223
+ img[:,h2:,h2:] = torch.where(torch.rand(3,h2,h2)>0.5, torch.ones(3,h2,h2), -torch.ones(3,h2,h2))
224
+ elif noise_type == 13: img = torch.tan(math.pi*(torch.rand(3,s,s)-0.5)).clamp(-3,3)
225
+ elif noise_type == 14: img = torch.empty(3,s,s).exponential_(1.0)-1.0
226
+ elif noise_type == 15:
227
+ u = torch.rand(3,s,s)-0.5; img = -torch.sign(u)*torch.log1p(-2*u.abs())
228
+ else: img = torch.randn(3,s,s)
229
+ imgs.append(img.clamp(-4,4))
230
+ return torch.stack(imgs)
231
+
232
+
233
+ # ════════════════════════════════════════════════════════════════
234
+ # DIAGNOSTIC TESTS
235
+ # ════════════════════════════════════════════════════════════════
236
+
237
+ def test_1_per_type_mse(model, n=100, s=64):
238
+ """Per-type reconstruction MSE."""
239
+ print(f"\n{'='*70}")
240
+ print("TEST 1: Per-Type Reconstruction MSE (100 samples each)")
241
+ print(f"{'='*70}")
242
+ results = {}
243
+ model.eval()
244
+ with torch.no_grad():
245
+ for t in range(16):
246
+ imgs = generate_noise(t, n, s).to(DEVICE)
247
+ out = model(imgs)
248
+ mse = F.mse_loss(out['recon'], imgs, reduction='none').mean(dim=(1,2,3))
249
+ results[NOISE_NAMES[t]] = {
250
+ 'mean': mse.mean().item(),
251
+ 'std': mse.std().item(),
252
+ 'min': mse.min().item(),
253
+ 'max': mse.max().item(),
254
+ }
255
+ print(f" {NOISE_NAMES[t]:18s}: {mse.mean():.6f} Β± {mse.std():.6f} "
256
+ f"[{mse.min():.6f} β€” {mse.max():.6f}]")
257
+ return results
258
+
259
+
260
+ def test_2_byte_accuracy(model, n=100, s=64):
261
+ """Byte-level reconstruction accuracy per type."""
262
+ print(f"\n{'='*70}")
263
+ print("TEST 2: Byte-Level Accuracy (quantized to 256 levels)")
264
+ print(f"{'='*70}")
265
+ results = {}
266
+ model.eval()
267
+ with torch.no_grad():
268
+ for t in range(16):
269
+ imgs = generate_noise(t, n, s).to(DEVICE)
270
+ out = model(imgs)
271
+ # Quantize to 256 levels
272
+ orig_q = ((imgs + 4) / 8 * 255).round().clamp(0, 255).long()
273
+ recon_q = ((out['recon'] + 4) / 8 * 255).round().clamp(0, 255).long()
274
+ acc = (orig_q == recon_q).float().mean().item()
275
+ # Within-1 accuracy
276
+ acc1 = ((orig_q - recon_q).abs() <= 1).float().mean().item()
277
+ results[NOISE_NAMES[t]] = {'exact': acc, 'within_1': acc1}
278
+ print(f" {NOISE_NAMES[t]:18s}: exact={acc*100:5.1f}% Β±1={acc1*100:5.1f}%")
279
+ return results
280
+
281
+
282
+ def test_3_geometric_fingerprint(model, n=64, s=64):
283
+ """Geometric properties per noise type."""
284
+ print(f"\n{'='*70}")
285
+ print("TEST 3: Geometric Fingerprint Per Type")
286
+ print(f"{'='*70}")
287
+ D = model.D
288
+ results = {}
289
+ model.eval()
290
+ with torch.no_grad():
291
+ for t in range(16):
292
+ imgs = generate_noise(t, n, s).to(DEVICE)
293
+ out = model(imgs)
294
+ S = out['svd']['S'] # (B, N, D)
295
+ S_mean = S.mean(dim=(0, 1))
296
+ ratio = (S_mean[0] / (S_mean[-1] + 1e-8)).item()
297
+ erank = model.effective_rank(S.reshape(-1, D)).mean().item()
298
+ s0 = S_mean[0].item()
299
+ sd = S_mean[-1].item()
300
+ results[NOISE_NAMES[t]] = {'S0': s0, 'SD': sd, 'ratio': ratio, 'erank': erank}
301
+ print(f" {NOISE_NAMES[t]:18s}: Sβ‚€={s0:.3f} SD={sd:.3f} "
302
+ f"ratio={ratio:.2f} erank={erank:.2f}")
303
+ return results
304
+
305
+
306
+ def test_4_omega_similarity(model, n=32, s=64):
307
+ """Cross-type omega token cosine similarity matrix."""
308
+ print(f"\n{'='*70}")
309
+ print("TEST 4: Cross-Type Omega Token Similarity")
310
+ print(f"{'='*70}")
311
+ D = model.D
312
+ type_centroids = {}
313
+ model.eval()
314
+ with torch.no_grad():
315
+ for t in range(16):
316
+ imgs = generate_noise(t, n, s).to(DEVICE)
317
+ out = model(imgs)
318
+ # Average omega token per type: (D,)
319
+ omega = out['svd']['S'].mean(dim=(0, 1))
320
+ type_centroids[t] = omega
321
+
322
+ # Cosine similarity matrix
323
+ keys = sorted(type_centroids.keys())
324
+ centroids = torch.stack([type_centroids[k] for k in keys])
325
+ centroids_norm = F.normalize(centroids, dim=-1)
326
+ sim_matrix = centroids_norm @ centroids_norm.T
327
+
328
+ # Print matrix
329
+ header = " " + " ".join([f"{NOISE_NAMES[k][:5]:>5s}" for k in keys])
330
+ print(f" {header}")
331
+ for i, ki in enumerate(keys):
332
+ row = f" {NOISE_NAMES[ki]:8s}"
333
+ for j, kj in enumerate(keys):
334
+ v = sim_matrix[i, j].item()
335
+ row += f" {v:5.2f}"
336
+ print(row)
337
+ return sim_matrix.cpu()
338
+
339
+
340
+ def test_5_spectrum_per_type(model, n=64, s=64):
341
+ """Singular value spectrum analysis per type."""
342
+ print(f"\n{'='*70}")
343
+ print("TEST 5: Spectrum Profile Per Type")
344
+ print(f"{'='*70}")
345
+ D = model.D
346
+ results = {}
347
+ model.eval()
348
+ with torch.no_grad():
349
+ for t in range(16):
350
+ imgs = generate_noise(t, n, s).to(DEVICE)
351
+ out = model(imgs)
352
+ S_mean = out['svd']['S'].mean(dim=(0, 1))
353
+ total = (S_mean**2).sum()
354
+ cum = 0
355
+ spectrum = []
356
+ for d in range(D):
357
+ e = (S_mean[d]**2).item()
358
+ cum += e
359
+ spectrum.append({'value': S_mean[d].item(), 'energy_pct': cum/total.item()*100})
360
+ results[NOISE_NAMES[t]] = spectrum
361
+
362
+ # Print top-3 and bottom-3 modes per type
363
+ for t in range(16):
364
+ name = NOISE_NAMES[t]
365
+ sp = results[name]
366
+ top = f"S0={sp[0]['value']:.3f}({sp[0]['energy_pct']:.1f}%)"
367
+ mid = f"S7={sp[7]['value']:.3f}({sp[7]['energy_pct']:.1f}%)"
368
+ bot = f"S15={sp[15]['value']:.3f}(100%)"
369
+ print(f" {name:18s}: {top} {mid} {bot}")
370
+ return results
371
+
372
+
373
+ def test_6_reconstruction_grid(model, s=64):
374
+ """Visual reconstruction grid β€” all 16 types."""
375
+ print(f"\n{'='*70}")
376
+ print("TEST 6: Reconstruction Grid (saved to johanna_diagnostic_grid.png)")
377
+ print(f"{'='*70}")
378
+ import matplotlib
379
+ matplotlib.use('Agg')
380
+ import matplotlib.pyplot as plt
381
+
382
+ model.eval()
383
+ fig, axes = plt.subplots(16, 3, figsize=(9, 48))
384
+
385
+ with torch.no_grad():
386
+ for t in range(16):
387
+ img = generate_noise(t, 1, s).to(DEVICE)
388
+ out = model(img)
389
+ recon = out['recon']
390
+ mse = F.mse_loss(recon, img).item()
391
+
392
+ orig_np = img[0].cpu().clamp(-3, 3).permute(1, 2, 0).numpy()
393
+ recon_np = recon[0].cpu().clamp(-3, 3).permute(1, 2, 0).numpy()
394
+ diff_np = (img[0] - recon[0]).abs().cpu().clamp(0, 2).permute(1, 2, 0).numpy()
395
+
396
+ # Normalize for display
397
+ for arr in [orig_np, recon_np]:
398
+ arr -= arr.min(); arr /= (arr.max() + 1e-8)
399
+ diff_np /= (diff_np.max() + 1e-8)
400
+
401
+ axes[t, 0].imshow(orig_np); axes[t, 0].set_ylabel(NOISE_NAMES[t], fontsize=8)
402
+ axes[t, 1].imshow(recon_np)
403
+ axes[t, 2].imshow(diff_np)
404
+ for j in range(3):
405
+ axes[t, j].axis('off')
406
+
407
+ axes[0, 0].set_title('Original', fontsize=9)
408
+ axes[0, 1].set_title('Recon', fontsize=9)
409
+ axes[0, 2].set_title('|Error|', fontsize=9)
410
+ plt.tight_layout()
411
+ plt.savefig('johanna_diagnostic_grid.png', dpi=150, bbox_inches='tight')
412
+ print(" Saved: johanna_diagnostic_grid.png")
413
+ plt.close()
414
+
415
+
416
+ def test_7_real_images(model, s=64):
417
+ """Zero-shot: real images through noise-trained model."""
418
+ print(f"\n{'='*70}")
419
+ print("TEST 7: Zero-Shot Real Image Reconstruction")
420
+ print(f"{'='*70}")
421
+ from datasets import load_dataset
422
+
423
+ ds = load_dataset('zh-plus/tiny-imagenet', split='valid', streaming=True)
424
+ transform = T.Compose([T.ToTensor(), T.Normalize((0.4802,0.4481,0.3975),(0.2770,0.2691,0.2821))])
425
+
426
+ imgs = []
427
+ for i, sample in enumerate(ds):
428
+ img = sample['image'].convert('RGB')
429
+ imgs.append(transform(img))
430
+ if i >= 99:
431
+ break
432
+
433
+ batch = torch.stack(imgs).to(DEVICE)
434
+ model.eval()
435
+ with torch.no_grad():
436
+ out = model(batch)
437
+ mse = F.mse_loss(out['recon'], batch, reduction='none').mean(dim=(1,2,3))
438
+
439
+ print(f" TinyImageNet (100 images, {s}Γ—{s}):")
440
+ print(f" Mean MSE: {mse.mean():.6f}")
441
+ print(f" Std: {mse.std():.6f}")
442
+ print(f" Min/Max: {mse.min():.6f} / {mse.max():.6f}")
443
+ print(f" Fidelity: {(1 - mse.mean())*100:.3f}%")
444
+ return {'mean': mse.mean().item(), 'std': mse.std().item()}
445
+
446
+
447
+ def test_8_text_bytes(model, s=64):
448
+ """Zero-shot: text through noise-trained model."""
449
+ print(f"\n{'='*70}")
450
+ print("TEST 8: Zero-Shot Text Byte Reconstruction")
451
+ print(f"{'='*70}")
452
+
453
+ texts = [
454
+ "Hello, world! This is a test of the Johanna geometric encoder.",
455
+ "The quick brown fox jumps over the lazy dog. 0123456789 ABCDEF",
456
+ "import torch; model = PatchSVAE(); output = model(x)",
457
+ "E = mcΒ² β€” Albert Einstein, theoretical physicist, 1905",
458
+ "To be, or not to be, that is the question. β€” Shakespeare",
459
+ ]
460
+
461
+ n_bytes = 3 * s * s
462
+ model.eval()
463
+
464
+ for text in texts:
465
+ raw = text.encode('utf-8')
466
+ actual_len = min(len(raw), n_bytes)
467
+ if len(raw) < n_bytes:
468
+ raw = raw + b'\x00' * (n_bytes - len(raw))
469
+ else:
470
+ raw = raw[:n_bytes]
471
+
472
+ arr = np.frombuffer(raw, dtype=np.uint8).copy()
473
+ tensor = torch.from_numpy(arr).float()
474
+ tensor = (tensor / 127.5) - 1.0
475
+ tensor = tensor.reshape(1, 3, s, s).to(DEVICE)
476
+
477
+ with torch.no_grad():
478
+ out = model(tensor)
479
+ recon = out['recon']
480
+ mse = F.mse_loss(recon, tensor).item()
481
+
482
+ recon_bytes = ((recon.squeeze(0).cpu().flatten() + 1.0) * 127.5).round().clamp(0, 255).byte().numpy()
483
+ recovered = recon_bytes[:actual_len].tobytes().decode('utf-8', errors='replace')
484
+
485
+ orig_b = ((tensor.squeeze(0).cpu().flatten() + 1.0) * 127.5).round().clamp(0, 255).byte()
486
+ recon_b = ((recon.squeeze(0).cpu().flatten() + 1.0) * 127.5).round().clamp(0, 255).byte()
487
+ exact_acc = (orig_b[:actual_len] == recon_b[:actual_len]).float().mean().item()
488
+
489
+ print(f"\n Input: '{text[:60]}'")
490
+ print(f" Output: '{recovered[:60]}'")
491
+ print(f" MSE: {mse:.6f}")
492
+ print(f" Byte acc: {exact_acc*100:.1f}%")
493
+
494
+
495
+ def test_9_piecemeal(model, s=64):
496
+ """Piecemeal: tile 256Γ—256 noise into 64Γ—64 tiles."""
497
+ print(f"\n{'='*70}")
498
+ print(f"TEST 9: Piecemeal 256β†’{s} Tiled Reconstruction")
499
+ print(f"{'='*70}")
500
+ model.eval()
501
+
502
+ results = {}
503
+ with torch.no_grad():
504
+ for t in [0, 4, 6, 13]: # Gaussian, Pink, Salt-pepper, Cauchy
505
+ img_256 = generate_noise(t, 1, 256).squeeze(0) # (3, 256, 256)
506
+ tiles = []
507
+ gh, gw = 256 // s, 256 // s
508
+ for gy in range(gh):
509
+ for gx in range(gw):
510
+ tile = img_256[:, gy*s:(gy+1)*s, gx*s:(gx+1)*s]
511
+ tiles.append(tile)
512
+ tile_batch = torch.stack(tiles).to(DEVICE)
513
+ out = model(tile_batch)
514
+ recon_tiles = out['recon'].cpu()
515
+
516
+ # Stitch
517
+ recon_full = torch.zeros(3, 256, 256)
518
+ idx = 0
519
+ for gy in range(gh):
520
+ for gx in range(gw):
521
+ recon_full[:, gy*s:(gy+1)*s, gx*s:(gx+1)*s] = recon_tiles[idx]
522
+ idx += 1
523
+
524
+ mse = F.mse_loss(recon_full, img_256).item()
525
+ results[NOISE_NAMES[t]] = mse
526
+ n_tiles = gh * gw
527
+ print(f" {NOISE_NAMES[t]:18s}: {n_tiles} tiles, MSE={mse:.6f}")
528
+ return results
529
+
530
+
531
+ def test_10_signal_survival(model, n=100, s=64):
532
+ """What percentage of the original signal energy survives reconstruction?"""
533
+ print(f"\n{'='*70}")
534
+ print("TEST 10: Signal Energy Survival Rate")
535
+ print(f"{'='*70}")
536
+ model.eval()
537
+ with torch.no_grad():
538
+ for t in range(16):
539
+ imgs = generate_noise(t, n, s).to(DEVICE)
540
+ out = model(imgs)
541
+ recon = out['recon']
542
+ orig_energy = (imgs**2).mean().item()
543
+ recon_energy = (recon**2).mean().item()
544
+ error_energy = ((imgs - recon)**2).mean().item()
545
+ survival = recon_energy / (orig_energy + 1e-8) * 100
546
+ snr = 10 * math.log10(orig_energy / (error_energy + 1e-8))
547
+ print(f" {NOISE_NAMES[t]:18s}: survival={survival:6.1f}% SNR={snr:5.1f}dB "
548
+ f"orig_E={orig_energy:.3f} recon_E={recon_energy:.3f}")
549
+
550
+
551
+ def test_11_alpha_profile(model):
552
+ """Cross-attention alpha analysis."""
553
+ print(f"\n{'='*70}")
554
+ print("TEST 11: Cross-Attention Alpha Profile")
555
+ print(f"{'='*70}")
556
+ for li, layer in enumerate(model.cross_attn):
557
+ alpha = layer.alpha.detach().cpu()
558
+ print(f"\n Layer {li}: mean={alpha.mean():.4f} max={alpha.max():.4f} "
559
+ f"min={alpha.min():.4f} std={alpha.std():.6f}")
560
+ bar_scale = 50 / (alpha.max().item() + 1e-8)
561
+ for d in range(len(alpha)):
562
+ bar = "β–ˆ" * int(alpha[d].item() * bar_scale)
563
+ print(f" Ξ±[{d:2d}]: {alpha[d]:.5f} {bar}")
564
+
565
+
566
+ def test_12_compression_ratio(model, s=64):
567
+ """Actual compression metrics."""
568
+ print(f"\n{'='*70}")
569
+ print("TEST 12: Compression Metrics")
570
+ print(f"{'='*70}")
571
+ D = model.D
572
+ ps = model.patch_size
573
+ n_patches = (s // ps) ** 2
574
+ input_values = 3 * s * s
575
+ latent_values = D * n_patches
576
+ ratio = input_values / latent_values
577
+ print(f" Input: {s}Γ—{s}Γ—3 = {input_values:,} values")
578
+ print(f" Latent: {D}Γ—{n_patches} = {latent_values:,} values (omega tokens)")
579
+ print(f" Ratio: {ratio:.1f}:1 compression")
580
+ print(f" Patches: {n_patches} of {ps}Γ—{ps}")
581
+ print(f" Omega shape: ({D}, {s//ps}, {s//ps})")
582
+
583
+ # Bits per value at different quantization levels
584
+ for bits in [8, 16, 32]:
585
+ input_bytes = input_values * (bits // 8)
586
+ latent_bytes = latent_values * (bits // 8)
587
+ print(f" At {bits}-bit: input={input_bytes/1024:.1f}KB latent={latent_bytes/1024:.1f}KB "
588
+ f"ratio={input_bytes/latent_bytes:.1f}:1")
589
+
590
+
591
+ # ═════════════════════════════════════════════════════════════��══
592
+ # RUN ALL
593
+ # ════════════════════════════════════════════════════════════════
594
+
595
+ def run_all():
596
+ print("=" * 70)
597
+ print("JOHANNA-TINY FULL BATTERY DIAGNOSTIC")
598
+ print("=" * 70)
599
+
600
+ model, cfg = load_model()
601
+ s = cfg.get('img_size', 64)
602
+
603
+ all_results = {}
604
+ all_results['config'] = cfg
605
+
606
+ all_results['per_type_mse'] = test_1_per_type_mse(model, n=100, s=s)
607
+ all_results['byte_accuracy'] = test_2_byte_accuracy(model, n=100, s=s)
608
+ all_results['geometry'] = test_3_geometric_fingerprint(model, n=64, s=s)
609
+ sim_matrix = test_4_omega_similarity(model, n=32, s=s)
610
+ all_results['spectrum'] = test_5_spectrum_per_type(model, n=64, s=s)
611
+ test_6_reconstruction_grid(model, s=s)
612
+ all_results['real_images'] = test_7_real_images(model, s=s)
613
+ test_8_text_bytes(model, s=s)
614
+ all_results['piecemeal'] = test_9_piecemeal(model, s=s)
615
+ test_10_signal_survival(model, n=100, s=s)
616
+ test_11_alpha_profile(model)
617
+ test_12_compression_ratio(model, s=s)
618
+
619
+ # Save results
620
+ out_path = 'johanna_diagnostic_results.json'
621
+ with open(out_path, 'w') as f:
622
+ json.dump(all_results, f, indent=2, default=str)
623
+ print(f"\n Results saved: {out_path}")
624
+
625
+ print(f"\n{'='*70}")
626
+ print("DIAGNOSTIC COMPLETE")
627
+ print(f"{'='*70}")
628
+
629
+
630
+ if __name__ == "__main__":
631
+ run_all()