import os
import sys
import base64
import logging
import tempfile
from io import BytesIO
from math import radians, sin, cos, acos
from typing import List, Tuple, Optional, Dict

import torch
import numpy as np
import matplotlib.pyplot as plt
import rasterio
import folium
import gradio as gr
from huggingface_hub import hf_hub_download
from PIL import Image as PILImage

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class ImageHandler:
    """Handles satellite image loading and processing"""

    @staticmethod
    def load_image(path, plot=False):
        """Load a Sentinel-2 GeoTIFF, insert an empty B10 band, and return a tensor"""
        with rasterio.open(path) as f:
            data = f.read().astype(np.float32)
            image = data / 10000.0  # Convert digital numbers to reflectance
            # Insert an all-zero B10 band so the image has the 13 bands SatCLIP expects
            B10 = np.zeros((1, *image.shape[1:]), dtype=image.dtype)
            image = np.concatenate([image[:10], B10, image[10:]], axis=0)
            image = torch.tensor(image)

            if plot:
                rgb_image = np.stack([
                    f.read(4) / 10000.0,  # Red
                    f.read(3) / 10000.0,  # Green
                    f.read(2) / 10000.0   # Blue
                ], axis=-1)
                return image, rgb_image
            return image

    @staticmethod
    def plot_rgb(rgb_image):
        """Render the RGB composite with matplotlib and return it as a PIL image"""
        plt.figure(figsize=(10, 10))
        plt.imshow(np.clip(rgb_image, 0, 1))  # Clip reflectance to the displayable range
        plt.title("RGB Sentinel-2 Image")
        plt.axis('off')

        # Save the figure to an in-memory buffer; gr.Image accepts a PIL image directly
        buf = BytesIO()
        plt.savefig(buf, format='png', bbox_inches='tight')
        plt.close()
        buf.seek(0)
        return PILImage.open(buf).convert("RGB")


class SatCLIPModel:
    """Wrapper for SatCLIP model"""

    def __init__(self, model_size: str = 'large'):
        """Initialize model"""
        self.device = torch.device('cpu')
        logger.info(f"Using device: {self.device}")

        # Model config
        model_configs = {
            'small': ("microsoft/SatCLIP-ResNet18-L40", "satclip-resnet18-l40.ckpt"),
            'large': ("microsoft/SatCLIP-ResNet50-L40", "satclip-resnet50-l40.ckpt")
        }

        try:
            # Setup environment
            if not os.path.exists('satclip'):
                os.system('git clone https://github.com/microsoft/satclip.git')
                os.makedirs('satclip/satclip/datamodules', exist_ok=True)
                open('satclip/satclip/datamodules/__init__.py', 'a').close()
            if './satclip' not in sys.path:
                sys.path.append('./satclip')

            # Load model
            from satclip.load import get_satclip
            model_name, checkpoint = model_configs[model_size]
            checkpoint_path = hf_hub_download(model_name, checkpoint)
            self.model = get_satclip(
                checkpoint_path,
                return_all=True,
                device=self.device
            )
            self.model.eval()
            logger.info(f"Model loaded: {model_size}")
        except Exception as e:
            logger.error(f"Error initializing model: {str(e)}")
            raise

    def get_location_embedding(self, coords: List[Tuple[float, float]]) -> torch.Tensor:
        """Get embeddings for (longitude, latitude) coordinates"""
        # The SatCLIP location encoder expects double-precision coordinates
        locations = torch.tensor(coords, dtype=torch.float64).to(self.device)
        with torch.no_grad():
            return self.model.encode_location(locations)

    def get_image_embedding(self, image_path: str) -> torch.Tensor:
        """Get embedding from image"""
        image = ImageHandler.load_image(image_path)
        with torch.no_grad():
            return self.model.visual(image.unsqueeze(0).to(self.device))

    def find_similar_locations(self, query_embedding: torch.Tensor,
                               n_points: int = 1000) -> Tuple[np.ndarray, np.ndarray]:
        """Find similar locations globally"""
        # Generate a regular longitude/latitude grid
        lats = np.linspace(-80, 80, int(np.sqrt(n_points)))
        lons = np.linspace(-180, 180, int(np.sqrt(n_points)))
        grid_lats, grid_lons = np.meshgrid(lats, lons)
        locations = torch.tensor(
            np.stack([grid_lons.flatten(), grid_lats.flatten()], axis=1)
        ).double().to(self.device)

        # Get embeddings and cosine similarities
        with torch.no_grad():
            loc_embeddings = self.model.encode_location(locations)
            similarities = torch.nn.functional.cosine_similarity(
                query_embedding.to(loc_embeddings.dtype), loc_embeddings
            ).cpu().numpy()

        return locations.cpu().numpy(), similarities

def create_similarity_map(locations: np.ndarray, similarities: np.ndarray,
                          reference: Optional[Tuple[float, float]] = None) -> str:
    """Create an interactive folium map and return it as an HTML snippet"""
    # Center map
    if reference:
        center = [reference[1], reference[0]]
    else:
        center = [0, 0]
    m = folium.Map(location=center, zoom_start=3)

    # Add reference point if provided
    if reference:
        folium.Marker(
            center,
            popup='Reference Location',
            icon=folium.Icon(color='red')
        ).add_to(m)

    # Add similarity points
    for loc, sim in zip(locations, similarities):
        folium.CircleMarker(
            [loc[1], loc[0]],
            radius=5,
            popup=f"Similarity: {sim:.3f}",
            color='blue',
            fill=True,
            fill_opacity=float((sim + 1) / 2)  # Normalize to [0, 1]
        ).add_to(m)

    # Return the rendered map so it can be displayed directly by a gr.HTML component
    return m._repr_html_()


# Create Gradio interface
def create_interface():
    with gr.Blocks(title="SatCLIP Analysis") as iface:
        gr.Markdown("""
# SatCLIP Geographic Analysis
Analyze locations and satellite imagery using SatCLIP embeddings.
""")

        with gr.Tabs():
            # Location Analysis Tab
            with gr.Tab("Location Analysis"):
                with gr.Row():
                    with gr.Column():
                        loc_input = gr.Textbox(
                            label="Enter coordinates (longitude,latitude)",
                            placeholder="-74.006,40.7128",
                            lines=2
                        )
                        n_points = gr.Slider(
                            minimum=100,
                            maximum=10000,
                            value=1000,
                            step=100,
                            label="Number of comparison points"
                        )
                        find_similar_btn = gr.Button("Find Similar Locations")
                    with gr.Column():
                        map_output = gr.HTML(label="Similarity Map")

            # Image Analysis Tab
            with gr.Tab("Image Analysis"):
                with gr.Row():
                    with gr.Column():
                        image_input = gr.File(label="Upload Sentinel-2 image")
                        image_coords = gr.Textbox(
                            label="Image coordinates (if known)",
                            placeholder="-74.006,40.7128"
                        )
                        analyze_img_btn = gr.Button("Analyze Image")
                    with gr.Column():
                        image_display = gr.Image(label="RGB Preview")
                        img_map_output = gr.HTML(label="Similar Locations")

        def analyze_location(coords: str, num_points: int) -> str:
            try:
                # Process coordinates
                lon, lat = map(float, coords.strip().split(','))
                if not (-180 <= lon <= 180 and -90 <= lat <= 90):
                    return "Error: Invalid coordinates"

                # Initialize model and get embedding
                model = SatCLIPModel()
                embedding = model.get_location_embedding([(lon, lat)])

                # Find similar locations
                locations, similarities = model.find_similar_locations(
                    embedding, n_points=num_points
                )

                # Create map
                return create_similarity_map(locations, similarities, (lon, lat))
            except Exception as e:
                logger.error(f"Location analysis error: {str(e)}")
                return f"Error: {str(e)}"

        def analyze_image(image_path, coords: str = None):
            try:
                # gr.File may pass a path string or a temp-file object, depending on the Gradio version
                if hasattr(image_path, 'name'):
                    image_path = image_path.name

                # Load and display image
                image, rgb = ImageHandler.load_image(image_path, plot=True)
                preview = ImageHandler.plot_rgb(rgb)

                # Get embedding and find similar locations
                model = SatCLIPModel()
                embedding = model.get_image_embedding(image_path)
                locations, similarities = model.find_similar_locations(embedding)

                # Create map (with reference point if coordinates provided)
                reference = None
                if coords:
                    lon, lat = map(float, coords.strip().split(','))
                    reference = (lon, lat)

                map_html = create_similarity_map(locations, similarities, reference)
                return preview, map_html
            except Exception as e:
                logger.error(f"Image analysis error: {str(e)}")
                return None, f"Error: {str(e)}"

        find_similar_btn.click(
            analyze_location,
            inputs=[loc_input, n_points],
            outputs=[map_output]
        )
        analyze_img_btn.click(
            analyze_image,
            inputs=[image_input, image_coords],
            outputs=[image_display, img_map_output]
        )

    return iface


if __name__ == "__main__":
    iface = create_interface()
    iface.launch()
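
# Optional sanity check, left as comments so it never runs alongside the app. This is a
# sketch assuming the checkpoint download in SatCLIPModel.__init__ succeeds; it queries
# the location encoder directly with (longitude, latitude) pairs, as used elsewhere here.
#
#   model = SatCLIPModel(model_size='large')
#   embedding = model.get_location_embedding([(-74.006, 40.7128)])  # New York City
#   print(embedding.shape)  # one embedding vector per coordinate pair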