"""Modified from https://github.com/mlfoundations/open_clip"""
from typing import Optional, Tuple
import collections
import logging
import os
import numpy as np
import pandas as pd
import torch
import torchvision.datasets as datasets
from torchvision import transforms as T
from PIL import Image, ImageFont
from torch.utils.data import Dataset
from fromage import utils
def collate_fn(batch):
    """Collate a batch after dropping failed (None) samples.

    Dataset __getitem__ may yield None for unreadable examples; filter
    those out so default_collate only sees valid entries.
    """
    valid_examples = [example for example in batch if example is not None]
    return torch.utils.data.dataloader.default_collate(valid_examples)
def get_dataset(args, split: str, tokenizer, precision: str = 'fp32') -> Dataset:
    """Build the train or val dataset selected by ``args``.

    Args:
        args: parsed CLI namespace. Reads ``dataset`` / ``val_dataset``,
            ``dataset_dir``, ``image_dir``, ``visual_model``, ``max_len``,
            ``precision``, ``image_size``, ``retrieval_token_idx``.
        split: one of ``'train'`` or ``'val'``.
        tokenizer: tokenizer forwarded to :class:`CsvDataset`.
        precision: NOTE(review): this parameter is never used —
            ``args.precision`` is forwarded instead. Confirm whether the
            parameter should take precedence.

    Returns:
        A :class:`CsvDataset`, or a ``ConcatDataset`` when several dataset
        paths are configured.

    Raises:
        AssertionError: if ``split`` is invalid.
        NotImplementedError: for unrecognized dataset names or splits.
        ValueError: if no dataset path was resolved.
    """
    # Bug fix: the message was a plain string, so '{split}' was printed
    # literally instead of the offending value.
    assert split in ['train', 'val'
        ], f'Expected split to be one of "train" or "val", got {split} instead.'

    dataset_paths = []
    image_data_dirs = []
    train = split == 'train'

    # Default configs for datasets.
    # Folder structure should look like:
    if split == 'train':
        if 'cc3m' in args.dataset:
            dataset_paths.append(os.path.join(args.dataset_dir, 'cc3m_train.tsv'))
            image_data_dirs.append(os.path.join(args.image_dir, 'cc3m/training/'))
        else:
            raise NotImplementedError
    elif split == 'val':
        if 'cc3m' in args.val_dataset:
            dataset_paths.append(os.path.join(args.dataset_dir, 'cc3m_val.tsv'))
            image_data_dirs.append(os.path.join(args.image_dir, 'cc3m/validation'))
        else:
            raise NotImplementedError
        # Validation currently supports exactly one dataset.
        assert len(dataset_paths) == len(image_data_dirs) == 1, (dataset_paths, image_data_dirs)
    else:
        raise NotImplementedError

    if len(dataset_paths) > 1:
        print(f'{len(dataset_paths)} datasets requested: {dataset_paths}')
        dataset = torch.utils.data.ConcatDataset([
            CsvDataset(path, image_dir, tokenizer, 'image',
                'caption', args.visual_model, train=train, max_len=args.max_len, precision=args.precision,
                image_size=args.image_size, retrieval_token_idx=args.retrieval_token_idx)
            for (path, image_dir) in zip(dataset_paths, image_data_dirs)])
    elif len(dataset_paths) == 1:
        dataset = CsvDataset(dataset_paths[0], image_data_dirs[0], tokenizer, 'image',
            'caption', args.visual_model, train=train, max_len=args.max_len, precision=args.precision,
            image_size=args.image_size, retrieval_token_idx=args.retrieval_token_idx)
    else:
        raise ValueError(f'There should be at least one valid dataset, got train={args.dataset}, val={args.val_dataset} instead.')
    return dataset
class CsvDataset(Dataset):
    """Image-caption pairs loaded from a TSV file.

    Each row supplies an image filename (relative to ``base_image_dir``)
    under ``img_key`` and a caption under ``caption_key``. ``__getitem__``
    returns ``(image_path, pixel_values, caption_image, token_ids,
    caption_length)``; unreadable examples are skipped by retrying a
    randomly chosen index.
    """

    def __init__(self, input_filename, base_image_dir, tokenizer, img_key,
                 caption_key, feature_extractor_model: str,
                 train: bool = True, max_len: int = 32, sep="\t", precision: str = 'fp32',
                 image_size: int = 224, retrieval_token_idx: int = -1):
        logging.debug(f'Loading tsv data from {input_filename}.')
        df = pd.read_csv(input_filename, sep=sep)

        self.base_image_dir = base_image_dir
        self.images = df[img_key].tolist()
        self.captions = df[caption_key].tolist()
        # Every image must have exactly one caption.
        assert len(self.images) == len(self.captions)

        self.feature_extractor_model = feature_extractor_model
        # NOTE(review): train=False is passed regardless of the `train`
        # argument — confirm eval-mode preprocessing is intended here.
        self.feature_extractor = utils.get_feature_extractor_for_model(
            feature_extractor_model, image_size=image_size, train=False)
        self.image_size = image_size

        self.tokenizer = tokenizer
        self.max_len = max_len
        self.precision = precision
        self.retrieval_token_idx = retrieval_token_idx
        self.font = None  # lazily loaded on first __getitem__

        logging.debug('Done loading data.')

    def __len__(self):
        return len(self.captions)

    def __getitem__(self, idx):
        while True:
            image_path = os.path.join(self.base_image_dir, str(self.images[idx]))
            caption = str(self.captions[idx])
            try:
                # Fix: close the image file handle promptly; the retry loop
                # could otherwise accumulate open handles.
                with Image.open(image_path) as img:
                    images = utils.get_pixel_values_for_model(self.feature_extractor, img)

                # Append the retrieval token so the model is supervised to emit it.
                caption += '[RET]'
                tokenized_data = self.tokenizer(
                    caption,
                    return_tensors="pt",
                    padding='max_length',
                    truncation=True,
                    max_length=self.max_len)
                tokens = tokenized_data.input_ids[0]
                caption_len = tokenized_data.attention_mask[0].sum()

                decode_caption = self.tokenizer.decode(tokens, skip_special_tokens=False)
                self.font = self.font or ImageFont.load_default()
                cap_img = utils.create_image_of_text(decode_caption.encode('ascii', 'ignore'), width=self.image_size, nrows=2, font=self.font)

                # If truncation dropped the [RET] token, force the last token
                # to be the retrieval token so supervision is always present.
                if tokens[-1] not in [self.retrieval_token_idx, self.tokenizer.pad_token_id]:
                    tokens[-1] = self.retrieval_token_idx

                return image_path, images, cap_img, tokens, caption_len
            except Exception as e:
                print(f'Error reading {image_path} with caption {caption}: {e}')
                # Pick a new example at random.
                # Bug fix: randint's upper bound is exclusive, so the previous
                # `len(self)-1` could never re-sample the final example.
                idx = np.random.randint(0, len(self))
|