| |
|
| | |
| | """ |
| | Compute Embeddings for Major-TOM Sentinel-2 Images |
| | |
| | This script generates embeddings for Sentinel-2 imagery using various models: |
| | - DINOv2: Vision Transformer trained with self-supervised learning |
| | - SigLIP: Vision-Language model with sigmoid loss |
| | - FarSLIP: Remote sensing fine-tuned CLIP |
| | - SatCLIP: Satellite imagery CLIP with location awareness |
| | |
| | Usage: |
| | python compute_embeddings.py --model dinov2 --device cuda:1 |
| | python compute_embeddings.py --model siglip --device cuda:5 |
| | python compute_embeddings.py --model satclip --device cuda:3 |
| | python compute_embeddings.py --model farslip --device cuda:4 |
| | |
| | Author: Generated by Copilot |
| | """ |
| |
|
| | import os |
| | import sys |
| | import argparse |
| | import logging |
| | from pathlib import Path |
| | from datetime import datetime |
| |
|
| | import numpy as np |
| | import pandas as pd |
| | import torch |
| | from PIL import Image |
| | from tqdm.auto import tqdm |
| |
|
| | |
# Make the project root importable so `models.*` resolves even when the
# script is launched from a different working directory.
PROJECT_ROOT = Path(__file__).parent.absolute()
if str(PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(PROJECT_ROOT))
| |
|
| | from models.load_config import load_and_process_config |
| |
|
| |
|
| | |
| | |
| | |
# --- Dataset and output locations (server-specific absolute paths) ---

# Parquet file with per-image metadata for the cropped 384x384 dataset.
METADATA_PATH = Path("/data1/zyj/Core-S2L2A-249k/Core_S2L2A_249k_crop_384x384_metadata.parquet")
# Directory containing the batch_*.parquet shards holding raw image bytes.
IMAGE_PARQUET_DIR = Path("/data1/zyj/Core-S2L2A-249k/images")
# Root directory under which one embeddings parquet per model is written.
OUTPUT_BASE_DIR = Path("/data1/zyj/EarthEmbeddings/Core-S2L2A-249k")

# Bookkeeping/payload columns dropped from the output parquet.
COLUMNS_TO_REMOVE = ['cloud_cover', 'nodata', 'geometry_wkt', 'bands', 'image_shape', 'image_dtype']

# Column renames applied to the output ('crs' becomes 'utm_crs').
COLUMNS_RENAME = {'crs': 'utm_crs'}

# Constant pixel bounding box recorded on every output row.
# NOTE(review): presumably [x_min, y_min, x_max, y_max] of the 384x384 crop
# inside the original tile — confirm against the cropping pipeline.
PIXEL_BBOX = [342, 342, 726, 726]

# Destination parquet file for each supported model.
MODEL_OUTPUT_PATHS = {
    'dinov2': OUTPUT_BASE_DIR / 'dinov2' / 'DINOv2_crop_384x384.parquet',
    'siglip': OUTPUT_BASE_DIR / 'siglip' / 'SigLIP_crop_384x384.parquet',
    'farslip': OUTPUT_BASE_DIR / 'farslip' / 'FarSLIP_crop_384x384.parquet',
    'satclip': OUTPUT_BASE_DIR / 'satclip' / 'SatCLIP_crop_384x384.parquet',
}

# Default batch size per model (overridable via --batch-size).
BATCH_SIZES = {
    'dinov2': 64,
    'siglip': 64,
    'farslip': 64,
    'satclip': 128,
}
| |
|
| |
|
| | |
| | |
| | |
def setup_logging(model_name: str):
    """Route log records to a per-model file under logs/ and to stdout.

    Args:
        model_name: Model identifier used to name the log file.

    Returns:
        The module logger, configured at INFO level.
    """
    log_dir = PROJECT_ROOT / "logs"
    log_dir.mkdir(parents=True, exist_ok=True)

    handlers = [
        logging.FileHandler(log_dir / f"compute_embeddings_{model_name}.log"),
        logging.StreamHandler(sys.stdout),
    ]
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s [%(levelname)s] %(message)s",
        handlers=handlers,
    )
    return logging.getLogger(__name__)
| |
|
| |
|
| | |
| | |
| | |
def decode_image_bytes(row) -> np.ndarray:
    """
    Reconstruct an image array from raw bytes stored in a parquet row.

    Args:
        row: pandas Series providing 'image_bytes', 'image_shape', 'image_dtype'

    Returns:
        np.ndarray of shape (H, W, 12) with uint16 values
    """
    target_shape = tuple(int(dim) for dim in row['image_shape'])
    buffer = np.frombuffer(row['image_bytes'], dtype=np.dtype(row['image_dtype']))
    return buffer.reshape(target_shape)
| |
|
| |
|
def extract_rgb_image(img_array: np.ndarray, clip_max: float = 4000.0) -> Image.Image:
    """
    Build an 8-bit RGB PIL image from a 12-band Sentinel-2 array.

    Sentinel-2 Bands: [B01, B02, B03, B04, B05, B06, B07, B08, B8A, B09, B11, B12]
    RGB Mapping: R=B04(idx 3), G=B03(idx 2), B=B02(idx 1)

    Args:
        img_array: numpy array of shape (H, W, 12)
        clip_max: Value to clip reflectance data for visualization

    Returns:
        PIL.Image: RGB image
    """
    # Select the visible-light bands in R, G, B order.
    visible = img_array[:, :, [3, 2, 1]].astype(np.float32)
    # Scale reflectance to [0, 1], then quantize to 8-bit.
    unit = np.clip(visible / clip_max, 0, 1)
    return Image.fromarray((unit * 255).astype(np.uint8))
| |
|
| |
|
| | |
| | |
| | |
def load_model(model_name: str, device: str, config: dict):
    """
    Instantiate the requested embedding model.

    Imports are deferred into each branch so only the chosen model's
    dependencies are loaded.

    Args:
        model_name: One of 'dinov2', 'siglip', 'farslip', 'satclip'
        device: Device string like 'cuda:0' or 'cpu'
        config: Configuration dictionary from local.yaml

    Returns:
        Model instance

    Raises:
        ValueError: If model_name is not one of the supported identifiers.
    """
    logger = logging.getLogger(__name__)
    model_config = config.get(model_name, {})

    if model_name == 'dinov2':
        from models.dinov2_model import DINOv2Model
        model = DINOv2Model(
            ckpt_path=model_config.get('ckpt_path', '/data1/zyj/checkpoints/dinov2-large'),
            model_name='facebook/dinov2-large',
            embedding_path=None,
            device=device,
        )
        logger.info(f"DINOv2 model loaded on {device}")

    elif model_name == 'siglip':
        from models.siglip_model import SigLIPModel
        model = SigLIPModel(
            ckpt_path=model_config.get('ckpt_path', './checkpoints/ViT-SO400M-14-SigLIP-384/open_clip_pytorch_model.bin'),
            model_name='ViT-SO400M-14-SigLIP-384',
            tokenizer_path=model_config.get('tokenizer_path', './checkpoints/ViT-SO400M-14-SigLIP-384'),
            embedding_path=None,
            device=device,
        )
        # Clear any precomputed embedding state carried by the wrapper.
        model.df_embed = None
        model.image_embeddings = None
        logger.info(f"SigLIP model loaded on {device}")

    elif model_name == 'farslip':
        from models.farslip_model import FarSLIPModel
        model = FarSLIPModel(
            ckpt_path=model_config.get('ckpt_path', './checkpoints/FarSLIP/FarSLIP2_ViT-B-16.pt'),
            model_name='ViT-B-16',
            embedding_path=None,
            device=device,
        )
        logger.info(f"FarSLIP model loaded on {device}")

    elif model_name == 'satclip':
        from models.satclip_ms_model import SatCLIPMSModel
        model = SatCLIPMSModel(
            ckpt_path=model_config.get('ckpt_path', './checkpoints/SatCLIP/satclip-vit16-l40.ckpt'),
            embedding_path=None,
            device=device,
        )
        logger.info(f"SatCLIP-MS model loaded on {device}")

    else:
        raise ValueError(f"Unknown model: {model_name}")

    return model
| |
|
| |
|
| | |
| | |
| | |
def compute_embedding_single(model, model_name: str, img_array: np.ndarray) -> np.ndarray:
    """
    Compute embedding for a single image.

    Args:
        model: Model instance
        model_name: Model identifier
        img_array: numpy array of shape (H, W, 12)

    Returns:
        np.ndarray: 1D embedding vector, or None on failure / unknown model
    """
    if model_name in ('dinov2', 'siglip', 'farslip'):
        # RGB-only models: convert the multispectral stack to an RGB image.
        feature = model.encode_image(extract_rgb_image(img_array))
    elif model_name == 'satclip':
        # SatCLIP consumes all 12 bands directly.
        feature = model.encode_image(img_array, is_multispectral=True)
    else:
        return None

    if feature is None:
        return None
    return feature.cpu().numpy().flatten()
| |
|
| |
|
def _encode_batch(model, inputs: list, **encode_kwargs) -> list:
    """Encode a list of inputs, preferring the model's batch API.

    Tries ``model.encode_images`` first when available; on any failure falls
    back to encoding each item individually with ``model.encode_image``.
    Items that fail yield None instead of raising.

    Args:
        model: Model instance exposing encode_image (and optionally encode_images)
        inputs: List of per-item inputs (PIL images or numpy arrays)
        **encode_kwargs: Extra keyword arguments forwarded to both encode calls

    Returns:
        List of 1D numpy embedding vectors, with None for failed items.
    """
    if hasattr(model, 'encode_images'):
        try:
            features = model.encode_images(inputs, **encode_kwargs)
            if features is not None:
                return [features[i].cpu().numpy().flatten() for i in range(len(features))]
        except Exception:
            # Batch path is best-effort; fall through to per-item encoding.
            pass

    results = []
    for item in inputs:
        try:
            feature = model.encode_image(item, **encode_kwargs)
            results.append(feature.cpu().numpy().flatten() if feature is not None else None)
        except Exception:
            results.append(None)
    return results


def compute_embedding_batch(model, model_name: str, img_arrays: list) -> list:
    """
    Compute embeddings for a batch of images.
    Falls back to single-image processing if batch method unavailable.

    Args:
        model: Model instance
        model_name: Model identifier
        img_arrays: List of numpy arrays of shape (H, W, 12)

    Returns:
        List of 1D embedding vectors (numpy arrays), None for failed items
    """
    if model_name in ('dinov2', 'siglip', 'farslip'):
        # RGB-only models: convert each 12-band array to an RGB image first.
        return _encode_batch(model, [extract_rgb_image(arr) for arr in img_arrays])
    if model_name == 'satclip':
        # SatCLIP consumes the full multispectral stack directly.
        return _encode_batch(model, img_arrays, is_multispectral=True)
    # Unknown model: report every item as failed.
    return [None] * len(img_arrays)
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
def process_parquet_file(
    file_path: Path,
    model,
    model_name: str,
    batch_size: int = 64
) -> pd.DataFrame:
    """
    Process a single parquet file and generate embeddings using batch processing.

    Args:
        file_path: Path to input parquet file
        model: Model instance
        model_name: Model identifier
        batch_size: Batch size for processing

    Returns:
        DataFrame with embeddings, or None if no row produced a valid embedding
    """
    logger = logging.getLogger(__name__)

    # Load the raw shard (metadata columns + serialized image bytes).
    df = pd.read_parquet(file_path)
    n_rows = len(df)

    # Pre-sized per-row slots so batch results can be written back by index.
    embeddings_list = [None] * n_rows
    valid_mask = [False] * n_rows

    # Walk the shard in fixed-size batches.
    for batch_start in range(0, n_rows, batch_size):
        batch_end = min(batch_start + batch_size, n_rows)
        batch_indices = list(range(batch_start, batch_end))

        # Decode this batch's images; rows that fail to decode are skipped
        # (their slots stay None/False).
        batch_arrays = []
        batch_valid_indices = []

        for idx in batch_indices:
            try:
                row = df.iloc[idx]
                img_array = decode_image_bytes(row)
                batch_arrays.append(img_array)
                batch_valid_indices.append(idx)
            except Exception as e:
                logger.warning(f"Error decoding row {idx}: {e}")
                continue

        if not batch_arrays:
            continue

        # Embed the decoded images as one batch; if the whole batch call
        # raises, retry row by row so one bad image can't sink the batch.
        try:
            batch_embeddings = compute_embedding_batch(model, model_name, batch_arrays)

            # Write results back to the original row positions.
            for i, idx in enumerate(batch_valid_indices):
                if batch_embeddings[i] is not None:
                    embeddings_list[idx] = batch_embeddings[i]
                    valid_mask[idx] = True

        except Exception as e:
            logger.warning(f"Error computing batch embeddings: {e}")
            # Per-row fallback path.
            for i, idx in enumerate(batch_valid_indices):
                try:
                    embedding = compute_embedding_single(model, model_name, batch_arrays[i])
                    if embedding is not None:
                        embeddings_list[idx] = embedding
                        valid_mask[idx] = True
                except Exception as inner_e:
                    logger.warning(f"Error processing row {idx}: {inner_e}")
                    continue

    # Keep only rows that produced an embedding.
    valid_indices = [i for i, v in enumerate(valid_mask) if v]

    if not valid_indices:
        logger.warning(f"No valid embeddings for {file_path.name}")
        return None

    result_df = df.iloc[valid_indices].copy()
    valid_embeddings = [embeddings_list[i] for i in valid_indices]

    # Strip bookkeeping columns that should not appear in the output.
    cols_to_drop = [c for c in COLUMNS_TO_REMOVE if c in result_df.columns]
    if cols_to_drop:
        result_df = result_df.drop(columns=cols_to_drop)

    # Drop the heavy raw-image payload.
    if 'image_bytes' in result_df.columns:
        result_df = result_df.drop(columns=['image_bytes'])

    # Drop 'geometry' if present (removed from the output schema alongside
    # 'geometry_wkt').
    if 'geometry' in result_df.columns:
        result_df = result_df.drop(columns=['geometry'])

    # Apply the standard renames (e.g. 'crs' -> 'utm_crs').
    result_df = result_df.rename(columns=COLUMNS_RENAME)

    # Every row in this dataset shares the same crop bounding box.
    result_df['pixel_bbox'] = [PIXEL_BBOX] * len(result_df)

    # Attach the embeddings: one 1D vector per surviving row.
    result_df['embedding'] = valid_embeddings

    return result_df
| |
|
| | |
| | |
| | |
def main():
    """CLI entry point: parse arguments, load the chosen model, embed every
    input parquet shard, and write one merged embeddings parquet.

    Returns:
        int: 0 on success, 1 if no shard produced any data (used as exit code).
    """
    parser = argparse.ArgumentParser(description='Compute embeddings for Major-TOM images')
    parser.add_argument('--model', type=str, required=True,
                        choices=['dinov2', 'siglip', 'farslip', 'satclip'],
                        help='Model to use for embedding computation')
    parser.add_argument('--device', type=str, default='cuda:0',
                        help='Device to run on (e.g., cuda:0, cuda:1, cpu)')
    parser.add_argument('--batch-size', type=int, default=None,
                        help='Batch size for processing (default: model-specific)')
    parser.add_argument('--max-files', type=int, default=None,
                        help='Maximum number of files to process (for testing)')

    args = parser.parse_args()

    # Log to both logs/compute_embeddings_<model>.log and stdout.
    logger = setup_logging(args.model)

    logger.info("=" * 80)
    logger.info(f"Computing {args.model.upper()} embeddings")
    logger.info(f"Timestamp: {datetime.now().isoformat()}")
    logger.info(f"Device: {args.device}")
    logger.info("=" * 80)

    # Optional config file; fall back to the built-in default paths.
    config = load_and_process_config()
    if config is None:
        logger.warning("No config file found, using default paths")
        config = {}

    # CLI override wins; otherwise use the per-model default batch size.
    batch_size = args.batch_size or BATCH_SIZES.get(args.model, 64)
    logger.info(f"Batch size: {batch_size}")

    # Ensure the model-specific output directory exists before processing.
    output_path = MODEL_OUTPUT_PATHS[args.model]
    output_path.parent.mkdir(parents=True, exist_ok=True)
    logger.info(f"Output path: {output_path}")

    logger.info(f"Loading {args.model} model...")
    model = load_model(args.model, args.device, config)

    # Input shards are named batch_*.parquet; --max-files caps them for tests.
    parquet_files = sorted(IMAGE_PARQUET_DIR.glob("batch_*.parquet"))
    if args.max_files:
        parquet_files = parquet_files[:args.max_files]

    logger.info(f"Found {len(parquet_files)} input files")

    # Process each shard independently; a failing shard is logged and skipped.
    all_results = []
    total_rows = 0

    for file_path in tqdm(parquet_files, desc=f"Processing {args.model}"):
        try:
            result_df = process_parquet_file(file_path, model, args.model, batch_size)

            if result_df is not None:
                all_results.append(result_df)
                total_rows += len(result_df)
                logger.info(f"[{file_path.name}] Processed {len(result_df)} rows")

        except Exception as e:
            logger.error(f"Error processing {file_path.name}: {e}")
            import traceback
            traceback.print_exc()
            continue

    # Merge all per-shard results and run sanity checks before saving.
    if all_results:
        logger.info("Merging all results...")
        final_df = pd.concat(all_results, ignore_index=True)

        logger.info(f"Final columns: {list(final_df.columns)}")

        # Sanity check: bookkeeping columns should already be dropped per shard.
        removed = [c for c in COLUMNS_TO_REMOVE if c in final_df.columns]
        if removed:
            logger.warning(f"Columns still present that should be removed: {removed}")
        else:
            logger.info("✓ All unwanted columns removed")

        # Sanity check: the 'crs' -> 'utm_crs' rename took effect.
        if 'utm_crs' in final_df.columns and 'crs' not in final_df.columns:
            logger.info("✓ Column 'crs' renamed to 'utm_crs'")

        # Sanity check: the constant pixel_bbox column was added.
        if 'pixel_bbox' in final_df.columns:
            logger.info("✓ Column 'pixel_bbox' added")

        logger.info(f"Saving to {output_path}...")
        final_df.to_parquet(output_path, index=False)

        logger.info(f"=" * 80)
        logger.info(f"Processing complete!")
        logger.info(f"Total rows: {len(final_df):,}")
        logger.info(f"Embedding dimension: {len(final_df['embedding'].iloc[0])}")
        logger.info(f"Output file: {output_path}")
        logger.info(f"=" * 80)

    else:
        logger.error("No data processed!")
        return 1

    return 0
| |
|
| |
|
# Script entry point: propagate main()'s return code to the shell.
if __name__ == "__main__":
    sys.exit(main())
| |
|