Spaces:
Running
Running
"""
Biomass Prediction Gradio App with Exact 99 Features
Author: najahpokkiri
Date: 2025-05-19
Updated with side-by-side RGB comparison, fixed sample image loading, and corrected biomass calculation.
"""
import os | |
import sys | |
import torch | |
import numpy as np | |
import gradio as gr | |
import joblib | |
import tempfile | |
import matplotlib.pyplot as plt | |
import matplotlib.colors as colors | |
from PIL import Image | |
import io | |
import logging | |
from huggingface_hub import hf_hub_download | |
# Module-wide logging: timestamped INFO-level records, one logger per module.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
)
logger = logging.getLogger(__name__)

# Network architecture used to rebuild the model from its saved weights.
from model import StableResNet

# Turns raw image bands into the per-pixel feature matrix fed to the model.
from feature_engineering import extract_all_features

# The config module has to be importable before model_package.pkl is loaded
# (presumably the pickle references BiomassPipelineConfig -- confirm against
# the training pipeline). A failed import is only logged here; the app keeps
# starting and the package loader applies its own fallback.
try:
    from config import BiomassPipelineConfig
except ImportError as e:
    logger.error(f"Failed to import config.BiomassPipelineConfig: {e}")
    logger.error("This will likely cause errors when loading the model package")
else:
    logger.info("Successfully imported config.BiomassPipelineConfig")
class BiomassPredictorApp:
    """Gradio app for biomass prediction from satellite imagery.

    On construction the app downloads ``model.pt`` and ``model_package.pkl``
    from the configured Hugging Face repository and builds a ``StableResNet``
    from them. ``predict_biomass`` then turns a multi-band GeoTIFF into a
    biomass map plus a markdown statistics table, and ``create_interface``
    wires that into a Gradio ``Blocks`` UI.

    Attributes:
        model: The loaded network in eval mode, or ``None`` when loading failed.
        package: Dict-like preprocessing metadata (``n_features``, ``scaler``,
            ``use_log_transform``, ``epsilon``), or a minimal fallback dict.
        feature_names: Feature names from the package, or generated placeholders.
        model_repo: Hugging Face repository id the model is downloaded from.
        device: Torch device used for inference.
        temp_files: Paths of temporary upload copies that still need deleting.
    """

    # Images with more bands are truncated to this many before feature extraction.
    MAX_BANDS = 59
    # Feature count assumed when the model package does not provide one.
    DEFAULT_N_FEATURES = 99
    # Number of pixels pushed through the model per forward pass.
    BATCH_SIZE = 10000
    # Bundled example image, resolved relative to the working directory.
    SAMPLE_IMAGE_PATH = "input_chip_1.tif"
    # 0-indexed positions of the red, green and blue bands (bands 4, 3, 2).
    RGB_BANDS = (3, 2, 1)

    def __init__(self, model_repo="pokkiri/biomass-model"):
        """Initialize the app and immediately try to load the model.

        Args:
            model_repo: Hugging Face repository id holding ``model.pt`` and
                ``model_package.pkl``.
        """
        self.model = None
        self.package = None
        self.feature_names = []
        self.model_repo = model_repo
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        # Temporary copies of uploads that have not been deleted yet
        self.temp_files = []
        # A failed load is logged and leaves self.model as None
        self.load_model()

    def load_model(self):
        """Load the model and preprocessing package from the Hugging Face Hub.

        ``self.model`` is only assigned once the weights have been loaded
        successfully, so a partially initialised (randomly weighted) network
        can never be used for predictions.

        Returns:
            True when the model is ready for inference, False otherwise.
        """
        try:
            logger.info(f"Loading model from {self.model_repo}")
            model_path = hf_hub_download(repo_id=self.model_repo, filename="model.pt")
            package_path = hf_hub_download(repo_id=self.model_repo, filename="model_package.pkl")

            try:
                logger.info(f"Loading package from {package_path}")
                # NOTE(security): joblib.load unpickles the file and can run
                # arbitrary code; only point model_repo at trusted repositories.
                package = joblib.load(package_path)
                logger.info("Successfully loaded model package")

                n_features = package['n_features']
                feature_names = package.get(
                    'feature_names', [f"feature_{i}" for i in range(n_features)]
                )
                logger.info(f"Package keys: {list(package.keys())}")
                logger.info(f"Model expects {n_features} features")
                if n_features != 99:
                    logger.warning(f"Warning: Model expects {n_features} features, not the expected 99. This may cause issues.")
            except Exception as e:
                logger.error(f"Error loading package file: {e}")
                # Fall back to a minimal package so the weights can still be used
                n_features = self.DEFAULT_N_FEATURES
                feature_names = [f"feature_{i}" for i in range(n_features)]
                package = {
                    'n_features': n_features,
                    'use_log_transform': True,
                    'epsilon': 1.0,
                    'scaler': None,  # prediction skips scaling when this is None
                }

            self.package = package
            self.feature_names = feature_names

            # Build into a local first; publish on self only after the weights
            # are in place.
            model = StableResNet(n_features=n_features)
            # NOTE(security): torch.load unpickles the checkpoint; consider
            # weights_only=True once the minimum supported torch version allows it.
            model.load_state_dict(torch.load(model_path, map_location=self.device))
            model.to(self.device)
            model.eval()
            self.model = model

            logger.info(f"Model loaded successfully from {self.model_repo}")
            logger.info(f"Number of features: {n_features}")
            logger.info(f"Using device: {self.device}")
            logger.info(f"Log transform: {self.package.get('use_log_transform', True)}")
            logger.info(f"Epsilon: {self.package.get('epsilon', 1.0)}")
            return True
        except Exception as e:
            # Never leave a half-initialised model behind
            self.model = None
            logger.error(f"Error loading model: {e}")
            import traceback
            logger.error(traceback.format_exc())
            return False

    @staticmethod
    def _unlink_quietly(tmp_path):
        """Delete ``tmp_path`` if it exists; log instead of raising on failure."""
        try:
            os.unlink(tmp_path)
        except FileNotFoundError:
            pass  # already gone, nothing to do
        except Exception as e:
            logger.warning(f"Failed to remove temporary file {tmp_path}: {e}")

    def _discard_temp_file(self, tmp_path):
        """Unregister and delete a single temporary file owned by one request."""
        try:
            self.temp_files.remove(tmp_path)
        except ValueError:
            pass  # not registered (or already removed by cleanup())
        self._unlink_quietly(tmp_path)

    def cleanup(self):
        """Delete every registered temporary file and reset the registry."""
        pending, self.temp_files = self.temp_files, []
        for tmp_path in pending:
            self._unlink_quietly(tmp_path)

    def load_sample_image(self):
        """Return the path of the bundled sample image, or None if it is missing."""
        try:
            sample_path = self.SAMPLE_IMAGE_PATH
            if os.path.exists(sample_path):
                logger.info(f"Loading sample image from {sample_path}")
                return sample_path
            logger.warning(f"Sample image not found at {sample_path}")
            return None
        except Exception as e:
            logger.error(f"Error loading sample image: {e}")
            return None

    @staticmethod
    def _invert_log_transform(values, epsilon):
        """Undo ``log(y + epsilon)``: return ``max(exp(values) - epsilon, 0)``.

        Epsilon values within 1e-10 of zero are treated as "no offset".
        """
        restored = np.exp(values)
        if abs(epsilon) > 1e-10:
            restored = restored - epsilon
        # Biomass cannot be negative
        return np.maximum(restored, 0)

    @staticmethod
    def _compose_rgb(image, band_indices):
        """Build a contrast-stretched RGB array from three bands.

        Args:
            image: Array indexed as ``image[band]`` with 2-D bands.
            band_indices: Three 0-indexed band positions for R, G and B.

        Returns:
            A float32 ``(height, width, 3)`` array with values in [0, 1], or
            None when the image does not contain all requested bands.
        """
        if image.shape[0] <= max(band_indices):
            return None
        rgb = np.stack([image[idx] for idx in band_indices], axis=-1).astype(np.float32)
        rgb = np.nan_to_num(rgb)
        for channel in range(rgb.shape[-1]):
            plane = rgb[..., channel]
            p2, p98 = np.percentile(plane, (2, 98))
            if p98 > p2:
                # Percentile stretch for better contrast
                rgb[..., channel] = np.clip((plane - p2) / (p98 - p2), 0, 1)
            else:
                # Flat channel: just keep it inside the displayable range
                rgb[..., channel] = np.clip(plane, 0, 1)
        return rgb

    def _predict_pixels(self, feature_matrix, valid_mask, height, width):
        """Run the model over all valid pixels and return a (height, width) map.

        Row ``k`` of ``feature_matrix`` is taken to belong to the ``k``-th valid
        pixel in ``np.where(valid_mask)`` order -- this mirrors how the feature
        extractor's output has always been consumed here; confirm against
        ``extract_all_features`` if that helper changes.
        """
        predictions = np.zeros((height, width), dtype=np.float32)
        valid_y, valid_x = np.where(valid_mask)
        n_valid = len(valid_y)
        use_log = self.package.get('use_log_transform', True)
        epsilon = self.package.get('epsilon', 1.0)

        if feature_matrix.shape[0] != n_valid:
            logger.warning(
                f"Feature matrix has {feature_matrix.shape[0]} rows but the mask has "
                f"{n_valid} valid pixels; predictions may be misaligned."
            )

        logger.info(f"Running model inference on {n_valid} valid pixels...")
        with torch.no_grad():
            # Batched to keep memory bounded
            for start in range(0, n_valid, self.BATCH_SIZE):
                end = min(start + self.BATCH_SIZE, n_valid)
                batch = feature_matrix[start:end]
                batch_tensor = torch.tensor(batch, dtype=torch.float32).to(self.device)
                raw = np.atleast_1d(self.model(batch_tensor).cpu().numpy())
                if raw.size == 0:
                    continue

                # Accept (N,) or (N, 1) outputs; anything else is not one value per pixel
                per_pixel = raw.reshape(raw.shape[0], -1)
                if per_pixel.shape[1] != 1:
                    raise ValueError(
                        f"Model returned shape {raw.shape}; expected one value per pixel."
                    )
                batch_predictions = per_pixel[:, 0]

                if start == 0:
                    logger.info(f"Raw prediction sample: {batch_predictions[:5]}")

                if use_log:
                    batch_predictions = self._invert_log_transform(batch_predictions, epsilon)

                if start == 0:
                    logger.info(f"Transformed prediction sample: {batch_predictions[:5]}")
                    logger.info(f"Using log transform: {use_log}, "
                                f"epsilon: {epsilon}")

                # Vectorised write-back of this batch into the output map
                stop = start + len(batch_predictions)
                predictions[valid_y[start:stop], valid_x[start:stop]] = batch_predictions

                if (start // self.BATCH_SIZE) % 5 == 0 or end == n_valid:
                    logger.info(f"Processed {end}/{n_valid} pixels")
        return predictions

    def _render_map(self, image, predictions, valid_mask, display_type):
        """Render the prediction map (optionally next to an RGB view) as a PIL image.

        Raises:
            ValueError: If ``display_type`` is neither "heatmap" nor "rgb_overlay".
        """
        rgb = None
        if display_type == "rgb_overlay":
            rgb = self._compose_rgb(image, self.RGB_BANDS)
            if rgb is None:
                needed = max(self.RGB_BANDS) + 1
                logger.warning(f"Not enough bands for RGB display (need {needed}, got {image.shape[0]}). Showing biomass only.")
        elif display_type != "heatmap":
            raise ValueError(f"Unsupported display type: {display_type!r}")

        # Mask invalid pixels and clip the colour range to the 1st-99th percentile
        masked_predictions = np.ma.masked_where(~valid_mask, predictions)
        vmin, vmax = np.percentile(predictions[valid_mask], (1, 99))

        if rgb is None:
            fig, ax_map = plt.subplots(figsize=(10, 8))
            map_title = 'Predicted Above-Ground Biomass'
        else:
            fig, (ax_rgb, ax_map) = plt.subplots(1, 2, figsize=(16, 8))
            map_title = 'Predicted Biomass'

        try:
            if rgb is not None:
                ax_rgb.imshow(rgb)
                ax_rgb.set_title('RGB Image (Bands 4,3,2)')
                ax_rgb.axis('off')

            im = ax_map.imshow(masked_predictions, cmap='viridis', vmin=vmin, vmax=vmax)
            fig.colorbar(im, ax=ax_map, label='Biomass (Mg/ha)')
            ax_map.set_title(map_title)
            ax_map.axis('off')

            if rgb is not None:
                # Use the figure's own methods rather than pyplot's global "current figure"
                fig.suptitle('RGB Image and Biomass Prediction', fontsize=16)
                fig.tight_layout()

            buf = io.BytesIO()
            fig.savefig(buf, format='png', dpi=150, bbox_inches='tight')
        finally:
            # Always release the figure, even when rendering fails
            plt.close(fig)

        buf.seek(0)
        return Image.open(buf)

    @staticmethod
    def _format_stats(valid_predictions, n_valid, n_features, transform, crs):
        """Format summary statistics as a markdown table.

        Total biomass and area are derived from the raster transform's pixel
        size, assumed to be in metres (TODO confirm for non-metric projected
        CRSs). They are omitted when the CRS is geographic, because pixel
        sizes are then in degrees.
        """
        stats = {
            'Mean Biomass': f"{np.mean(valid_predictions):.2f} Mg/ha",
            'Median Biomass': f"{np.median(valid_predictions):.2f} Mg/ha",
            'Min Biomass': f"{np.min(valid_predictions):.2f} Mg/ha",
            'Max Biomass': f"{np.max(valid_predictions):.2f} Mg/ha",
        }

        note = ""
        if transform is not None:
            if crs is not None and getattr(crs, 'is_geographic', False):
                logger.warning("Image CRS is geographic; skipping area and total biomass.")
                note = "\n\n*Total biomass and area omitted: image CRS is geographic (pixel size in degrees).*"
            else:
                # |a*e - b*d| is the pixel area and also holds for rotated rasters
                pixel_area_m2 = abs(transform[0] * transform[4] - transform[1] * transform[3])
                hectares_per_pixel = pixel_area_m2 / 10000
                total_biomass = np.sum(valid_predictions) * hectares_per_pixel
                area_hectares = n_valid * hectares_per_pixel
                stats['Total Biomass'] = f"{total_biomass:.2f} Mg"
                stats['Area'] = f"{area_hectares:.2f} hectares"

        rows = "".join(f"| {k} | {v} |\n" for k, v in stats.items())
        return (
            "### Biomass Statistics\n\n"
            "| Metric | Value |\n|--------|-------|\n"
            f"{rows}"
            f"\n\n*Processed {n_valid:,} valid pixels with {n_features} features*"
            f"{note}"
        )

    def predict_biomass(self, image_file, display_type="heatmap"):
        """Predict biomass from a satellite image.

        Args:
            image_file: Either a path string (sample image, or an upload when
                Gradio passes file paths) or an uploaded-file object exposing
                a ``.name`` path.
            display_type: "heatmap" for the biomass map alone, "rgb_overlay"
                for an RGB view next to the map.

        Returns:
            ``(image, message)``: a PIL image and a markdown statistics table on
            success, or ``(None, error_message)`` on failure. Never raises.
        """
        if self.model is None:
            return None, "Error: Model not loaded. Please check logs for details."
        if image_file is None:
            return None, "Error: No file uploaded. Please upload a GeoTIFF file or use the sample image."

        owned_tmp = None  # temp copy created for this request only
        try:
            # rasterio is optional at import time; fail with a clear message
            try:
                import rasterio
            except ImportError:
                return None, "Error: rasterio is required but not installed. Please install with: pip install rasterio"

            if isinstance(image_file, str):
                logger.info(f"Using sample image: {image_file}")
                image_path = image_file  # read in place; never deleted
            else:
                # Copy the upload to a .tif temp file
                with tempfile.NamedTemporaryFile(suffix='.tif', delete=False) as tmp_file:
                    owned_tmp = tmp_file.name
                    self.temp_files.append(owned_tmp)
                    with open(image_file.name, 'rb') as upload:
                        tmp_file.write(upload.read())
                image_path = owned_tmp

            with rasterio.open(image_path) as src:
                image = src.read()
                transform = src.transform
                crs = src.crs

            if image.shape[0] > self.MAX_BANDS:
                logger.info(f"Image has {image.shape[0]} bands, selecting first {self.MAX_BANDS} for model compatibility")
                image = image[:self.MAX_BANDS]
            height, width = image.shape[1], image.shape[2]
            logger.info(f"Processing image: {height}x{width} pixels, {image.shape[0]} bands")

            if image.shape[0] < 1:
                return None, "Error: Image has no bands. Please use multi-band satellite imagery."

            logger.info("Generating all 99 features from bands...")
            feature_matrix, valid_mask, _generated_features = extract_all_features(image)
            # Boolean mask is required for ~mask and mask indexing below
            valid_mask = np.asarray(valid_mask, dtype=bool)
            n_valid = int(np.count_nonzero(valid_mask))

            if n_valid == 0 or feature_matrix.shape[0] == 0:
                return None, "Error: No valid pixels found in the image. Please check that the image contains data."

            logger.info(f"Feature statistics - Min: {np.min(feature_matrix, axis=0)[:5]}, "
                        f"Max: {np.max(feature_matrix, axis=0)[:5]}, "
                        f"Mean: {np.mean(feature_matrix, axis=0)[:5]}")

            # Compare against what the loaded model was actually built for
            expected_features = self.package.get('n_features', self.DEFAULT_N_FEATURES)
            n_features = feature_matrix.shape[1]
            if n_features != expected_features:
                logger.error(f"Error: Generated {n_features} features, but model expects {expected_features}.")
                return None, f"Error: Generated {n_features} features, but model expects {expected_features}."

            # Scaling is best-effort: on failure the unscaled features are used
            try:
                scaler = self.package.get('scaler')
                if scaler is not None:
                    logger.info("Applying feature scaling...")
                    feature_matrix = scaler.transform(feature_matrix)
                    logger.info("Scaling complete")
                    logger.info(f"After scaling - Min: {np.min(feature_matrix, axis=0)[:5]}, "
                                f"Max: {np.max(feature_matrix, axis=0)[:5]}")
            except Exception as e:
                logger.warning(f"Error applying scaler: {e}. Using original features.")

            predictions = self._predict_pixels(feature_matrix, valid_mask, height, width)

            valid_predictions = predictions[valid_mask]
            logger.info(f"Prediction statistics - Min: {np.min(valid_predictions):.2f}, "
                        f"Max: {np.max(valid_predictions):.2f}, "
                        f"Mean: {np.mean(valid_predictions):.2f}, "
                        f"Median: {np.median(valid_predictions):.2f}")

            logger.info("Creating visualization...")
            rendered = self._render_map(image, predictions, valid_mask, display_type)
            stats_md = self._format_stats(valid_predictions, n_valid, n_features, transform, crs)
            return rendered, stats_md
        except Exception as e:
            import traceback
            logger.error(f"Error predicting biomass: {e}")
            logger.error(traceback.format_exc())
            return None, f"Error predicting biomass: {str(e)}\n\nPlease check logs for details."
        finally:
            # Remove only this request's temp copy, on every exit path
            if owned_tmp is not None:
                self._discard_temp_file(owned_tmp)

    def create_interface(self):
        """Build and return the Gradio ``Blocks`` interface for the app."""
        with gr.Blocks(title="Biomass Prediction Model") as interface:
            gr.Markdown("# Above-Ground Biomass Prediction")

            # gr.Warning only shows up when raised from an event handler, so a
            # load failure is reported with a regular, always-visible banner.
            if self.model is None:
                gr.Markdown("⚠️ Model failed to load. The app may not work correctly. Check logs for details.")

            gr.Markdown("""
            Upload a multi-band satellite image to predict above-ground biomass (AGB) across the landscape.
            **Requirements:**
            - Image must be a GeoTIFF with spectral bands
            - For best results, use imagery with at least 59 bands or similar to training data
            """)

            with gr.Row():
                with gr.Column(scale=1):
                    input_image = gr.File(
                        label="Upload Satellite Image (GeoTIFF)",
                        file_types=[".tif", ".tiff"]
                    )
                    display_type = gr.Radio(
                        choices=["heatmap", "rgb_overlay"],
                        value="heatmap",
                        label="Display Type"
                    )
                    with gr.Row():
                        submit_btn = gr.Button("Generate Biomass Prediction", variant="primary")
                        sample_btn = gr.Button("Use Sample Image")
                with gr.Column(scale=2):
                    output_image = gr.Image(
                        label="Biomass Prediction Map",
                        type="pil"
                    )
                    output_stats = gr.Markdown(
                        label="Statistics"
                    )

            with gr.Accordion("About", open=False):
                gr.Markdown("""
                ## About This Model
                This biomass prediction model uses the StableResNet architecture to predict above-ground biomass from satellite imagery.
                ### Model Details
                - Architecture: StableResNet
                - Input: Multi-spectral satellite imagery
                - Output: Above-ground biomass (Mg/ha)
                - Creator: vertify.earth for GIZ Forest Forward
                - Date: 2025-05-19
                ### How It Works
                1. The model extracts features from each pixel in the satellite image
                2. These features include spectral bands, vegetation indices, texture metrics, and more
                3. The model outputs a biomass prediction for each pixel
                4. Results are visualized as a heatmap or RGB overlay
                ### Updates in This Version
                - Fixed biomass value calculation issue (improved log transform handling)
                - Added detailed diagnostics for troubleshooting
                - Enhanced RGB visualization with band verification
                """)

            # Run a prediction on the uploaded file
            submit_btn.click(
                fn=self.predict_biomass,
                inputs=[input_image, display_type],
                outputs=[output_image, output_stats]
            )

            # Run a prediction on the bundled sample image
            def use_sample_image(display_type):
                sample_path = self.load_sample_image()
                if sample_path is None:
                    return None, "Error: Sample image not found. Please make sure 'input_chip_1.tif' exists in the app directory."
                return self.predict_biomass(sample_path, display_type)

            sample_btn.click(
                fn=use_sample_image,
                inputs=[display_type],
                outputs=[output_image, output_stats]
            )
        return interface
def launch_app():
    """Build the predictor app and serve its Gradio interface.

    Any exception raised while constructing or launching the app is logged
    together with its traceback and then swallowed.
    """
    import traceback

    try:
        predictor = BiomassPredictorApp()
        ui = predictor.create_interface()
        # Important: no share=True in Hugging Face Spaces
        ui.launch()
    except Exception as e:
        logger.error(f"Error launching app: {e}")
        logger.error(traceback.format_exc())


if __name__ == "__main__":
    launch_app()