import json
import os

import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image

from models.loader import ModelLoader
from models.uncertainty import BlockUncertaintyTracker


class BathymetrySuperResolution:
    """
    Bathymetry super-resolution model with uncertainty estimation.
    """

    def __init__(self, model_type="vqvae", checkpoint_path=None, config_path=None):
        """
        Initialize the super-resolution model with uncertainty awareness.

        Args:
            model_type: Type of model ('srcnn', 'gan', or 'vqvae')
            checkpoint_path: Path to model checkpoint
            config_path: Path to configuration file
        """
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # Load configuration from file if provided, otherwise fall back to defaults
        if config_path is not None and os.path.exists(config_path):
            with open(config_path, 'r') as f:
                self.config = json.load(f)
        else:
            self.config = {
                "model_type": model_type,
                "model_config": {
                    "in_channels": 1,
                    "hidden_dims": [32, 64, 128, 256],
                    "num_embeddings": 512,
                    "embedding_dim": 256,
                    "block_size": 4
                },
                "normalization": {
                    "mean": -3911.3894,
                    "std": 1172.8374,
                    "min": 0.0,
                    "max": 1.0
                }
            }

        self.model_loader = ModelLoader()

        # Load model weights; a valid checkpoint is required
        if checkpoint_path is not None and os.path.exists(checkpoint_path):
            self.model = self.model_loader.load_model(
                self.config['model_type'],
                checkpoint_path,
                config_overrides=self.config.get('model_config', {})
            )
        else:
            raise ValueError("Checkpoint path not provided or invalid")

        self.model.eval()

        # Cache normalization statistics used by preprocess() and denormalize()
        self.mean = self.config['normalization']['mean']
        self.std = self.config['normalization']['std']
        self.min_val = self.config['normalization']['min']
        self.max_val = self.config['normalization']['max']
    def preprocess(self, data):
        """
        Preprocess input data for the model.

        Args:
            data: Input array/image (can be numpy array, PIL Image, or tensor)

        Returns:
            Normalized 4-D tensor resized to 32x32 on the target device
        """
        # Convert PIL images to numpy, then everything to a float tensor
        if isinstance(data, Image.Image):
            data = np.array(data)

        if isinstance(data, np.ndarray):
            tensor = torch.from_numpy(data).float()
        else:
            tensor = data.float()

        # Add batch and channel dimensions as needed: (H, W) -> (1, 1, H, W)
        if len(tensor.shape) == 2:
            tensor = tensor.unsqueeze(0).unsqueeze(0)
        elif len(tensor.shape) == 3:
            tensor = tensor.unsqueeze(0)

        # Z-score normalize with dataset statistics, then rescale to [0, 1]
        tensor = (tensor - self.mean) / (self.std + 1e-8)
        tensor = (tensor - tensor.min()) / (tensor.max() - tensor.min() + 1e-8)

        # Resize to the 32x32 input resolution expected by the model
        if tensor.shape[-1] != 32 or tensor.shape[-2] != 32:
            tensor = F.interpolate(
                tensor,
                size=(32, 32),
                mode='bicubic',
                align_corners=False
            )

        return tensor.to(self.device)
    def denormalize(self, tensor):
        """
        Denormalize output tensor.

        Args:
            tensor: Output tensor from model

        Returns:
            Denormalized tensor in the original data range
        """
        # Undo the [min, max] rescaling, then invert the z-score normalization
        tensor = tensor * (self.max_val - self.min_val) + self.min_val
        tensor = tensor * self.std + self.mean

        return tensor
    def predict(self, data, with_uncertainty=True, confidence_level=0.95):
        """
        Generate super-resolution output with uncertainty bounds.

        Args:
            data: Input data (can be numpy array, PIL Image, or tensor)
            with_uncertainty: Whether to include uncertainty bounds
            confidence_level: Confidence level for uncertainty bounds

        Returns:
            Tuple of (prediction, lower_bound, upper_bound) if with_uncertainty=True,
            otherwise just the prediction
        """
        input_tensor = self.preprocess(data)

        with torch.no_grad():
            if with_uncertainty and hasattr(self.model, 'predict_with_uncertainty'):
                prediction, lower_bound, upper_bound = self.model.predict_with_uncertainty(
                    input_tensor, confidence_level
                )

                # Map outputs back to the original depth range
                prediction = self.denormalize(prediction)
                lower_bound = self.denormalize(lower_bound) if lower_bound is not None else None
                upper_bound = self.denormalize(upper_bound) if upper_bound is not None else None

                # Return CPU numpy arrays
                prediction = prediction.cpu().numpy()
                lower_bound = lower_bound.cpu().numpy() if lower_bound is not None else None
                upper_bound = upper_bound.cpu().numpy() if upper_bound is not None else None

                return prediction, lower_bound, upper_bound
            else:
                prediction = self.model(input_tensor)

                # Map output back to the original depth range and return a CPU numpy array
                prediction = self.denormalize(prediction)
                prediction = prediction.cpu().numpy()

                return prediction
    def load_npy(self, file_path):
        """
        Load bathymetry data from a numpy file.

        Args:
            file_path: Path to .npy file

        Returns:
            Numpy array containing bathymetry data
        """
        try:
            return np.load(file_path)
        except Exception as e:
            raise ValueError(f"Error loading numpy file: {str(e)}")
    @staticmethod
    def get_uncertainty_width(lower_bound, upper_bound):
        """
        Calculate the mean uncertainty width (difference between upper and lower bounds).

        Args:
            lower_bound: Lower uncertainty bound
            upper_bound: Upper uncertainty bound

        Returns:
            Mean uncertainty width, or None if either bound is missing
        """
        if lower_bound is None or upper_bound is None:
            return None

        return np.mean(upper_bound - lower_bound)
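

# Minimal usage sketch, not part of the class: assumes a trained checkpoint and a
# sample .npy tile at the hypothetical paths below; adjust them to your own setup.
if __name__ == "__main__":
    sr = BathymetrySuperResolution(
        model_type="vqvae",
        checkpoint_path="checkpoints/vqvae_best.pth",  # hypothetical path
    )
    tile = sr.load_npy("data/sample_tile.npy")  # hypothetical input file

    result = sr.predict(tile, with_uncertainty=True, confidence_level=0.95)
    if isinstance(result, tuple):
        # Model supports predict_with_uncertainty: unpack prediction and bounds
        prediction, lower, upper = result
        print("Prediction shape:", prediction.shape)
        print("Mean uncertainty width:",
              BathymetrySuperResolution.get_uncertainty_width(lower, upper))
    else:
        # Fallback path: plain prediction without uncertainty bounds
        print("Prediction shape:", result.shape)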