import os
import torch
import numpy as np
from PIL import Image
import torch.nn.functional as F
import json
# Import your model components
from models.loader import ModelLoader
from models.uncertainty import BlockUncertaintyTracker
class BathymetrySuperResolution:
    """
    Bathymetry super-resolution model with uncertainty estimation.

    Wraps a pretrained super-resolution model (loaded through ModelLoader)
    and handles input preprocessing, inference (optionally with uncertainty
    bounds), and denormalization of outputs back to the original data scale.
    """

    def __init__(self, model_type="vqvae", checkpoint_path=None, config_path=None):
        """
        Initialize the super-resolution model with uncertainty awareness.

        Args:
            model_type: Type of model ('srcnn', 'gan', or 'vqvae'). Used only
                when no config file is supplied; a config file's "model_type"
                takes precedence.
            checkpoint_path: Path to model checkpoint (required; must exist).
            config_path: Optional path to a JSON configuration file.

        Raises:
            ValueError: If checkpoint_path is missing or does not exist.
        """
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # Load config if provided; otherwise fall back to built-in defaults.
        if config_path is not None and os.path.exists(config_path):
            with open(config_path, 'r') as f:
                self.config = json.load(f)
        else:
            # Default configuration. The normalization constants were
            # presumably fit on the training set — verify against training
            # pipeline if results look off.
            self.config = {
                "model_type": model_type,
                "model_config": {
                    "in_channels": 1,
                    "hidden_dims": [32, 64, 128, 256],
                    "num_embeddings": 512,
                    "embedding_dim": 256,
                    "block_size": 4
                },
                "normalization": {
                    "mean": -3911.3894,
                    "std": 1172.8374,
                    "min": 0.0,
                    "max": 1.0
                }
            }

        # Initialize model loader and load the checkpoint.
        self.model_loader = ModelLoader()
        if checkpoint_path is not None and os.path.exists(checkpoint_path):
            self.model = self.model_loader.load_model(
                self.config['model_type'],
                checkpoint_path,
                config_overrides=self.config.get('model_config', {})
            )
        else:
            raise ValueError("Checkpoint path not provided or invalid")

        # Inference only — disable dropout/batch-norm updates.
        self.model.eval()

        # Cache normalization parameters for preprocess()/denormalize().
        self.mean = self.config['normalization']['mean']
        self.std = self.config['normalization']['std']
        self.min_val = self.config['normalization']['min']
        self.max_val = self.config['normalization']['max']

    def preprocess(self, data):
        """
        Preprocess input data for the model.

        Accepts a torch tensor, a numpy array, a PIL Image, or any
        array-like object convertible by numpy (e.g. a nested list).

        Args:
            data: Input array/image of shape (H, W), (C, H, W) or
                (B, C, H, W).

        Returns:
            Float tensor of shape (B, C, 32, 32) on self.device, z-score
            normalized then min-max scaled to [0, 1].
        """
        # Convert anything that is not already a tensor via numpy; this
        # covers PIL Images (np.asarray(img) yields the pixel array) as
        # well as plain numpy arrays and array-likes.
        if isinstance(data, torch.Tensor):
            tensor = data.float()
        else:
            tensor = torch.from_numpy(np.asarray(data)).float()

        # Promote to a 4D batch: (H, W) -> (1, 1, H, W), (C, H, W) -> (1, C, H, W).
        if len(tensor.shape) == 2:
            tensor = tensor.unsqueeze(0).unsqueeze(0)
        elif len(tensor.shape) == 3:
            tensor = tensor.unsqueeze(0)

        # Z-score with the dataset statistics, then min-max scale this
        # particular tensor to [0, 1]. Epsilons guard against division by
        # zero (e.g. a constant-valued input).
        tensor = (tensor - self.mean) / (self.std + 1e-8)
        tensor = (tensor - tensor.min()) / (tensor.max() - tensor.min() + 1e-8)

        # Resize to the model's expected 32x32 input if needed.
        if tensor.shape[-1] != 32 or tensor.shape[-2] != 32:
            tensor = F.interpolate(
                tensor,
                size=(32, 32),
                mode='bicubic',
                align_corners=False
            )
        return tensor.to(self.device)

    def denormalize(self, tensor):
        """
        Denormalize output tensor back to the original data range.

        NOTE(review): this inverts the min-max step using the *config*
        min/max (defaults 0/1, an identity), not the per-tensor min/max
        computed in preprocess(), so the round trip is only exact when the
        config range matches — confirm against the training pipeline.

        Args:
            tensor: Output tensor from the model, assumed in [0, 1].

        Returns:
            Tensor rescaled to the original (pre-normalization) units.
        """
        # Scale from [0, 1] back to the configured range, then undo z-score.
        tensor = tensor * (self.max_val - self.min_val) + self.min_val
        tensor = tensor * self.std + self.mean
        return tensor

    def predict(self, data, with_uncertainty=True, confidence_level=0.95):
        """
        Generate super-resolution output with optional uncertainty bounds.

        Args:
            data: Input data (numpy array, PIL Image, tensor, or array-like).
            with_uncertainty: Whether to include uncertainty bounds. Bounds
                are only produced when the model also implements
                predict_with_uncertainty; otherwise falls through to a
                standard forward pass.
            confidence_level: Confidence level for the uncertainty bounds.

        Returns:
            Tuple (prediction, lower_bound, upper_bound) as numpy arrays
            when uncertainty is used (bounds may be None); otherwise just
            the prediction as a numpy array.
        """
        input_tensor = self.preprocess(data)

        with torch.no_grad():
            if with_uncertainty and hasattr(self.model, 'predict_with_uncertainty'):
                prediction, lower_bound, upper_bound = self.model.predict_with_uncertainty(
                    input_tensor, confidence_level
                )
                # Denormalize outputs; bounds may legitimately be absent.
                prediction = self.denormalize(prediction)
                lower_bound = self.denormalize(lower_bound) if lower_bound is not None else None
                upper_bound = self.denormalize(upper_bound) if upper_bound is not None else None
                # Move to CPU and convert to numpy for callers.
                prediction = prediction.cpu().numpy()
                lower_bound = lower_bound.cpu().numpy() if lower_bound is not None else None
                upper_bound = upper_bound.cpu().numpy() if upper_bound is not None else None
                return prediction, lower_bound, upper_bound
            else:
                # Standard inference path.
                # NOTE(review): assumes the model's forward returns a single
                # tensor — confirm for models that return (recon, loss, ...).
                prediction = self.model(input_tensor)
                prediction = self.denormalize(prediction)
                return prediction.cpu().numpy()

    def load_npy(self, file_path):
        """
        Load bathymetry data from a numpy .npy file.

        Args:
            file_path: Path to the .npy file.

        Returns:
            Numpy array containing bathymetry data.

        Raises:
            ValueError: If the file cannot be loaded (original exception
                is chained for debugging).
        """
        try:
            return np.load(file_path)
        except Exception as e:
            # Chain the cause so the underlying I/O error is not lost.
            raise ValueError(f"Error loading numpy file: {str(e)}") from e

    @staticmethod
    def get_uncertainty_width(lower_bound, upper_bound):
        """
        Calculate the mean uncertainty width (upper minus lower bound).

        Args:
            lower_bound: Lower uncertainty bound array, or None.
            upper_bound: Upper uncertainty bound array, or None.

        Returns:
            Mean width as a scalar, or None if either bound is missing.
        """
        if lower_bound is None or upper_bound is None:
            return None
        return np.mean(upper_bound - lower_bound)