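# Source: FROMAGe / fromage/data.py (Hugging Face file viewer: raw | history | blame)
# Last commit: 1f43fd8 "Add cheese model" by alvanli
# File size: 4.75 kB; Hugging Face scan: no virus detected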
"""Modified from https://github.com/mlfoundations/open_clip"""
from typing import Optional, Tuple
import collections
import logging
import os
import numpy as np
import pandas as pd
import torch
import torchvision.datasets as datasets
from torchvision import transforms as T
from PIL import Image, ImageFont
from torch.utils.data import Dataset
from fromage import utils
def collate_fn(batch):
    """Collate a batch after discarding samples that are ``None``.

    Args:
        batch: Sequence of samples produced by a dataset's ``__getitem__``.

    Returns:
        Whatever ``torch.utils.data.dataloader.default_collate`` produces for
        the remaining (non-``None``) samples.
    """
    kept = [sample for sample in batch if sample is not None]
    return torch.utils.data.dataloader.default_collate(kept)
def get_dataset(args, split: str, tokenizer, precision: str = 'fp32') -> Dataset:
    """Build the caption dataset for ``split`` (only CC3M is supported).

    Expected on-disk layout (these are the paths built below):
        {args.dataset_dir}/cc3m_train.tsv    -- training image/caption table
        {args.dataset_dir}/cc3m_val.tsv      -- validation image/caption table
        {args.image_dir}/cc3m/training/      -- training images
        {args.image_dir}/cc3m/validation     -- validation images
    Each table must have an ``image`` and a ``caption`` column.

    Args:
        args: Config namespace. Reads ``dataset``, ``val_dataset``,
            ``dataset_dir``, ``image_dir``, ``visual_model``, ``max_len``,
            ``precision``, ``image_size`` and ``retrieval_token_idx``.
        split: Either ``'train'`` or ``'val'``.
        tokenizer: Tokenizer forwarded to ``CsvDataset``.
        precision: Not used. ``args.precision`` is what gets forwarded to
            ``CsvDataset``; the parameter is kept for interface compatibility.

    Returns:
        A single ``CsvDataset``, or a ``torch.utils.data.ConcatDataset`` of
        them when more than one table is configured.

    Raises:
        AssertionError: ``split`` is not ``'train'`` or ``'val'``.
        NotImplementedError: The requested dataset name does not contain
            ``'cc3m'``.
        ValueError: No dataset paths were collected.
    """
    assert split in ['train', 'val'], (
        f'Expected split to be one of "train" or "val", got {split} instead.')

    train = split == 'train'
    dataset_paths = []
    image_data_dirs = []

    if split == 'train':
        if 'cc3m' in args.dataset:
            dataset_paths.append(os.path.join(args.dataset_dir, 'cc3m_train.tsv'))
            image_data_dirs.append(os.path.join(args.image_dir, 'cc3m/training/'))
        else:
            raise NotImplementedError
    elif split == 'val':
        if 'cc3m' in args.val_dataset:
            dataset_paths.append(os.path.join(args.dataset_dir, 'cc3m_val.tsv'))
            image_data_dirs.append(os.path.join(args.image_dir, 'cc3m/validation'))
        else:
            raise NotImplementedError
        assert len(dataset_paths) == len(image_data_dirs) == 1, (dataset_paths, image_data_dirs)
    else:
        # Only reachable when asserts are stripped (python -O).
        raise NotImplementedError

    def _build(path, image_dir):
        # One table + image directory -> one CsvDataset. Note that
        # args.precision (not the ``precision`` argument) is forwarded.
        return CsvDataset(
            path, image_dir, tokenizer, 'image', 'caption', args.visual_model,
            train=train, max_len=args.max_len, precision=args.precision,
            image_size=args.image_size, retrieval_token_idx=args.retrieval_token_idx)

    if len(dataset_paths) > 1:
        print(f'{len(dataset_paths)} datasets requested: {dataset_paths}')
        dataset = torch.utils.data.ConcatDataset(
            [_build(path, image_dir) for path, image_dir in zip(dataset_paths, image_data_dirs)])
    elif len(dataset_paths) == 1:
        dataset = _build(dataset_paths[0], image_data_dirs[0])
    else:
        raise ValueError(f'There should be at least one valid dataset, got train={args.dataset}, val={args.val_dataset} instead.')
    return dataset
class CsvDataset(Dataset):
    """Image/caption pairs read from a delimited text table (TSV by default).

    Each item loads the image from disk, tokenizes the caption with a trailing
    ``[RET]`` marker, and renders the decoded caption as a small image.
    """

    def __init__(self, input_filename, base_image_dir, tokenizer, img_key,
                 caption_key, feature_extractor_model: str,
                 train: bool = True, max_len: int = 32, sep="\t", precision: str = 'fp32',
                 image_size: int = 224, retrieval_token_idx: int = -1):
        """Load the image/caption table and set up per-item processing.

        Args:
            input_filename: Path to the table, read with ``pandas.read_csv``.
            base_image_dir: Directory that each image entry is joined onto.
            tokenizer: Callable tokenizer that also exposes ``decode`` and
                ``pad_token_id`` (presumably a Hugging Face tokenizer).
            img_key: Column holding image file names.
            caption_key: Column holding captions.
            feature_extractor_model: Name passed to
                ``utils.get_feature_extractor_for_model``.
            train: Accepted but not read by this class; the feature extractor
                is always built with ``train=False``.
            max_len: Token length that captions are padded/truncated to.
            sep: Column delimiter of the table.
            precision: Stored on the instance, not otherwise used here.
            image_size: Forwarded to the feature extractor; also the width of
                the rendered caption image.
            retrieval_token_idx: Token id forced into the last position when
                the sequence does not already end in it or in padding.
        """
        logging.debug(f'Loading tsv data from {input_filename}.')
        df = pd.read_csv(input_filename, sep=sep)

        self.base_image_dir = base_image_dir
        self.images = df[img_key].tolist()
        self.captions = df[caption_key].tolist()
        assert len(self.images) == len(self.captions)

        self.feature_extractor_model = feature_extractor_model
        # NOTE(review): train=False regardless of ``train`` -- confirm intended.
        self.feature_extractor = utils.get_feature_extractor_for_model(
            feature_extractor_model, image_size=image_size, train=False)
        self.image_size = image_size
        self.tokenizer = tokenizer
        self.max_len = max_len
        self.precision = precision
        self.retrieval_token_idx = retrieval_token_idx
        # Loaded lazily on the first __getitem__ call.
        self.font = None
        logging.debug('Done loading data.')

    def __len__(self):
        """Number of image/caption pairs."""
        return len(self.captions)

    def _pick_random_index(self) -> int:
        """Draw a replacement index uniformly from ``[0, len(self))``."""
        # np.random.randint's upper bound is exclusive, so passing len(self)
        # (not len(self) - 1) keeps the last sample eligible.
        return int(np.random.randint(0, len(self)))

    def __getitem__(self, idx):
        """Return ``(image_path, images, cap_img, tokens, caption_len)``.

        ``tokens`` is the first row of the tokenizer's ``input_ids``, padded or
        truncated to ``max_len``. ``caption_len`` is the sum of the attention
        mask. ``images`` and ``cap_img`` come from ``utils`` helpers.

        If anything fails while processing a sample, the error is printed and
        a different sample is drawn at random, repeating until one succeeds.
        On a dataset with fewer than two samples the error is re-raised
        instead, because there is nothing else to fall back to.
        """
        while True:
            image_path = os.path.join(self.base_image_dir, str(self.images[idx]))
            caption = str(self.captions[idx])
            try:
                # Close the file handle once pixel values have been extracted.
                with Image.open(image_path) as img:
                    images = utils.get_pixel_values_for_model(self.feature_extractor, img)

                caption += '[RET]'
                tokenized_data = self.tokenizer(
                    caption,
                    return_tensors="pt",
                    padding='max_length',
                    truncation=True,
                    max_length=self.max_len)
                tokens = tokenized_data.input_ids[0]
                caption_len = tokenized_data.attention_mask[0].sum()

                decode_caption = self.tokenizer.decode(tokens, skip_special_tokens=False)
                self.font = self.font or ImageFont.load_default()
                cap_img = utils.create_image_of_text(
                    decode_caption.encode('ascii', 'ignore'),
                    width=self.image_size, nrows=2, font=self.font)

                # If the last slot holds neither [RET] nor padding (e.g. the
                # caption was truncated), force [RET] into it.
                if tokens[-1] not in [self.retrieval_token_idx, self.tokenizer.pad_token_id]:
                    tokens[-1] = self.retrieval_token_idx

                return image_path, images, cap_img, tokens, caption_len
            except Exception as e:
                print(f'Error reading {image_path} with caption {caption}: {e}')
                if len(self) < 2:
                    # No other sample to fall back to; resampling would loop forever.
                    raise
                # Pick a new example at random.
                idx = self._pick_random_index()