| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| |
import subprocess
import sys


# Self-healing import guard: some environments ship a broken/partial sympy
# whose `core` submodule fails to resolve. Touch it once; on failure, upgrade
# sympy in place via pip.
# NOTE(review): this runs at import time and mutates the environment
# (--break-system-packages) — confirm that side effect is intended here.
try:
    import sympy
    _ = sympy.core  # touch a submodule to detect a partially-broken install
except (ImportError, AttributeError):
    print("Fixing sympy...")
    subprocess.check_call(
        [sys.executable, "-m", "pip", "install", "--upgrade", "sympy", "--break-system-packages", "-q"]
    )
    print(" sympy upgraded. Restart kernel if needed.")
|
|
| import gc |
| import json |
| import math |
| import os |
| import shutil |
| import time |
| from dataclasses import dataclass |
| from typing import Any, Dict, List, Optional, Tuple |
|
|
| import numpy as np |
| import torch |
| from torch.utils.data import Dataset, DataLoader |
| from datasets import ( |
| Dataset as HFDataset, |
| Features, |
| Sequence, |
| Value, |
| Array2D, |
| concatenate_datasets, |
| load_dataset, |
| load_from_disk, |
| ) |
|
|
# Single global compute device: CUDA when available, otherwise CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
|
|
| |
| |
| |
|
|
@dataclass
class Stage1Config:
    """Tuning knobs for the Stage-1 embedding precompute pipeline."""

    cache_dir: str = "/home/claude/geo_cache"  # root dir for cached Arrow datasets
    max_text_len: int = 32                     # tokenizer max_length (pad/truncate)
    batch_size: int = 512                      # encoding batch size
    num_workers: int = 8                       # DataLoader worker processes
    shard_size: int = 2048                     # rows accumulated before flushing a shard
    writer_batch_size: int = 256               # forwarded to flush_shard (currently unused there)
    pin_memory: bool = torch.cuda.is_available()  # pin host memory only when a GPU exists
    prefetch_factor: int = 2                   # batches prefetched per worker
    cleanup_shards_after_merge: bool = True    # delete temp shards once merged
    print_every: int = 1000                    # progress-print cadence, in samples
|
|
|
|
# Module-level config instance shared by all helpers below.
CFG = Stage1Config()
|
|
|
|
| |
| |
| |
|
|
def extract_caption(sample: Dict[str, Any]) -> str:
    """
    Deterministically extract a caption string from a raw dataset sample.

    Candidate keys are scanned in a fixed order; the first non-empty caption
    wins. Returns "" when nothing usable is found.

    Accepted shapes per key:
      * str                -> used directly (stripped)
      * list, first element:
          - str            -> stripped
          - dict           -> its "raw" (preferred) or "text" field, stripped
          - anything else  -> str() of the element, stripped

    Fix vs. original: a dict element whose "raw"/"text" fields were empty
    previously fell through to the generic ``str(item)`` fallback, so the
    dict's repr (e.g. "{'raw': ''}") was returned as the caption. Such
    elements are now skipped so later candidate keys can still be tried.
    """
    for key in ["answer", "caption", "captions", "sentences", "text"]:
        if key not in sample:
            continue

        val = sample[key]

        if isinstance(val, str):
            caption = val.strip()
            if caption:
                return caption

        if isinstance(val, list) and val:
            item = val[0]

            if isinstance(item, str):
                caption = item.strip()
                if caption:
                    return caption
                # Empty string element: move on to the next candidate key.
                continue

            if isinstance(item, dict):
                caption = str(item.get("raw", item.get("text", ""))).strip()
                if caption:
                    return caption
                # Dict with no usable text: do NOT stringify the dict itself.
                continue

            # Last resort for scalar-ish elements (e.g. numbers).
            caption = str(item).strip()
            if caption:
                return caption

    return ""
|
|
|
|
def make_dataloader(dataset: Dataset, batch_size: int, num_workers: int = 8, shuffle: bool = False) -> DataLoader:
    """Build a DataLoader wired to the module config's pinning/prefetch settings."""
    loader_args: Dict[str, Any] = {
        "dataset": dataset,
        "batch_size": batch_size,
        "shuffle": shuffle,
        "num_workers": num_workers,
        "pin_memory": CFG.pin_memory,
        # Keep worker processes alive across epochs (only meaningful with workers).
        "persistent_workers": num_workers > 0,
    }
    if num_workers > 0:
        # prefetch_factor may only be supplied when multiprocessing workers exist.
        loader_args["prefetch_factor"] = CFG.prefetch_factor
    return DataLoader(**loader_args)
|
|
|
|
def flush_shard(
    shard_root: str,
    shard_index: int,
    features: Features,
    shard_rows: Dict[str, List[Any]],
    writer_batch_size: int,
) -> Optional[str]:
    """
    Persist the accumulated rows as one on-disk Arrow shard.

    Returns the shard path, or None when there is nothing to write. The
    caller is responsible for resetting ``shard_rows`` afterwards.
    NOTE(review): ``writer_batch_size`` is accepted for interface parity but
    is not used by this save path — confirm whether it should be wired in.
    """
    if not shard_rows["source_idx"]:
        return None

    os.makedirs(shard_root, exist_ok=True)
    target = os.path.join(shard_root, f"shard_{shard_index:05d}")

    HFDataset.from_dict(shard_rows, features=features).save_to_disk(target)
    return target
|
|
|
|
def reset_shard_rows() -> Dict[str, List[Any]]:
    """Return a fresh, empty column-store for the next shard's rows."""
    columns = ("source_idx", "text_hidden", "text_mask", "image_hidden")
    return {name: [] for name in columns}
|
|
|
|
def write_manifest(path: str, data: Dict[str, Any]) -> None:
    """
    Write ``data`` as pretty-printed JSON to ``path``.

    Fix: opens the file with an explicit UTF-8 encoding. The previous code
    relied on the locale's default encoding, which is platform-dependent and
    can fail (or produce non-portable files) for non-ASCII tags/paths.
    """
    with open(path, "w", encoding="utf-8") as f:
        json.dump(data, f, indent=2)
|
|
|
|
| |
| |
| |
|
|
class ImageTextDataset(Dataset):
    """
    Torch-side wrapper around an HF dataset.

    All CPU preprocessing (caption extraction, tokenization, image
    processing) happens inside ``__getitem__`` so DataLoader workers run it
    in parallel. Each item is ``(idx, input_ids, attn_mask, pixel_values,
    valid)``; ``valid`` is False when the image was missing or failed to
    process, in which case ``pixel_values`` is a zero tensor of the
    fallback shape.
    """

    def __init__(self, hf_dataset, tokenizer, image_processor, max_text_len: int):
        self.ds = hf_dataset
        self.tok = tokenizer
        self.proc = image_processor
        self.max_text_len = max_text_len
        # Shape of the zero placeholder used when an image can't be processed.
        self.fallback_pixel_shape = self._infer_fallback_pixel_shape()

    def _infer_fallback_pixel_shape(self) -> Tuple[int, int, int]:
        # HF image processors usually expose `.size` as a dict; fall back to
        # a 518x518 default when it is absent or has another form.
        size_cfg = getattr(self.proc, "size", None)
        if not isinstance(size_cfg, dict):
            return (3, 518, 518)
        height = size_cfg.get("height", size_cfg.get("shortest_edge", 518))
        width = size_cfg.get("width", size_cfg.get("shortest_edge", 518))
        return (3, int(height), int(width))

    def __len__(self):
        return len(self.ds)

    def __getitem__(self, idx):
        sample = self.ds[idx]

        # Tokenize to a fixed length so batches stack without collation logic.
        tokens = self.tok(
            extract_caption(sample),
            padding="max_length",
            truncation=True,
            max_length=self.max_text_len,
            return_tensors="pt",
        )
        input_ids = tokens["input_ids"].squeeze(0)
        attn_mask = tokens["attention_mask"].squeeze(0)

        image = sample.get("image", None)
        pixel_values = None
        valid = False

        # Only PIL-like objects (anything with .convert) are processed; any
        # failure downgrades the sample to an invalid zero-image placeholder.
        if image is not None and hasattr(image, "convert"):
            try:
                rgb = image.convert("RGB")
                pixel_values = self.proc(images=rgb, return_tensors="pt")["pixel_values"].squeeze(0)
                valid = True
            except Exception:
                pixel_values = None

        if pixel_values is None:
            pixel_values = torch.zeros(self.fallback_pixel_shape, dtype=torch.float32)

        return idx, input_ids, attn_mask, pixel_values, valid
|
|
|
|
| |
| |
| |
|
|
def process_and_cache(
    dataset_id: str,
    split: str,
    max_samples: Optional[int],
    batch_size: int = 512,
    num_workers: int = 8,
    shard_size: int = 2048,
    tag: Optional[str] = None,
    bert=None,
    dino=None,
    tokenizer=None,
    processor=None,
) -> str:
    """
    Encode one dataset split with frozen text/image encoders and cache the
    hidden states to disk as an Arrow dataset.

    Full pipeline:
      1. load_dataset -> HF Dataset
      2. Wrap in torch Dataset (tokenize + image process in workers)
      3. DataLoader -> GPU encode
      4. Write shard-safe Arrow datasets
      5. Concatenate shards -> save final dataset

    Args:
        dataset_id: HF hub dataset id (e.g. "lmms-lab/COCO-Caption").
        split: split name to load.
        max_samples: optional hard cap on raw samples considered.
        batch_size / num_workers / shard_size: throughput knobs.
        tag: cache directory name; derived from dataset_id/split if None.
        bert / dino: encoder modules already on `device`, in eval mode.
        tokenizer / processor: matching HF tokenizer and image processor.

    Returns:
        Path of the cached dataset (an existing cache is returned untouched).

    Raises:
        RuntimeError: if not a single valid sample could be encoded.
    """
    # The encoders are required despite the None defaults (defaults exist
    # only so call sites can pass everything by keyword).
    assert bert is not None
    assert dino is not None
    assert tokenizer is not None
    assert processor is not None

    tag = tag or f"{dataset_id.replace('/', '_')}_{split}"
    cache_path = os.path.join(CFG.cache_dir, tag)
    shard_root = os.path.join(CFG.cache_dir, f"{tag}__shards")
    manifest_path = os.path.join(CFG.cache_dir, f"{tag}__manifest.json")

    # Fast path: reuse an existing cache (loaded only to report its size).
    if os.path.exists(cache_path):
        print(f" Cache exists: {cache_path}")
        ds = load_from_disk(cache_path)
        print(f" {len(ds)} samples cached")
        return cache_path

    os.makedirs(CFG.cache_dir, exist_ok=True)

    print(f"\n Loading {dataset_id} ({split})...")
    t0 = time.time()
    hf_ds = load_dataset(dataset_id, split=split)
    raw_total = len(hf_ds)
    print(f" Dataset: {raw_total} samples")

    # Optionally truncate the raw dataset before any preprocessing happens.
    if max_samples is not None and raw_total > max_samples:
        hf_ds = hf_ds.select(range(max_samples))
        print(f" Truncated raw dataset to {len(hf_ds)}")
        raw_total = len(hf_ds)

    first = hf_ds[0]
    print(f" Columns: {list(first.keys())}")

    # CPU-side preprocessing runs inside DataLoader worker processes.
    torch_ds = ImageTextDataset(
        hf_dataset=hf_ds,
        tokenizer=tokenizer,
        image_processor=processor,
        max_text_len=CFG.max_text_len,
    )

    loader = make_dataloader(
        dataset=torch_ds,
        batch_size=batch_size,
        num_workers=num_workers,
        shuffle=False,
    )

    print(" Encoding...")

    # The on-disk feature schema is inferred lazily from the first batch.
    feature_schema: Optional[Features] = None
    shard_rows = reset_shard_rows()
    shard_paths: List[str] = []
    shard_index = 0

    raw_seen = 0          # samples drawn from the loader so far
    valid_saved = 0       # samples actually encoded and buffered/written
    invalid_dropped = 0   # samples whose image was missing or broken

    for batch in loader:
        source_idx, input_ids, attn_mask, pixel_values, valid = batch

        batch_raw = int(input_ids.shape[0])
        raw_seen += batch_raw

        valid_b = valid.bool()
        invalid_dropped += int((~valid_b).sum().item())

        # All-invalid batch: nothing to encode, but still report progress.
        if not valid_b.any():
            # Fires roughly every CFG.print_every raw samples; the modulo
            # trick works because counters advance by whole batches.
            if raw_seen % CFG.print_every < batch_raw or raw_seen <= batch_raw:
                rate = raw_seen / max(time.time() - t0, 1e-6)
                print(f" raw={raw_seen}/{raw_total} valid={valid_saved} invalid={invalid_dropped} ({rate:.0f} raw/s)")
            continue

        # Keep only valid rows; move encoder inputs to GPU asynchronously.
        source_idx_v = source_idx[valid_b]
        input_ids_v = input_ids[valid_b].to(device, non_blocking=True)
        attn_mask_v = attn_mask[valid_b].to(device, non_blocking=True)
        pixel_values_v = pixel_values[valid_b].to(device, non_blocking=True)

        # Frozen encoders; fp16 autocast on GPU for throughput.
        with torch.no_grad():
            if torch.cuda.is_available():
                with torch.amp.autocast("cuda", enabled=True):
                    text_h = bert(input_ids=input_ids_v, attention_mask=attn_mask_v).last_hidden_state
                    image_h = dino(pixel_values=pixel_values_v).last_hidden_state
            else:
                text_h = bert(input_ids=input_ids_v, attention_mask=attn_mask_v).last_hidden_state
                image_h = dino(pixel_values=pixel_values_v).last_hidden_state

        # Store hidden states as fp16 numpy to halve the cache footprint.
        text_h = text_h.detach().to(dtype=torch.float16).cpu().numpy()
        text_m = attn_mask_v.bool().cpu().numpy()
        image_h = image_h.detach().to(dtype=torch.float16).cpu().numpy()
        source_idx_np = source_idx_v.cpu().numpy().astype(np.int64)

        # First valid batch defines the feature schema for all shards.
        if feature_schema is None:
            text_shape = tuple(text_h.shape[1:])
            image_shape = tuple(image_h.shape[1:])

            feature_schema = Features({
                "source_idx": Value("int64"),
                "text_hidden": Array2D(shape=text_shape, dtype="float16"),
                "text_mask": Sequence(Value("bool"), length=text_shape[0]),
                "image_hidden": Array2D(shape=image_shape, dtype="float16"),
            })

            print(f" Feature schema:")
            print(f" text_hidden: {text_shape} float16")
            print(f" text_mask: ({text_shape[0]},) bool")
            print(f" image_hidden:{image_shape} float16")

        # Append row-by-row into the in-memory column store.
        for i in range(text_h.shape[0]):
            shard_rows["source_idx"].append(int(source_idx_np[i]))
            shard_rows["text_hidden"].append(text_h[i])
            shard_rows["text_mask"].append(text_m[i].tolist())
            shard_rows["image_hidden"].append(image_h[i])

        valid_saved += int(text_h.shape[0])

        if valid_saved % CFG.print_every < text_h.shape[0] or valid_saved <= text_h.shape[0]:
            rate = raw_seen / max(time.time() - t0, 1e-6)
            print(
                f" raw={raw_seen}/{raw_total} valid={valid_saved} "
                f"invalid={invalid_dropped} ({rate:.0f} raw/s)"
            )

        # Flush a shard once enough rows have accumulated, then reset.
        if len(shard_rows["source_idx"]) >= shard_size:
            shard_path = flush_shard(
                shard_root=shard_root,
                shard_index=shard_index,
                features=feature_schema,
                shard_rows=shard_rows,
                writer_batch_size=CFG.writer_batch_size,
            )
            if shard_path is not None:
                shard_paths.append(shard_path)
                print(f" Flushed shard {shard_index:05d} ({len(load_from_disk(shard_path))} rows)")
            shard_index += 1
            shard_rows = reset_shard_rows()

    # No batch ever produced a valid sample -> nothing to cache.
    if feature_schema is None:
        raise RuntimeError("No valid samples were encoded. Cannot build cache.")

    # Flush whatever remains as the final (possibly short) shard.
    tail_path = flush_shard(
        shard_root=shard_root,
        shard_index=shard_index,
        features=feature_schema,
        shard_rows=shard_rows,
        writer_batch_size=CFG.writer_batch_size,
    )
    if tail_path is not None:
        shard_paths.append(tail_path)
        print(f" Flushed shard {shard_index:05d} ({len(load_from_disk(tail_path))} rows)")

    # Merge all shards into the final cached dataset.
    print(" Merging shards...")
    shard_datasets = [load_from_disk(p) for p in shard_paths]
    result_ds = concatenate_datasets(shard_datasets)
    result_ds.save_to_disk(cache_path)

    elapsed = time.time() - t0
    print(f" Saved {len(result_ds)} samples to {cache_path} ({elapsed:.0f}s)")

    # Sidecar manifest with bookkeeping stats for later inspection.
    manifest = {
        "dataset_id": dataset_id,
        "split": split,
        "tag": tag,
        "cache_path": cache_path,
        "raw_total_considered": raw_total,
        "raw_seen": raw_seen,
        "valid_saved": valid_saved,
        "invalid_dropped": invalid_dropped,
        "invalid_rate": (invalid_dropped / raw_seen) if raw_seen > 0 else 0.0,
        "num_shards": len(shard_paths),
        "feature_schema": {
            "text_hidden_shape": list(feature_schema["text_hidden"].shape),
            "text_mask_len": feature_schema["text_mask"].length,
            "image_hidden_shape": list(feature_schema["image_hidden"].shape),
        },
        "elapsed_sec": elapsed,
    }
    write_manifest(manifest_path, manifest)
    print(f" Wrote manifest: {manifest_path}")

    # Temporary shards are redundant once the merged dataset is saved.
    if CFG.cleanup_shards_after_merge and os.path.exists(shard_root):
        shutil.rmtree(shard_root, ignore_errors=True)
        print(f" Removed temporary shards: {shard_root}")

    # Free host/GPU memory before the caller moves on to the next split.
    del result_ds
    del shard_datasets
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    return cache_path
|
|
|
|
| |
| |
| |
|
|
if __name__ == "__main__":
    os.makedirs(CFG.cache_dir, exist_ok=True)

    # ---- Banner / environment info -----------------------------------
    print("=" * 70)
    print("STAGE 1: PRECOMPUTE EMBEDDINGS")
    print("=" * 70)
    if torch.cuda.is_available():
        print(f"GPU: {torch.cuda.get_device_name()}")
        print(f"VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
    print(f"Cache dir: {CFG.cache_dir}")

    # ---- Load frozen encoders (fp16 on GPU, fp32 on CPU) -------------
    print("\nLoading encoders...")
    # Imported here so that merely importing this module never pulls in
    # transformers or triggers model downloads.
    from transformers import BertModel, BertTokenizer, Dinov2Model, AutoImageProcessor

    tokenizer = BertTokenizer.from_pretrained("google-bert/bert-large-uncased")
    processor = AutoImageProcessor.from_pretrained("facebook/dinov2-large")

    bert = BertModel.from_pretrained(
        "google-bert/bert-large-uncased",
        torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    ).to(device).eval()

    dino = Dinov2Model.from_pretrained(
        "facebook/dinov2-large",
        torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    ).to(device).eval()

    print(" Encoders ready.")

    paths = {}  # dataset name -> cache path (None when a dataset failed)

    # NOTE(review): the 'β' in the banner strings below looks like a
    # mojibake'd separator/arrow glyph — confirm the intended character.
    # ---- [1/3] COCO val: training data -------------------------------
    print(f"\n{'β' * 50}")
    print("[1/3] COCO-Caption val (training) β FULL")
    paths["coco_val"] = process_and_cache(
        dataset_id="lmms-lab/COCO-Caption",
        split="val",
        max_samples=None,
        batch_size=CFG.batch_size,
        num_workers=CFG.num_workers,
        shard_size=CFG.shard_size,
        tag="coco_val",
        bert=bert,
        dino=dino,
        tokenizer=tokenizer,
        processor=processor,
    )

    # ---- [2/3] COCO test: held-out evaluation ------------------------
    print(f"\n{'β' * 50}")
    print("[2/3] COCO-Caption test (held-out) β FULL")
    paths["coco_test"] = process_and_cache(
        dataset_id="lmms-lab/COCO-Caption",
        split="test",
        max_samples=None,
        batch_size=CFG.batch_size,
        num_workers=CFG.num_workers,
        shard_size=CFG.shard_size,
        tag="coco_test",
        bert=bert,
        dino=dino,
        tokenizer=tokenizer,
        processor=processor,
    )

    # ---- [3/3] Flickr30k: cross-dataset transfer (best-effort) -------
    print(f"\n{'β' * 50}")
    print("[3/3] Flickr30k (cross-dataset) β FULL")
    try:
        paths["flickr"] = process_and_cache(
            dataset_id="Mozilla/flickr30k-transformed-captions",
            split="test",
            max_samples=None,
            batch_size=CFG.batch_size,
            num_workers=CFG.num_workers,
            shard_size=CFG.shard_size,
            tag="flickr30k",
            bert=bert,
            dino=dino,
            tokenizer=tokenizer,
            processor=processor,
        )
    except Exception as e:
        # Flickr is optional: a download/processing failure must not kill
        # the COCO caches already built above.
        print(f" Flickr30k failed: {e}")
        paths["flickr"] = None

    # ---- Release encoder weights and GPU memory ----------------------
    del bert, dino, tokenizer, processor
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    # ---- Summary ------------------------------------------------------
    print(f"\n{'=' * 70}")
    print("CACHE SUMMARY")
    print(f"{'=' * 70}")
    for name, path in paths.items():
        if path and os.path.exists(path):
            ds = load_from_disk(path)
            print(f" {name:15s}: {len(ds):6d} samples [{path}]")

    print(f"\n Stage 2 usage:")
    print(f' ds = load_from_disk("{CFG.cache_dir}/coco_val").with_format("torch")')
    print(f' loader = DataLoader(ds, batch_size=64, num_workers=4)')
    print("\nDone.")