# astropt_aim / sample_embeddings.py
# Uploaded by SogolS via huggingface_hub (commit 4563ea9, verified)
"""
Sample embeddings + reconstructions from a trained AstroPT model (AIM-compatible)
DETERMINISTIC / NO-SHUFFLE VERSION:
- Pulls from HuggingFace "Smith42/galaxies" using streaming=True
- DOES NOT shuffle at all
- Iterates in the exact order HuggingFace yields examples (shard/file order)
- Generates 12 reconstructions FIRST using the first 12 examples of `recon_split`
- Then extracts embeddings from BOTH `test` + `validation` in order and saves .npy
- Saves IDs from HF column `dr8_id` in a separate file:
- idxs_...npy (dtype=str/object), aligned 1:1 with embeddings rows
- Avoids concatenate_datasets() for streaming datasets (prevents PyArrow crash with torch tensors)
Notes:
- If you want the strictest “exact order” behavior, set num_workers = 0.
- The recon figure uses validate()-style rendering (prepend zero token + optional antispiralise).
"""
import os
import math
import functools
from contextlib import nullcontext
from typing import Dict, Any, List, Optional
import numpy as np
import torch
from torch.utils.data import DataLoader
from torchvision import transforms
import matplotlib.pyplot as plt
import einops
from tqdm import tqdm
from datasets import load_dataset
from astropt.model import GPT, GPTConfig
from astropt.local_datasets import GalaxyImageDataset
# -----------------------------------------------------------------------------
# Config (module-level knobs; all read by the __main__ section below)
# -----------------------------------------------------------------------------
# checkpoint/log dir: directory containing ckpt.pt; all outputs are written here too
out_dir = "/mnt/c/Users/shaha/Downloads/AIM" # where ckpt.pt lives + where outputs will be written
# compute
device = "cuda"
dtype = "bfloat16"  # one of "float32" | "bfloat16" | "float16" (see load_model)
# NOTE(review): `compile` shadows the builtin of the same name; kept for
# checkpoint-config parity — rename with care if refactoring.
compile = False
# HF dataset
dataset_name = "Smith42/galaxies"
stream_hf_dataset = True # keep streaming
revision = None # set e.g. "v2.0" if you need a pinned revision
# Embedding extraction
splits_for_embeddings = ["test", "validation"] # as requested
batch_size = 256
num_workers = 0 # set to 0 for strictest order guarantee
pin_memory = True
prefix_len = 64 # number of image-tokens used for embeddings (clamped to block_size in __main__)
embed_reduction = "mean" # "mean" | "last" | "exp_decay" | "none"
# Reconstruction figure
n_recon = 12
recon_split = "test"
save_recon_name = "recon_12.png"
# If your HF dataset uses 'dr8_id' (string), we save it. If missing, we store "-1".
id_field_name = "dr8_id"
# -----------------------------------------------------------------------------
# Transforms (match training)
# -----------------------------------------------------------------------------
def normalise(x: torch.Tensor, use_hf: bool = False) -> torch.Tensor:
    """Standardise ``x`` per row along dim=1 and cast the result to float16.

    HF streaming yields numpy arrays; when ``use_hf`` is set and the input is
    an ndarray, it is first converted to a float32 tensor.
    """
    if use_hf and isinstance(x, np.ndarray):
        x = torch.from_numpy(x).to(torch.float32)
    std, mean = torch.std_mean(x, dim=1, keepdim=True)
    # epsilon keeps constant rows (std == 0) from dividing by zero
    normed = x.sub(mean).div(std + 1e-8)
    return normed.to(torch.float16)
def data_transforms(use_hf: bool):
    """Build the image transform pipeline used at training time.

    A single step today: per-row normalisation, with the numpy->tensor
    conversion enabled when samples come from a streaming HF dataset.
    """
    normalise_step = transforms.Lambda(functools.partial(normalise, use_hf=use_hf))
    return transforms.Compose([normalise_step])
def process_galaxy_wrapper(galdict: Dict[str, Any], func):
    """Wrapper for processing galaxy images from HF dataset (MATCH TRAINING)."""
    # swapaxes(0, 2) converts the HF array layout before tokenisation,
    # mirroring what the training pipeline does.
    raw = np.array(galdict["image"]).swapaxes(0, 2)
    tokens = func(raw)
    out = {
        "images": tokens.to(torch.float),
        "images_positions": torch.arange(0, len(tokens), dtype=torch.long),
    }
    # keep the catalogue ID as a string when present; "-1" marks a missing ID
    out[id_field_name] = galdict.get(id_field_name, "-1")
    return out
# -----------------------------------------------------------------------------
# Model loading
# -----------------------------------------------------------------------------
def load_model(out_dir: str, device: str, dtype: str, compile_model: bool):
    """Restore a trained GPT checkpoint and build its autocast context.

    Loads ``ckpt.pt`` from ``out_dir``, strips any torch.compile() prefix from
    the state-dict keys, moves the model to ``device`` in eval mode, and
    returns ``(model, modality_registry, checkpoint, ctx)``.

    Raises:
        FileNotFoundError: if ``out_dir/ckpt.pt`` does not exist.
    """
    ckpt_path = os.path.join(out_dir, "ckpt.pt")
    if not os.path.exists(ckpt_path):
        raise FileNotFoundError(f"Checkpoint not found: {ckpt_path}")
    checkpoint = torch.load(ckpt_path, map_location=device, weights_only=False)

    modality_registry = checkpoint["modality_registry"]
    model = GPT(GPTConfig(**checkpoint["model_args"]), modality_registry)

    # torch.compile() persists weights under an "_orig_mod." prefix; strip it
    # so the weights load into a plain (uncompiled) module.
    state_dict = checkpoint["model"]
    prefix = "_orig_mod."
    for key in list(state_dict.keys()):
        if key.startswith(prefix):
            state_dict[key[len(prefix):]] = state_dict.pop(key)
    model.load_state_dict(state_dict)
    model.eval().to(device)
    if compile_model:
        model = torch.compile(model)

    device_type = "cuda" if "cuda" in device else "cpu"
    ptdtype = {
        "float32": torch.float32,
        "bfloat16": torch.bfloat16,
        "float16": torch.float16,
    }[dtype]
    if device_type == "cpu":
        ctx = nullcontext()
    else:
        ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype)
    return model, modality_registry, checkpoint, ctx
# -----------------------------------------------------------------------------
# HF dataset builder (match training pipeline)
# -----------------------------------------------------------------------------
def build_hf_stream(
    dataset_name: str,
    split: str,
    galproc: GalaxyImageDataset,
    streaming: bool,
    revision: Optional[str] = None,
):
    """Open one split of the HF dataset with the training-style processing map.

    The stream is never shuffled, so iteration follows HF shard/file order.
    """
    load_kwargs: Dict[str, Any] = {"split": split, "streaming": streaming}
    if revision is not None:
        load_kwargs["revision"] = revision
    ds = load_dataset(dataset_name, **load_kwargs)

    # Keep the ID column alongside the image when it exists. Streaming
    # datasets raise from select_columns() on a missing column, so fall back
    # to selecting only "image_crop" in that case (broad except is deliberate:
    # the exception type varies across datasets versions).
    try:
        ds = ds.select_columns(["image_crop", id_field_name])
    except Exception:
        ds = ds.select_columns(["image_crop"])

    ds = ds.rename_column("image_crop", "image")
    # process_galaxy_wrapper emits "images", "images_positions" and the ID;
    # the raw "image" column is dropped afterwards to keep batches light.
    ds = ds.map(functools.partial(process_galaxy_wrapper, func=galproc.process_galaxy))
    ds = ds.remove_columns("image")
    return ds
# -----------------------------------------------------------------------------
# validate()-style token->image conversion (shift + spiral handling)
# -----------------------------------------------------------------------------
def tokens_to_images_validate_style(
tokens: torch.Tensor, # [B, T, patch_dim]
galproc: GalaxyImageDataset,
spiral: bool,
patch_size: int,
image_size: int,
) -> np.ndarray:
"""
Mirrors validate() behavior:
- prepend a zero_block at t=0
- if spiral: antispiralise per sample
- rearrange patches into [B, H, W, C]
"""
B, T, D = tokens.shape
zero_block = torch.zeros((B, 1, D), device=tokens.device, dtype=tokens.dtype)
tok = torch.cat((zero_block, tokens), dim=1) # [B, T+1, D]
if spiral:
tok = torch.stack([galproc.antispiralise(yy) for yy in tok])
# Drop the prepended token so we can render a clean h*w grid
tok = tok[:, 1:, :] # [B, h*w, D]
n_chan = D // (patch_size * patch_size)
if (patch_size * patch_size * n_chan) != D:
raise RuntimeError(
f"Cannot factor patch_dim={D} into patch_size^2 * n_chan (patch_size={patch_size})."
)
h = image_size // patch_size
w = image_size // patch_size
if h * w != tok.size(1):
raise RuntimeError(
f"Token count {tok.size(1)} != h*w ({h}*{w}={h*w}). "
f"image_size={image_size}, patch_size={patch_size}."
)
img = einops.rearrange(
tok,
"b (hh ww) (p1 p2 c) -> b (hh p1) (ww p2) c",
p1=patch_size,
p2=patch_size,
hh=h,
ww=w,
c=n_chan,
)
return img.to(torch.float32).detach().cpu().numpy()
# -----------------------------------------------------------------------------
# Main
# -----------------------------------------------------------------------------
if __name__ == "__main__":
    os.makedirs(out_dir, exist_ok=True)
    # NOTE: `compile` is the module-level config flag (it shadows the builtin).
    model, modality_registry, checkpoint, ctx = load_model(out_dir, device, dtype, compile)
    # -------------------------------------------------------------------------
    # Match training config when available (falls back to defaults otherwise)
    # -------------------------------------------------------------------------
    train_cfg = checkpoint.get("config", {}) if isinstance(checkpoint, dict) else {}
    spiral = bool(train_cfg.get("spiral", True))
    block_size = int(train_cfg.get("block_size", checkpoint["model_args"].get("block_size", 1024)))
    # Clamp prefix_len to block_size for safety (rebinds the module-level knob)
    prefix_len = int(min(prefix_len, block_size))
    # Build GalaxyImageDataset processor (same tokenization/spiral ops as training)
    transforms_map = {"images": data_transforms(use_hf=True)}
    galproc = GalaxyImageDataset(
        paths=None,
        spiral=spiral,
        transform=transforms_map,
        modality_registry=modality_registry,
    )
    # Modality params
    img_cfg = modality_registry.get_config("images")
    patch_size = int(img_cfg.patch_size)
    # -------------------------------------------------------------------------
    # [1/2] Reconstructions FIRST (NO SHUFFLE => first 12 items in HF order)
    # -------------------------------------------------------------------------
    print("\n[1/2] Creating 12-image reconstruction panel (orig left, recon right)...")
    ds_recon = build_hf_stream(
        dataset_name=dataset_name,
        split=recon_split,
        galproc=galproc,
        streaming=stream_hf_dataset,
        revision=revision,
    )
    # Pull first 12 samples in HF order (manual iteration keeps exact order)
    samples: List[Dict[str, Any]] = []
    it = iter(ds_recon)
    while len(samples) < n_recon:
        samples.append(next(it))
    # Stack into a batch: images [B,T,D], positions [B,T]
    images = torch.stack([s["images"] for s in samples], dim=0).to(device)
    positions = torch.stack([s["images_positions"] for s in samples], dim=0).to(device)
    # Infer image_size from token length T = (H/patch)*(W/patch);
    # only works for square images tokenised on a square grid.
    T = images.size(1)
    side = int(round(math.sqrt(T)))
    if side * side != T:
        raise RuntimeError(f"Token length T={T} is not a perfect square; cannot infer image_size cleanly.")
    image_size = side * patch_size
    # Teacher-forced forward: P, loss = model(X, targets=Y)
    # (inputs double as targets, so P is the model's reconstruction of X)
    X = {"images": images, "images_positions": positions}
    Y = {"images": images, "images_positions": positions}
    with torch.no_grad():
        with ctx:
            P, loss = model(X, targets=Y)
    # validate()-style visualization: originals (Y) and predictions (P)
    Y_img = tokens_to_images_validate_style(
        tokens=Y["images"],
        galproc=galproc,
        spiral=spiral,
        patch_size=patch_size,
        image_size=image_size,
    )
    P_img = tokens_to_images_validate_style(
        tokens=P["images"],
        galproc=galproc,
        spiral=spiral,
        patch_size=patch_size,
        image_size=image_size,
    )
    # One row per galaxy: original on the left, reconstruction on the right
    fig, axs = plt.subplots(n_recon, 2, figsize=(6, 3 * n_recon), constrained_layout=True)
    if n_recon == 1:
        # plt.subplots returns a 1-D axes array for a single row; normalise to 2-D
        axs = np.array([axs])
    for i in range(n_recon):
        axs[i, 0].imshow(np.clip(Y_img[i], 0, 1))
        axs[i, 0].axis("off")
        axs[i, 0].set_title("Original")
        axs[i, 1].imshow(np.clip(P_img[i], 0, 1))
        axs[i, 1].axis("off")
        axs[i, 1].set_title("Reconstructed")
    recon_path = os.path.join(out_dir, save_recon_name)
    fig.savefig(recon_path, dpi=150)
    plt.close(fig)
    print(f"Saved recon panel: {recon_path}")
    # -------------------------------------------------------------------------
    # [2/2] Embedding extraction SECOND (NO SHUFFLE, NO CONCAT for streaming)
    # Chunks from all splits are appended to the same lists, so the final
    # arrays hold `test` rows first, then `validation`, in HF stream order.
    # -------------------------------------------------------------------------
    print("\n[2/2] Extracting embeddings (test + validation) and saving .npy...")
    zss_chunks = []
    ids_chunks = []
    with torch.no_grad():
        with ctx:
            for split in splits_for_embeddings:
                print(f" -> split: {split}")
                ds_embed = build_hf_stream(
                    dataset_name=dataset_name,
                    split=split,
                    galproc=galproc,
                    streaming=stream_hf_dataset,
                    revision=revision,
                )
                dl = DataLoader(
                    ds_embed,
                    batch_size=batch_size,
                    num_workers=num_workers,
                    pin_memory=pin_memory,
                )
                # total=None: streaming datasets don't expose a length up front
                tt = tqdm(total=None, unit="galaxies", unit_scale=True)
                for B in dl:
                    # Only the first prefix_len tokens feed the embedding pass
                    xs = B["images"][:, :prefix_len].to(device)
                    pos = B["images_positions"][:, :prefix_len].to(device)
                    inputs = {"images": xs, "images_positions": pos}
                    zs = model.generate_embeddings(inputs, reduction=embed_reduction)
                    zss_chunks.append(zs["images"].detach().cpu().numpy())
                    # Save IDs (strings) aligned with embeddings
                    if id_field_name in B:
                        # B[id_field_name] can be list[str] or np array depending on HF formatting
                        ids = np.array(B[id_field_name], dtype=object)
                    else:
                        ids = np.array(["-1"] * xs.size(0), dtype=object)
                    ids_chunks.append(ids)
                    tt.update(xs.size(0))
                tt.close()
    # Rows of zss and ids are aligned 1:1 by construction above
    zss = np.concatenate(zss_chunks, axis=0)
    ids = np.concatenate(ids_chunks, axis=0)
    emb_path = os.path.join(out_dir, f"zss_{prefix_len}t_{embed_reduction}.npy")
    ids_path = os.path.join(out_dir, f"idxs_{prefix_len}t_{embed_reduction}.npy")
    np.save(emb_path, zss)
    np.save(ids_path, ids)
    print(f"Saved embeddings: {zss.shape}")
    print(f" - {emb_path}")
    print(f"Saved ids: {ids.shape} (dtype={ids.dtype})")
    print(f" - {ids_path}")
    print("Done.")