| import torch |
| import torch.nn as nn |
| import numpy as np |
| import matplotlib.pyplot as plt |
| import torch.nn.functional as F |
|
|
| |
| |
| |
# Sequence length in positions. It has to be bound before INPUT_DIM is
# computed; the evaluation script further down assigns the same value again.
L = 50

# Flattened input width: 4 one-hot channels plus 1 extra feature per position.
INPUT_DIM = L * 5
# Width of the latent code produced by the encoder.
LATENT_DIM = 2048
# Width of the hidden layers in both the encoder and the decoder.
HIDDEN_DIM = 1024
|
|
class SparseAE(nn.Module):
    """Autoencoder over a flattened (dna, phy) input pair.

    The encoder is two Linear layers, each followed by ReLU, so every latent
    activation is non-negative. The decoder has one shared hidden layer that
    feeds two output heads: ``dec_dna`` with ``L * 4`` outputs and
    ``dec_phy`` with ``L`` outputs, where ``L`` is the module-level sequence
    length.
    """

    def __init__(self, input_dim=INPUT_DIM, latent_dim=LATENT_DIM, hidden_dim=HIDDEN_DIM):
        """Build the encoder, the shared decoder layer and both output heads.

        Args:
            input_dim: Width of the concatenated, flattened input vector.
            latent_dim: Width of the latent code returned by the encoder.
            hidden_dim: Width of the encoder and decoder hidden layers.
        """
        super().__init__()

        # Encoder: input -> hidden -> latent, with ReLU after each layer.
        encoder_layers = [
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, latent_dim),
            nn.ReLU(),
        ]
        self.encoder = nn.Sequential(*encoder_layers)

        # Shared decoder layer; its ReLU is applied in forward().
        self.dec_hidden = nn.Linear(latent_dim, hidden_dim)

        # Output heads: 4 values per position for dna, 1 per position for phy.
        self.dec_dna = nn.Linear(hidden_dim, L * 4)
        self.dec_phy = nn.Linear(hidden_dim, L * 1)

    def forward(self, dna, phy):
        """Encode the inputs and reconstruct both of them from the latent code.

        Args:
            dna: Tensor whose first dimension is the batch; the remaining
                dimensions are flattened.
            phy: Tensor with the same batch size as ``dna``; flattened the
                same way.

        Returns:
            A ``(recon_dna, recon_phy, h)`` tuple: the raw linear output of
            the dna head reshaped to ``(B, L, 4)``, the tanh-squashed output
            of the phy head reshaped to ``(B, L)``, and the latent code.
        """
        batch_size = dna.size(0)

        # Flatten each input per sample and join them feature-wise.
        flat_dna = dna.reshape(batch_size, -1)
        flat_phy = phy.reshape(batch_size, -1)
        features = torch.cat((flat_dna, flat_phy), dim=1)

        latent = self.encoder(features)
        shared = torch.relu(self.dec_hidden(latent))

        dna_out = self.dec_dna(shared)
        phy_out = torch.tanh(self.dec_phy(shared))

        recon_dna = dna_out.reshape(batch_size, L, 4)
        recon_phy = phy_out.reshape(batch_size, L)
        return recon_dna, recon_phy, latent
|
|
|
|
| |
# Sequence length in positions; SparseAE reads this module-level name.
L = 50
device = "cuda" if torch.cuda.is_available() else "cpu"
model = SparseAE().to(device)

# Restore the trained weights and switch to inference mode.
# NOTE(review): torch.load unpickles the file, so only load checkpoints that
# come from a trusted source.
model.load_state_dict(torch.load("sparse_ae_50bp_epoch3.pt", map_location=device))
model.eval()
print("Model loaded.")

print("Generating test data...")

N_SAMPLES = 10000
# Draw each position uniformly from the 4 classes, then one-hot encode it.
probs = torch.tensor([0.25, 0.25, 0.25, 0.25])
test_dna_idx = torch.multinomial(probs, N_SAMPLES * L, replacement=True).view(N_SAMPLES, L)
test_dna = F.one_hot(test_dna_idx, num_classes=4).float().to(device)
# Standard-normal values stand in for the second input.
test_phy = torch.randn(N_SAMPLES, L).to(device)

print("Running inference...")
with torch.no_grad():
    # Only the latent code is needed, so call the encoder directly on the
    # same flattened, concatenated input that SparseAE.forward builds.
    B = test_dna.size(0)
    x = torch.cat([test_dna.reshape(B, -1), test_phy.reshape(B, -1)], dim=1)
    h = model.encoder(x)

# Rows are samples, columns are latent neurons.
h_np = h.cpu().numpy()

# A neuron "fires" on a sample when its activation exceeds this threshold.
FIRING_THRESHOLD = 0.1
# For each neuron, the number of samples on which it fired.
neuron_firing_counts = np.sum(h_np > FIRING_THRESHOLD, axis=0)

# Firing counts ordered from most to least active neuron.
sorted_counts = np.sort(neuron_firing_counts)[::-1]

print("\n--- VOCABULARY HEALTH CHECK ---")
# Report the actual latent width instead of a hard-coded number.
print(f"Total Neurons: {h_np.shape[1]}")
print(f"Dead Neurons (Never fire): {np.sum(neuron_firing_counts == 0)}")
# The "< 10" bucket also includes the dead neurons counted above.
print(f"Rare Neurons (Fire < 10 times): {np.sum(neuron_firing_counts < 10)}")
print(f"Common Neurons (Fire > 1000 times): {np.sum(neuron_firing_counts > 1000)}")
|
|