| import math |
| import time |
| from dataclasses import dataclass |
| from io import BytesIO |
| from typing import Any |
|
|
| import requests |
| import streamlit as st |
| from PIL import Image |
|
|
|
|
# Esri World Imagery tile endpoint. The path order is zoom / row (y) / column (x).
ESRI_TILE_URL = (
    "https://server.arcgisonline.com"
    "/ArcGIS/rest/services/World_Imagery/MapServer"
    "/tile/{z}/{y}/{x}"
)
# Edge length, in pixels, assumed for every downloaded tile when stitching.
TILE_SIZE = 256
# Sent as the User-Agent header on every tile request.
USER_AGENT = "eurosat-rgb-streamlit-demo/1.0"

# Side-length window (meters) reported as a "good" match by bbox_scale_status.
EUROSAT_TARGET_MIN_M = 500
EUROSAT_TARGET_MAX_M = 1000
# Side lengths (meters) outside this wider window are reported as "invalid".
EUROSAT_ACCEPTABLE_MIN_M = 250
EUROSAT_ACCEPTABLE_MAX_M = 1500
# Largest long-side / short-side ratio accepted before a bbox is "too stretched".
EUROSAT_MAX_ASPECT_RATIO = 2.0
|
|
|
|
class TileFetchError(RuntimeError):
    """Signal that Esri imagery could not be downloaded, decoded, or cropped."""
|
|
|
|
@dataclass(frozen=True)
class BBox:
    """Immutable lon/lat bounding box, with all edges given in degrees."""

    # Longitude of the western (minimum-longitude) edge.
    west: float
    # Latitude of the southern (minimum-latitude) edge.
    south: float
    # Longitude of the eastern (maximum-longitude) edge.
    east: float
    # Latitude of the northern (maximum-latitude) edge.
    north: float
|
|
|
|
@dataclass(frozen=True)
class TileRange:
    """Immutable, inclusive span of XYZ tile indices at a single zoom level."""

    # Zoom level the tile indices below belong to.
    zoom: int
    # First and last tile column (both inclusive).
    x_min: int
    x_max: int
    # First and last tile row (both inclusive); row 0 is the northernmost.
    y_min: int
    y_max: int
|
|
|
|
def extract_bbox_from_geojson(drawing: dict[str, Any]) -> BBox:
    """Extract a lon/lat bbox from a Folium Draw GeoJSON rectangle.

    Args:
        drawing: GeoJSON feature dict. Its ``geometry`` must be a ``Polygon``
            whose first ring holds points shaped like ``[lon, lat, ...]``.

    Returns:
        The axis-aligned bounding box of the polygon's first ring.

    Raises:
        ValueError: If the geometry is missing or null, is not a polygon, has
            an empty or malformed ring, or encloses no area.
    """
    geometry = drawing.get("geometry")
    # A feature may carry an explicit null geometry; treat that (and any other
    # non-mapping value) the same as a missing geometry instead of crashing.
    if not isinstance(geometry, dict):
        raise ValueError("Expected a drawn rectangle polygon.")

    coordinates = geometry.get("coordinates")
    if geometry.get("type") != "Polygon" or not coordinates:
        raise ValueError("Expected a drawn rectangle polygon.")

    try:
        ring = coordinates[0]
        lons = [point[0] for point in ring]
        lats = [point[1] for point in ring]
        # An empty ring would otherwise surface as an opaque min()/max() error.
        if not lons:
            raise ValueError("Expected a drawn rectangle polygon.")
        west, east = min(lons), max(lons)
        south, north = min(lats), max(lats)
    except (TypeError, IndexError, KeyError) as exc:
        # Non-sequence rings, points with fewer than two entries, or values
        # that cannot be ordered all mean the payload is not a usable polygon.
        raise ValueError("Expected a drawn rectangle polygon.") from exc

    if west == east or south == north:
        raise ValueError("The drawn rectangle has no area.")

    return BBox(west=west, south=south, east=east, north=north)
|
|
|
|
def lonlat_to_tile_fraction(lon: float, lat: float, zoom: int) -> tuple[float, float]:
    """Project a lon/lat position onto fractional XYZ tile coordinates.

    Follows the OpenStreetMap slippy-map convention
    (https://wiki.openstreetmap.org/wiki/Slippy_map_tilenames): x is 0 at
    longitude -180 and grows eastward, y is 0 at the northern edge of the
    world and grows southward.

    Args:
        lon: Longitude in degrees.
        lat: Latitude in degrees; limited to +/-85.05112878 before projecting.
        zoom: Tile zoom level; each axis spans ``2**zoom`` tiles.

    Returns:
        The ``(x, y)`` tile coordinates, including the fractional position
        inside the tile.
    """
    lat_limit = 85.05112878
    # Keep the latitude inside the band where the projection stays finite.
    bounded_lat = min(lat, lat_limit)
    bounded_lat = max(bounded_lat, -lat_limit)

    tiles_per_axis = 2**zoom
    phi = math.radians(bounded_lat)
    mercator_y = math.log(math.tan(phi) + (1.0 / math.cos(phi)))

    x_fraction = (lon + 180.0) / 360.0 * tiles_per_axis
    y_fraction = (1.0 - mercator_y / math.pi) / 2.0 * tiles_per_axis
    return x_fraction, y_fraction
|
|
|
|
def bbox_to_tile_range(bbox: BBox, zoom: int) -> TileRange:
    """Compute the inclusive XYZ tile range that covers a lon/lat bbox.

    Args:
        bbox: Area to cover.
        zoom: Tile zoom level.

    Returns:
        Tile columns and rows, clamped to the valid ``0 .. 2**zoom - 1`` index
        range and ordered so each ``*_min`` is not greater than its ``*_max``.
    """
    last_index = (2**zoom) - 1

    def clamp_index(fraction: float) -> int:
        # Drop the in-tile fraction, then pin the index inside the tile grid.
        return max(0, min(last_index, math.floor(fraction)))

    # The north-west corner yields the smallest x/y, the south-east the largest.
    west_x, north_y = lonlat_to_tile_fraction(bbox.west, bbox.north, zoom)
    east_x, south_y = lonlat_to_tile_fraction(bbox.east, bbox.south, zoom)

    first_col, last_col = sorted((clamp_index(west_x), clamp_index(east_x)))
    first_row, last_row = sorted((clamp_index(north_y), clamp_index(south_y)))

    return TileRange(
        zoom=zoom,
        x_min=first_col,
        x_max=last_col,
        y_min=first_row,
        y_max=last_row,
    )
|
|
|
|
def choose_zoom_level(bbox: BBox) -> int:
    """Pick a tile zoom level from the longer side of the bbox.

    Returns 15 when the longer side is at most 1 km, 14 when it is at most
    5 km, and 13 for anything larger, so EuroSAT-scale rectangles land on
    zoom 14-15.
    """
    longest_side_m = max(bbox_size_meters(bbox))
    # (upper side-length bound in meters, zoom level), checked smallest first.
    for limit_m, zoom in ((1_000, 15), (5_000, 14)):
        if longest_side_m <= limit_m:
            return zoom
    return 13
|
|
|
|
def bbox_size_meters(bbox: BBox) -> tuple[float, float]:
    """Estimate the bbox extent in meters as ``(width, height)``.

    The width is measured along the bbox's middle latitude and the height
    along its western edge, both as great-circle distances.
    """
    center_lat = (bbox.north + bbox.south) / 2.0
    east_west_m = _haversine_meters(bbox.west, center_lat, bbox.east, center_lat)
    north_south_m = _haversine_meters(bbox.west, bbox.south, bbox.west, bbox.north)
    return (east_west_m, north_south_m)
|
|
|
|
def size_warning_for_bbox(bbox: BBox) -> str | None:
    """Build a user-facing warning when a bbox is outside the demo range.

    Returns:
        A message when the shorter side is under 50 m or the longer side is
        over 5 km; ``None`` when the rectangle is within both limits.
    """
    sides_m = bbox_size_meters(bbox)
    shortest_m, longest_m = min(sides_m), max(sides_m)

    # The too-small check wins when both limits are violated.
    if shortest_m < 50:
        return "This rectangle is very small. Draw at least about 50m on a side."
    if longest_m > 5_000:
        return "This rectangle is very large. Draw at most about 5km on a side."
    return None
|
|
|
|
def bbox_scale_status(bbox: BBox) -> tuple[str, str]:
    """Classify whether a bbox is close enough to EuroSAT-RGB tile scale.

    Args:
        bbox: Lon/lat bounding box to evaluate.

    Returns:
        A ``(status, message)`` pair. ``status`` is ``"invalid"`` (too small,
        too large, or too stretched), ``"good"`` (both sides inside the target
        window), or ``"usable"`` (acceptable but outside the target window).
        ``message`` is user-facing text explaining the verdict.
    """
    width_m, height_m = bbox_size_meters(bbox)
    min_side_m = min(width_m, height_m)
    max_side_m = max(width_m, height_m)

    # The explicit non-positive check keeps the division below safe even if the
    # acceptable minimum is ever configured as zero.
    if min_side_m <= 0 or min_side_m < EUROSAT_ACCEPTABLE_MIN_M:
        return (
            "invalid",
            "This rectangle is too small for a useful EuroSAT-style prediction. "
            "Draw closer to 500m-1km on each side.",
        )
    if max_side_m > EUROSAT_ACCEPTABLE_MAX_M:
        return (
            "invalid",
            "This rectangle is too large for this EuroSAT-style demo. "
            "Zoom in and draw closer to 500m-1km on each side.",
        )

    # Computed only after the size guards: a degenerate bbox (zero width or
    # height) must report "invalid" above rather than raise ZeroDivisionError.
    aspect_ratio = max_side_m / min_side_m
    if aspect_ratio > EUROSAT_MAX_ASPECT_RATIO:
        return (
            "invalid",
            "This rectangle is too stretched. Draw a more square region, "
            "like the original EuroSAT tiles.",
        )
    if (
        EUROSAT_TARGET_MIN_M <= min_side_m
        and max_side_m <= EUROSAT_TARGET_MAX_M
    ):
        return (
            "good",
            "Great scale: this is close to the original EuroSAT-RGB tile footprint.",
        )
    return (
        "usable",
        "Usable, but not ideal. For the most trustworthy demo result, "
        "draw 500m-1km on each side.",
    )
|
|
|
|
def fetch_bbox_image(bbox: BBox, zoom: int | None = None) -> Image.Image:
    """Download the Esri tiles covering a bbox and return the cropped mosaic.

    Args:
        bbox: Area to fetch imagery for.
        zoom: Tile zoom level; when ``None`` it is derived from the bbox size.

    Returns:
        An RGB image cropped to the bbox.

    Raises:
        TileFetchError: If a tile cannot be fetched or the resulting crop has
            no pixels.
    """
    if zoom is None:
        zoom = choose_zoom_level(bbox)
    tiles = bbox_to_tile_range(bbox, zoom)

    # NOTE(review): nothing here caps the number of tiles; presumably callers
    # gate on size_warning_for_bbox / bbox_scale_status first - confirm.
    column_count = tiles.x_max - tiles.x_min + 1
    row_count = tiles.y_max - tiles.y_min + 1
    mosaic = Image.new("RGB", (column_count * TILE_SIZE, row_count * TILE_SIZE))

    for column_offset in range(column_count):
        for row_offset in range(row_count):
            tile_image = fetch_esri_tile(
                zoom,
                tiles.x_min + column_offset,
                tiles.y_min + row_offset,
            )
            mosaic.paste(
                tile_image,
                (column_offset * TILE_SIZE, row_offset * TILE_SIZE),
            )
            # Short pause between tiles to avoid hammering the tile server.
            time.sleep(0.05)

    region = mosaic.crop(_bbox_crop_box(bbox, tiles, mosaic.size))
    if region.width > 0 and region.height > 0:
        return region
    raise TileFetchError("The fetched imagery crop was empty.")
|
|
|
|
# max_entries bounds the cache: the (zoom, x, y) key space is effectively
# unbounded, so an uncapped cache would keep growing in a long-running app.
# Once the cap is reached, older tiles are evicted and re-downloaded on demand.
@st.cache_data(show_spinner=False, max_entries=512)
def fetch_esri_tile(zoom: int, x: int, y: int) -> Image.Image:
    """Download one Esri World Imagery XYZ tile as an RGB image.

    Results are cached by Streamlit per ``(zoom, x, y)``.

    Args:
        zoom: Tile zoom level.
        x: Tile column index.
        y: Tile row index.

    Returns:
        The decoded tile converted to RGB mode.

    Raises:
        TileFetchError: If the request fails or returns an HTTP error status,
            or if the response body cannot be decoded as an image.
    """
    url = ESRI_TILE_URL.format(z=zoom, x=x, y=y)
    try:
        response = requests.get(
            url,
            headers={"User-Agent": USER_AGENT},
            timeout=10,
        )
        response.raise_for_status()
    except requests.RequestException as exc:
        raise TileFetchError(f"Could not download imagery tile z{zoom}/{x}/{y}.") from exc

    try:
        # convert() decodes the pixels and returns an independent image, so the
        # source image can be released as soon as the block exits.
        with Image.open(BytesIO(response.content)) as tile:
            return tile.convert("RGB")
    except OSError as exc:
        raise TileFetchError(f"Downloaded imagery tile z{zoom}/{x}/{y} was invalid.") from exc
|
|
|
|
def _bbox_crop_box(
    bbox: BBox, tile_range: TileRange, stitched_size: tuple[int, int]
) -> tuple[int, int, int, int]:
    """Locate the bbox inside the stitched mosaic as a pixel crop box.

    Returns:
        ``(left, top, right, bottom)`` relative to the mosaic's top-left
        corner, clamped to the mosaic bounds. The top-left corner is rounded
        down and the bottom-right corner rounded up, so partially covered
        pixels are kept.
    """
    zoom = tile_range.zoom
    mosaic_width, mosaic_height = stitched_size
    # Global pixel position of the mosaic's top-left corner.
    offset_x = tile_range.x_min * TILE_SIZE
    offset_y = tile_range.y_min * TILE_SIZE

    def clamp(value: int, upper: int) -> int:
        return max(0, min(upper, value))

    west_px, north_px = _lonlat_to_global_pixel(bbox.west, bbox.north, zoom)
    east_px, south_px = _lonlat_to_global_pixel(bbox.east, bbox.south, zoom)

    return (
        clamp(math.floor(west_px - offset_x), mosaic_width),
        clamp(math.floor(north_px - offset_y), mosaic_height),
        clamp(math.ceil(east_px - offset_x), mosaic_width),
        clamp(math.ceil(south_px - offset_y), mosaic_height),
    )
|
|
|
|
def _lonlat_to_global_pixel(lon: float, lat: float, zoom: int) -> tuple[float, float]:
    """Convert lon/lat to pixel coordinates measured from tile (0, 0) at ``zoom``."""
    tile_position = lonlat_to_tile_fraction(lon, lat, zoom)
    # Each tile is TILE_SIZE pixels wide and tall.
    return tile_position[0] * TILE_SIZE, tile_position[1] * TILE_SIZE
|
|
|
|
| def _haversine_meters(lon1: float, lat1: float, lon2: float, lat2: float) -> float: |
| radius_m = 6_371_000 |
| phi1 = math.radians(lat1) |
| phi2 = math.radians(lat2) |
| delta_phi = math.radians(lat2 - lat1) |
| delta_lambda = math.radians(lon2 - lon1) |
| a = ( |
| math.sin(delta_phi / 2.0) ** 2 |
| + math.cos(phi1) * math.cos(phi2) * math.sin(delta_lambda / 2.0) ** 2 |
| ) |
| return 2.0 * radius_m * math.atan2(math.sqrt(a), math.sqrt(1.0 - a)) |
|
|