File size: 6,062 Bytes
6142a25 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 |
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Dataset classes and a Lightning data module for image-caption data
@author: Tu Bui @University of Surrey
"""
import json
from PIL import Image
import numpy as np
from pathlib import Path
import torch
from torch.utils.data import Dataset, DataLoader
from functools import partial
import pytorch_lightning as pl
from ldm.util import instantiate_from_config
import pandas as pd
def worker_init_fn(_):
    """Re-seed NumPy's global RNG inside a DataLoader worker process.

    When workers are forked they inherit the parent's NumPy RNG state, so
    without re-seeding every worker would draw identical NumPy random numbers.
    The worker id is added to the inherited seed word so the workers differ.

    Args:
        _: worker id passed by DataLoader (unused; the id is read from
           ``torch.utils.data.get_worker_info()`` instead).

    Returns:
        None.
    """
    worker_info = torch.utils.data.get_worker_info()
    if worker_info is None:
        # Not running inside a DataLoader worker (e.g. called directly or with
        # num_workers=0): nothing to de-correlate, leave the RNG untouched.
        return None
    base_seed = int(np.random.get_state()[1][0])
    # np.random.seed only accepts values in [0, 2**32 - 1]; wrap around instead
    # of failing when the inherited state word is close to the upper bound.
    np.random.seed((base_seed + worker_info.id) % 2**32)
    return None
class WrappedDataset(Dataset):
    """Adapt any object exposing ``__len__`` and ``__getitem__`` to a torch ``Dataset``.

    Both length queries and item access are forwarded unchanged to the wrapped
    object, which is kept on the ``data`` attribute.
    """

    def __init__(self, dataset):
        self.data = dataset

    def __getitem__(self, idx):
        wrapped = self.data
        return wrapped[idx]

    def __len__(self):
        wrapped = self.data
        return len(wrapped)
class DataModuleFromConfig(pl.LightningDataModule):
    """Lightning data module whose datasets are built from config objects.

    Each of ``train`` / ``validation`` / ``test`` / ``predict`` is a config that
    is passed to ``instantiate_from_config``. A dataloader method is bound on
    the instance only for the splits that were provided.

    Args:
        batch_size: batch size used by every dataloader.
        train, validation, test, predict: optional dataset configs.
        wrap: if True, wrap each instantiated dataset in ``WrappedDataset``.
        num_workers: DataLoader worker count; defaults to ``batch_size * 2``.
        shuffle_test_loader: shuffle the test dataloader.
        use_worker_init_fn: pass ``worker_init_fn`` to every DataLoader.
        shuffle_val_dataloader: shuffle the validation dataloader.
    """

    def __init__(self, batch_size, train=None, validation=None, test=None, predict=None,
                 wrap=False, num_workers=None, shuffle_test_loader=False,
                 use_worker_init_fn=False, shuffle_val_dataloader=False):
        super().__init__()
        self.batch_size = batch_size
        self.dataset_configs = dict()
        self.num_workers = num_workers if num_workers is not None else batch_size * 2
        self.use_worker_init_fn = use_worker_init_fn
        # Bind the dataloader hooks per instance so only configured splits expose one.
        if train is not None:
            self.dataset_configs["train"] = train
            self.train_dataloader = self._train_dataloader
        if validation is not None:
            self.dataset_configs["validation"] = validation
            self.val_dataloader = partial(self._val_dataloader, shuffle=shuffle_val_dataloader)
        if test is not None:
            self.dataset_configs["test"] = test
            self.test_dataloader = partial(self._test_dataloader, shuffle=shuffle_test_loader)
        if predict is not None:
            self.dataset_configs["predict"] = predict
            self.predict_dataloader = self._predict_dataloader
        self.wrap = wrap

    def prepare_data(self):
        """Instantiate every configured dataset once and discard the result.

        Presumably this lets datasets do one-time work (e.g. downloads) in the
        main process -- TODO confirm against the dataset implementations.
        """
        for data_cfg in self.dataset_configs.values():
            instantiate_from_config(data_cfg)

    def setup(self, stage=None):
        """Instantiate all configured datasets into ``self.datasets``.

        ``stage`` is ignored: every configured split is (re)built on each call.
        """
        self.datasets = {
            split: instantiate_from_config(cfg)
            for split, cfg in self.dataset_configs.items()
        }
        if self.wrap:
            for split in self.datasets:
                self.datasets[split] = WrappedDataset(self.datasets[split])

    def _resolve_worker_init_fn(self):
        """Return the module-level ``worker_init_fn`` if enabled, else None."""
        return worker_init_fn if self.use_worker_init_fn else None

    def _build_loader(self, split, shuffle):
        """Create the DataLoader for ``split`` using the module's shared settings."""
        return DataLoader(self.datasets[split],
                          batch_size=self.batch_size,
                          num_workers=self.num_workers,
                          shuffle=shuffle,
                          worker_init_fn=self._resolve_worker_init_fn())

    def _train_dataloader(self):
        return self._build_loader("train", shuffle=True)

    def _val_dataloader(self, shuffle=False):
        return self._build_loader("validation", shuffle=shuffle)

    def _test_dataloader(self, shuffle=False):
        return self._build_loader("test", shuffle=shuffle)

    def _predict_dataloader(self, shuffle=False):
        # Forward ``shuffle`` so the accepted argument is not silently ignored.
        return self._build_loader("predict", shuffle=shuffle)
class ImageCaptionRaw(Dataset):
    """Image-caption dataset read from a JSON-lines caption file.

    Each non-blank line of ``caption_file`` is a JSON object. ``__getitem__``
    reads its ``'image'`` entry (a path relative to ``image_dir``) and its
    ``'captions'`` entry (a sequence from which one caption is sampled).

    Args:
        image_dir: directory that image paths are resolved against.
        caption_file: path to the JSON-lines file, read as UTF-8.
        secret_len: number of random 0/1 values in each sample's ``secret``.
        transform: optional callable applied to the resized PIL image.
        image_size: size passed to ``Image.resize``; an int means a square
            ``(image_size, image_size)``, otherwise a ``(width, height)`` pair.
            Defaults to 512, i.e. the previous fixed 512x512.
    """

    def __init__(self, image_dir, caption_file, secret_len=100, transform=None, image_size=512):
        super().__init__()
        self.image_dir = Path(image_dir)
        with open(caption_file, 'rt', encoding='utf-8') as f:
            # Skip blank lines (e.g. extra newlines at end of file); json.loads rejects them.
            self.data = [json.loads(line) for line in f if line.strip()]
        self.secret_len = secret_len
        self.transform = transform
        if isinstance(image_size, int):
            self.image_size = (image_size, image_size)
        else:
            self.image_size = tuple(image_size)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return one sample as a dict.

        Keys: ``image`` (float32 array divided by 255, i.e. [0, 1] assuming
        0..255 pixel values after ``transform``), ``caption`` (one caption
        drawn uniformly with ``torch.randint``), ``target`` (``image`` mapped
        to [-1, 1]) and ``secret`` (``secret_len`` random 0/1 floats).
        """
        item = self.data[idx]
        # Context manager closes the underlying file as soon as pixels are loaded.
        with Image.open(self.image_dir / item['image']) as img:
            image = img.convert('RGB').resize(self.image_size)
        captions = item['captions']
        cid = torch.randint(0, len(captions), (1,)).item()
        caption = captions[cid]
        if self.transform is not None:
            image = self.transform(image)
        image = np.array(image, dtype=np.float32) / 255.0  # normalize to [0, 1]
        target = image * 2.0 - 1.0  # normalize to [-1, 1]
        secret = torch.zeros(self.secret_len, dtype=torch.float).random_(0, 2)
        return dict(image=image, caption=caption, target=target, secret=secret)
class BAMFG(Dataset):
    """Paired ground-truth / style-image dataset described by a CSV file.

    ``data_list`` is read with ``pandas.read_csv``. ``__getitem__`` uses the
    columns ``gt_img`` and ``style_img`` (file names relative to ``gt_dir`` and
    ``style_dir``) and ``prompt``.

    Args:
        style_dir: directory holding the style images.
        gt_dir: directory holding the ground-truth images.
        data_list: path or buffer accepted by ``pandas.read_csv``.
        transform: optional callable applied to each resized PIL image.
        image_size: size passed to ``Image.resize``; an int means a square
            ``(image_size, image_size)``, otherwise a ``(width, height)`` pair.
            Defaults to 512, i.e. the previous fixed 512x512.
    """

    def __init__(self, style_dir, gt_dir, data_list, transform=None, image_size=512):
        super().__init__()
        self.style_dir = Path(style_dir)
        self.gt_dir = Path(gt_dir)
        self.data = pd.read_csv(data_list)
        self.transform = transform
        if isinstance(image_size, int):
            self.image_size = (image_size, image_size)
        else:
            self.image_size = tuple(image_size)

    def __len__(self):
        return len(self.data)

    def _load_image(self, path):
        """Open ``path`` as RGB resized to ``self.image_size``, closing the file afterwards."""
        with Image.open(path) as img:
            return img.convert('RGB').resize(self.image_size)

    def __getitem__(self, idx):
        """Return one sample as a dict.

        Keys: ``image`` (ground-truth image as float32 divided by 255),
        ``txt`` (the row's prompt, ``''`` if the CSV cell was empty) and
        ``hint`` (style image as float32 divided by 255).
        """
        item = self.data.iloc[idx]
        gt_img = self._load_image(self.gt_dir / item['gt_img'])
        style_img = self._load_image(self.style_dir / item['style_img'])
        txt = item['prompt']
        if not isinstance(txt, str) and pd.isna(txt):
            # read_csv turns an empty cell into NaN (a float); use an empty prompt instead.
            txt = ''
        if self.transform is not None:
            gt_img = self.transform(gt_img)
            style_img = self.transform(style_img)
        gt_img = np.array(gt_img, dtype=np.float32) / 255.0  # normalize to [0, 1]
        style_img = np.array(style_img, dtype=np.float32) / 255.0  # normalize to [0, 1]
        # NOTE(review): ``image`` is returned in [0, 1]; confirm the consumer does not
        # expect [-1, 1] (ImageCaptionRaw returns a separate ``target`` for that).
        return dict(image=gt_img, txt=txt, hint=style_img)