Spaces:
Sleeping
Sleeping
| """ | |
| step2_prepare_data.py | |
| ====================== | |
| Task 3 β Component 2: Prepare 500 COCO validation images for inference. | |
| Loads 500 images from the MS-COCO 2017 validation split (via HuggingFace | |
| Datasets) and wraps them in a standard PyTorch DataLoader. | |
| Public API | |
| ---------- | |
| load_val_data(processor, n=500, batch_size=8, seed=42) | |
| -> torch.utils.data.DataLoader | |
| Each batch yields a dict: | |
| { | |
| "pixel_values" : FloatTensor (B, 3, 384, 384), | |
| "labels" : LongTensor (B, max_len), # reference caption ids | |
| "captions" : list[str] # raw reference strings | |
| } | |
| Standalone usage | |
| ---------------- | |
| export PYTHONPATH=. | |
| venv/bin/python task/task_03/step2_prepare_data.py | |
| """ | |
| import os | |
| import sys | |
| sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))) | |
| import torch | |
| from torch.utils.data import DataLoader, Dataset | |
| from transformers import BlipProcessor | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # Dataset wrapper | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| DATASET_ID = "nlphuji/flickr30k" # fallback if COCO unavailable | |
| COCO_ID = "phiyodr/coco2017" | |
| class COCOValDataset(Dataset): | |
| """ | |
| Wraps a HuggingFace dataset split into a torch Dataset. | |
| Args: | |
| hf_dataset : HuggingFace Dataset object with 'image' and 'captions' fields. | |
| processor : BlipProcessor instance. | |
| max_len : Maximum tokenization length for reference captions. | |
| """ | |
| def __init__(self, hf_dataset, processor: BlipProcessor, max_len: int = 64): | |
| self.data = hf_dataset | |
| self.processor = processor | |
| self.max_len = max_len | |
| def __len__(self): | |
| return len(self.data) | |
| def __getitem__(self, idx): | |
| example = self.data[idx] | |
| image = example["image"].convert("RGB") | |
| # Pick the first reference caption | |
| captions = example.get("captions", example.get("caption", ["<no caption>"])) | |
| if isinstance(captions, str): | |
| captions = [captions] | |
| caption = captions[0] | |
| enc = self.processor( | |
| images=image, | |
| text=caption, | |
| return_tensors="pt", | |
| padding="max_length", | |
| truncation=True, | |
| max_length=self.max_len, | |
| ) | |
| return { | |
| "pixel_values": enc["pixel_values"].squeeze(0), # (3, H, W) | |
| "labels": enc["input_ids"].squeeze(0), # (max_len,) | |
| "caption": caption, | |
| } | |
| def _collate_fn(batch): | |
| return { | |
| "pixel_values": torch.stack([b["pixel_values"] for b in batch]), | |
| "labels": torch.stack([b["labels"] for b in batch]), | |
| "captions": [b["caption"] for b in batch], | |
| } | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # Public loader | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| def load_val_data( | |
| processor: BlipProcessor, | |
| n: int = 500, | |
| batch_size: int = 8, | |
| seed: int = 42, | |
| max_len: int = 64, | |
| ) -> DataLoader: | |
| """ | |
| Download and prepare n COCO validation images. | |
| Falls back to Flickr30k if COCO is unavailable (e.g. firewall/proxy). | |
| Args: | |
| processor : BlipProcessor (from step1_load_model) | |
| n : Number of validation images to use (default 500) | |
| batch_size : DataLoader batch size | |
| seed : Random seed for reproducible shuffle | |
| max_len : Max caption token length for labels | |
| Returns: | |
| DataLoader that yields batches with keys: | |
| pixel_values, labels, captions | |
| """ | |
| from datasets import load_dataset | |
| print("=" * 60) | |
| print(" Task 3 β Step 2: Prepare Validation Data") | |
| print("=" * 60) | |
| print(f" Target images : {n}") | |
| print(f" Batch size : {batch_size}") | |
| # ββ Try COCO first ββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| ds = None | |
| try: | |
| print(f" Loading dataset: {COCO_ID} ...") | |
| raw = load_dataset(COCO_ID, split="validation", trust_remote_code=True) | |
| ds = raw.shuffle(seed=seed).select(range(min(n, len(raw)))) | |
| print(f" β COCO loaded ({len(ds)} images)") | |
| except Exception as e: | |
| print(f" β οΈ COCO unavailable ({e}). Falling back to Flickr30k β¦") | |
| # ββ Fallback to Flickr30k βββββββββββββββββββββββββββββββββββββββββββββββββ | |
| if ds is None: | |
| raw = load_dataset(DATASET_ID, split="test", trust_remote_code=True) | |
| ds = raw.shuffle(seed=seed).select(range(min(n, len(raw)))) | |
| print(f" β Flickr30k loaded ({len(ds)} images)") | |
| dataset = COCOValDataset(ds, processor, max_len=max_len) | |
| dataloader = DataLoader( | |
| dataset, | |
| batch_size=batch_size, | |
| shuffle=False, | |
| num_workers=0, | |
| pin_memory=False, | |
| collate_fn=_collate_fn, | |
| ) | |
| print(f" Batches : {len(dataloader)}") | |
| print("=" * 60) | |
| return dataloader | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # Standalone entrypoint | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| if __name__ == "__main__": | |
| from step1_load_model import load_model | |
| _, processor, _ = load_model() | |
| loader = load_val_data(processor, n=500, batch_size=8) | |
| # Peek at first batch | |
| batch = next(iter(loader)) | |
| print(f"\nβ DataLoader ready!") | |
| print(f" pixel_values shape : {batch['pixel_values'].shape}") | |
| print(f" labels shape : {batch['labels'].shape}") | |
| print(f" Sample caption : {batch['captions'][0][:80]}") | |
| print(f"\nImport in notebooks:") | |
| print(" from task.task_03.step2_prepare_data import load_val_data") | |