Spaces:
Sleeping
Sleeping
| import os | |
| import random | |
| from typing import Dict, List, Optional, Tuple | |
| import torch | |
| from PIL import Image | |
| from torch import Tensor | |
| from torch.utils.data import DataLoader, Dataset, Subset | |
| from torchvision import transforms | |
| from transformers import GPT2TokenizerFast | |
| from .config import PathsConfig, TrainingConfig | |
# Per-channel mean and standard deviation handed to transforms.Normalize by
# both the training and the evaluation pipelines in this module. The names
# indicate these are the usual ImageNet statistics (presumably chosen to match
# an ImageNet-pretrained image encoder -- confirm against the model code).
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]
def train_image_transform() -> transforms.Compose:
    """
    Build the randomized preprocessing pipeline applied to training images.

    The steps run in this order: Resize(256), RandomResizedCrop to 224 with
    scale=(0.8, 1.0), horizontal flip with probability 0.5, a mild
    ColorJitter, conversion to a tensor, and normalization with
    IMAGENET_MEAN / IMAGENET_STD. The augmentation strengths are kept
    moderate so the scene content a caption describes is not altered.
    """
    # Photometric perturbation; all four strengths are deliberately small.
    mild_jitter = transforms.ColorJitter(
        brightness=0.2,
        contrast=0.2,
        saturation=0.2,
        hue=0.05,
    )
    geometric_steps = [
        transforms.Resize(256),
        transforms.RandomResizedCrop(224, scale=(0.8, 1.0)),
        transforms.RandomHorizontalFlip(p=0.5),
    ]
    tensor_steps = [
        transforms.ToTensor(),
        transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ]
    return transforms.Compose(geometric_steps + [mild_jitter] + tensor_steps)
def eval_image_transform() -> transforms.Compose:
    """
    Build the augmentation-free preprocessing pipeline for validation/test.

    Steps, in order: Resize(256), CenterCrop(224), conversion to a tensor,
    and normalization with IMAGENET_MEAN / IMAGENET_STD. No random transform
    is involved, so the same image always yields the same tensor.
    """
    resize_and_crop = [transforms.Resize(256), transforms.CenterCrop(224)]
    to_normalized_tensor = [
        transforms.ToTensor(),
        transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ]
    return transforms.Compose(resize_and_crop + to_normalized_tensor)
class ImageCaptionDataset(Dataset):
    """
    Custom Dataset for the visually impaired image captioning data.

    This implementation is tailored to the existing layout:
      - Images: <data_root>/visual_dataset/*.jpg
      - Text:
          - visual.token.txt        (image#idx<TAB>caption)
          - visual.trainImages.txt  (one image filename per line)
          - visual.testImages.txt   (one image filename per line)

    The concrete paths are read from ``paths_cfg`` (``token_file``,
    ``train_list_file``, ``test_list_file``, ``images_dir``).

    Each item is a dict with the keys ``image``, ``input_ids``,
    ``attention_mask``, ``labels`` (tensors) and ``caption``, ``image_id``
    (strings).
    """

    def __init__(
        self,
        paths_cfg: PathsConfig,
        tokenizer: GPT2TokenizerFast,
        split: str = "train",
        training_cfg: Optional[TrainingConfig] = None,
        transform: Optional[transforms.Compose] = None,
        random_caption: bool = True,
    ) -> None:
        """
        Load the caption table and the image list for ``split``.

        Parameters
        ----------
        paths_cfg:
            Paths configuration providing the caption file, list files and
            image directory.
        tokenizer:
            Tokenizer used to encode captions.
        split:
            One of {'train', 'val', 'test'}. 'val' and 'test' both read
            ``paths_cfg.test_list_file`` because the dataset ships a single
            held-out list.
        training_cfg:
            Supplies ``max_caption_length``; a default ``TrainingConfig()``
            is used when omitted.
        transform:
            Image transform; defaults to ``eval_image_transform()``.
        random_caption:
            If True, sample one of the first three captions per access;
            otherwise always use the first caption.

        Raises
        ------
        ValueError
            If ``split`` is invalid or the caption file has a line without
            a TAB separator.
        FileNotFoundError
            If the caption file or the list file does not exist.
        RuntimeError
            If no listed image has a caption.
        """
        super().__init__()
        if split not in {"train", "val", "test"}:
            raise ValueError("split must be one of {'train', 'val', 'test'}")

        self.paths_cfg = paths_cfg
        self.tokenizer = tokenizer
        self.training_cfg = TrainingConfig() if training_cfg is None else training_cfg
        # If no transform is provided, fall back to the deterministic eval
        # transform so this class can still be used directly.
        self.transform = eval_image_transform() if transform is None else transform
        self.random_caption = random_caption
        self.max_length: int = int(self.training_cfg.max_caption_length)

        self.captions_by_image: Dict[str, List[str]] = self._load_captions(
            self.paths_cfg.token_file
        )

        if split == "train":
            list_file = self.paths_cfg.train_list_file
        else:
            # Only a single test list exists in this dataset; it serves both
            # the 'val' and the 'test' split.
            list_file = self.paths_cfg.test_list_file
        self.image_ids: List[str] = self._load_image_ids(list_file, split)

        print(f"Loaded {len(self.image_ids)} {split} images with captions.")

    @staticmethod
    def _load_captions(token_path) -> Dict[str, List[str]]:
        """Parse the ``image#idx<TAB>caption`` file into {image name: captions}."""
        if not os.path.exists(token_path):
            raise FileNotFoundError(f"Caption file not found: {token_path}")

        captions: Dict[str, List[str]] = {}
        with open(token_path, "r", encoding="utf-8") as f:
            for raw_line in f:
                line = raw_line.strip()
                if not line:
                    continue
                key, tab, caption = line.partition("\t")
                if not tab:
                    raise ValueError(f"Malformed line in {token_path}: {line}")
                # The key looks like "<image name>#<caption index>".
                img_name = key.split("#")[0]
                captions.setdefault(img_name, []).append(caption.strip())
        return captions

    def _load_image_ids(self, list_file, split: str) -> List[str]:
        """Read the split's image list, keeping only images that have captions."""
        if not os.path.exists(list_file):
            raise FileNotFoundError(
                f"Image list file for split '{split}' not found: {list_file}"
            )

        with open(list_file, "r", encoding="utf-8") as f:
            stripped_names = (line.strip() for line in f)
            # Blank lines are ignored; images without captions are skipped to
            # avoid failures later in __getitem__.
            image_ids = [
                name
                for name in stripped_names
                if name and name in self.captions_by_image
            ]

        if not image_ids:
            raise RuntimeError(f"No images with captions found for split '{split}'.")
        return image_ids

    def _encode_caption(self, caption: str) -> Tuple[Tensor, Tensor, Tensor]:
        """
        Turn a caption string into fixed-length ``(input_ids, attention_mask,
        labels)`` tensors of length ``self.max_length``.

        The sequence is BOS + caption tokens + EOS. When it is longer than
        ``max_length`` it is truncated, but the final kept position is
        overwritten with EOS so that every training target still ends with an
        end-of-sequence marker.
        """
        # No extra special tokens from the tokenizer: BOS/EOS are added
        # explicitly below.
        token_ids: List[int] = self.tokenizer.encode(
            caption,
            add_special_tokens=False,
        )

        eos_token_id = self.tokenizer.eos_token_id
        if eos_token_id is None:
            raise ValueError("Tokenizer must define an EOS token.")
        # Compare against None explicitly: a BOS id of 0 is valid and must not
        # be replaced by EOS.
        bos_token_id = self.tokenizer.bos_token_id
        if bos_token_id is None:
            bos_token_id = eos_token_id
        # Same fallback create_tokenizer() applies when no pad token exists.
        pad_id = self.tokenizer.pad_token_id
        if pad_id is None:
            pad_id = eos_token_id

        seq_ids: List[int] = [bos_token_id] + token_ids + [eos_token_id]
        if len(seq_ids) > self.max_length:
            seq_ids = seq_ids[: self.max_length]
            # Keep the end marker when there is room for BOS and EOS.
            if self.max_length >= 2:
                seq_ids[-1] = eos_token_id

        input_ids = torch.full((self.max_length,), pad_id, dtype=torch.long)
        attention_mask = torch.zeros(self.max_length, dtype=torch.long)
        seq_len = len(seq_ids)
        input_ids[:seq_len] = torch.tensor(seq_ids, dtype=torch.long)
        attention_mask[:seq_len] = 1

        # Padding is masked via the attention mask (not by token value), so a
        # real EOS stays a training target even when pad id == EOS id.
        labels = input_ids.clone()
        labels[attention_mask == 0] = -100
        return input_ids, attention_mask, labels

    def __len__(self) -> int:
        """Number of images in this split that have at least one caption."""
        return len(self.image_ids)

    def __getitem__(self, idx: int) -> Dict[str, object]:
        """
        Load, transform and tokenize the sample at position ``idx``.

        Returns a dict with tensors under ``image``, ``input_ids``,
        ``attention_mask`` and ``labels``, plus the chosen caption string
        under ``caption`` and the image filename under ``image_id``.

        Raises
        ------
        FileNotFoundError
            If the image file is missing from ``paths_cfg.images_dir``.
        RuntimeError
            If the image has an empty caption list.
        """
        img_name = self.image_ids[idx]
        img_path = os.path.join(self.paths_cfg.images_dir, img_name)
        if not os.path.exists(img_path):
            raise FileNotFoundError(f"Image file not found: {img_path}")

        # Close the underlying file as soon as the pixels are converted so
        # long-running loaders do not accumulate open file handles.
        with Image.open(img_path) as raw_image:
            image = raw_image.convert("RGB")
        image_tensor = self.transform(image)

        caption_list = self.captions_by_image[img_name]
        if not caption_list:
            raise RuntimeError(f"No captions available for image {img_name}")

        # Training samples among (at most) the first three captions; evaluation
        # always uses the first one. Only surrounding whitespace is stripped.
        if self.random_caption:
            caption = random.choice(caption_list[:3])
        else:
            caption = caption_list[0]
        caption = caption.strip()

        input_ids, attention_mask, labels = self._encode_caption(caption)

        return {
            "image": image_tensor,
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": labels,
            "caption": caption,
            "image_id": img_name,
        }
def create_tokenizer() -> GPT2TokenizerFast:
    """
    Load the pretrained "gpt2" tokenizer and make sure it can pad.

    When the loaded tokenizer has no pad token, its EOS token is assigned as
    the pad token before the tokenizer is returned.
    """
    gpt2_tok = GPT2TokenizerFast.from_pretrained("gpt2")
    pad_missing = gpt2_tok.pad_token is None
    if pad_missing:
        # Reuse EOS for padding rather than adding a new vocabulary entry.
        gpt2_tok.pad_token = gpt2_tok.eos_token
    return gpt2_tok
| def _infer_category_from_filename(filename: str) -> str: | |
| """ | |
| Infer a coarse category label from an image filename. | |
| Heuristic: | |
| - Strip directory and extension. | |
| - Remove trailing digits to group files like 'bench1.jpg', 'bench25.jpg' | |
| into the same category 'bench'. | |
| """ | |
| base = os.path.basename(filename) | |
| stem, _ext = os.path.splitext(base) | |
| # Remove trailing digits | |
| i = len(stem) | |
| while i > 0 and stem[i - 1].isdigit(): | |
| i -= 1 | |
| category = stem[:i] or stem | |
| return category | |
def _balanced_train_val_indices(
    dataset: ImageCaptionDataset,
    val_ratio: float = 0.2,
) -> Tuple[List[int], List[int]]:
    """
    Partition the positions of ``dataset.image_ids`` into train and val lists.

    Images are grouped by the category inferred from their filename. Every
    category hands the same number of its lowest indices to validation; that
    number is the per-category share of ``val_ratio * len(image_ids)``
    (at least 1), capped by the size of the smallest category. Everything else
    goes to training. Note that a category no larger than this quota therefore
    ends up entirely in validation.

    Returns ``(train_indices, val_indices)``. The result is deterministic:
    categories are visited in sorted order and indices ascend within each.

    Raises RuntimeError when the dataset has no images.
    """
    total = len(dataset.image_ids)
    if total == 0:
        raise RuntimeError("Cannot create train/val split from an empty dataset.")

    # Bucket dataset positions by inferred category.
    buckets: Dict[str, List[int]] = {}
    for position, name in enumerate(dataset.image_ids):
        label = _infer_category_from_filename(name)
        if label not in buckets:
            buckets[label] = []
        buckets[label].append(position)

    ordered_labels = sorted(buckets)

    # How many validation items we would like overall, and per category.
    wanted_val_total = max(1, int(round(val_ratio * total)))
    even_share = max(1, int(round(wanted_val_total / max(1, len(ordered_labels)))))
    # The smallest category bounds the quota so all categories stay balanced.
    smallest_bucket = min(len(buckets[label]) for label in ordered_labels)
    quota = min(smallest_bucket, even_share)

    train_indices: List[int] = []
    val_indices: List[int] = []
    for label in ordered_labels:
        members = sorted(buckets[label])
        val_indices += members[:quota]
        train_indices += members[quota:]

    return train_indices, val_indices
def create_dataloader(
    paths_cfg: PathsConfig,
    training_cfg: TrainingConfig,
    split: str,
    tokenizer: Optional[GPT2TokenizerFast] = None,
    shuffle: Optional[bool] = None,
) -> Tuple[DataLoader, GPT2TokenizerFast]:
    """
    Factory function to create a DataLoader for a given split.

    'train' and 'val' are both carved out of the training list file via a
    category-balanced split (val_ratio=0.2); 'test' uses the dedicated test
    list file. Only the 'train' split gets random image augmentation and
    random caption sampling; 'val' and 'test' use the deterministic eval
    transform and always the first caption, so evaluation metrics are
    reproducible from run to run.

    Parameters
    ----------
    paths_cfg:
        Paths configuration.
    training_cfg:
        Training configuration containing batch size, max caption length, etc.
    split:
        One of {'train', 'val', 'test'}.
    tokenizer:
        Optional pre-initialized GPT-2 tokenizer. If None, a new one is created.
    shuffle:
        Optional flag to override shuffle behavior. If None, shuffle is True
        for the 'train' split and False otherwise.

    Returns
    -------
    The DataLoader for the split and the tokenizer that was used.

    Raises
    ------
    ValueError
        If ``split`` is not one of the supported values.
    """
    # Reject a bad split before any tokenizer download or file parsing.
    if split not in {"train", "val", "test"}:
        raise ValueError("split must be one of {'train', 'val', 'test'}")

    if tokenizer is None:
        tokenizer = create_tokenizer()

    is_train = split == "train"
    if shuffle is None:
        shuffle = is_train

    if split == "test":
        dataset = ImageCaptionDataset(
            paths_cfg=paths_cfg,
            tokenizer=tokenizer,
            split="test",
            training_cfg=training_cfg,
            transform=eval_image_transform(),
            random_caption=False,
        )
    else:
        # Train and val share the training list file. The index split depends
        # only on the image list, so the underlying dataset can be configured
        # per split: augmentation and caption sampling for training,
        # deterministic preprocessing and the first caption for validation.
        source_dataset = ImageCaptionDataset(
            paths_cfg=paths_cfg,
            tokenizer=tokenizer,
            split="train",
            training_cfg=training_cfg,
            transform=train_image_transform() if is_train else eval_image_transform(),
            random_caption=is_train,
        )
        train_indices, val_indices = _balanced_train_val_indices(
            source_dataset,
            val_ratio=0.2,
        )
        dataset = Subset(source_dataset, train_indices if is_train else val_indices)

    dataloader = DataLoader(
        dataset,
        batch_size=training_cfg.batch_size,
        shuffle=shuffle,
        num_workers=training_cfg.num_workers,
        # Pinned host memory is only useful when batches are copied to a GPU.
        pin_memory=torch.cuda.is_available(),
    )
    return dataloader, tokenizer