# geolip-procrustes / data_prep.py
# (Hub page residue preserved as comments so the module parses:
#  "AbstractPhil's picture / Update data_prep.py / 74b9407 verified")
# ============================================================================
# STAGE 1: PRECOMPUTE EMBEDDINGS β€” DATALOADER PIPELINE (CORRECTED) -> flikr fixed
#
# Architecture:
# HF load_dataset
# β†’ custom torch.Dataset (__getitem__ does CPU tokenization + image processing)
# β†’ DataLoader (workers do CPU I/O)
# β†’ GPU encode
# β†’ shard-safe HF Arrow writes
# β†’ concatenate shards
# β†’ save_to_disk final dataset
# ============================================================================
# Fix broken sympy before torch imports it: torch pulls sympy in at import
# time, and a partially-installed sympy (importable but missing sympy.core)
# would crash there instead of here.
import subprocess
import sys
try:
    import sympy
    _ = sympy.core  # probe: a broken install raises AttributeError on this access
except (ImportError, AttributeError):
    print("Fixing sympy...")
    # Reinstall via the current interpreter's pip; --break-system-packages is
    # required on PEP 668 "externally managed" environments.
    subprocess.check_call(
        [sys.executable, "-m", "pip", "install", "--upgrade", "sympy", "--break-system-packages", "-q"]
    )
    # The already-imported broken module stays loaded; a kernel restart may be needed.
    print(" sympy upgraded. Restart kernel if needed.")
import gc
import json
import math
import os
import shutil
import time
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from datasets import (
Dataset as HFDataset,
Features,
Sequence,
Value,
Array2D,
concatenate_datasets,
load_dataset,
load_from_disk,
)
# Single process-wide device: prefer CUDA when present, otherwise CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# ══════════════════════════════════════════════════════════════════
# CONFIG
# ══════════════════════════════════════════════════════════════════
@dataclass
class Stage1Config:
    """Tunables for the Stage 1 embedding precompute pipeline."""
    cache_dir: str = "/home/claude/geo_cache"  # root dir for caches, shards, manifests
    max_text_len: int = 32  # BERT tokenization length (pad/truncate)
    batch_size: int = 512  # samples per GPU encode step
    num_workers: int = 8  # DataLoader worker processes (CPU preprocessing)
    shard_size: int = 2048  # number of valid encoded samples per shard
    writer_batch_size: int = 256  # HF internal writer batch size
    pin_memory: bool = torch.cuda.is_available()  # pinned host memory only helps with CUDA
    prefetch_factor: int = 2  # batches prefetched per worker (only used when num_workers > 0)
    cleanup_shards_after_merge: bool = True  # delete temporary shard dirs once merged
    print_every: int = 1000  # progress print cadence, in samples
# Shared module-level config instance used throughout this file.
CFG = Stage1Config()
# ══════════════════════════════════════════════════════════════════
# HELPERS
# ══════════════════════════════════════════════════════════════════
def extract_caption(sample: Dict[str, Any]) -> str:
    """
    Deterministically extract a caption string from a heterogeneous sample dict.

    Candidate keys are checked in a fixed priority order and the first
    non-empty caption wins. Three value shapes are handled:
      * a plain string,
      * a list/tuple of strings (first element wins),
      * a list/tuple of dicts carrying a "raw" or "text" field (COCO-style).

    Args:
        sample: one raw dataset row.

    Returns:
        The stripped caption, or "" when no usable caption exists.
    """
    for key in ("answer", "caption", "captions", "sentences", "text"):
        if key not in sample:
            continue
        val = sample[key]
        if isinstance(val, str):
            caption = val.strip()
            if caption:
                return caption
        # Generalized from list-only: tuples of captions behave identically.
        if isinstance(val, (list, tuple)) and val:
            item = val[0]
            if isinstance(item, str):
                caption = item.strip()
                if caption:
                    return caption
            if isinstance(item, dict):
                caption = str(item.get("raw", item.get("text", ""))).strip()
                if caption:
                    return caption
            # Last resort: stringify whatever the first element is.
            caption = str(item).strip()
            if caption:
                return caption
    return ""
def make_dataloader(dataset: Dataset, batch_size: int, num_workers: int = 8, shuffle: bool = False) -> DataLoader:
    """Build a DataLoader configured for GPU feeding: pinned memory, prefetch, persistent workers."""
    use_workers = num_workers > 0
    loader_args: Dict[str, Any] = {
        "dataset": dataset,
        "batch_size": batch_size,
        "shuffle": shuffle,
        "num_workers": num_workers,
        "pin_memory": CFG.pin_memory,
        "persistent_workers": use_workers,
    }
    # prefetch_factor is only accepted by DataLoader when workers exist.
    if use_workers:
        loader_args["prefetch_factor"] = CFG.prefetch_factor
    return DataLoader(**loader_args)
def flush_shard(
    shard_root: str,
    shard_index: int,
    features: Features,
    shard_rows: Dict[str, List[Any]],
    writer_batch_size: int,
) -> Optional[str]:
    """
    Persist the accumulated shard rows as an Arrow dataset on disk.

    Returns the shard directory path, or None when there is nothing to write.
    The caller is responsible for resetting `shard_rows` afterwards.

    NOTE(review): `writer_batch_size` is accepted for interface stability but
    is not forwarded anywhere — confirm whether it was meant for a
    `Dataset` writer option (neither `from_dict` nor `save_to_disk` takes it).
    """
    if not shard_rows["source_idx"]:
        return None
    os.makedirs(shard_root, exist_ok=True)
    target = os.path.join(shard_root, f"shard_{shard_index:05d}")
    HFDataset.from_dict(shard_rows, features=features).save_to_disk(target)
    return target
def reset_shard_rows() -> Dict[str, List[Any]]:
    """Fresh, empty per-shard column accumulator (new list objects on every call)."""
    return {column: [] for column in ("source_idx", "text_hidden", "text_mask", "image_hidden")}
def write_manifest(path: str, data: Dict[str, Any]) -> None:
    """Serialize *data* to *path* as pretty-printed JSON (overwrites any existing file)."""
    with open(path, "w") as manifest_file:
        manifest_file.write(json.dumps(data, indent=2))
# ══════════════════════════════════════════════════════════════════
# TORCH DATASET β€” workers do tokenization + image processing
# ══════════════════════════════════════════════════════════════════
class ImageTextDataset(Dataset):
    """
    Torch-side wrapper around an HF dataset that pushes ALL CPU work
    (caption extraction, tokenization, image preprocessing) into
    DataLoader workers via __getitem__.

    Each item is a 5-tuple:
        (source index, input_ids, attention_mask, pixel_values, valid flag)
    `valid` is False when the image is absent or unreadable; such items carry
    a zero-filled pixel tensor so the default collate function still works.
    """
    def __init__(self, hf_dataset, tokenizer, image_processor, max_text_len: int):
        self.ds = hf_dataset
        self.tok = tokenizer
        self.proc = image_processor
        self.max_text_len = max_text_len
        # Shape of the zero-filled placeholder used for invalid samples.
        # Valid samples define the real downstream contract if these differ.
        self.fallback_pixel_shape = self._infer_fallback_pixel_shape()

    def _infer_fallback_pixel_shape(self) -> Tuple[int, int, int]:
        """Best-effort guess of the processor's output shape; 3x518x518 otherwise."""
        size = getattr(self.proc, "size", None)
        if not isinstance(size, dict):
            # Dinov2-family processors commonly emit 3x518x518 — TODO confirm per model.
            return (3, 518, 518)
        height = size.get("height", size.get("shortest_edge", 518))
        width = size.get("width", size.get("shortest_edge", 518))
        return (3, int(height), int(width))

    def __len__(self):
        return len(self.ds)

    def __getitem__(self, idx):
        sample = self.ds[idx]
        # Tokenize the caption on CPU with fixed-length padding.
        encoded = self.tok(
            extract_caption(sample),
            padding="max_length",
            truncation=True,
            max_length=self.max_text_len,
            return_tensors="pt",
        )
        input_ids = encoded["input_ids"].squeeze(0)
        attn_mask = encoded["attention_mask"].squeeze(0)
        # Preprocess the image on CPU (resize, normalize, to tensor);
        # anything that fails yields the zero placeholder + valid=False.
        pixel_values = None
        image = sample.get("image", None)
        if image is not None and hasattr(image, "convert"):
            try:
                rgb = image.convert("RGB")
                pixel_values = self.proc(images=rgb, return_tensors="pt")["pixel_values"].squeeze(0)
            except Exception:
                pixel_values = None
        valid = pixel_values is not None
        if not valid:
            pixel_values = torch.zeros(self.fallback_pixel_shape, dtype=torch.float32)
        return idx, input_ids, attn_mask, pixel_values, valid
# ══════════════════════════════════════════════════════════════════
# FULL PIPELINE
# ══════════════════════════════════════════════════════════════════
def process_and_cache(
    dataset_id: str,
    split: str,
    max_samples: Optional[int],
    batch_size: int = 512,
    num_workers: int = 8,
    shard_size: int = 2048,
    tag: Optional[str] = None,
    bert=None,
    dino=None,
    tokenizer=None,
    processor=None,
) -> str:
    """
    Encode an image-caption dataset into cached BERT/DINOv2 hidden states.

    Full pipeline:
      1. load_dataset -> HF Dataset (optionally truncated to `max_samples`)
      2. Wrap in a torch Dataset (tokenize + image process in workers)
      3. DataLoader -> GPU encode (fp16 autocast when CUDA is available)
      4. Write shard-safe Arrow datasets of ~`shard_size` valid rows each
      5. Concatenate shards -> save final dataset + JSON manifest

    Args:
        dataset_id: HF Hub dataset identifier.
        split: dataset split name.
        max_samples: cap on raw samples before encoding; None = keep all.
        batch_size / num_workers / shard_size: pipeline tunables.
        tag: cache directory name; defaults to "<dataset_id>_<split>".
        bert / dino: loaded encoder models (required, asserted below).
        tokenizer / processor: matching CPU-side preprocessors (required).

    Returns:
        Path to the saved (or already-existing) cache directory.

    Raises:
        RuntimeError: when no sample yields a valid encoded image.
    """
    assert bert is not None
    assert dino is not None
    assert tokenizer is not None
    assert processor is not None
    tag = tag or f"{dataset_id.replace('/', '_')}_{split}"
    cache_path = os.path.join(CFG.cache_dir, tag)
    shard_root = os.path.join(CFG.cache_dir, f"{tag}__shards")
    manifest_path = os.path.join(CFG.cache_dir, f"{tag}__manifest.json")
    # Idempotence: a pre-existing cache short-circuits the whole pipeline.
    if os.path.exists(cache_path):
        print(f" Cache exists: {cache_path}")
        ds = load_from_disk(cache_path)
        print(f" {len(ds)} samples cached")
        return cache_path
    os.makedirs(CFG.cache_dir, exist_ok=True)
    print(f"\n Loading {dataset_id} ({split})...")
    t0 = time.time()
    hf_ds = load_dataset(dataset_id, split=split)
    raw_total = len(hf_ds)
    print(f" Dataset: {raw_total} samples")
    # Truncate raw source dataset if requested
    if max_samples is not None and raw_total > max_samples:
        hf_ds = hf_ds.select(range(max_samples))
        print(f" Truncated raw dataset to {len(hf_ds)}")
        raw_total = len(hf_ds)
    first = hf_ds[0]
    print(f" Columns: {list(first.keys())}")
    # CPU preprocessing happens inside the workers of the loader below.
    torch_ds = ImageTextDataset(
        hf_dataset=hf_ds,
        tokenizer=tokenizer,
        image_processor=processor,
        max_text_len=CFG.max_text_len,
    )
    loader = make_dataloader(
        dataset=torch_ds,
        batch_size=batch_size,
        num_workers=num_workers,
        shuffle=False,
    )
    print(" Encoding...")
    # Schema is established lazily from the first valid encoded batch.
    feature_schema: Optional[Features] = None
    shard_rows = reset_shard_rows()
    shard_paths: List[str] = []
    shard_index = 0
    raw_seen = 0       # all samples drawn from the loader
    valid_saved = 0    # samples with a usable image, encoded and written
    invalid_dropped = 0  # samples dropped for missing/broken images
    for batch in loader:
        source_idx, input_ids, attn_mask, pixel_values, valid = batch
        batch_raw = int(input_ids.shape[0])
        raw_seen += batch_raw
        valid_b = valid.bool()
        invalid_dropped += int((~valid_b).sum().item())
        if not valid_b.any():
            # Entirely-invalid batch: still emit the periodic progress line.
            # The modular check fires roughly once per CFG.print_every raw samples.
            if raw_seen % CFG.print_every < batch_raw or raw_seen <= batch_raw:
                rate = raw_seen / max(time.time() - t0, 1e-6)
                print(f" raw={raw_seen}/{raw_total} valid={valid_saved} invalid={invalid_dropped} ({rate:.0f} raw/s)")
            continue
        # Keep only valid rows; move model inputs to the GPU asynchronously.
        source_idx_v = source_idx[valid_b]
        input_ids_v = input_ids[valid_b].to(device, non_blocking=True)
        attn_mask_v = attn_mask[valid_b].to(device, non_blocking=True)
        pixel_values_v = pixel_values[valid_b].to(device, non_blocking=True)
        with torch.no_grad():
            if torch.cuda.is_available():
                # Mixed precision on CUDA; plain forward on CPU.
                with torch.amp.autocast("cuda", enabled=True):
                    text_h = bert(input_ids=input_ids_v, attention_mask=attn_mask_v).last_hidden_state
                    image_h = dino(pixel_values=pixel_values_v).last_hidden_state
            else:
                text_h = bert(input_ids=input_ids_v, attention_mask=attn_mask_v).last_hidden_state
                image_h = dino(pixel_values=pixel_values_v).last_hidden_state
        # Move results back to host as fp16 numpy to halve cache size on disk.
        text_h = text_h.detach().to(dtype=torch.float16).cpu().numpy()
        text_m = attn_mask_v.bool().cpu().numpy()
        image_h = image_h.detach().to(dtype=torch.float16).cpu().numpy()
        source_idx_np = source_idx_v.cpu().numpy().astype(np.int64)
        # Establish explicit schema once from the first valid encoded batch.
        if feature_schema is None:
            text_shape = tuple(text_h.shape[1:])
            image_shape = tuple(image_h.shape[1:])
            feature_schema = Features({
                "source_idx": Value("int64"),
                "text_hidden": Array2D(shape=text_shape, dtype="float16"),
                "text_mask": Sequence(Value("bool"), length=text_shape[0]),
                "image_hidden": Array2D(shape=image_shape, dtype="float16"),
            })
            print(f" Feature schema:")
            print(f" text_hidden: {text_shape} float16")
            print(f" text_mask: ({text_shape[0]},) bool")
            print(f" image_hidden:{image_shape} float16")
        # Accumulate only the current shard in memory.
        for i in range(text_h.shape[0]):
            shard_rows["source_idx"].append(int(source_idx_np[i]))
            shard_rows["text_hidden"].append(text_h[i])
            shard_rows["text_mask"].append(text_m[i].tolist())
            shard_rows["image_hidden"].append(image_h[i])
        valid_saved += int(text_h.shape[0])
        # Progress line roughly every CFG.print_every valid samples.
        if valid_saved % CFG.print_every < text_h.shape[0] or valid_saved <= text_h.shape[0]:
            rate = raw_seen / max(time.time() - t0, 1e-6)
            print(
                f" raw={raw_seen}/{raw_total} valid={valid_saved} "
                f"invalid={invalid_dropped} ({rate:.0f} raw/s)"
            )
        # Spill a full shard to disk and start a fresh accumulator.
        if len(shard_rows["source_idx"]) >= shard_size:
            shard_path = flush_shard(
                shard_root=shard_root,
                shard_index=shard_index,
                features=feature_schema,
                shard_rows=shard_rows,
                writer_batch_size=CFG.writer_batch_size,
            )
            if shard_path is not None:
                shard_paths.append(shard_path)
                print(f" Flushed shard {shard_index:05d} ({len(load_from_disk(shard_path))} rows)")
                shard_index += 1
            shard_rows = reset_shard_rows()
    # Flush tail shard
    if feature_schema is None:
        # No batch ever produced a valid image; nothing was encoded.
        raise RuntimeError("No valid samples were encoded. Cannot build cache.")
    tail_path = flush_shard(
        shard_root=shard_root,
        shard_index=shard_index,
        features=feature_schema,
        shard_rows=shard_rows,
        writer_batch_size=CFG.writer_batch_size,
    )
    if tail_path is not None:
        shard_paths.append(tail_path)
        print(f" Flushed shard {shard_index:05d} ({len(load_from_disk(tail_path))} rows)")
    # Merge shards into final dataset
    print(" Merging shards...")
    shard_datasets = [load_from_disk(p) for p in shard_paths]
    result_ds = concatenate_datasets(shard_datasets)
    result_ds.save_to_disk(cache_path)
    elapsed = time.time() - t0
    print(f" Saved {len(result_ds)} samples to {cache_path} ({elapsed:.0f}s)")
    # Sidecar manifest records provenance + drop statistics for auditing.
    manifest = {
        "dataset_id": dataset_id,
        "split": split,
        "tag": tag,
        "cache_path": cache_path,
        "raw_total_considered": raw_total,
        "raw_seen": raw_seen,
        "valid_saved": valid_saved,
        "invalid_dropped": invalid_dropped,
        "invalid_rate": (invalid_dropped / raw_seen) if raw_seen > 0 else 0.0,
        "num_shards": len(shard_paths),
        "feature_schema": {
            "text_hidden_shape": list(feature_schema["text_hidden"].shape),
            "text_mask_len": feature_schema["text_mask"].length,
            "image_hidden_shape": list(feature_schema["image_hidden"].shape),
        },
        "elapsed_sec": elapsed,
    }
    write_manifest(manifest_path, manifest)
    print(f" Wrote manifest: {manifest_path}")
    # Cleanup shard directories if requested
    if CFG.cleanup_shards_after_merge and os.path.exists(shard_root):
        shutil.rmtree(shard_root, ignore_errors=True)
        print(f" Removed temporary shards: {shard_root}")
    # Free RAM/VRAM between datasets
    del result_ds
    del shard_datasets
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    return cache_path
# ══════════════════════════════════════════════════════════════════
# MAIN
# ══════════════════════════════════════════════════════════════════
if __name__ == "__main__":
    # Orchestrates three cache builds (COCO val, COCO test, Flickr30k) with a
    # single pair of shared encoders.
    os.makedirs(CFG.cache_dir, exist_ok=True)
    print("=" * 70)
    print("STAGE 1: PRECOMPUTE EMBEDDINGS")
    print("=" * 70)
    if torch.cuda.is_available():
        print(f"GPU: {torch.cuda.get_device_name()}")
        print(f"VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
    print(f"Cache dir: {CFG.cache_dir}")
    # Load encoders ONCE — shared across all datasets
    print("\nLoading encoders...")
    from transformers import BertModel, BertTokenizer, Dinov2Model, AutoImageProcessor
    tokenizer = BertTokenizer.from_pretrained("google-bert/bert-large-uncased")
    processor = AutoImageProcessor.from_pretrained("facebook/dinov2-large")
    # fp16 weights on GPU to halve VRAM; fp32 on CPU.
    bert = BertModel.from_pretrained(
        "google-bert/bert-large-uncased",
        torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    ).to(device).eval()
    dino = Dinov2Model.from_pretrained(
        "facebook/dinov2-large",
        torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    ).to(device).eval()
    print(" Encoders ready.")
    paths = {}
    # ── COCO val — FULL ──
    print(f"\n{'─' * 50}")
    print("[1/3] COCO-Caption val (training) — FULL")
    paths["coco_val"] = process_and_cache(
        dataset_id="lmms-lab/COCO-Caption",
        split="val",
        max_samples=None,
        batch_size=CFG.batch_size,
        num_workers=CFG.num_workers,
        shard_size=CFG.shard_size,
        tag="coco_val",
        bert=bert,
        dino=dino,
        tokenizer=tokenizer,
        processor=processor,
    )
    # ── COCO test — FULL ──
    print(f"\n{'─' * 50}")
    print("[2/3] COCO-Caption test (held-out) — FULL")
    paths["coco_test"] = process_and_cache(
        dataset_id="lmms-lab/COCO-Caption",
        split="test",
        max_samples=None,
        batch_size=CFG.batch_size,
        num_workers=CFG.num_workers,
        shard_size=CFG.shard_size,
        tag="coco_test",
        bert=bert,
        dino=dino,
        tokenizer=tokenizer,
        processor=processor,
    )
    # ── Flickr30k — FULL ──
    # Flickr30k is best-effort: a failure here must not lose the COCO caches.
    print(f"\n{'─' * 50}")
    print("[3/3] Flickr30k (cross-dataset) — FULL")
    try:
        paths["flickr"] = process_and_cache(
            dataset_id="Mozilla/flickr30k-transformed-captions",
            split="test",
            max_samples=None,
            batch_size=CFG.batch_size,
            num_workers=CFG.num_workers,
            shard_size=CFG.shard_size,
            tag="flickr30k",
            bert=bert,
            dino=dino,
            tokenizer=tokenizer,
            processor=processor,
        )
    except Exception as e:
        print(f" Flickr30k failed: {e}")
        paths["flickr"] = None
    # Unload encoders to free RAM/VRAM before the summary pass.
    del bert, dino, tokenizer, processor
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    # Summary
    print(f"\n{'=' * 70}")
    print("CACHE SUMMARY")
    print(f"{'=' * 70}")
    for name, path in paths.items():
        if path and os.path.exists(path):
            ds = load_from_disk(path)
            print(f" {name:15s}: {len(ds):6d} samples [{path}]")
    print(f"\n Stage 2 usage:")
    print(f' ds = load_from_disk("{CFG.cache_dir}/coco_val").with_format("torch")')
    print(f' loader = DataLoader(ds, batch_size=64, num_workers=4)')
    print("\nDone.")