Spaces:
Sleeping
Sleeping
import os | |
import gradio as gr | |
import folium | |
from folium import plugins | |
import geopandas as gpd | |
import rasterio | |
from rasterio.warp import transform_bounds | |
import json | |
import tempfile | |
import shutil | |
import uuid | |
import logging | |
import traceback | |
import numpy as np | |
from PIL import Image | |
# Configure logging for HF Spaces: INFO and above, written by a single stream
# handler as timestamped lines. (basicConfig is a no-op if the root logger
# already has handlers.)
logging.basicConfig(
    handlers=[logging.StreamHandler()],
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(name='forestai')
# ================================
# CONFIGURATIONS
# ================================
# Folium style dict applied to every detected-tree polygon ('trees' is the
# only feature class this app renders).
FEATURE_STYLES = {
    'trees': {"color": "yellow", "fillColor": "yellow", "fillOpacity": 0.3, "weight": 2}
}
# Example GeoTIFF shipped alongside this script. Resolved against the script's
# own directory rather than the process working directory, so "Use Example
# File" finds it wherever the app is launched from -- matching the in-app
# instructions ("placed in the same directory as this application").
EXAMPLE_FILE_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)), "example.tif")
# ================================
# TEMP DIRECTORY SETUP
# ================================
def setup_temp_dirs(prefix="forestai_"):
    """Create a fresh scratch directory tree and return its sub-directories.

    A new base directory is created with ``tempfile.mkdtemp`` and populated
    with ``uploads``, ``processed`` and ``static`` sub-directories. The whole
    tree is scheduled for removal when the interpreter exits.

    Args:
        prefix: Name prefix for the base directory (default ``"forestai_"``).

    Returns:
        dict: maps ``'uploads'``, ``'processed'`` and ``'static'`` to their paths.
    """
    import atexit

    temp_base = tempfile.mkdtemp(prefix=prefix)
    # mkdtemp leaves deletion to the caller; without this, every run of the app
    # leaves uploaded GeoTIFFs and rendered maps behind in the system temp dir.
    atexit.register(shutil.rmtree, temp_base, ignore_errors=True)
    dirs = {
        name: os.path.join(temp_base, name)
        for name in ('uploads', 'processed', 'static')
    }
    for dir_path in dirs.values():
        os.makedirs(dir_path, exist_ok=True)
    return dirs


# Global temp directories, created once at import time.
TEMP_DIRS = setup_temp_dirs()
# ================================
# CORE FUNCTIONS
# ================================
def get_bounds_from_geotiff(geotiff_path):
    """Return the raster footprint of *geotiff_path* as (west, south, east, north) in EPSG:4326.

    The dataset's native bounds are reprojected with
    ``rasterio.warp.transform_bounds``. A raster without a CRS, or any error
    while opening or reading it, yields the fixed placeholder box
    (-74.1, 40.6, -73.9, 40.8) instead of raising; errors are logged.
    """
    placeholder = (-74.1, 40.6, -73.9, 40.8)
    try:
        with rasterio.open(geotiff_path) as dataset:
            native = dataset.bounds
            if not dataset.crs:
                # Without a CRS the native coordinates cannot be reprojected.
                return placeholder
            west, south, east, north = transform_bounds(
                dataset.crs, 'EPSG:4326',
                native.left, native.bottom, native.right, native.top,
            )
            return west, south, east, north
    except Exception as exc:
        logger.error(f"Error extracting bounds: {str(exc)}")
        return placeholder
def _zoom_for_span(max_diff):
    """Map the larger side of a lon/lat box (in degrees) to a starting zoom level."""
    if max_diff < 0.01:
        return 16
    if max_diff < 0.05:
        return 14
    if max_diff < 0.1:
        return 12
    return 10


def _features_have_property(geojson_data, prop):
    """Return True if *geojson_data* has features and every one has *prop* in its properties."""
    features = geojson_data.get('features') or []
    return bool(features) and all(
        prop in (feature.get('properties') or {}) for feature in features
    )


def create_split_view_map(geojson_data, bounds):
    """Build a folium map that compares OpenStreetMap and satellite tiles side by side.

    Detected tree polygons are drawn on top with ``FEATURE_STYLES['trees']``.
    The raster footprint is outlined in red.

    Args:
        geojson_data: GeoJSON-style dict with a ``'features'`` list. It may be
            empty or None, in which case no detection layer is drawn.
        bounds: ``(west, south, east, north)`` in degrees, in the order
            produced by ``get_bounds_from_geotiff``.

    Returns:
        folium.Map: the split-view map. On any error, the failure is logged
        with its traceback and a plain map at a fixed location is returned.
    """
    try:
        west, south, east, north = bounds
        center = [(south + north) / 2, (west + east) / 2]
        zoom = _zoom_for_span(max(north - south, east - west))

        # tiles=None: the two split-view layers below are the only base maps.
        # Otherwise folium's default OSM layer would sit underneath both halves
        # and show up as an extra base-layer entry in the layer control.
        m = folium.Map(location=center, zoom_start=zoom, tiles=None)

        left_layer = folium.TileLayer(
            tiles='OpenStreetMap',
            name='OpenStreetMap',
            overlay=False,
            control=False,
        )
        right_layer = folium.TileLayer(
            tiles='https://mt1.google.com/vt/lyrs=s&x={x}&y={y}&z={z}',
            attr='Google Satellite',
            name='Google Satellite',
            overlay=False,
            control=False,
        )
        # Both layers must be on the map for SideBySideLayers to split them.
        left_layer.add_to(m)
        right_layer.add_to(m)

        if geojson_data and geojson_data.get('features'):
            style = FEATURE_STYLES['trees']
            # Attach the popup only when every feature really carries a
            # `confidence` property. folium checks popup fields against the
            # feature properties when the map is rendered, so the old substring
            # match on str(geojson_data) could make the later save() fail. The
            # old code also attached an empty popup when the field was absent.
            popup = None
            if _features_have_property(geojson_data, 'confidence'):
                popup = folium.GeoJsonPopup(
                    fields=['confidence'],
                    aliases=['Confidence:'],
                    localize=True,
                )
            folium.GeoJson(
                geojson_data,
                name='Detected Trees',
                style_function=lambda _feature: style,
                popup=popup,
            ).add_to(m)

        # Red outline for the region bounding box.
        folium.Rectangle(
            bounds=[[south, west], [north, east]],
            color='red',
            weight=3,
            fill=False,
        ).add_to(m)

        plugins.SideBySideLayers(
            layer_left=left_layer,
            layer_right=right_layer,
        ).add_to(m)
        folium.LayerControl().add_to(m)
        # fit_bounds supersedes the zoom_start estimate once the map loads.
        m.fit_bounds([[south, west], [north, east]], padding=(20, 20))
        return m
    except Exception as e:
        logger.exception(f"Error creating map: {str(e)}")
        # Return basic map on error.
        return folium.Map(location=[40.7, -74.0], zoom_start=10)
def _save_upload(geotiff_file, dest_path):
    """Copy an uploaded GeoTIFF (file-like object or filesystem path) to *dest_path*."""
    if hasattr(geotiff_file, 'read'):
        # Stream in chunks instead of reading the whole raster into memory.
        with open(dest_path, "wb") as f:
            shutil.copyfileobj(geotiff_file, f)
    else:
        shutil.copy(geotiff_file, dest_path)


def _wrap_map_html(html_content):
    """Embed a standalone HTML document in a bordered iframe via its ``srcdoc`` attribute."""
    from html import escape

    # srcdoc is a double-quoted attribute. The document must be entity-escaped
    # (at minimum `"` and `&`), otherwise the attribute ends at the first quote
    # and the map never renders. The browser decodes the entities back.
    srcdoc = escape(html_content, quote=True)
    return f'''
<div style="width:100%; height:600px; border:1px solid #ddd; border-radius:8px; overflow:hidden;">
<iframe srcdoc="{srcdoc}"
width="100%" height="600px" style="border:none;"></iframe>
</div>
'''


def process_geotiff_file(geotiff_file):
    """Run tree detection on an uploaded GeoTIFF and render the result as embeddable HTML.

    Args:
        geotiff_file: a filesystem path (str), or an object exposing ``name``
            and/or ``read()``. ``None`` or an empty path means nothing was
            uploaded.

    Returns:
        tuple: ``(iframe_html_or_None, status_message)``. Errors are reported
        in the status message rather than raised.
    """
    if not geotiff_file:
        return None, "Please upload a GeoTIFF file or use the example file"
    try:
        # Short random prefix so uploads with the same filename don't collide.
        unique_id = uuid.uuid4().hex[:8]
        source_name = geotiff_file.name if hasattr(geotiff_file, 'name') else geotiff_file
        filename = os.path.basename(source_name)

        geotiff_path = os.path.join(TEMP_DIRS['uploads'], f"{unique_id}_{filename}")
        _save_upload(geotiff_file, geotiff_path)
        logger.info(f"File saved to {geotiff_path}")

        # Imported here rather than at module top, so an ImportError is
        # reported through the status message by the handler below.
        from utils.advanced_extraction import extract_features_from_geotiff
        logger.info("Extracting tree features...")
        geojson_data = extract_features_from_geotiff(geotiff_path, TEMP_DIRS['processed'], "trees")
        if not geojson_data or not geojson_data.get('features'):
            return None, "No trees detected in the image"

        bounds = get_bounds_from_geotiff(geotiff_path)
        map_obj = create_split_view_map(geojson_data, bounds)
        if not map_obj:
            return None, "Failed to create map"

        # Render via a file in the static temp dir, then read the HTML back.
        html_path = os.path.join(TEMP_DIRS['static'], f"map_{unique_id}.html")
        map_obj.save(html_path)
        with open(html_path, 'r', encoding='utf-8') as f:
            html_content = f.read()

        num_features = len(geojson_data['features'])
        return _wrap_map_html(html_content), f"✅ Detected {num_features} tree areas in {filename}"
    except Exception as e:
        logger.exception(f"Error processing file: {str(e)}")
        return None, f"❌ Error: {str(e)}"
def load_example_file():
    """Return EXAMPLE_FILE_PATH if that file exists, otherwise None (errors are logged, not raised)."""
    try:
        found = os.path.exists(EXAMPLE_FILE_PATH)
        if not found:
            logger.warning("Example file not found")
            return None
        logger.info("Loading example file...")
        return EXAMPLE_FILE_PATH
    except Exception as err:
        logger.error(f"Error loading example file: {str(err)}")
        return None
def process_example_file():
    """Run the tree-detection pipeline on the bundled example file.

    Returns the same ``(html_or_None, status)`` pair as ``process_geotiff_file``.
    If the example file is missing, returns ``None`` with an error status.
    """
    example_path = load_example_file()
    if not example_path:
        return None, "❌ Example file (example.tif) not found in the root directory"
    return process_geotiff_file(example_path)
def check_example_file_exists():
    """Return a one-line status string saying whether EXAMPLE_FILE_PATH exists."""
    present = os.path.exists(EXAMPLE_FILE_PATH)
    label = "✅ Example file found" if present else "⚠️ Example file not found"
    return f"{label}: {EXAMPLE_FILE_PATH}"
# ================================
# GRADIO INTERFACE
# ================================
def create_gradio_interface():
    """Build the Gradio Blocks UI for tree detection and return it (not launched).

    Gradio places components in the order they are declared inside the
    ``with`` contexts, so the statements below should not be reordered.
    Both buttons send their results to the same two outputs: ``map_output``
    (HTML) and ``status_output`` (text).

    Returns:
        gr.Blocks: the assembled application.
    """
    # Page-wide layout overrides. `.map-container` and `.example-button` are
    # attached to components below through `elem_classes`.
    css = """
.gradio-container {
max-width: 100% !important;
width: 100% !important;
margin: 0 !important;
padding: 10px !important;
}
.map-container {
border-radius: 8px;
overflow: hidden;
box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);
width: 100% !important;
}
body {
margin: 0 !important;
padding: 0 !important;
}
.contain {
max-width: none !important;
padding: 0 !important;
}
.example-button {
background: linear-gradient(135deg, #28a745 0%, #20c997 100%) !important;
border: none !important;
color: white !important;
}
"""
    with gr.Blocks(title="🌲 ForestAI - Tree Detection", css=css, theme=gr.themes.Soft()) as app:
        # Simple header
        gr.HTML("""
<div style="text-align: center; padding: 20px; background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); border-radius: 10px; margin-bottom: 20px;">
<h1 style="color: white; margin: 0; font-size: 2.5em;">🌲 ForestAI</h1>
<p style="color: white; margin: 10px 0 0 0; font-size: 1.2em;">Tree Detection from Satellite Imagery</p>
</div>
""")
        with gr.Row():
            # Left column (scale=1): upload, action buttons, status boxes.
            with gr.Column(scale=1):
                gr.Markdown("### Upload GeoTIFF File")
                file_input = gr.File(
                    label="Select GeoTIFF File",
                    file_types=[".tif", ".tiff"],
                    type="filepath"
                )
                with gr.Row():
                    analyze_btn = gr.Button(
                        "🔍 Detect Trees",
                        variant="primary",
                        size="lg",
                        scale=2
                    )
                    example_btn = gr.Button(
                        "📁 Use Example File",
                        variant="secondary",
                        size="lg",
                        scale=1,
                        elem_classes=["example-button"]
                    )
                # Example file status
                # NOTE: the value is computed once, when the interface is built,
                # and is not wired to any event, so it is not refreshed if
                # example.tif is added or removed later.
                example_status = gr.Textbox(
                    label="Example File Status",
                    value=check_example_file_exists(),
                    interactive=False,
                    lines=1
                )
                gr.Markdown("### Status")
                status_output = gr.Textbox(
                    label="Processing Status",
                    interactive=False,
                    placeholder="Upload a file and click 'Detect Trees' or use the example file...",
                    lines=3
                )
            # Right column (scale=2): rendered map, with a placeholder shown
            # until the first successful run replaces it.
            with gr.Column(scale=2):
                gr.Markdown("### Results Map")
                map_output = gr.HTML(
                    value='''
<div style="width:100%; height:600px; border:1px solid #ddd; border-radius:8px;
display:flex; align-items:center; justify-content:center;
background: linear-gradient(135deg, #f5f7fa 0%, #c3cfe2 100%);">
<div style="text-align:center; color:#666;">
<h3>🌲 Upload a GeoTIFF file or use example to see detected trees</h3>
<p>Interactive map will appear here</p>
</div>
</div>
''',
                    elem_classes=["map-container"]
                )
        # Event handlers
        # The upload component's value is passed to process_geotiff_file, which
        # returns (map_html_or_None, status_message).
        analyze_btn.click(
            fn=process_geotiff_file,
            inputs=[file_input],
            outputs=[map_output, status_output],
            show_progress=True
        )
        # No inputs: process_example_file locates the example via
        # load_example_file / EXAMPLE_FILE_PATH.
        example_btn.click(
            fn=process_example_file,
            inputs=[],
            outputs=[map_output, status_output],
            show_progress=True
        )
        # Simple instructions
        gr.Markdown("""
### How to Use:
1. **Upload** a GeoTIFF satellite image file OR click "Use Example File" to try with the included sample
2. **Click** "Detect Trees" to analyze your uploaded image
3. **Explore** the interactive map with detected tree areas
4. **Use** the split-view slider to compare base map and satellite imagery
### Map Controls:
- **Split View**: Drag the vertical slider to compare layers
- **Zoom**: Scroll to zoom in/out, drag to pan
- **Layers**: Use layer control to toggle trees on/off
### Example File:
- The example file should be named `example.tif` and placed in the same directory as this application
- Click "Use Example File" to quickly test the tree detection without uploading your own file
""")
    return app
if __name__ == "__main__":
    # Script entry point: build the interface and start serving it.
    logger.info("🌲 Starting ForestAI Tree Detection")
    create_gradio_interface().launch()