File size: 24,098 Bytes
5284cae
0d67e1c
5284cae
 
 
0d67e1c
5284cae
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0d67e1c
 
5284cae
0d67e1c
 
 
 
 
 
 
5284cae
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0d67e1c
5284cae
 
 
0d67e1c
 
 
5284cae
 
0d67e1c
5284cae
 
 
 
 
 
 
 
 
 
0d67e1c
 
 
 
 
5284cae
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0d67e1c
5284cae
 
0d67e1c
 
5284cae
 
 
 
0d67e1c
5284cae
 
 
 
 
 
 
 
 
 
 
 
 
 
0d67e1c
 
5284cae
0d67e1c
 
 
 
 
 
5284cae
 
0d67e1c
5284cae
 
0d67e1c
 
5284cae
 
 
 
0d67e1c
5284cae
 
0d67e1c
5284cae
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0d67e1c
 
 
 
 
 
 
5284cae
0d67e1c
 
5284cae
 
 
0d67e1c
 
 
 
3e2ec65
0d67e1c
5284cae
 
0d67e1c
 
5284cae
0d67e1c
 
 
5284cae
0d67e1c
 
 
 
5284cae
0d67e1c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6d71f56
0d67e1c
 
6d71f56
0d67e1c
 
 
6d71f56
0d67e1c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b1b09f9
6d71f56
0d67e1c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
eb80d64
0d67e1c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cb51211
 
 
1ae5245
cb51211
1ae5245
cb51211
 
0d67e1c
cb51211
 
 
 
 
 
 
 
 
 
0d67e1c
cb51211
 
0d67e1c
 
cb51211
 
 
 
 
 
 
 
0d67e1c
cb51211
980d8b6
 
 
cb51211
 
 
 
 
 
 
 
0d67e1c
cb51211
 
 
 
 
 
0d67e1c
cb51211
 
 
 
 
 
0d67e1c
 
cb51211
 
 
 
 
 
 
 
 
0d67e1c
 
 
 
 
6d71f56
 
0d67e1c
 
cb51211
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5ae7dd3
cb51211
0d67e1c
8f42556
cb51211
0d67e1c
 
 
 
 
 
 
 
cb51211
0d67e1c
 
 
cb51211
 
 
 
 
 
0d67e1c
 
cb51211
0d67e1c
cb51211
 
1ae5245
0d67e1c
 
 
 
 
 
 
 
 
 
 
 
cb51211
 
 
 
 
 
 
 
 
 
 
 
0d67e1c
cb51211
 
 
0d67e1c
cb51211
3a50655
cb51211
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
"""
Biomass Prediction Gradio App with Exact 99 Features
Author: najahpokkiri
Date: 2025-05-19

Updated with side-by-side RGB comparison, fixed sample image loading, and corrected biomass calculation.
"""
import os
import sys
import torch
import numpy as np
import gradio as gr
import joblib
import tempfile
import matplotlib.pyplot as plt
import matplotlib.colors as colors
from PIL import Image
import io
import logging
from huggingface_hub import hf_hub_download

# Configure logger
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Import model architecture
from model import StableResNet

# Import feature engineering
from feature_engineering import extract_all_features

# Import config - this must happen before loading model_package.pkl
# NOTE(review): presumably the pickled package references
# BiomassPipelineConfig, so the class must be importable when the package is
# unpickled — confirm against how model_package.pkl was produced.
# A failed import is only logged here; start-up continues regardless.
try:
    from config import BiomassPipelineConfig
    logger.info("Successfully imported config.BiomassPipelineConfig")
except ImportError as e:
    logger.error(f"Failed to import config.BiomassPipelineConfig: {e}")
    logger.error("This will likely cause errors when loading the model package")

class BiomassPredictorApp:
    """Gradio app for biomass prediction from satellite imagery"""
    
    def __init__(self, model_repo="pokkiri/biomass-model"):
        """Initialize the app with model repository information"""
        self.model = None
        self.package = None
        self.feature_names = []
        self.model_repo = model_repo
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        
        # Cache for storing temporary files
        self.temp_files = []
        
        # Load the model
        self.load_model()
    
    def load_model(self):
        """Load the model and preprocessing package from HuggingFace Hub.

        Downloads ``model.pt`` and ``model_package.pkl`` from ``self.model_repo``.
        If the package cannot be read, a minimal stand-in package for 99
        features (log transform on, epsilon 1.0, no scaler) is used instead.

        ``self.model`` is only assigned once the weights have been loaded
        successfully, so a failed load never leaves an untrained network in
        place for ``predict_biomass`` to use.

        Returns:
            True if the model is ready for inference, False otherwise (the
            error and its traceback are logged).
        """
        try:
            logger.info(f"Loading model from {self.model_repo}")

            # Download model files from HuggingFace
            model_path = hf_hub_download(repo_id=self.model_repo, filename="model.pt")
            package_path = hf_hub_download(repo_id=self.model_repo, filename="model_package.pkl")

            try:
                # SECURITY: joblib.load unpickles the file, which can execute
                # arbitrary code. Only point model_repo at a trusted repository.
                logger.info(f"Loading package from {package_path}")
                self.package = joblib.load(package_path)
                logger.info("Successfully loaded model package")

                # Extract information from package
                n_features = self.package['n_features']
                default_names = [f"feature_{i}" for i in range(n_features)]
                self.feature_names = self.package.get('feature_names', default_names)

                logger.info(f"Package keys: {list(self.package.keys())}")
                logger.info(f"Model expects {n_features} features")

                # Verify feature count is 99
                if n_features != 99:
                    logger.warning(f"Warning: Model expects {n_features} features, not the expected 99. This may cause issues.")

            except Exception as e:
                logger.error(f"Error loading package file: {e}")
                # Fallback to default values
                n_features = 99  # We know there are 99 features
                self.feature_names = [f"feature_{i}" for i in range(n_features)]

                # Create a minimal package with essential components
                self.package = {
                    'n_features': n_features,
                    'use_log_transform': True,
                    'epsilon': 1.0,
                    'scaler': None  # The None case is handled in prediction
                }

            # Build and initialise the network in a local variable first; it is
            # published on self.model only after every step succeeded.
            # SECURITY: torch.load also unpickles; same trust requirement as above.
            model = StableResNet(n_features=n_features)
            model.load_state_dict(torch.load(model_path, map_location=self.device))
            model.to(self.device)
            model.eval()
            self.model = model

            logger.info(f"Model loaded successfully from {self.model_repo}")
            logger.info(f"Number of features: {n_features}")
            logger.info(f"Using device: {self.device}")
            logger.info(f"Log transform: {self.package.get('use_log_transform', True)}")
            logger.info(f"Epsilon: {self.package.get('epsilon', 1.0)}")

            return True
        except Exception as e:
            # logger.exception records the traceback along with the message.
            logger.exception(f"Error loading model: {e}")
            return False
    
    def cleanup(self):
        """Clean up temporary files"""
        for tmp_path in self.temp_files:
            try:
                if os.path.exists(tmp_path):
                    os.unlink(tmp_path)
            except Exception as e:
                logger.warning(f"Failed to remove temporary file {tmp_path}: {e}")
        
        self.temp_files = []
    
    def load_sample_image(self):
        """Return the path of the bundled sample image, or None if unavailable.

        The sample is looked up as ``input_chip_1.tif`` relative to the current
        working directory. Only its existence is checked; the file is not opened.
        """
        sample_path = "input_chip_1.tif"
        try:
            if not os.path.exists(sample_path):
                logger.warning(f"Sample image not found at {sample_path}")
                return None

            logger.info(f"Loading sample image from {sample_path}")
            return sample_path
        except Exception as exc:
            logger.error(f"Error loading sample image: {exc}")
            return None
    
    def predict_biomass(self, image_file, display_type="heatmap"):
        """Predict biomass from a satellite image"""
        if self.model is None:
            return None, "Error: Model not loaded. Please check logs for details."
        
        if image_file is None:
            return None, "Error: No file uploaded. Please upload a GeoTIFF file or use the sample image."
        
        try:
            # Check if we're using the sample image (string path) or an uploaded file
            if isinstance(image_file, str):
                logger.info(f"Using sample image: {image_file}")
                tmp_path = image_file  # Use the sample path directly
                cleanup_tmp = False  # Don't delete the sample file
            else:
                # Create a temporary file to save the uploaded file
                with tempfile.NamedTemporaryFile(suffix='.tif', delete=False) as tmp_file:
                    tmp_path = tmp_file.name
                    with open(image_file.name, 'rb') as f:
                        tmp_file.write(f.read())
                
                # Add to list for cleanup later
                self.temp_files.append(tmp_path)
                cleanup_tmp = True
            
            # Ensure rasterio is available
            try:
                import rasterio
            except ImportError:
                return None, "Error: rasterio is required but not installed. Please install with: pip install rasterio"
            
            # Open the image file
            with rasterio.open(tmp_path) as src:
                image = src.read()
                height, width = image.shape[1], image.shape[2]
                transform = src.transform
                crs = src.crs
                
                # Check if we need to limit to 59 bands
                if image.shape[0] > 59:
                    logger.info(f"Image has {image.shape[0]} bands, selecting first 59 for model compatibility")
                    image = image[:59, :, :]
                
                logger.info(f"Processing image: {height}x{width} pixels, {image.shape[0]} bands")
                
                # Validate minimum band count
                if image.shape[0] < 1:
                    return None, f"Error: Image has no bands. Please use multi-band satellite imagery."
                
                # Generate all features using feature engineering
                logger.info("Generating all 99 features from bands...")
                feature_matrix, valid_mask, generated_features = extract_all_features(image)
                
                # Print basic feature statistics for debugging
                logger.info(f"Feature statistics - Min: {np.min(feature_matrix, axis=0)[:5]}, " +
                          f"Max: {np.max(feature_matrix, axis=0)[:5]}, " + 
                          f"Mean: {np.mean(feature_matrix, axis=0)[:5]}")
                
                # Verify we have exactly 99 features
                if feature_matrix.shape[1] != 99:
                    logger.error(f"Error: Generated {feature_matrix.shape[1]} features, but model expects 99.")
                    return None, f"Error: Generated {feature_matrix.shape[1]} features, but model expects 99."
                
                # Apply feature scaling if available
                try:
                    if 'scaler' in self.package and self.package['scaler'] is not None:
                        logger.info("Applying feature scaling...")
                        feature_matrix = self.package['scaler'].transform(feature_matrix)
                        logger.info("Scaling complete")
                        logger.info(f"After scaling - Min: {np.min(feature_matrix, axis=0)[:5]}, " +
                                  f"Max: {np.max(feature_matrix, axis=0)[:5]}")
                except Exception as e:
                    logger.warning(f"Error applying scaler: {e}. Using original features.")
                
                # Initialize predictions array
                predictions = np.zeros((height, width), dtype=np.float32)
                
                # Get valid pixel coordinates
                valid_y, valid_x = np.where(valid_mask)
                
                # Make predictions
                logger.info(f"Running model inference on {len(valid_y)} valid pixels...")
                with torch.no_grad():
                    # Process in batches to avoid memory issues
                    batch_size = 10000
                    for i in range(0, len(valid_y), batch_size):
                        end_idx = min(i + batch_size, len(valid_y))
                        batch = feature_matrix[i:end_idx]
                        
                        # Convert to tensor
                        batch_tensor = torch.tensor(batch, dtype=torch.float32).to(self.device)
                        
                        # Get predictions
                        batch_predictions = self.model(batch_tensor).cpu().numpy()
                        
                        # Handle scalar case for single-item batches
                        if batch_predictions.ndim == 0:
                            batch_predictions = np.array([batch_predictions])
                        
                        # Log raw predictions
                        if i == 0:
                            logger.info(f"Raw prediction sample: {batch_predictions[:5]}")
                        
                        # Fix: Correct log transform reversal
                        if self.package.get('use_log_transform', True):
                            # Get epsilon value, default to 1.0
                            epsilon = self.package.get('epsilon', 1.0)
                            
                            # Log transform should be exp(x) - epsilon
                            batch_predictions = np.exp(batch_predictions)
                            
                            # Only subtract epsilon if it's not zero or close to zero
                            if abs(epsilon) > 1e-10:
                                batch_predictions = batch_predictions - epsilon
                            
                            # Ensure non-negative
                            batch_predictions = np.maximum(batch_predictions, 0)
                        
                        # Log transformed predictions
                        if i == 0:
                            logger.info(f"Transformed prediction sample: {batch_predictions[:5]}")
                            logger.info(f"Using log transform: {self.package.get('use_log_transform', True)}, " +
                                      f"epsilon: {self.package.get('epsilon', 1.0)}")
                        
                        # Map predictions back to image
                        for j, pred in enumerate(batch_predictions):
                            y_idx = valid_y[i + j]
                            x_idx = valid_x[i + j]
                            predictions[y_idx, x_idx] = pred
                        
                        # Log progress
                        if (i // batch_size) % 5 == 0 or end_idx == len(valid_y):
                            logger.info(f"Processed {end_idx}/{len(valid_y)} pixels")
                
                # Calculate and log prediction statistics
                valid_predictions = predictions[valid_mask]
                logger.info(f"Prediction statistics - Min: {np.min(valid_predictions):.2f}, " +
                          f"Max: {np.max(valid_predictions):.2f}, " +
                          f"Mean: {np.mean(valid_predictions):.2f}, " + 
                          f"Median: {np.median(valid_predictions):.2f}")
                
                # Create visualization
                logger.info("Creating visualization...")
                
                if display_type == "heatmap":
                    # Create heatmap visualization
                    fig, ax = plt.subplots(figsize=(10, 8))
                    
                    # Use masked array for better visualization
                    masked_predictions = np.ma.masked_where(~valid_mask, predictions)
                    
                    # Set min/max values based on percentiles for better contrast
                    vmin = np.percentile(predictions[valid_mask], 1)
                    vmax = np.percentile(predictions[valid_mask], 99)
                    
                    im = ax.imshow(masked_predictions, cmap='viridis', vmin=vmin, vmax=vmax)
                    fig.colorbar(im, ax=ax, label='Biomass (Mg/ha)')
                    ax.set_title('Predicted Above-Ground Biomass')
                    ax.axis('off')  # Hide axes for cleaner visualization
                    
                elif display_type == "rgb_overlay":
                    # Create side-by-side comparison (RGB and Biomass)
                    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 8))
                    
                    # Prepare RGB image using bands 4,3,2 (0-indexed: 3,2,1)
                    rgb_bands = [3, 2, 1]  # Using 4,3,2 for RGB (0-indexed)
                    
                    if image.shape[0] >= 5:  # Ensure we have enough bands (need at least 5 for 0-indexed band 4)
                        # Create RGB image
                        rgb = np.zeros((height, width, 3), dtype=np.float32)
                        for i, band_idx in enumerate(rgb_bands):
                            if band_idx < image.shape[0]:
                                rgb[:, :, i] = image[band_idx]
                        
                        # Handle potential NaN values
                        rgb = np.nan_to_num(rgb)
                        
                        # Enhance contrast with percentile-based normalization
                        for i in range(3):
                            p2 = np.percentile(rgb[:,:,i], 2)
                            p98 = np.percentile(rgb[:,:,i], 98)
                            if p98 > p2:
                                rgb[:,:,i] = np.clip((rgb[:,:,i] - p2) / (p98 - p2), 0, 1)
                        
                        # Display RGB image
                        ax1.imshow(rgb)
                        ax1.set_title('RGB Image (Bands 4,3,2)')
                        ax1.axis('off')
                        
                        # Display biomass prediction
                        masked_predictions = np.ma.masked_where(~valid_mask, predictions)
                        vmin = np.percentile(predictions[valid_mask], 1)
                        vmax = np.percentile(predictions[valid_mask], 99)
                        
                        im = ax2.imshow(masked_predictions, cmap='viridis', vmin=vmin, vmax=vmax)
                        fig.colorbar(im, ax=ax2, label='Biomass (Mg/ha)')
                        ax2.set_title('Predicted Biomass')
                        ax2.axis('off')
                        
                        # Add super title
                        plt.suptitle('RGB Image and Biomass Prediction', fontsize=16)
                        plt.tight_layout()
                    else:
                        # Fallback to heatmap if not enough bands
                        logger.warning(f"Not enough bands for RGB display (need 5, got {image.shape[0]}). Showing biomass only.")
                        masked_predictions = np.ma.masked_where(~valid_mask, predictions)
                        im = ax1.imshow(masked_predictions, cmap='viridis')
                        fig.colorbar(im, ax=ax1, label='Biomass (Mg/ha)')
                        ax1.set_title('Predicted Above-Ground Biomass')
                        ax1.axis('off')
                
                # Save figure to bytes buffer
                buf = io.BytesIO()
                fig.savefig(buf, format='png', dpi=150, bbox_inches='tight')
                buf.seek(0)
                plt.close(fig)
                
                # Calculate summary statistics
                valid_predictions = predictions[valid_mask]
                stats = {
                    'Mean Biomass': f"{np.mean(valid_predictions):.2f} Mg/ha",
                    'Median Biomass': f"{np.median(valid_predictions):.2f} Mg/ha",
                    'Min Biomass': f"{np.min(valid_predictions):.2f} Mg/ha",
                    'Max Biomass': f"{np.max(valid_predictions):.2f} Mg/ha"
                }
                
                # Add area and total biomass if transform is available
                if transform is not None:
                    pixel_area_m2 = abs(transform[0] * transform[4])  # Assuming square pixels
                    total_biomass = np.sum(valid_predictions) * (pixel_area_m2 / 10000)  # Convert to hectares
                    area_hectares = np.sum(valid_mask) * (pixel_area_m2 / 10000)
                    
                    stats['Total Biomass'] = f"{total_biomass:.2f} Mg"
                    stats['Area'] = f"{area_hectares:.2f} hectares"
                
                # Format statistics as markdown
                stats_md = "### Biomass Statistics\n\n"
                stats_md += "| Metric | Value |\n|--------|-------|\n"
                for k, v in stats.items():
                    stats_md += f"| {k} | {v} |\n"
                
                # Add processing info
                stats_md += f"\n\n*Processed {np.sum(valid_mask):,} valid pixels with {feature_matrix.shape[1]} features*"
                
                # Cleanup temporary files if needed
                if cleanup_tmp:
                    self.cleanup()
                
                # Return visualization and statistics
                return Image.open(buf), stats_md
                
        except Exception as e:
            # Ensure cleanup even on error
            self.cleanup()
            
            import traceback
            logger.error(f"Error predicting biomass: {e}")
            logger.error(traceback.format_exc())
            
            return None, f"Error predicting biomass: {str(e)}\n\nPlease check logs for details."

    def create_interface(self):
        """Build the Gradio Blocks UI for the predictor.

        The layout has an upload column (file input, display-type selector and
        two buttons) beside an output column (prediction image and statistics),
        followed by a collapsible "About" section. If the model failed to load,
        a warning banner is rendered in the page.

        Returns:
            The constructed ``gr.Blocks`` interface (not yet launched).
        """
        with gr.Blocks(title="Biomass Prediction Model") as interface:
            gr.Markdown("# Above-Ground Biomass Prediction")

            # gr.Warning only reaches the user when called from inside an event
            # handler; at layout-build time it shows nothing in the page. Render
            # the notice as a visible component instead.
            if self.model is None:
                gr.Markdown("**⚠️ Model failed to load. The app may not work correctly. Check logs for details.**")

            gr.Markdown("""
            Upload a multi-band satellite image to predict above-ground biomass (AGB) across the landscape.
            
            **Requirements:**
            - Image must be a GeoTIFF with spectral bands
            - For best results, use imagery with at least 59 bands or similar to training data
            """)
            
            with gr.Row():
                with gr.Column(scale=1):
                    input_image = gr.File(
                        label="Upload Satellite Image (GeoTIFF)",
                        file_types=[".tif", ".tiff"]
                    )
                    
                    display_type = gr.Radio(
                        choices=["heatmap", "rgb_overlay"],
                        value="heatmap",
                        label="Display Type"
                    )
                    
                    with gr.Row():
                        submit_btn = gr.Button("Generate Biomass Prediction", variant="primary")
                        sample_btn = gr.Button("Use Sample Image")
                
                with gr.Column(scale=2):
                    output_image = gr.Image(
                        label="Biomass Prediction Map",
                        type="pil"
                    )
                    
                    output_stats = gr.Markdown(
                        label="Statistics"
                    )
            
            with gr.Accordion("About", open=False):
                gr.Markdown("""
                ## About This Model
                
                This biomass prediction model uses the StableResNet architecture to predict above-ground biomass from satellite imagery.
                
                ### Model Details
                
                - Architecture: StableResNet
                - Input: Multi-spectral satellite imagery
                - Output: Above-ground biomass (Mg/ha)
                - Creator: vertify.earth for GIZ Forest Forward
                - Date: 2025-05-19
                
                ### How It Works
                
                1. The model extracts features from each pixel in the satellite image
                2. These features include spectral bands, vegetation indices, texture metrics, and more
                3. The model outputs a biomass prediction for each pixel
                4. Results are visualized as a heatmap or RGB overlay
                
                ### Updates in This Version
                
                - Fixed biomass value calculation issue (improved log transform handling)
                - Added detailed diagnostics for troubleshooting
                - Enhanced RGB visualization with band verification
                """)
            
            # Connect the submit button
            submit_btn.click(
                fn=self.predict_biomass,
                inputs=[input_image, display_type],
                outputs=[output_image, output_stats]
            )
            
            # Handle sample image button: run the same prediction on the
            # bundled sample, or report that the sample file is missing.
            def use_sample_image(display_type):
                sample_path = self.load_sample_image()
                if sample_path is None:
                    return None, "Error: Sample image not found. Please make sure 'input_chip_1.tif' exists in the app directory."
                return self.predict_biomass(sample_path, display_type)
            
            sample_btn.click(
                fn=use_sample_image,
                inputs=[display_type],
                outputs=[output_image, output_stats]
            )
        
        return interface

def launch_app():
    """Build the predictor app and serve its Gradio interface.

    Any exception raised while constructing or launching the app is logged
    together with its traceback instead of being propagated.
    """
    try:
        predictor = BiomassPredictorApp()
        ui = predictor.create_interface()

        # Important: no share=True in Hugging Face Spaces
        ui.launch()
    except Exception as exc:
        import traceback

        logger.error(f"Error launching app: {exc}")
        logger.error(traceback.format_exc())

# Script entry point: start the app only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    launch_app()