# Author: davidlsan
# Commit 9d33171: "Add Streamlit app source and RGB model weights"
import math
import time
from dataclasses import dataclass
from io import BytesIO
from typing import Any
import requests
import streamlit as st
from PIL import Image
# Esri World Imagery XYZ tile endpoint. The path is ordered {z}/{y}/{x}
# (row before column), unlike the common {z}/{x}/{y} layout.
ESRI_TILE_URL = (
    "https://server.arcgisonline.com"
    "/ArcGIS/rest/services/World_Imagery/MapServer"
    "/tile/{z}/{y}/{x}"
)
# Edge length of one square XYZ tile, in pixels.
TILE_SIZE = 256
# User-Agent header sent with every tile request.
USER_AGENT = "eurosat-rgb-streamlit-demo/1.0"

# Side lengths (meters) that bbox_scale_status reports as "good".
EUROSAT_TARGET_MIN_M = 500
EUROSAT_TARGET_MAX_M = 1000
# Side lengths (meters) outside of which bbox_scale_status reports "invalid".
EUROSAT_ACCEPTABLE_MIN_M = 250
EUROSAT_ACCEPTABLE_MAX_M = 1500
# Largest allowed ratio of the long side to the short side.
EUROSAT_MAX_ASPECT_RATIO = 2.0
class TileFetchError(RuntimeError):
    """Error for Esri imagery that could not be downloaded, decoded, or cropped.

    Raised for a failed tile request, a response body that is not a valid
    image, or an empty crop of the stitched tile mosaic.
    """
@dataclass(frozen=True)
class BBox:
    """Immutable longitude/latitude bounding box, in degrees."""

    # Minimum longitude (western edge).
    west: float
    # Minimum latitude (southern edge).
    south: float
    # Maximum longitude (eastern edge).
    east: float
    # Maximum latitude (northern edge).
    north: float
@dataclass(frozen=True)
class TileRange:
    """Immutable, inclusive rectangle of XYZ tile indices at one zoom level."""

    # XYZ zoom level the indices refer to.
    zoom: int
    # Inclusive column (x) bounds.
    x_min: int
    x_max: int
    # Inclusive row (y) bounds; y counts downward from the northern edge.
    y_min: int
    y_max: int
def extract_bbox_from_geojson(drawing: dict[str, Any]) -> BBox:
    """Extract a lon/lat bbox from a Folium Draw GeoJSON rectangle.

    Args:
        drawing: A GeoJSON Feature as produced by the Folium Draw plugin.
            Only its ``geometry`` member is read.

    Returns:
        The axis-aligned bounding box of the polygon's outer ring.

    Raises:
        ValueError: If the feature has no Polygon geometry (including a
            missing or ``null`` geometry), the outer ring is empty or holds
            malformed points, or the ring has zero width or height.
    """
    # GeoJSON allows ``"geometry": null``, so don't assume a mapping.
    geometry = drawing.get("geometry")
    if not isinstance(geometry, dict) or geometry.get("type") != "Polygon":
        raise ValueError("Expected a drawn rectangle polygon.")
    coordinates = geometry.get("coordinates")
    # Reject a missing/empty polygon *and* an empty outer ring; the latter
    # would otherwise surface as an opaque "min() arg is an empty sequence".
    if not coordinates or not coordinates[0]:
        raise ValueError("Expected a drawn rectangle polygon.")
    ring = coordinates[0]
    try:
        lons = [float(point[0]) for point in ring]
        lats = [float(point[1]) for point in ring]
    except (TypeError, IndexError, ValueError) as exc:
        # Malformed positions (too short, non-numeric) get the same error type
        # callers already handle for bad drawings.
        raise ValueError("Expected a drawn rectangle polygon.") from exc
    west, east = min(lons), max(lons)
    south, north = min(lats), max(lats)
    if west == east or south == north:
        raise ValueError("The drawn rectangle has no area.")
    return BBox(west=west, south=south, east=east, north=north)
def lonlat_to_tile_fraction(lon: float, lat: float, zoom: int) -> tuple[float, float]:
    """Map a lon/lat position (degrees) to fractional XYZ tile coordinates.

    Follows the OpenStreetMap slippy-map tile scheme
    (https://wiki.openstreetmap.org/wiki/Slippy_map_tilenames): x grows
    eastward from 0 at longitude -180, and y grows southward from 0 at the
    northern edge of the world. Latitude is first clamped to +/-85.05112878
    degrees, the Mercator cut-off, so the poles map to finite values. The
    integer part of each coordinate is the tile index at *zoom*.
    """
    mercator_limit = 85.05112878
    clamped_lat = min(max(lat, -mercator_limit), mercator_limit)
    phi = math.radians(clamped_lat)
    tiles_per_axis = 2**zoom

    tile_x = (lon + 180.0) / 360.0 * tiles_per_axis
    # Mercator northing: ln(tan(phi) + sec(phi)).
    northing = math.log(math.tan(phi) + (1.0 / math.cos(phi)))
    tile_y = (1.0 - northing / math.pi) / 2.0 * tiles_per_axis
    return tile_x, tile_y
def bbox_to_tile_range(bbox: BBox, zoom: int) -> TileRange:
    """Compute the inclusive block of XYZ tiles covering *bbox* at *zoom*.

    The north-west and south-east corners are converted to fractional tile
    coordinates, floored to tile indices, clamped to ``[0, 2**zoom - 1]``,
    and ordered so that each ``*_min`` is not greater than its ``*_max``.
    """
    last_index = (2**zoom) - 1

    def _to_index(fraction: float) -> int:
        # Floor to the containing tile, then keep it inside the world grid.
        return max(0, min(last_index, math.floor(fraction)))

    nw_x, nw_y = lonlat_to_tile_fraction(bbox.west, bbox.north, zoom)
    se_x, se_y = lonlat_to_tile_fraction(bbox.east, bbox.south, zoom)
    cols = sorted((_to_index(nw_x), _to_index(se_x)))
    rows = sorted((_to_index(nw_y), _to_index(se_y)))
    return TileRange(
        zoom=zoom,
        x_min=cols[0],
        x_max=cols[1],
        y_min=rows[0],
        y_max=rows[1],
    )
def choose_zoom_level(bbox: BBox) -> int:
    """Pick an XYZ zoom level from the longer side of *bbox*.

    Returns 15 when the longer side is at most 1 km, 14 when it is at most
    5 km, and 13 otherwise, so EuroSAT-scale rectangles land on zoom 14-15.
    """
    longest_side_m = max(bbox_size_meters(bbox))
    for side_limit_m, zoom in ((1_000, 15), (5_000, 14)):
        if longest_side_m <= side_limit_m:
            return zoom
    return 13
def bbox_size_meters(bbox: BBox) -> tuple[float, float]:
    """Return the approximate ``(width_m, height_m)`` of *bbox*.

    Width is the haversine distance between the west and east edges along
    the bbox's middle latitude. Height is the haversine distance between the
    south and north edges along the western meridian.
    """
    center_lat = (bbox.north + bbox.south) / 2.0
    east_west_m = _haversine_meters(bbox.west, center_lat, bbox.east, center_lat)
    north_south_m = _haversine_meters(bbox.west, bbox.south, bbox.west, bbox.north)
    return east_west_m, north_south_m
def size_warning_for_bbox(bbox: BBox) -> str | None:
    """Return a warning if *bbox* falls outside the demo's supported size.

    A short side under 50 m or a long side over 5 km produces a user-facing
    message; otherwise ``None`` is returned.
    """
    sides_m = bbox_size_meters(bbox)
    if min(sides_m) < 50:
        return "This rectangle is very small. Draw at least about 50m on a side."
    if max(sides_m) > 5_000:
        return "This rectangle is very large. Draw at most about 5km on a side."
    return None
def bbox_scale_status(bbox: BBox) -> tuple[str, str]:
    """Classify whether a bbox is close enough to EuroSAT-RGB tile scale.

    Returns:
        A ``(status, message)`` pair. ``status`` is ``"invalid"`` when a side
        is outside the acceptable range or the rectangle is too elongated,
        ``"good"`` when both sides are within the target range, and
        ``"usable"`` otherwise. ``message`` is a user-facing explanation.
    """
    width_m, height_m = bbox_size_meters(bbox)
    min_side_m = min(width_m, height_m)
    max_side_m = max(width_m, height_m)
    if min_side_m < EUROSAT_ACCEPTABLE_MIN_M:
        return (
            "invalid",
            "This rectangle is too small for a useful EuroSAT-style prediction. "
            "Draw closer to 500m-1km on each side.",
        )
    if max_side_m > EUROSAT_ACCEPTABLE_MAX_M:
        return (
            "invalid",
            "This rectangle is too large for this EuroSAT-style demo. "
            "Zoom in and draw closer to 500m-1km on each side.",
        )
    # Computed only after the minimum-size check has passed, so a degenerate
    # (zero-width or zero-height) bbox is reported as "too small" instead of
    # raising ZeroDivisionError.
    aspect_ratio = max_side_m / min_side_m
    if aspect_ratio > EUROSAT_MAX_ASPECT_RATIO:
        return (
            "invalid",
            "This rectangle is too stretched. Draw a more square region, like the original EuroSAT tiles.",
        )
    if (
        EUROSAT_TARGET_MIN_M <= min_side_m
        and max_side_m <= EUROSAT_TARGET_MAX_M
    ):
        return (
            "good",
            "Great scale: this is close to the original EuroSAT-RGB tile footprint.",
        )
    return (
        "usable",
        "Usable, but not ideal. For the most trustworthy demo result, draw 500m-1km on each side.",
    )
def fetch_bbox_image(bbox: BBox, zoom: int | None = None) -> Image.Image:
    """Download, stitch, and crop Esri imagery covering *bbox*.

    Args:
        bbox: Area of interest in lon/lat degrees.
        zoom: XYZ zoom level to fetch. When ``None``, ``choose_zoom_level``
            picks one from the bbox size.

    Returns:
        An RGB image cropped to the pixel footprint of *bbox*.

    Raises:
        TileFetchError: If any tile cannot be downloaded or decoded, or if
            the crop of the stitched mosaic is empty.
    """
    if zoom is None:
        zoom = choose_zoom_level(bbox)
    tiles = bbox_to_tile_range(bbox, zoom)
    n_cols = tiles.x_max - tiles.x_min + 1
    n_rows = tiles.y_max - tiles.y_min + 1
    mosaic = Image.new("RGB", (n_cols * TILE_SIZE, n_rows * TILE_SIZE))

    # Column-major order: every row of one column before the next column.
    for col in range(n_cols):
        for row in range(n_rows):
            piece = fetch_esri_tile(zoom, tiles.x_min + col, tiles.y_min + row)
            mosaic.paste(piece, (col * TILE_SIZE, row * TILE_SIZE))
            # Short pause after each tile (presumably to throttle requests
            # to the tile server).
            time.sleep(0.05)

    result = mosaic.crop(_bbox_crop_box(bbox, tiles, mosaic.size))
    if result.width <= 0 or result.height <= 0:
        raise TileFetchError("The fetched imagery crop was empty.")
    return result
# Bound the cache: keys are arbitrary (zoom, x, y) triples, so without a limit
# a long-running app would keep every tile ever requested in memory.
@st.cache_data(show_spinner=False, max_entries=512)
def fetch_esri_tile(zoom: int, x: int, y: int) -> Image.Image:
    """Download one Esri World Imagery XYZ tile as an RGB image.

    Results are memoized by Streamlit's data cache, keyed on
    ``(zoom, x, y)`` and limited to the most recent entries.

    Raises:
        TileFetchError: If the HTTP request fails or returns an error status,
            or if the response body cannot be decoded as an image.
    """
    url = ESRI_TILE_URL.format(z=zoom, x=x, y=y)
    try:
        response = requests.get(
            url,
            headers={"User-Agent": USER_AGENT},
            timeout=10,
        )
        response.raise_for_status()
    except requests.RequestException as exc:
        raise TileFetchError(f"Could not download imagery tile z{zoom}/{x}/{y}.") from exc
    try:
        # Open through a context manager so the lazily-read source image is
        # closed once the converted (fully loaded) copy has been produced.
        with Image.open(BytesIO(response.content)) as raw_tile:
            return raw_tile.convert("RGB")
    except OSError as exc:
        raise TileFetchError(f"Downloaded imagery tile z{zoom}/{x}/{y} was invalid.") from exc
def _bbox_crop_box(
    bbox: BBox, tile_range: TileRange, stitched_size: tuple[int, int]
) -> tuple[int, int, int, int]:
    """Map *bbox* to a ``(left, top, right, bottom)`` box in the stitched mosaic.

    Corner positions are converted to world pixels, shifted so the mosaic's
    top-left tile is the origin, and widened outward to whole pixels (floor
    for the west/north edges, ceil for east/south). Each edge is then
    clamped to ``[0, width]`` or ``[0, height]`` of *stitched_size*.
    """
    offset_x = tile_range.x_min * TILE_SIZE
    offset_y = tile_range.y_min * TILE_SIZE
    nw_x, nw_y = _lonlat_to_global_pixel(bbox.west, bbox.north, tile_range.zoom)
    se_x, se_y = _lonlat_to_global_pixel(bbox.east, bbox.south, tile_range.zoom)

    edges = (
        math.floor(nw_x - offset_x),
        math.floor(nw_y - offset_y),
        math.ceil(se_x - offset_x),
        math.ceil(se_y - offset_y),
    )
    mosaic_w, mosaic_h = stitched_size
    limits = (mosaic_w, mosaic_h, mosaic_w, mosaic_h)
    left, top, right, bottom = (
        max(0, min(limit, edge)) for edge, limit in zip(edges, limits)
    )
    return left, top, right, bottom
def _lonlat_to_global_pixel(lon: float, lat: float, zoom: int) -> tuple[float, float]:
    """Convert lon/lat to world pixel coordinates at *zoom*.

    The coordinates are fractional tile coordinates scaled by ``TILE_SIZE``
    pixels per tile.
    """
    frac_x, frac_y = lonlat_to_tile_fraction(lon, lat, zoom)
    px = frac_x * TILE_SIZE
    py = frac_y * TILE_SIZE
    return px, py
def _haversine_meters(lon1: float, lat1: float, lon2: float, lat2: float) -> float:
    """Return the great-circle distance in meters between two lon/lat points.

    Uses the haversine formula on a sphere of radius 6,371,000 m. All inputs
    are in degrees.
    """
    radius_m = 6_371_000
    phi1 = math.radians(lat1)
    phi2 = math.radians(lat2)
    delta_phi = math.radians(lat2 - lat1)
    delta_lambda = math.radians(lon2 - lon1)
    a = (
        math.sin(delta_phi / 2.0) ** 2
        + math.cos(phi1) * math.cos(phi2) * math.sin(delta_lambda / 2.0) ** 2
    )
    # For (near-)antipodal points, rounding can push ``a`` a hair above 1.0.
    # Then ``math.sqrt(1.0 - a)`` would raise "math domain error", so clamp.
    a = min(a, 1.0)
    return 2.0 * radius_m * math.atan2(math.sqrt(a), math.sqrt(1.0 - a))