Spaces:
Running
Running
| import fsspec | |
| import pyarrow.parquet as pq | |
| import numpy as np | |
| from PIL import Image | |
| from io import BytesIO | |
| from rasterio.io import MemoryFile | |
| import matplotlib.pyplot as plt | |
| import cartopy.crs as ccrs | |
| import cartopy.io.img_tiles as cimgt | |
| from matplotlib.patches import Rectangle | |
| import math | |
| from matplotlib.figure import Figure | |
| from matplotlib.backends.backend_agg import FigureCanvasAgg | |
def crop_center(img_array, cropx, cropy):
    """Return the central ``cropy`` x ``cropx`` window of an image array.

    Only the first two axes (rows, columns) are cropped, so both 2-D (H, W)
    and 3-D (H, W, C) arrays are accepted. When the requested size exceeds the
    image along an axis, the whole extent of that axis is kept; without the
    clamp a negative start index would wrap around and select the wrong
    region.

    Args:
        img_array: NumPy array with at least two dimensions.
        cropx: Desired width in pixels.
        cropy: Desired height in pixels.

    Returns:
        A view into ``img_array`` covering the centred window.
    """
    height, width = img_array.shape[:2]
    startx = max(width // 2 - cropx // 2, 0)
    starty = max(height // 2 - cropy // 2, 0)
    return img_array[starty:starty + cropy, startx:startx + cropx]
def read_tif_bytes(tif_bytes):
    """Decode in-memory GeoTIFF bytes into a NumPy array.

    The raster is opened through a rasterio ``MemoryFile`` (no temp file on
    disk). All bands are read, and length-1 axes are squeezed away, so a
    single-band file comes back as a 2-D array.
    """
    with MemoryFile(tif_bytes) as memfile, memfile.open(driver='GTiff') as dataset:
        pixels = dataset.read()
    return pixels.squeeze()
def read_row_memory(row_dict, columns=None):
    """Download selected columns of one sample from a remote parquet file.

    ``row_dict['parquet_row']`` is used as a *row-group* index, and only the
    first row of that group is returned. This presumably relies on each row
    group holding a single sample; confirm against the dataset layout.

    Args:
        row_dict: Mapping with ``parquet_url`` (a URL opened via fsspec) and
            ``parquet_row`` (row-group index).
        columns: Column names to fetch. Defaults to ``["thumbnail"]``.

    Returns:
        dict mapping each requested column to either a PIL image (for
        ``thumbnail``, decoded from an in-memory buffer) or the array returned
        by ``read_tif_bytes`` (every other column is treated as GeoTIFF bytes).
    """
    # None sentinel instead of a shared mutable list default.
    columns = ["thumbnail"] if columns is None else list(columns)
    url = row_dict['parquet_url']
    row_idx = row_dict['parquet_row']
    # Readahead caching with 5 MiB blocks, presumably to reduce the number of
    # small HTTP range requests made while parquet metadata is parsed.
    fs_options = {
        "cache_type": "readahead",
        "block_size": 5 * 1024 * 1024,
    }
    with fsspec.open(url, mode='rb', **fs_options) as f:
        with pq.ParquetFile(f) as pf:
            table = pf.read_row_group(row_idx, columns=columns)

    row_output = {}
    for col in columns:
        col_data = table[col][0].as_py()
        if col != 'thumbnail':
            row_output[col] = read_tif_bytes(col_data)
        else:
            # Image.open is lazy; the BytesIO stays referenced by the image.
            row_output[col] = Image.open(BytesIO(col_data))
    return row_output
def _mirror_parquet_url(url):
    """Rewrite a Hugging Face / hf-mirror.com URL to point at modelscope.cn.

    URLs mentioning neither source host are returned unchanged.
    """
    if 'huggingface.co' in url:
        url = url.replace('https://huggingface.co', 'https://modelscope.cn')
    elif 'hf-mirror.com' in url:
        url = url.replace('https://hf-mirror.com', 'https://modelscope.cn')
    else:
        return url
    # The mirror is addressed with revision "master" where the source used "main".
    return url.replace('resolve/main', 'resolve/master')


def _to_display_uint8(rgb_img):
    """Convert raw band values to an 8-bit display image.

    Divides by 10000 (presumably the Sentinel-2 reflectance scale factor;
    confirm against the dataset), applies a 2.5x brightness gain, clips to
    [0, 1] and rescales to 0-255.
    """
    rgb_norm = (2.5 * (rgb_img.astype(float) / 10000.0)).clip(0, 1)
    return (rgb_norm * 255).astype(np.uint8)


def download_and_process_image(product_id, df_source=None, verbose=True):
    """Fetch the RGB bands for ``product_id`` and return display-ready images.

    The first row of ``df_source`` whose ``product_id`` matches is used. Its
    ``parquet_url`` is rewritten to the ModelScope mirror, bands B04/B03/B02
    are downloaded with ``read_row_memory``, and the bands are stacked as
    R, G, B.

    Args:
        product_id: Value matched against the ``product_id`` column.
        df_source: DataFrame with ``product_id``, ``parquet_url`` and
            ``parquet_row`` columns.
        verbose: When True, print progress and diagnostics, including
            tracebacks on failure.

    Returns:
        ``(img_384, img_full)``:
            img_384: a 384x384 PIL image. This is a centre crop, or a resize
                when the scene is smaller than 384 px on a side.
            img_full: the full-size PIL image.
        Returns ``(None, None)`` on any failure.
    """
    if df_source is None:
        if verbose:
            print("❌ Error: No DataFrame provided.")
        return None, None

    row_subset = df_source[df_source['product_id'] == product_id]
    if len(row_subset) == 0:
        if verbose:
            print(f"❌ Error: Product ID {product_id} not found in DataFrame.")
        return None, None

    row_dict = row_subset.iloc[0].to_dict()
    url = row_dict.get('parquet_url')
    # pandas yields NaN/None for missing cells; treat those like an absent column
    # instead of crashing on the substring checks below.
    if not isinstance(url, str) or not url:
        if verbose:
            print("❌ Error: 'parquet_url' missing in metadata.")
        return None, None
    row_dict['parquet_url'] = _mirror_parquet_url(url)

    if verbose:
        print(f"⬇️ Fetching data for {product_id} from {row_dict['parquet_url']}...")

    band_names = ['B04', 'B03', 'B02']  # stacked as R, G, B
    try:
        bands_data = read_row_memory(row_dict, columns=band_names)
        if not all(b in bands_data for b in band_names):
            if verbose:
                print(f"❌ Error: Missing bands in fetched data for {product_id}")
            return None, None

        rgb_img = np.stack([bands_data[b] for b in band_names], axis=-1)
        if verbose:
            print(f"Raw RGB stats: Min={rgb_img.min()}, Max={rgb_img.max()}, Mean={rgb_img.mean()}, Dtype={rgb_img.dtype}")

        rgb_uint8 = _to_display_uint8(rgb_img)
        if verbose:
            print(f"Processed RGB stats: Min={rgb_uint8.min()}, Max={rgb_uint8.max()}, Mean={rgb_uint8.mean()}")

        img_full = Image.fromarray(rgb_uint8)
        if rgb_uint8.shape[0] >= 384 and rgb_uint8.shape[1] >= 384:
            img_384 = Image.fromarray(crop_center(rgb_uint8, 384, 384))
        else:
            if verbose:
                print(f"⚠️ Image too small {rgb_uint8.shape}, resizing to 384x384.")
            img_384 = img_full.resize((384, 384))

        if verbose:
            print(f"✅ Successfully processed {product_id}")
        return img_384, img_full
    except Exception as e:
        # Best-effort: any download/decode failure maps to (None, None).
        if verbose:
            print(f"❌ Error processing {product_id}: {e}")
            import traceback
            traceback.print_exc()
        return None, None
# Cartopy tile source backed by Esri World Imagery.
class EsriImagery(cimgt.GoogleTiles):
    """Reuse Cartopy's GoogleTiles machinery but fetch tiles from Esri's
    World_Imagery MapServer, which expects a {z}/{y}/{x} path order."""

    def _image_url(self, tile):
        col, row, zoom = tile
        return f'https://server.arcgisonline.com/ArcGIS/rest/services/World_Imagery/MapServer/tile/{zoom}/{row}/{col}'
| from PIL import Image, ImageDraw, ImageFont | |
def get_placeholder_image(text="Image Unavailable", size=(384, 384)):
    """Return a flat grey RGB image with ``text`` drawn in the middle.

    Args:
        text: Message to render; may contain ``\\n`` for multiple lines.
        size: ``(width, height)`` of the image in pixels.

    Returns:
        A new RGB ``PIL.Image.Image``.
    """
    img = Image.new('RGB', size, color=(200, 200, 200))
    d = ImageDraw.Draw(img)
    try:
        font = ImageFont.load_default()
    except Exception:
        # Best effort: with font=None Pillow picks its own default.
        font = None

    width, height = img.size
    try:
        left, top, right, bottom = d.multiline_textbbox((0, 0), text, font=font, align="center")
    except (AttributeError, ValueError):
        # Older Pillow lacks textbbox (or only supports it for TrueType fonts):
        # fall back to a fixed left offset at mid-height.
        d.text((20, height // 2), text, fill=(0, 0, 0), font=font)
        return img

    # Centre the text's bounding box, never pushing it off the top-left edge.
    x = max((width - (right - left)) // 2 - left, 0)
    y = max((height - (bottom - top)) // 2 - top, 0)
    d.multiline_text((x, y), text, fill=(0, 0, 0), font=font, align="center")
    return img
def _format_map_title(score=None, rank=None, query=None):
    """Join the optional query, rank and score into a newline-separated title."""
    parts = []
    if query:
        parts.append(f"{query}")
    if rank is not None:
        parts.append(f"Rank {rank}")
    if score is not None:
        parts.append(f"Score: {score:.4f}")
    return "\n".join(parts)


def _box_size_deg(lat, box_size_m):
    """Approximate a square of side ``box_size_m`` metres as ``(dlon, dlat)`` degrees.

    Uses 111320 m per degree of latitude and shrinks longitude degrees by
    cos(lat) (spherical approximation).
    """
    dlat = box_size_m / 111320
    dlon = box_size_m / (111320 * math.cos(math.radians(lat)))
    return dlon, dlat


def _format_coord(value, digits):
    """Format a coordinate with ``digits`` decimals, or via ``str`` if it is not numeric."""
    try:
        return f"{float(value):.{digits}f}"
    except (TypeError, ValueError):
        return str(value)


def get_esri_satellite_image(lat, lon, score=None, rank=None, query=None, zoom=14):
    """Render an Esri World Imagery map around (lat, lon) as a PIL image.

    The map shows:
      * a yellow '+' at the point;
      * a red 3840 m x 3840 m box centred on it;
      * a title built from ``query``, ``rank`` and ``score``.

    Plotting goes through the object-oriented Matplotlib API (``Figure`` +
    Agg canvas), so pyplot's global figure state is never touched.

    Args:
        lat, lon: Centre point in degrees.
        score, rank, query: Optional values shown in the title.
        zoom: Tile zoom level passed to Cartopy's ``add_image``.

    Returns:
        A PNG-backed PIL image. On any error a placeholder image is returned
        instead; the error is printed, not raised.
    """
    try:
        imagery = EsriImagery()
        fig = Figure(figsize=(5, 5), dpi=100)
        FigureCanvasAgg(fig)  # attach an Agg canvas to the figure
        ax = fig.add_subplot(1, 1, 1, projection=imagery.crs)

        # +/-0.05 deg: about 11 km north-south; the east-west span shrinks with cos(lat).
        extent_deg = 0.05
        ax.set_extent([lon - extent_deg, lon + extent_deg, lat - extent_deg, lat + extent_deg], crs=ccrs.PlateCarree())
        ax.add_image(imagery, zoom)

        ax.plot(lon, lat, marker='+', color='yellow', markersize=12, markeredgewidth=2, transform=ccrs.PlateCarree())

        # 384 * 10 m; presumably 384 px at 10 m/px to match the 384x384 crop (confirm).
        box_size_m = 384 * 10
        dlon, dlat = _box_size_deg(lat, box_size_m)
        rect = Rectangle((lon - dlon / 2, lat - dlat / 2), dlon, dlat,
                         linewidth=2, edgecolor='red', facecolor='none', transform=ccrs.PlateCarree())
        ax.add_patch(rect)

        ax.set_title(_format_map_title(score, rank, query), fontsize=10)

        buf = BytesIO()
        fig.savefig(buf, format='png', bbox_inches='tight')
        buf.seek(0)
        return Image.open(buf)
    except Exception as e:
        # Network failures get a short one-line warning to avoid log spam.
        error_msg = str(e)
        network_markers = ("Connection reset by peer", "Network is unreachable", "urlopen error")
        if any(marker in error_msg for marker in network_markers):
            print(f"⚠️ Network warning: Could not fetch Esri satellite map for ({_format_coord(lat, 4)}, {_format_coord(lon, 4)}). Server might be offline.")
        else:
            print(f"Error generating Esri image for {lat}, {lon}: {e}")
        # _format_coord never raises, so the fallback image is always produced.
        return get_placeholder_image(f"Map Unavailable\n({_format_coord(lat, 2)}, {_format_coord(lon, 2)})")
def get_esri_satellite_image_url(lat, lon, zoom=14):
    """Return the URL of the Esri World Imagery tile containing (lat, lon).

    Tile indices follow the Web-Mercator XYZ ("slippy map") scheme, using the
    same ``{z}/{y}/{x}`` URL layout as ``EsriImagery``. Longitude wraps around
    the antimeridian. Latitudes beyond the Mercator limit are clamped to the
    top or bottom tile row.

    Args:
        lat: Latitude in degrees.
        lon: Longitude in degrees.
        zoom: Non-negative integer zoom level.

    Returns:
        The tile URL as a string.

    Raises:
        ValueError: if ``zoom`` is negative or ``lat``/``lon`` is not finite.
    """
    zoom = int(zoom)
    if zoom < 0:
        raise ValueError(f"zoom must be non-negative, got {zoom}")
    if not (math.isfinite(lat) and math.isfinite(lon)):
        raise ValueError(f"lat/lon must be finite, got ({lat}, {lon})")

    n = 2 ** zoom
    x = math.floor(((lon + 180.0) % 360.0) / 360.0 * n)
    lat_rad = math.radians(lat)
    y = math.floor((1.0 - math.asinh(math.tan(lat_rad)) / math.pi) / 2.0 * n)
    # Clamp into the valid tile range (handles poles and floating-point edges).
    x = min(max(x, 0), n - 1)
    y = min(max(y, 0), n - 1)
    return f'https://server.arcgisonline.com/ArcGIS/rest/services/World_Imagery/MapServer/tile/{zoom}/{y}/{x}'