test / tools /imgcap_dataset.py
Tu Bui
first commit
6142a25
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Dataset class for image-caption
@author: Tu Bui @University of Surrey
"""
import json
from PIL import Image
import numpy as np
from pathlib import Path
import torch
from torch.utils.data import Dataset, DataLoader
from functools import partial
import pytorch_lightning as pl
from ldm.util import instantiate_from_config
import pandas as pd
def worker_init_fn(_):
    """Give the current DataLoader worker its own NumPy random seed.

    Intended to be passed as ``worker_init_fn`` to a ``DataLoader``. The new
    seed is the first word of the current NumPy global RNG state offset by the
    worker id, so workers that start from the same state diverge.

    Args:
        _: Ignored (the DataLoader passes the worker id here; the id is read
            from ``get_worker_info()`` instead).

    Returns:
        None.
    """
    worker_info = torch.utils.data.get_worker_info()
    if worker_info is None:
        # Called outside a DataLoader worker process: there is no worker id to
        # offset by, so leave the global NumPy RNG untouched.
        return None

    # Convert to a plain Python int so the addition cannot wrap or be promoted
    # in a NumPy-version-dependent way.
    base_seed = int(np.random.get_state()[1][0])
    # np.random.seed only accepts values in [0, 2**32 - 1]; wrap explicitly so
    # a base seed near the upper bound cannot raise ValueError.
    np.random.seed((base_seed + worker_info.id) % 2**32)
    return None
class WrappedDataset(Dataset):
    """Adapter that presents any sized, indexable object as a pytorch ``Dataset``.

    The wrapped object only needs to implement ``__len__`` and
    ``__getitem__``; both calls are forwarded to it unchanged.
    """

    def __init__(self, dataset):
        # Only a reference is stored; the wrapped object is not copied.
        self.data = dataset

    def __getitem__(self, idx):
        """Return whatever the wrapped object yields for ``idx``."""
        wrapped = self.data
        return wrapped[idx]

    def __len__(self):
        """Return the length reported by the wrapped object."""
        wrapped = self.data
        return len(wrapped)
class DataModuleFromConfig(pl.LightningDataModule):
    """Lightning data module whose datasets are built from config objects.

    Each of ``train`` / ``validation`` / ``test`` / ``predict`` is a config
    passed to ``instantiate_from_config``. A dataloader hook
    (``train_dataloader``, ``val_dataloader``, ...) is attached to the
    instance only for the splits that were actually configured.

    Args:
        batch_size: Batch size used by every dataloader.
        train: Config for the training dataset, or None.
        validation: Config for the validation dataset, or None.
        test: Config for the test dataset, or None.
        predict: Config for the prediction dataset, or None.
        wrap: If True, wrap every instantiated dataset in ``WrappedDataset``.
        num_workers: Worker processes per dataloader; defaults to
            ``batch_size * 2`` when None.
        shuffle_test_loader: Whether the test dataloader shuffles.
        use_worker_init_fn: If True, pass ``worker_init_fn`` to every
            dataloader.
        shuffle_val_dataloader: Whether the validation dataloader shuffles.
    """

    def __init__(self, batch_size, train=None, validation=None, test=None, predict=None, wrap=False, num_workers=None, shuffle_test_loader=False, use_worker_init_fn=False,
                 shuffle_val_dataloader=False):
        super().__init__()
        self.batch_size = batch_size
        self.dataset_configs = dict()
        self.num_workers = num_workers if num_workers is not None else batch_size * 2
        self.use_worker_init_fn = use_worker_init_fn
        self.wrap = wrap
        # Hooks are bound per instance so that only configured splits expose a
        # dataloader (presumably so Lightning treats the others as absent).
        if train is not None:
            self.dataset_configs["train"] = train
            self.train_dataloader = self._train_dataloader
        if validation is not None:
            self.dataset_configs["validation"] = validation
            self.val_dataloader = partial(self._val_dataloader, shuffle=shuffle_val_dataloader)
        if test is not None:
            self.dataset_configs["test"] = test
            self.test_dataloader = partial(self._test_dataloader, shuffle=shuffle_test_loader)
        if predict is not None:
            self.dataset_configs["predict"] = predict
            self.predict_dataloader = self._predict_dataloader

    def prepare_data(self):
        """Instantiate every configured dataset once and discard the result.

        NOTE(review): presumably this exists to trigger one-time side effects
        of dataset construction (e.g. downloads) -- confirm.
        """
        for data_cfg in self.dataset_configs.values():
            instantiate_from_config(data_cfg)

    def setup(self, stage=None):
        """Build ``self.datasets`` (split name -> dataset) from the configs.

        ``stage`` is accepted for the Lightning hook signature but not used.
        """
        self.datasets = {
            split: instantiate_from_config(cfg)
            for split, cfg in self.dataset_configs.items()
        }
        if self.wrap:
            for split in self.datasets:
                self.datasets[split] = WrappedDataset(self.datasets[split])

    def _make_loader(self, split, shuffle):
        """Build the DataLoader for ``split`` with the module-wide settings."""
        init_fn = worker_init_fn if self.use_worker_init_fn else None
        return DataLoader(self.datasets[split],
                          batch_size=self.batch_size,
                          num_workers=self.num_workers,
                          worker_init_fn=init_fn,
                          shuffle=shuffle)

    def _train_dataloader(self):
        """Return the training dataloader (always shuffled)."""
        return self._make_loader("train", shuffle=True)

    def _val_dataloader(self, shuffle=False):
        """Return the validation dataloader."""
        return self._make_loader("validation", shuffle=shuffle)

    def _test_dataloader(self, shuffle=False):
        """Return the test dataloader."""
        return self._make_loader("test", shuffle=shuffle)

    def _predict_dataloader(self, shuffle=False):
        """Return the prediction dataloader.

        ``shuffle`` is forwarded to the DataLoader; the default (False) keeps
        predictions in dataset order.
        """
        return self._make_loader("predict", shuffle=shuffle)
class ImageCaptionRaw(Dataset):
    """Image/caption dataset that also emits a random binary secret per item.

    Args:
        image_dir: Directory that the ``'image'`` paths in the caption file
            are relative to.
        caption_file: JSON-lines file; each non-blank line is a JSON object
            with an ``'image'`` entry (relative path) and a ``'captions'``
            entry (sequence of captions).
        secret_len: Length of the random 0/1 secret vector.
        transform: Optional callable applied to the resized PIL image before
            it is converted to a float array.
    """

    def __init__(self, image_dir, caption_file, secret_len=100, transform=None):
        super().__init__()
        self.image_dir = Path(image_dir)
        # Read as UTF-8 explicitly so non-ASCII captions do not depend on the
        # platform's locale encoding; blank lines are skipped because
        # json.loads would reject them.
        with open(caption_file, 'rt', encoding='utf-8') as f:
            self.data = [json.loads(line) for line in f if line.strip()]
        self.secret_len = secret_len
        self.transform = transform

    def __len__(self):
        """Return the number of records loaded from the caption file."""
        return len(self.data)

    def __getitem__(self, idx):
        """Load one sample.

        Returns:
            dict with keys:
                ``image``: float32 array scaled to [0, 1] (512x512 RGB when no
                    transform is given).
                ``caption``: one caption picked at random from the record.
                ``target``: the same image rescaled to [-1, 1].
                ``secret``: float tensor of ``secret_len`` random 0/1 values.
        """
        item = self.data[idx]
        # The context manager closes the underlying file as soon as the pixel
        # data has been copied out by convert().
        with Image.open(self.image_dir/item['image']) as raw:
            image = raw.convert('RGB').resize((512,512))
        captions = item['captions']
        cid = torch.randint(0, len(captions), (1,)).item()
        caption = captions[cid]
        if self.transform is not None:
            image = self.transform(image)

        image = np.array(image, dtype=np.float32)/ 255.0  # normalize to [0, 1]
        target = image * 2.0 - 1.0  # normalize to [-1, 1]
        # random_(0, 2) draws integers from {0, 1}, stored as floats.
        secret = torch.zeros(self.secret_len, dtype=torch.float).random_(0, 2)
        return dict(image=image, caption=caption, target=target, secret=secret)
class BAMFG(Dataset):
    """Paired ground-truth / style image dataset with a text prompt per row.

    Args:
        style_dir: Directory that the ``style_img`` column is relative to.
        gt_dir: Directory that the ``gt_img`` column is relative to.
        data_list: CSV file (read with ``pandas.read_csv``) providing the
            columns ``gt_img``, ``style_img`` and ``prompt``.
        transform: Optional callable applied to each resized PIL image before
            it is converted to a float array.
    """

    def __init__(self, style_dir, gt_dir, data_list, transform=None):
        super().__init__()
        self.style_dir = Path(style_dir)
        self.gt_dir = Path(gt_dir)
        self.data = pd.read_csv(data_list)
        self.transform = transform

    def __len__(self):
        """Return the number of rows in the data list."""
        return len(self.data)

    def __getitem__(self, idx):
        """Load one sample.

        Returns:
            dict with keys:
                ``image``: ground-truth image as a float32 array in [0, 1].
                ``txt``: the row's prompt.
                ``hint``: style image as a float32 array in [0, 1].
        """
        item = self.data.iloc[idx]
        # Context managers close each file once convert() has copied the
        # pixel data out.
        with Image.open(self.gt_dir/item['gt_img']) as raw_gt:
            gt_img = raw_gt.convert('RGB').resize((512,512))
        with Image.open(self.style_dir/item['style_img']) as raw_style:
            style_img = raw_style.convert('RGB').resize((512,512))
        txt = item['prompt']
        if self.transform is not None:
            gt_img = self.transform(gt_img)
            style_img = self.transform(style_img)

        gt_img = np.array(gt_img, dtype=np.float32)/ 255.0  # normalize to [0, 1]
        style_img = np.array(style_img, dtype=np.float32)/ 255.0  # normalize to [0, 1]
        # NOTE(review): no [-1, 1] 'target' entry is returned here; confirm
        # that consumers only need 'image' and 'hint' in [0, 1].
        return dict(image=gt_img, txt=txt, hint=style_img)