| """ |
| Sample embeddings + reconstructions from a trained AstroPT model (AIM-compatible) |
| |
| DETERMINISTIC / NO-SHUFFLE VERSION: |
| - Pulls from HuggingFace "Smith42/galaxies" using streaming=True |
| - DOES NOT shuffle at all |
| - Iterates in the exact order HuggingFace yields examples (shard/file order) |
| - Generates 12 reconstructions FIRST using the first 12 examples of `recon_split` |
| - Then extracts embeddings from BOTH `test` + `validation` in order and saves .npy |
| - Saves IDs from HF column `dr8_id` in a separate file: |
| - idxs_...npy (dtype=str/object), aligned 1:1 with embeddings rows |
| - Avoids concatenate_datasets() for streaming datasets (prevents PyArrow crash with torch tensors) |
| |
| Notes: |
| - If you want the strictest “exact order” behavior, set num_workers = 0. |
| - The recon figure uses validate()-style rendering (prepend zero token + optional antispiralise). |
| """ |
|
|
| import os |
| import math |
| import functools |
| from contextlib import nullcontext |
| from typing import Dict, Any, List, Optional |
|
|
| import numpy as np |
| import torch |
| from torch.utils.data import DataLoader |
| from torchvision import transforms |
| import matplotlib.pyplot as plt |
| import einops |
| from tqdm import tqdm |
|
|
| from datasets import load_dataset |
|
|
| from astropt.model import GPT, GPTConfig |
| from astropt.local_datasets import GalaxyImageDataset |
|
|
|
|
| |
| |
| |
| |
| out_dir = "/mnt/c/Users/shaha/Downloads/AIM" |
|
|
| |
| device = "cuda" |
| dtype = "bfloat16" |
| compile = False |
|
|
| |
| dataset_name = "Smith42/galaxies" |
| stream_hf_dataset = True |
| revision = None |
|
|
| |
| splits_for_embeddings = ["test", "validation"] |
| batch_size = 256 |
| num_workers = 0 |
| pin_memory = True |
|
|
| prefix_len = 64 |
| embed_reduction = "mean" |
|
|
| |
| n_recon = 12 |
| recon_split = "test" |
| save_recon_name = "recon_12.png" |
|
|
| |
| id_field_name = "dr8_id" |
|
|
|
|
| |
| |
| |
def normalise(x: torch.Tensor, use_hf: bool = False) -> torch.Tensor:
    """Standardise each row of ``x`` (zero mean, unit std) and cast to float16.

    When ``use_hf`` is set and ``x`` arrives as a numpy array (the HuggingFace
    streaming path), it is first converted to a float32 torch tensor.
    A small epsilon guards against division by zero for constant rows.
    """
    if use_hf and isinstance(x, np.ndarray):
        x = torch.from_numpy(x).to(torch.float32)
    mean = x.mean(dim=1, keepdim=True)
    std = x.std(dim=1, keepdim=True)
    return ((x - mean) / (std + 1e-8)).to(torch.float16)
|
|
|
|
def data_transforms(use_hf: bool):
    """Build the torchvision pipeline applied to each galaxy image.

    The single step is per-row normalisation via :func:`normalise`; ``use_hf``
    is forwarded so numpy arrays from HF streaming get converted to tensors.
    """
    norm_fn = functools.partial(normalise, use_hf=use_hf)
    return transforms.Compose([transforms.Lambda(norm_fn)])
|
|
|
|
def process_galaxy_wrapper(galdict: Dict[str, Any], func):
    """Wrapper for processing galaxy images from HF dataset (MATCH TRAINING).

    Converts the HF ``image`` entry to channels-last-to-channels-first layout
    (swapaxes 0<->2), patches it with ``func`` (galproc.process_galaxy), and
    emits the model-ready dict plus the galaxy id (``"-1"`` when missing).
    """
    raw = np.array(galdict["image"]).swapaxes(0, 2)
    patches = func(raw)
    return {
        "images": patches.to(torch.float),
        "images_positions": torch.arange(len(patches), dtype=torch.long),
        id_field_name: galdict.get(id_field_name, "-1"),
    }
|
|
|
|
| |
| |
| |
def load_model(out_dir: str, device: str, dtype: str, compile_model: bool):
    """Load a trained AstroPT checkpoint and build an autocast context.

    Args:
        out_dir: directory containing ``ckpt.pt``.
        device: torch device string (e.g. ``"cuda"`` or ``"cpu"``).
        dtype: one of ``"float32"``, ``"bfloat16"``, ``"float16"``.
        compile_model: wrap the model with ``torch.compile`` when True.

    Returns:
        ``(model, modality_registry, checkpoint, ctx)`` where ``ctx`` is a
        ``torch.amp.autocast`` context (or ``nullcontext`` on CPU).

    Raises:
        FileNotFoundError: when ``ckpt.pt`` is missing in ``out_dir``.
    """
    ckpt_path = os.path.join(out_dir, "ckpt.pt")
    if not os.path.exists(ckpt_path):
        raise FileNotFoundError(f"Checkpoint not found: {ckpt_path}")

    # weights_only=False: the checkpoint stores pickled Python objects
    # (the modality registry), not just tensors.
    checkpoint = torch.load(ckpt_path, map_location=device, weights_only=False)

    modality_registry = checkpoint["modality_registry"]
    gptconf = GPTConfig(**checkpoint["model_args"])
    model = GPT(gptconf, modality_registry)

    # Strip the torch.compile() wrapper prefix so keys match the raw module.
    state_dict = checkpoint["model"]
    unwanted_prefix = "_orig_mod."
    for k in list(state_dict):  # snapshot keys: we mutate the dict while iterating
        if k.startswith(unwanted_prefix):
            state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)

    model.load_state_dict(state_dict)
    model.eval().to(device)

    if compile_model:
        model = torch.compile(model)

    device_type = "cuda" if "cuda" in device else "cpu"
    ptdtype = {
        "float32": torch.float32,
        "bfloat16": torch.bfloat16,
        "float16": torch.float16,
    }[dtype]
    ctx = (
        nullcontext()
        if device_type == "cpu"
        else torch.amp.autocast(device_type=device_type, dtype=ptdtype)
    )

    return model, modality_registry, checkpoint, ctx
|
|
|
|
| |
| |
| |
def build_hf_stream(
    dataset_name: str,
    split: str,
    galproc: GalaxyImageDataset,
    streaming: bool,
    revision: Optional[str] = None,
):
    """Open one HF split (no shuffling) and map it into model-ready tokens.

    Keeps only ``image_crop`` plus the id column (falling back to image-only
    when the id column is absent), renames ``image_crop`` to ``image``, runs
    :func:`process_galaxy_wrapper`, then drops the raw ``image`` column.
    """
    load_kwargs = {"split": split, "streaming": streaming}
    if revision is not None:
        load_kwargs["revision"] = revision

    stream = load_dataset(dataset_name, **load_kwargs)

    # Restrict columns early; the id column may not exist in every split.
    try:
        stream = stream.select_columns(["image_crop", id_field_name])
    except Exception:
        stream = stream.select_columns(["image_crop"])

    stream = stream.rename_column("image_crop", "image")
    stream = stream.map(
        functools.partial(process_galaxy_wrapper, func=galproc.process_galaxy)
    )
    # The raw pixel column is no longer needed once tokens are produced.
    stream = stream.remove_columns("image")

    return stream
|
|
|
|
| |
| |
| |
| def tokens_to_images_validate_style( |
| tokens: torch.Tensor, |
| galproc: GalaxyImageDataset, |
| spiral: bool, |
| patch_size: int, |
| image_size: int, |
| ) -> np.ndarray: |
| """ |
| Mirrors validate() behavior: |
| - prepend a zero_block at t=0 |
| - if spiral: antispiralise per sample |
| - rearrange patches into [B, H, W, C] |
| """ |
| B, T, D = tokens.shape |
|
|
| zero_block = torch.zeros((B, 1, D), device=tokens.device, dtype=tokens.dtype) |
| tok = torch.cat((zero_block, tokens), dim=1) |
|
|
| if spiral: |
| tok = torch.stack([galproc.antispiralise(yy) for yy in tok]) |
|
|
| |
| tok = tok[:, 1:, :] |
|
|
| n_chan = D // (patch_size * patch_size) |
| if (patch_size * patch_size * n_chan) != D: |
| raise RuntimeError( |
| f"Cannot factor patch_dim={D} into patch_size^2 * n_chan (patch_size={patch_size})." |
| ) |
|
|
| h = image_size // patch_size |
| w = image_size // patch_size |
| if h * w != tok.size(1): |
| raise RuntimeError( |
| f"Token count {tok.size(1)} != h*w ({h}*{w}={h*w}). " |
| f"image_size={image_size}, patch_size={patch_size}." |
| ) |
|
|
| img = einops.rearrange( |
| tok, |
| "b (hh ww) (p1 p2 c) -> b (hh p1) (ww p2) c", |
| p1=patch_size, |
| p2=patch_size, |
| hh=h, |
| ww=w, |
| c=n_chan, |
| ) |
| return img.to(torch.float32).detach().cpu().numpy() |
|
|
|
|
| |
| |
| |
| if __name__ == "__main__": |
| os.makedirs(out_dir, exist_ok=True) |
|
|
| model, modality_registry, checkpoint, ctx = load_model(out_dir, device, dtype, compile) |
|
|
| |
| |
| |
| train_cfg = checkpoint.get("config", {}) if isinstance(checkpoint, dict) else {} |
| spiral = bool(train_cfg.get("spiral", True)) |
| block_size = int(train_cfg.get("block_size", checkpoint["model_args"].get("block_size", 1024))) |
|
|
| |
| prefix_len = int(min(prefix_len, block_size)) |
|
|
| |
| transforms_map = {"images": data_transforms(use_hf=True)} |
| galproc = GalaxyImageDataset( |
| paths=None, |
| spiral=spiral, |
| transform=transforms_map, |
| modality_registry=modality_registry, |
| ) |
|
|
| |
| img_cfg = modality_registry.get_config("images") |
| patch_size = int(img_cfg.patch_size) |
|
|
| |
| |
| |
| print("\n[1/2] Creating 12-image reconstruction panel (orig left, recon right)...") |
|
|
| ds_recon = build_hf_stream( |
| dataset_name=dataset_name, |
| split=recon_split, |
| galproc=galproc, |
| streaming=stream_hf_dataset, |
| revision=revision, |
| ) |
|
|
| |
| samples: List[Dict[str, Any]] = [] |
| it = iter(ds_recon) |
| while len(samples) < n_recon: |
| samples.append(next(it)) |
|
|
| |
| images = torch.stack([s["images"] for s in samples], dim=0).to(device) |
| positions = torch.stack([s["images_positions"] for s in samples], dim=0).to(device) |
|
|
| |
| T = images.size(1) |
| side = int(round(math.sqrt(T))) |
| if side * side != T: |
| raise RuntimeError(f"Token length T={T} is not a perfect square; cannot infer image_size cleanly.") |
| image_size = side * patch_size |
|
|
| |
| X = {"images": images, "images_positions": positions} |
| Y = {"images": images, "images_positions": positions} |
|
|
| with torch.no_grad(): |
| with ctx: |
| P, loss = model(X, targets=Y) |
|
|
| |
| Y_img = tokens_to_images_validate_style( |
| tokens=Y["images"], |
| galproc=galproc, |
| spiral=spiral, |
| patch_size=patch_size, |
| image_size=image_size, |
| ) |
| P_img = tokens_to_images_validate_style( |
| tokens=P["images"], |
| galproc=galproc, |
| spiral=spiral, |
| patch_size=patch_size, |
| image_size=image_size, |
| ) |
|
|
| fig, axs = plt.subplots(n_recon, 2, figsize=(6, 3 * n_recon), constrained_layout=True) |
| if n_recon == 1: |
| axs = np.array([axs]) |
|
|
| for i in range(n_recon): |
| axs[i, 0].imshow(np.clip(Y_img[i], 0, 1)) |
| axs[i, 0].axis("off") |
| axs[i, 0].set_title("Original") |
|
|
| axs[i, 1].imshow(np.clip(P_img[i], 0, 1)) |
| axs[i, 1].axis("off") |
| axs[i, 1].set_title("Reconstructed") |
|
|
| recon_path = os.path.join(out_dir, save_recon_name) |
| fig.savefig(recon_path, dpi=150) |
| plt.close(fig) |
| print(f"Saved recon panel: {recon_path}") |
|
|
| |
| |
| |
| print("\n[2/2] Extracting embeddings (test + validation) and saving .npy...") |
|
|
| zss_chunks = [] |
| ids_chunks = [] |
|
|
| with torch.no_grad(): |
| with ctx: |
| for split in splits_for_embeddings: |
| print(f" -> split: {split}") |
|
|
| ds_embed = build_hf_stream( |
| dataset_name=dataset_name, |
| split=split, |
| galproc=galproc, |
| streaming=stream_hf_dataset, |
| revision=revision, |
| ) |
|
|
| dl = DataLoader( |
| ds_embed, |
| batch_size=batch_size, |
| num_workers=num_workers, |
| pin_memory=pin_memory, |
| ) |
|
|
| tt = tqdm(total=None, unit="galaxies", unit_scale=True) |
| for B in dl: |
| xs = B["images"][:, :prefix_len].to(device) |
| pos = B["images_positions"][:, :prefix_len].to(device) |
|
|
| inputs = {"images": xs, "images_positions": pos} |
|
|
| zs = model.generate_embeddings(inputs, reduction=embed_reduction) |
| zss_chunks.append(zs["images"].detach().cpu().numpy()) |
|
|
| |
| if id_field_name in B: |
| |
| ids = np.array(B[id_field_name], dtype=object) |
| else: |
| ids = np.array(["-1"] * xs.size(0), dtype=object) |
| ids_chunks.append(ids) |
|
|
| tt.update(xs.size(0)) |
| tt.close() |
|
|
| zss = np.concatenate(zss_chunks, axis=0) |
| ids = np.concatenate(ids_chunks, axis=0) |
|
|
| emb_path = os.path.join(out_dir, f"zss_{prefix_len}t_{embed_reduction}.npy") |
| ids_path = os.path.join(out_dir, f"idxs_{prefix_len}t_{embed_reduction}.npy") |
|
|
| np.save(emb_path, zss) |
| np.save(ids_path, ids) |
|
|
| print(f"Saved embeddings: {zss.shape}") |
| print(f" - {emb_path}") |
| print(f"Saved ids: {ids.shape} (dtype={ids.dtype})") |
| print(f" - {ids_path}") |
| print("Done.") |
|
|