import torch
import torch.nn as nn
import torch.nn.functional as F

from .uncertainty import BlockUncertaintyTracker


class ResidualAttentionBlock(nn.Module):
    """Residual attention block for capturing spatial dependencies"""

    def __init__(self, in_channels):
        super().__init__()

        # Trunk branch
        self.trunk = nn.Sequential(
            nn.ReflectionPad2d(1),
            nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=0),
            nn.BatchNorm2d(in_channels),
            nn.SiLU(),
            nn.ReflectionPad2d(1),
            nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=0),
            nn.BatchNorm2d(in_channels)
        )

        # Mask branch producing per-channel attention weights
        self.mask = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_channels, in_channels, kernel_size=1),
            nn.SiLU(),
            nn.Conv2d(in_channels, in_channels, kernel_size=1),
            nn.Sigmoid()
        )

    def forward(self, x):
        # Trunk branch
        trunk_output = self.trunk(x)

        # Mask branch for attention weights
        attention = self.mask(x)

        # Apply attention and residual connection
        out = x + attention * trunk_output
        return F.silu(out)


class VectorQuantizer(nn.Module):
    """Vector quantizer for discrete latent representation"""

    def __init__(self, n_embeddings=512, embedding_dim=256, beta=0.25):
        super().__init__()
        self.n_embeddings = n_embeddings
        self.embedding_dim = embedding_dim
        self.beta = beta

        # Initialize codebook embeddings uniformly in [-1/K, 1/K]
        self.embeddings = nn.Parameter(
            torch.empty(n_embeddings, embedding_dim).uniform_(
                -1.0 / n_embeddings, 1.0 / n_embeddings
            )
        )

        # Usage tracking (exponential moving average of codebook utilization)
        self.register_buffer('usage', torch.zeros(n_embeddings))

    def forward(self, z):
        # Move channels last so each spatial position becomes one
        # embedding_dim-sized vector, then flatten for quantization
        z_channels_last = z.permute(0, 2, 3, 1).contiguous()
        z_flattened = z_channels_last.reshape(-1, self.embedding_dim)

        # Squared Euclidean distances via the expansion
        # ||a - b||^2 = ||a||^2 + ||b||^2 - 2 a.b
        distances = torch.sum(z_flattened**2, dim=1, keepdim=True) + \
                    torch.sum(self.embeddings**2, dim=1) - \
                    2 * torch.matmul(z_flattened, self.embeddings.t())

        # Find nearest embedding for each input vector
        encoding_indices = torch.argmin(distances, dim=1)

        # Update usage statistics
        if self.training:
            with torch.no_grad():
                usage = torch.zeros_like(self.usage)
                usage.scatter_add_(
                    0, encoding_indices,
                    torch.ones_like(encoding_indices, dtype=torch.float)
                )
                self.usage.mul_(0.99).add_(usage, alpha=0.01)

        # Get quantized vectors and restore the (B, C, H, W) layout
        z_q = self.embeddings[encoding_indices].reshape(z_channels_last.shape)
        z_q = z_q.permute(0, 3, 1, 2).contiguous()

        # Commitment loss pulls encoder outputs towards the codebook;
        # codebook loss moves the embeddings towards encoder outputs
        commitment_loss = F.mse_loss(z_q.detach(), z)
        codebook_loss = F.mse_loss(z_q, z.detach())

        # Combine losses
        loss = codebook_loss + self.beta * commitment_loss

        # Straight-through estimator: gradients flow through z unchanged
        z_q = z + (z_q - z).detach()

        if self.training:
            return z_q, loss
        else:
            return z_q


class Encoder(nn.Module):
    """Encoder for VQ-VAE model"""

    def __init__(self, in_channels=1, hidden_dims=[32, 64, 128, 256], embedding_dim=256):
        super().__init__()

        # Initial conv layer
        layers = [
            nn.Conv2d(in_channels, hidden_dims[0], kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(hidden_dims[0]),
            nn.SiLU()
        ]

        # Hidden layers with downsampling
        for i in range(len(hidden_dims) - 1):
            layers.extend([
                nn.Conv2d(hidden_dims[i], hidden_dims[i + 1], kernel_size=4, stride=2, padding=1),
                nn.BatchNorm2d(hidden_dims[i + 1]),
                nn.SiLU()
            ])

        # Residual attention blocks
        for _ in range(2):
            layers.append(ResidualAttentionBlock(hidden_dims[-1]))

        # Final projection to embedding dimension
        layers.extend([
            nn.Conv2d(hidden_dims[-1], embedding_dim, kernel_size=1),
            nn.BatchNorm2d(embedding_dim)
        ])

        self.encoder = nn.Sequential(*layers)

    def forward(self, x):
        return self.encoder(x)


class Decoder(nn.Module):
    """Decoder for VQ-VAE model"""

    def __init__(self, embedding_dim=256, hidden_dims=[32, 64, 128, 256], out_channels=1):
        super().__init__()

        # hidden_dims is given in encoder order and reversed here so the
        # decoder mirrors the encoder (channels decrease while upsampling)
        hidden_dims = hidden_dims[::-1]

        # Initial processing
        layers = [
            nn.Conv2d(embedding_dim, hidden_dims[0], kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(hidden_dims[0]),
            nn.SiLU()
        ]

        # Residual attention blocks
        for _ in range(2):
            layers.append(ResidualAttentionBlock(hidden_dims[0]))

        # Upsampling blocks
        for i in range(len(hidden_dims) - 1):
            layers.extend([
                nn.ConvTranspose2d(hidden_dims[i], hidden_dims[i + 1], kernel_size=4, stride=2, padding=1),
                nn.BatchNorm2d(hidden_dims[i + 1]),
                nn.SiLU()
            ])

        # Final output layer
        layers.append(
            nn.Conv2d(hidden_dims[-1], out_channels, kernel_size=3, padding=1)
        )
        layers.append(nn.Sigmoid())

        self.decoder = nn.Sequential(*layers)

    def forward(self, x):
        return self.decoder(x)


class VQVAE(nn.Module):
    """
    Vector Quantized Variational Autoencoder with uncertainty awareness
    for bathymetry super-resolution
    """

    def __init__(self, in_channels=1, hidden_dims=[32, 64, 128, 256],
                 num_embeddings=512, embedding_dim=256, block_size=4, alpha=0.1):
        super().__init__()

        # Initialize block-wise uncertainty tracking
        self.uncertainty_tracker = BlockUncertaintyTracker(
            block_size=block_size,
            alpha=alpha,
            decay=0.99,
            eps=1e-5
        )

        # Main model components
        self.encoder = Encoder(
            in_channels=in_channels,
            hidden_dims=hidden_dims,
            embedding_dim=embedding_dim
        )

        self.vq = VectorQuantizer(
            n_embeddings=num_embeddings,
            embedding_dim=embedding_dim,
            beta=0.25
        )

        self.decoder = Decoder(
            embedding_dim=embedding_dim,
            hidden_dims=hidden_dims,
            out_channels=in_channels
        )

    def forward(self, x):
        """Forward pass through the model"""
        # Encode
        z = self.encoder(x)

        # Vector quantization
        if self.training:
            z_q, vq_loss = self.vq(z)

            # Decode
            reconstruction = self.decoder(z_q)
            return reconstruction, vq_loss
        else:
            z_q = self.vq(z)

            # Decode
            reconstruction = self.decoder(z_q)
            return reconstruction

    def train_forward(self, x, y):
        """Training forward pass with uncertainty tracking"""
        # Get reconstruction and VQ loss
        reconstruction, vq_loss = self.forward(x)

        # Calculate reconstruction error against the target
        error = torch.abs(reconstruction - y)

        # Update uncertainty tracker
        self.uncertainty_tracker.update(error)

        # Get uncertainty map for loss weighting
        uncertainty_map = self.uncertainty_tracker.get_uncertainty(error)

        return reconstruction, vq_loss, uncertainty_map

    def predict_with_uncertainty(self, x, confidence_level=0.95):
        """
        Forward pass with calibrated uncertainty bounds

        Args:
            x: Input tensor
            confidence_level: Confidence level for bounds (default: 0.95)

        Returns:
            tuple: (reconstruction, lower_bounds, upper_bounds)
        """
        self.eval()
        with torch.no_grad():
            # Get reconstruction
            reconstruction = self.forward(x)

            # Get calibrated uncertainty bounds
            lower_bounds, upper_bounds = self.uncertainty_tracker.get_bounds(
                reconstruction, confidence_level
            )

            return reconstruction, lower_bounds, upper_bounds
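

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the model). It assumes
# 64x64 single-channel tiles, random stand-in tensors for the input/target
# pair, and that BlockUncertaintyTracker exposes the update /
# get_uncertainty / get_bounds methods called above. The combined loss
# weighting is a placeholder choice. Because of the relative import, run
# this via `python -m <package>.<module>` rather than as a standalone script.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    model = VQVAE(in_channels=1, hidden_dims=[32, 64, 128, 256],
                  num_embeddings=512, embedding_dim=256)

    # One training step: input batch x, target batch y
    model.train()
    x = torch.rand(2, 1, 64, 64)
    y = torch.rand(2, 1, 64, 64)
    reconstruction, vq_loss, uncertainty_map = model.train_forward(x, y)
    loss = F.mse_loss(reconstruction, y) + vq_loss
    loss.backward()

    # Inference with calibrated uncertainty bounds
    recon, lower, upper = model.predict_with_uncertainty(x, confidence_level=0.95)
    print(recon.shape, lower.shape, upper.shape)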