import torch
import torch.nn as nn
import torch.nn.functional as F
from .uncertainty import BlockUncertaintyTracker
class ResidualAttentionBlock(nn.Module):
"""Residual attention block for capturing spatial dependencies"""
def __init__(self, in_channels):
super().__init__()
# Trunk branch
self.trunk = nn.Sequential(
nn.ReflectionPad2d(1),
nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=0),
nn.BatchNorm2d(in_channels),
nn.SiLU(),
nn.ReflectionPad2d(1),
nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=0),
nn.BatchNorm2d(in_channels)
)
# Mask branch for attention
self.mask = nn.Sequential(
nn.AdaptiveAvgPool2d(1),
nn.Conv2d(in_channels, in_channels, kernel_size=1),
nn.SiLU(),
nn.Conv2d(in_channels, in_channels, kernel_size=1),
nn.Sigmoid()
)
def forward(self, x):
# Trunk branch
trunk_output = self.trunk(x)
# Mask branch for attention weights
attention = self.mask(x)
# Apply attention and residual connection
out = x + attention * trunk_output
return F.silu(out)
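
# The block above is shape-preserving: reflection padding keeps the trunk's spatial
# size and the mask branch reduces to a per-channel gate in (0, 1). A hypothetical
# sanity check (illustrative sketch, not part of the original module):
#
#     block = ResidualAttentionBlock(64)
#     x = torch.randn(2, 64, 32, 32)
#     assert block(x).shape == x.shape
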
class VectorQuantizer(nn.Module):
"""Vector quantizer for discrete latent representation"""
def __init__(self, n_embeddings=512, embedding_dim=256, beta=0.25):
super().__init__()
self.n_embeddings = n_embeddings
self.embedding_dim = embedding_dim
self.beta = beta
# Initialize embeddings
self.embeddings = nn.Parameter(torch.randn(n_embeddings, embedding_dim))
nn.init.uniform_(self.embeddings, -1.0 / n_embeddings, 1.0 / n_embeddings)
# Usage tracking
self.register_buffer('usage', torch.zeros(n_embeddings))
def forward(self, z):
        # Move channels last so each spatial position becomes one embedding_dim vector
        z_flattened = z.permute(0, 2, 3, 1).contiguous().reshape(-1, self.embedding_dim)
# Calculate distances to embedding vectors
distances = torch.sum(z_flattened**2, dim=1, keepdim=True) + \
torch.sum(self.embeddings**2, dim=1) - \
2 * torch.matmul(z_flattened, self.embeddings.t())
# Find nearest embedding for each input vector
encoding_indices = torch.argmin(distances, dim=1)
# Update usage statistics
if self.training:
with torch.no_grad():
usage = torch.zeros_like(self.usage)
usage.scatter_add_(0, encoding_indices, torch.ones_like(encoding_indices, dtype=torch.float))
self.usage.mul_(0.99).add_(usage, alpha=0.01)
        # Look up quantized vectors and restore the (B, C, H, W) layout
        z_q = self.embeddings[encoding_indices].reshape(
            z.shape[0], z.shape[2], z.shape[3], self.embedding_dim
        ).permute(0, 3, 1, 2).contiguous()
        # Commitment loss pulls encoder outputs toward the (frozen) codes;
        # codebook loss moves the codes toward the (frozen) encoder outputs
commitment_loss = F.mse_loss(z_q.detach(), z)
codebook_loss = F.mse_loss(z_q, z.detach())
# Combine losses
loss = codebook_loss + self.beta * commitment_loss
# Straight-through estimator
z_q = z + (z_q - z).detach()
if self.training:
return z_q, loss
else:
return z_q
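
# Hypothetical usage sketch (not from the original source): the quantizer expects
# (B, embedding_dim, H, W) latents, snaps each spatial vector to its nearest
# codebook entry, and relies on the straight-through estimator so gradients reach
# the encoder as if quantization were the identity:
#
#     vq = VectorQuantizer(n_embeddings=512, embedding_dim=256)
#     vq.train()
#     z = torch.randn(2, 256, 8, 8)
#     z_q, vq_loss = vq(z)          # z_q.shape == z.shape, vq_loss is a scalar
#     vq.eval()
#     z_q = vq(z)                   # eval mode returns only the quantized latents
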
class Encoder(nn.Module):
"""Encoder for VQ-VAE model"""
def __init__(self, in_channels=1, hidden_dims=[32, 64, 128, 256], embedding_dim=256):
super().__init__()
# Initial conv layer
layers = [
nn.Conv2d(in_channels, hidden_dims[0], kernel_size=3, stride=1, padding=1),
nn.BatchNorm2d(hidden_dims[0]),
nn.SiLU()
]
# Hidden layers with downsampling
for i in range(len(hidden_dims) - 1):
layers.extend([
nn.Conv2d(hidden_dims[i], hidden_dims[i+1], kernel_size=4, stride=2, padding=1),
nn.BatchNorm2d(hidden_dims[i+1]),
nn.SiLU()
])
# Residual attention blocks
for _ in range(2):
layers.append(ResidualAttentionBlock(hidden_dims[-1]))
# Final projection to embedding dimension
layers.extend([
nn.Conv2d(hidden_dims[-1], embedding_dim, kernel_size=1),
nn.BatchNorm2d(embedding_dim)
])
self.encoder = nn.Sequential(*layers)
def forward(self, x):
return self.encoder(x)
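
# With the default hidden_dims=[32, 64, 128, 256], the encoder applies
# len(hidden_dims) - 1 = 3 strided convolutions, so spatial resolution shrinks by
# a factor of 2**3 = 8 and input sizes should be divisible by 8. Hypothetical
# shape example (assumed sizes, not from the original source):
#
#     enc = Encoder(in_channels=1)
#     z = enc(torch.randn(2, 1, 64, 64))   # -> (2, 256, 8, 8)
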
class Decoder(nn.Module):
"""Decoder for VQ-VAE model"""
    def __init__(self, embedding_dim=256, hidden_dims=[32, 64, 128, 256], out_channels=1):
super().__init__()
# Reverse hidden dims for decoder
hidden_dims = hidden_dims[::-1]
# Initial processing
layers = [
nn.Conv2d(embedding_dim, hidden_dims[0], kernel_size=3, stride=1, padding=1),
nn.BatchNorm2d(hidden_dims[0]),
nn.SiLU()
]
# Residual attention blocks
for _ in range(2):
layers.append(ResidualAttentionBlock(hidden_dims[0]))
# Upsampling blocks
for i in range(len(hidden_dims) - 1):
layers.extend([
nn.ConvTranspose2d(hidden_dims[i], hidden_dims[i+1],
kernel_size=4, stride=2, padding=1),
nn.BatchNorm2d(hidden_dims[i+1]),
nn.SiLU()
])
# Final output layer
layers.append(
nn.Conv2d(hidden_dims[-1], out_channels, kernel_size=3, padding=1)
)
layers.append(nn.Sigmoid())
self.decoder = nn.Sequential(*layers)
def forward(self, x):
return self.decoder(x)
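
# The decoder mirrors the encoder: three transposed convolutions upsample by 8x and
# the final Sigmoid constrains outputs to [0, 1], so targets are assumed to be
# normalized to that range. Hypothetical shape check (not from the original source):
#
#     dec = Decoder(embedding_dim=256, hidden_dims=[32, 64, 128, 256], out_channels=1)
#     y = dec(torch.randn(2, 256, 8, 8))   # -> (2, 1, 64, 64)
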
class VQVAE(nn.Module):
"""
Vector Quantized Variational Autoencoder with uncertainty awareness
for bathymetry super-resolution
"""
def __init__(self, in_channels=1, hidden_dims=[32, 64, 128, 256],
num_embeddings=512, embedding_dim=256, block_size=4, alpha=0.1):
super().__init__()
# Initialize block-wise uncertainty tracking
self.uncertainty_tracker = BlockUncertaintyTracker(
block_size=block_size,
alpha=alpha,
decay=0.99,
eps=1e-5
)
# Main model components
self.encoder = Encoder(
in_channels=in_channels,
hidden_dims=hidden_dims,
embedding_dim=embedding_dim
)
self.vq = VectorQuantizer(
n_embeddings=num_embeddings,
embedding_dim=embedding_dim,
beta=0.25
)
self.decoder = Decoder(
embedding_dim=embedding_dim,
hidden_dims=hidden_dims,
out_channels=in_channels
)
def forward(self, x):
"""Forward pass through the model"""
# Encode
z = self.encoder(x)
# Vector quantization
if self.training:
z_q, vq_loss = self.vq(z)
# Decode
reconstruction = self.decoder(z_q)
return reconstruction, vq_loss
else:
z_q = self.vq(z)
# Decode
reconstruction = self.decoder(z_q)
return reconstruction
def train_forward(self, x, y):
"""Training forward pass with uncertainty tracking"""
# Get reconstruction and VQ loss
reconstruction, vq_loss = self.forward(x)
# Calculate reconstruction error
error = torch.abs(reconstruction - y)
# Update uncertainty tracker
self.uncertainty_tracker.update(error)
# Get uncertainty map for loss weighting
uncertainty_map = self.uncertainty_tracker.get_uncertainty(error)
return reconstruction, vq_loss, uncertainty_map
def predict_with_uncertainty(self, x, confidence_level=0.95):
"""
Forward pass with calibrated uncertainty bounds
Args:
x: Input tensor
confidence_level: Confidence level for bounds (default: 0.95)
Returns:
tuple: (reconstruction, lower_bounds, upper_bounds)
"""
self.eval()
with torch.no_grad():
# Get reconstruction
reconstruction = self.forward(x)
# Get calibrated uncertainty bounds
lower_bounds, upper_bounds = self.uncertainty_tracker.get_bounds(
reconstruction, confidence_level
)
return reconstruction, lower_bounds, upper_bounds
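

if __name__ == "__main__":
    # Minimal smoke test, added as an illustrative sketch rather than part of the
    # original module. It assumes the file is executed as part of its package
    # (e.g. `python -m <package>.model`) so the relative uncertainty import
    # resolves, and that BlockUncertaintyTracker supports update / get_uncertainty /
    # get_bounds as used above. Input sizes are assumed: 64x64 is divisible by the
    # encoder's 8x downsampling factor.
    model = VQVAE(in_channels=1, block_size=4, alpha=0.1)
    x = torch.rand(2, 1, 64, 64)
    y = torch.rand(2, 1, 64, 64)

    # Training pass: reconstruction, VQ loss, and an uncertainty map for weighting
    model.train()
    recon, vq_loss, uncertainty = model.train_forward(x, y)
    print("train:", recon.shape, float(vq_loss), uncertainty.shape)

    # Inference pass with calibrated bounds at the 95% confidence level
    recon, lower, upper = model.predict_with_uncertainty(x, confidence_level=0.95)
    print("eval:", recon.shape, lower.shape, upper.shape)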