import torch
import torch.nn as nn
import torch.nn.functional as F

from .uncertainty import BlockUncertaintyTracker


class ResidualAttentionBlock(nn.Module):
    """Residual block with a convolutional trunk for spatial context,
    gated by squeeze-and-excitation style channel attention."""

    def __init__(self, in_channels):
        super().__init__()

        # Trunk: two reflection-padded 3x3 convolutions; spatial size is preserved
        self.trunk = nn.Sequential(
            nn.ReflectionPad2d(1),
            nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=0),
            nn.BatchNorm2d(in_channels),
            nn.SiLU(),
            nn.ReflectionPad2d(1),
            nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=0),
            nn.BatchNorm2d(in_channels)
        )

        # Attention mask: global average pooling followed by two 1x1 convolutions,
        # producing per-channel gates in (0, 1)
        self.mask = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_channels, in_channels, kernel_size=1),
            nn.SiLU(),
            nn.Conv2d(in_channels, in_channels, kernel_size=1),
            nn.Sigmoid()
        )

    def forward(self, x):
        trunk_output = self.trunk(x)
        attention = self.mask(x)  # (B, C, 1, 1), broadcast over H and W
        out = x + attention * trunk_output
        return F.silu(out)
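
# Usage sketch: the block preserves its input shape, so it can sit at any stage
# of the encoder or decoder:
#   block = ResidualAttentionBlock(64)
#   y = block(torch.randn(2, 64, 16, 16))  # -> (2, 64, 16, 16)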


class VectorQuantizer(nn.Module):
    """Vector quantizer for discrete latent representations"""

    def __init__(self, n_embeddings=512, embedding_dim=256, beta=0.25):
        super().__init__()
        self.n_embeddings = n_embeddings
        self.embedding_dim = embedding_dim
        self.beta = beta

        # Codebook, initialized uniformly in [-1/K, 1/K]
        self.embeddings = nn.Parameter(torch.empty(n_embeddings, embedding_dim))
        nn.init.uniform_(self.embeddings, -1.0 / n_embeddings, 1.0 / n_embeddings)

        # Exponential moving average of codebook usage (for monitoring dead codes)
        self.register_buffer('usage', torch.zeros(n_embeddings))

    def forward(self, z):
        # (B, C, H, W) -> (B, H, W, C) so each spatial position becomes a
        # C-dimensional vector; reshaping the channels-first tensor directly
        # would mix channel and spatial dimensions
        z = z.permute(0, 2, 3, 1).contiguous()
        z_flattened = z.reshape(-1, self.embedding_dim)

        # Squared Euclidean distance from each latent vector to each codebook entry
        distances = torch.sum(z_flattened**2, dim=1, keepdim=True) + \
                    torch.sum(self.embeddings**2, dim=1) - \
                    2 * torch.matmul(z_flattened, self.embeddings.t())

        encoding_indices = torch.argmin(distances, dim=1)

        # Update usage statistics with an EMA over one-hot counts
        if self.training:
            with torch.no_grad():
                usage = torch.zeros_like(self.usage)
                usage.scatter_add_(0, encoding_indices,
                                   torch.ones_like(encoding_indices, dtype=torch.float))
                self.usage.mul_(0.99).add_(usage, alpha=0.01)

        z_q = self.embeddings[encoding_indices].reshape(z.shape)

        # Commitment loss pulls the encoder output toward its assigned code;
        # codebook loss moves the code toward the encoder output
        commitment_loss = F.mse_loss(z_q.detach(), z)
        codebook_loss = F.mse_loss(z_q, z.detach())
        loss = codebook_loss + self.beta * commitment_loss

        # Straight-through estimator: gradients bypass the quantization step
        z_q = z + (z_q - z).detach()

        # (B, H, W, C) -> (B, C, H, W)
        z_q = z_q.permute(0, 3, 1, 2).contiguous()

        if self.training:
            return z_q, loss
        else:
            return z_q
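
# Quantization sketch (channels-first input, as produced by Encoder):
#   vq = VectorQuantizer(n_embeddings=512, embedding_dim=256)
#   vq.train()
#   z_q, vq_loss = vq(torch.randn(2, 256, 8, 8))  # z_q: (2, 256, 8, 8)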


class Encoder(nn.Module):
    """Encoder for the VQ-VAE model"""

    def __init__(self, in_channels=1, hidden_dims=[32, 64, 128, 256], embedding_dim=256):
        super().__init__()

        # Stem: stride-1 convolution to the first hidden width
        layers = [
            nn.Conv2d(in_channels, hidden_dims[0], kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(hidden_dims[0]),
            nn.SiLU()
        ]

        # Downsampling: each stride-2 convolution halves the spatial resolution
        for i in range(len(hidden_dims) - 1):
            layers.extend([
                nn.Conv2d(hidden_dims[i], hidden_dims[i + 1], kernel_size=4, stride=2, padding=1),
                nn.BatchNorm2d(hidden_dims[i + 1]),
                nn.SiLU()
            ])

        # Residual attention blocks at the lowest resolution
        for _ in range(2):
            layers.append(ResidualAttentionBlock(hidden_dims[-1]))

        # Project to the codebook dimension
        layers.extend([
            nn.Conv2d(hidden_dims[-1], embedding_dim, kernel_size=1),
            nn.BatchNorm2d(embedding_dim)
        ])

        self.encoder = nn.Sequential(*layers)

    def forward(self, x):
        return self.encoder(x)
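
# Downsampling sketch: with the default hidden_dims there are
# len(hidden_dims) - 1 = 3 stride-2 stages, so a (B, 1, 64, 64) input maps to a
# (B, 256, 8, 8) latent grid (spatial size divided by 2**3 = 8).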


class Decoder(nn.Module):
    """Decoder for the VQ-VAE model"""

    def __init__(self, embedding_dim=256, hidden_dims=[32, 64, 128, 256], out_channels=1):
        super().__init__()

        # hidden_dims is given in encoder order and reversed here so the decoder
        # mirrors the encoder (e.g. [32, 64, 128, 256] -> [256, 128, 64, 32])
        hidden_dims = hidden_dims[::-1]

        # Stem: project from the codebook dimension to the widest hidden width
        layers = [
            nn.Conv2d(embedding_dim, hidden_dims[0], kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(hidden_dims[0]),
            nn.SiLU()
        ]

        # Residual attention blocks at the lowest resolution
        for _ in range(2):
            layers.append(ResidualAttentionBlock(hidden_dims[0]))

        # Upsampling: each transposed convolution doubles the spatial resolution
        for i in range(len(hidden_dims) - 1):
            layers.extend([
                nn.ConvTranspose2d(hidden_dims[i], hidden_dims[i + 1],
                                   kernel_size=4, stride=2, padding=1),
                nn.BatchNorm2d(hidden_dims[i + 1]),
                nn.SiLU()
            ])

        # Output head maps to the target channel count
        layers.append(
            nn.Conv2d(hidden_dims[-1], out_channels, kernel_size=3, padding=1)
        )
        layers.append(nn.Sigmoid())

        self.decoder = nn.Sequential(*layers)

    def forward(self, x):
        return self.decoder(x)
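
# Note: the closing Sigmoid constrains outputs to (0, 1), which assumes the
# bathymetry targets are normalized to that range during preprocessing.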


class VQVAE(nn.Module):
    """
    Vector Quantized Variational Autoencoder with uncertainty awareness
    for bathymetry super-resolution
    """

    def __init__(self, in_channels=1, hidden_dims=[32, 64, 128, 256],
                 num_embeddings=512, embedding_dim=256, block_size=4, alpha=0.1):
        super().__init__()

        # Tracks blockwise reconstruction-error statistics for calibrated bounds
        self.uncertainty_tracker = BlockUncertaintyTracker(
            block_size=block_size,
            alpha=alpha,
            decay=0.99,
            eps=1e-5
        )

        self.encoder = Encoder(
            in_channels=in_channels,
            hidden_dims=hidden_dims,
            embedding_dim=embedding_dim
        )

        self.vq = VectorQuantizer(
            n_embeddings=num_embeddings,
            embedding_dim=embedding_dim,
            beta=0.25
        )

        self.decoder = Decoder(
            embedding_dim=embedding_dim,
            hidden_dims=hidden_dims,
            out_channels=in_channels
        )

    def forward(self, x):
        """Forward pass through the model"""
        z = self.encoder(x)

        # The quantizer returns (z_q, loss) in training mode, z_q alone in eval mode
        if self.training:
            z_q, vq_loss = self.vq(z)
            reconstruction = self.decoder(z_q)
            return reconstruction, vq_loss
        else:
            z_q = self.vq(z)
            reconstruction = self.decoder(z_q)
            return reconstruction
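
    # Call-contract sketch:
    #   model.train(); recon, vq_loss = model(x)
    #   model.eval();  recon = model(x)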
    def train_forward(self, x, y):
        """Training forward pass with uncertainty tracking"""
        # Assumes training mode, so forward returns (reconstruction, vq_loss)
        reconstruction, vq_loss = self.forward(x)

        # Absolute reconstruction error against the high-resolution target
        error = torch.abs(reconstruction - y)

        # Update running error statistics, then read back a per-pixel uncertainty map
        self.uncertainty_tracker.update(error)
        uncertainty_map = self.uncertainty_tracker.get_uncertainty(error)

        return reconstruction, vq_loss, uncertainty_map
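
    # Training-step sketch (hypothetical optimizer and loss weighting; y is the
    # high-resolution target for input x):
    #   recon, vq_loss, _ = model.train_forward(x, y)
    #   loss = F.mse_loss(recon, y) + vq_loss
    #   loss.backward(); optimizer.step()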
    def predict_with_uncertainty(self, x, confidence_level=0.95):
        """
        Forward pass with calibrated uncertainty bounds

        Args:
            x: Input tensor
            confidence_level: Confidence level for bounds (default: 0.95)

        Returns:
            tuple: (reconstruction, lower_bounds, upper_bounds)
        """
        self.eval()
        with torch.no_grad():
            # Eval-mode forward returns only the reconstruction
            reconstruction = self.forward(x)

            # Expand the point estimate into calibrated lower/upper bounds
            lower_bounds, upper_bounds = self.uncertainty_tracker.get_bounds(
                reconstruction, confidence_level
            )

        return reconstruction, lower_bounds, upper_bounds
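

# Minimal smoke test: a sketch that assumes BlockUncertaintyTracker supports the
# update/get_uncertainty/get_bounds calls used above and returns sensible bounds
# after a single update; y is set to x only to exercise the API, not as a
# meaningful target. Run with `python -m <package>.<module>` so the relative
# import resolves.
if __name__ == "__main__":
    model = VQVAE()
    x = torch.rand(2, 1, 64, 64)

    model.train()
    recon, vq_loss, uncertainty = model.train_forward(x, x)
    print(recon.shape, float(vq_loss))  # torch.Size([2, 1, 64, 64]) <scalar>

    recon, lower, upper = model.predict_with_uncertainty(x)
    print(recon.shape, lower.shape, upper.shape)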