File size: 8,151 Bytes
95f0e22
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1402dca
95f0e22
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
import os
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import cm
import gradio as gr
from PIL import Image
from pathlib import Path

# Import the inference module
from inference import BathymetrySuperResolution

# Checkpoint / config locations. Both the checkpoint directory and the config
# file can be overridden through environment variables at deploy time.
CHECKPOINT_DIR = os.environ.get("CHECKPOINT_DIR", "checkpoints")
MODEL_CHECKPOINT = os.path.join(CHECKPOINT_DIR, "calibrated.pth")
CONFIG_PATH = os.environ.get("CONFIG_PATH", "config.json")

# Load the model once at import time. A failure is reported on stdout but is
# not fatal: the UI still starts and process_upload answers with an error
# message while `model_loaded` is False.
model = None
model_loaded = False
try:
    model = BathymetrySuperResolution(
        model_type="vqvae",
        checkpoint_path=MODEL_CHECKPOINT,
        config_path=CONFIG_PATH,
    )
except Exception as e:
    print(f"Error loading model: {str(e)}")
else:
    model_loaded = True

def _resolve_upload_path(file):
    """Return the filesystem path carried by a Gradio ``File`` value.

    Depending on the Gradio version, and on whether the value came from a real
    upload or from the "Load Sample Data" button (which sets a plain path
    string), the handler receives either a path (``str`` / ``os.PathLike``)
    or a tempfile-like wrapper exposing the path as ``.name``.
    """
    if isinstance(file, (str, os.PathLike)):
        return os.fspath(file)
    return file.name


def _load_bathymetry(path):
    """Load an uploaded grid as a NumPy array.

    Files with a ``.npy`` extension (case-insensitive) are read with
    ``np.load``; anything else is opened with PIL and converted to 8-bit
    grayscale ("L" mode).
    """
    if path.lower().endswith('.npy'):
        # allow_pickle keeps its default of False, which matters because the
        # file is an untrusted upload.
        return np.load(path)
    # The context manager closes the file handle once the pixels are copied.
    with Image.open(path) as img:
        return np.array(img.convert('L'))


def _resize_for_display(data, size=32):
    """Resample the first two axes of *data* to ``size`` x ``size`` for plotting.

    Only the preview panel uses this; the model receives the unresized data.
    Each axis gets its own zoom factor so non-square inputs still come out
    ``size`` x ``size``, matching the "Input (32x32)" panel title.
    """
    if data.shape[:2] == (size, size):
        return data
    from scipy.ndimage import zoom
    factors = (size / data.shape[0], size / data.shape[1]) + (1,) * (data.ndim - 2)
    # Interpolate in float so spline overshoot cannot wrap integer (uint8) pixels.
    return zoom(np.asarray(data, dtype=float), factors)


def _build_figure(input_data, prediction, lower_bound, upper_bound, confidence_level):
    """Assemble the 2x3 results figure.

    The panels are: input preview, prediction, lower bound, upper bound,
    bound width (upper - lower), and a 3-D surface of the prediction.

    A standalone ``matplotlib.figure.Figure`` is used instead of ``pyplot`` so
    nothing is registered in pyplot's global figure manager. Per-request
    figures therefore do not accumulate in a long-running server, and
    concurrent requests never draw onto each other's "current" figure.
    """
    from matplotlib.figure import Figure

    fig = Figure(figsize=(15, 10))

    def _panel(position, image, title, cmap=cm.viridis):
        # One 2-D heat-map panel with its own colorbar.
        ax = fig.add_subplot(position)
        im = ax.imshow(image, cmap=cmap)
        ax.set_title(title)
        fig.colorbar(im, ax=ax)

    _panel(231, input_data, "Input (32x32)")
    _panel(232, prediction[0, 0], "Super-Resolution (64x64)")
    _panel(233, lower_bound[0, 0], f"Lower Bound ({confidence_level}% CI)")
    _panel(234, upper_bound[0, 0], f"Upper Bound ({confidence_level}% CI)")
    _panel(235, upper_bound[0, 0] - lower_bound[0, 0], "Uncertainty Width", cmap='hot')

    # 3-D surface of the prediction. meshgrid(cols, rows) yields arrays of
    # shape (rows, cols), matching prediction[0, 0] even when it is not square.
    ax6 = fig.add_subplot(236, projection='3d')
    rows, cols = prediction.shape[2], prediction.shape[3]
    X, Y = np.meshgrid(np.arange(cols), np.arange(rows))
    ax6.plot_surface(X, Y, prediction[0, 0], cmap=cm.viridis,
                     linewidth=0, antialiased=True)
    ax6.set_title("3D Bathymetry")

    fig.tight_layout()
    return fig


def process_upload(file, confidence_level, block_size, model_type):
    """Process uploaded bathymetry file.

    Runs the loaded model with uncertainty estimation on the uploaded grid and
    returns a results figure plus a Markdown summary.

    Args:
        file: Gradio ``File`` value (a path, or a tempfile-like object with
            ``.name``); ``None`` when nothing was uploaded.
        confidence_level: Confidence level in percent (e.g. 95); it is divided
            by 100 before being passed to ``model.predict``.
        block_size: Block size selected in the UI (echoed in the summary).
        model_type: Model type selected in the UI (echoed in the summary).

    Returns:
        ``(figure, markdown_summary)`` on success. On failure, ``(None,
        message)`` when no file was given, the model is not loaded, or
        processing raised.
    """
    if file is None:
        return None, "Please upload a file."

    try:
        if not model_loaded:
            return None, "Model not loaded. Please check server logs."

        data = _load_bathymetry(_resolve_upload_path(file))

        # NOTE(review): the selected model_type / block_size are not applied to
        # the already-loaded model; they are only reported in the summary.
        # TODO: reload or reconfigure the model when they differ from its config.

        prediction, lower_bound, upper_bound = model.predict(
            data,
            with_uncertainty=True,
            confidence_level=confidence_level / 100.0,  # slider is in percent
        )
        uncertainty_width = model.get_uncertainty_width(lower_bound, upper_bound)

        fig = _build_figure(
            _resize_for_display(data),
            prediction,
            lower_bound,
            upper_bound,
            confidence_level,
        )

        summary = f"""
        **Super-Resolution Results:**
        - **Model Type**: {model_type.upper()}
        - **Block Size**: {block_size}×{block_size}
        - **Confidence Level**: {confidence_level}%
        - **Average Uncertainty Width**: {uncertainty_width:.4f}
        - **Input Shape**: {data.shape}
        - **Output Shape**: {prediction.shape[2:]}
        """

        return fig, summary

    except Exception as e:
        import traceback
        traceback.print_exc()
        return None, f"Error processing file: {str(e)}"

def create_sample_data():
    """Write a synthetic 32x32 bathymetry grid to ``samples/sample.npy``.

    The surface is a constant -4000 baseline plus a sin/cos undulation
    (amplitude 500) and a Gaussian bump (height 300) centred at (0.3, 0.7)
    on the unit square. The ``samples`` directory is created relative to the
    current working directory if it does not exist.

    Returns:
        The path of the saved ``.npy`` file, as a string.
    """
    grid = np.linspace(0, 1, 32)
    gx, gy = np.meshgrid(grid, grid)

    undulation = 500 * np.sin(10 * gx) * np.cos(8 * gy)
    bump = 300 * np.exp(-((gx - 0.3) ** 2 + (gy - 0.7) ** 2) / 0.1)
    depth = -4000 + undulation + bump

    out_dir = Path("samples")
    out_dir.mkdir(exist_ok=True)
    out_file = out_dir / "sample.npy"
    np.save(out_file, depth)
    return str(out_file)

# Create the Gradio interface
# Layout: a header, then a row with two columns (left: upload + controls +
# buttons; right: results plot + Markdown summary), then an "About" section.
# Placement follows the order components are created inside each `with`.
with gr.Blocks(title="Bathymetry Super-Resolution") as demo:
    gr.Markdown("""
    # Bathymetry Super-Resolution with Uncertainty Quantification
    
    This application demonstrates super-resolution of ocean floor (bathymetry) data with uncertainty estimates.
    Upload a bathymetry file (NPY or image) to see the enhanced resolution output with confidence intervals.
    
    The model uses a **Vector Quantized Variational Autoencoder (VQ-VAE)** with **block-based uncertainty quantification**.
    """)
    
    with gr.Row():
        # Left column: inputs and controls.
        with gr.Column():
            input_file = gr.File(label="Upload Bathymetry File (.npy or image)")
            
            with gr.Row():
                # Value is in percent; process_upload divides it by 100 before
                # passing it to model.predict.
                confidence_level = gr.Slider(
                    minimum=80, maximum=99, value=95, step=1,
                    label="Confidence Level (%)"
                )
                
                # NOTE(review): process_upload does not reconfigure the loaded
                # model with block_size or model_type; both are only echoed in
                # the summary text. TODO confirm this is the intended behaviour.
                block_size = gr.Dropdown(
                    choices=[1, 2, 4, 8, 64], value=4,
                    label="Block Size"
                )
                
                model_type = gr.Dropdown(
                    choices=["vqvae", "srcnn", "gan"], value="vqvae",
                    label="Model Type"
                )
            
            with gr.Row():
                process_btn = gr.Button("Generate Super-Resolution")
                sample_btn = gr.Button("Load Sample Data")
        
        # Right column: outputs.
        with gr.Column():
            output_plots = gr.Plot(label="Super-Resolution Results")
            output_text = gr.Markdown(label="Summary")
    
    # Set up button actions
    # process_upload returns (figure, markdown_summary), matching these outputs.
    process_btn.click(
        fn=process_upload, 
        inputs=[input_file, confidence_level, block_size, model_type], 
        outputs=[output_plots, output_text]
    )
    
    # Sample data generation
    # create_sample_data writes samples/sample.npy and returns its path, which
    # is placed into the upload widget as its value.
    sample_btn.click(
        fn=lambda: gr.update(value=create_sample_data()),
        inputs=None,
        outputs=input_file
    )
    
    gr.Markdown("""
    ## About This Model
    
    This model enhances the resolution of bathymetric data from 32×32 to 64×64 while providing uncertainty estimates.
    It was trained on bathymetry data from multiple ocean regions including the Eastern Pacific Basin, Western Pacific Region, and Indian Ocean Basin.
    
    The uncertainty estimates help identify areas where the model is less confident in its predictions, which is crucial for:
    - Risk assessment in coastal hazard modeling
    - Climate change impact analysis
    - Tsunami propagation simulation
        
    ## Model Performance
    
    | Model | SSIM | PSNR | MSE | MAE | UWidth | CalErr |
    |-------|------|------|-----|-----|--------|--------|
    | UA-VQ-VAE | 0.9433 | 26.8779 | 0.0021 | 0.0317 | 0.1046 | 0.0664 |
    """)

# Launch the demo when run as a script: report whether the model loaded, then
# start the Gradio server (it still starts if the model failed to load).
if __name__ == "__main__":
    status = (
        "Model loaded successfully. Starting Gradio interface."
        if model_loaded
        else "Warning: Model not loaded. Demo will display errors when processing files."
    )
    print(status)
    demo.launch()