|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import os |
|
import numpy as np |
|
import zipfile |
|
import PIL.Image |
|
import json |
|
import torch |
|
import dnnlib |
|
import h5py as h5 |
|
|
|
try: |
|
import pyspng |
|
except ImportError: |
|
pyspng = None |
|
|
|
|
|
|
|
|
|
class Dataset(torch.utils.data.Dataset):
    """Base class for image datasets with optional labels and x-flip augmentation.

    Subclasses implement:
      * ``_load_raw_image(raw_idx)``: must return a uint8 ``np.ndarray`` whose
        shape equals ``image_shape`` (checked in ``__getitem__``).
      * ``_load_raw_labels(raw_idx)``: must return ``labels[raw_idx]`` or
        ``None`` when no labels exist. This class calls it once with
        ``slice(None)`` to obtain the full label array, which it then caches.

    Items are ``(image, label)`` pairs. int64 labels (1-D) are treated as
    class indices and expanded to float32 one-hot vectors; float32 labels are
    returned unchanged.
    """

    def __init__(
        self,
        name,               # Name of the dataset.
        raw_shape,          # Shape of the raw data: [num_images, *image_shape].
        max_size=None,      # Keep at most this many randomly chosen images. None = all.
        load_labels=False,  # Load labels via _load_raw_labels()? False = zero-width labels.
        xflip=False,        # Double the dataset with copies flipped along the last axis.
        random_seed=0,      # Seed for the random subset selected by max_size.
        **kwargs,           # Ignored by this class.
    ):
        self._name = name
        self._raw_shape = list(raw_shape)
        self._load_labels = load_labels
        self._raw_labels = None   # Full label array, filled lazily by _get_raw_labels().
        self._label_shape = None  # Filled lazily by the label_shape property.

        # Raw image indices exposed by this dataset (optionally a sorted random subset).
        self._raw_idx = np.arange(self._raw_shape[0], dtype=np.int64)
        if (max_size is not None) and (self._raw_idx.size > max_size):
            np.random.RandomState(random_seed).shuffle(self._raw_idx)
            self._raw_idx = np.sort(self._raw_idx[:max_size])

        # Per-item flip flags. With xflip, the index list is repeated once more
        # and the second copy is marked as flipped.
        self._xflip = np.zeros(self._raw_idx.size, dtype=np.uint8)
        if xflip:
            self._raw_idx = np.tile(self._raw_idx, 2)
            self._xflip = np.concatenate([self._xflip, np.ones_like(self._xflip)])

    def _get_raw_labels(self, idx=None):
        """Return the raw labels, loading and validating them on first use.

        The full label array (one entry per raw image) is loaded once through
        ``_load_raw_labels(slice(None))`` and cached. If labels are disabled
        or unavailable, a float32 array of shape ``[num_images, 0]`` is used.

        Args:
            idx: ``None`` to get the whole cached array; otherwise a raw index
                (or any numpy index expression) and ``raw_labels[idx]`` is
                returned.
        """
        if self._raw_labels is None:
            raw_labels = self._load_raw_labels(slice(None)) if self._load_labels else None
            if raw_labels is None:
                raw_labels = np.zeros([self._raw_shape[0], 0], dtype=np.float32)
            raw_labels = np.asarray(raw_labels)
            # Labels are indexed by raw image index, so there must be one per image.
            assert raw_labels.ndim >= 1 and raw_labels.shape[0] == self._raw_shape[0]
            assert raw_labels.dtype in [np.float32, np.int64]
            if raw_labels.dtype == np.int64:
                # Class indices: one non-negative integer per image.
                assert raw_labels.ndim == 1
                assert np.all(raw_labels >= 0)
            # Cache only after validation so a failed check is not remembered.
            self._raw_labels = raw_labels
        if idx is None:
            return self._raw_labels
        return self._raw_labels[idx]

    def close(self):
        """Release resources held by the dataset (no-op here; overridden by subclasses)."""
        pass

    def _load_raw_image(self, raw_idx):
        """Return the raw uint8 image for ``raw_idx`` (subclass responsibility)."""
        raise NotImplementedError

    def _load_raw_labels(self, raw_idx):
        """Return ``labels[raw_idx]`` or ``None`` (subclass responsibility)."""
        raise NotImplementedError

    def __getstate__(self):
        # Drop the label cache from pickled state; it is reloaded lazily.
        return dict(self.__dict__, _raw_labels=None)

    def __del__(self):
        # Best-effort cleanup: errors during finalization are deliberately ignored.
        try:
            self.close()
        except Exception:
            pass

    def __len__(self):
        return self._raw_idx.size

    def __getitem__(self, idx):
        """Return ``(image, label)`` for item ``idx``; the image is a fresh uint8 copy."""
        image = self._load_raw_image(self._raw_idx[idx])
        assert isinstance(image, np.ndarray)
        assert list(image.shape) == self.image_shape
        assert image.dtype == np.uint8
        if self._xflip[idx]:
            assert image.ndim == 3
            image = image[:, :, ::-1]  # Flip along the last axis.
        return image.copy(), self.get_label(idx)

    def get_label(self, idx):
        """Return the label of item ``idx`` (one-hot float32 for int64 labels)."""
        label = self._get_raw_labels(self._raw_idx[idx])
        if label.dtype == np.int64:
            onehot = np.zeros(self.label_shape, dtype=np.float32)
            onehot[label] = 1
            label = onehot
        return label.copy()

    def get_details(self, idx):
        """Return raw index, flip flag and raw label of item ``idx`` as an EasyDict."""
        d = dnnlib.EasyDict()
        d.raw_idx = int(self._raw_idx[idx])
        d.xflip = int(self._xflip[idx]) != 0
        d.raw_label = self._get_raw_labels()[d.raw_idx].copy()
        return d

    @property
    def name(self):
        return self._name

    @property
    def image_shape(self):
        """Shape of a single image, i.e. ``raw_shape[1:]``."""
        return list(self._raw_shape[1:])

    @property
    def num_channels(self):
        """First image axis; requires 3-D images."""
        assert len(self.image_shape) == 3
        return self.image_shape[0]

    @property
    def resolution(self):
        """Side length of square images; requires 3-D images with equal last two axes."""
        assert len(self.image_shape) == 3
        assert self.image_shape[1] == self.image_shape[2]
        return self.image_shape[1]

    @property
    def label_shape(self):
        """Shape of the label returned by ``get_label``.

        For int64 class-index labels this is ``[max_label + 1]`` (one-hot
        width); otherwise it is the per-image shape of the raw label array.
        """
        if self._label_shape is None:
            raw_labels = self._get_raw_labels()
            if raw_labels.dtype == np.int64:
                self._label_shape = [int(np.max(raw_labels)) + 1]
            else:
                self._label_shape = list(raw_labels.shape[1:])
        return list(self._label_shape)

    @property
    def label_dim(self):
        assert len(self.label_shape) == 1
        return self.label_shape[0]

    @property
    def has_labels(self):
        return any(x != 0 for x in self.label_shape)

    @property
    def has_onehot_labels(self):
        return self._get_raw_labels().dtype == np.int64
|
|
|
|
|
|
|
|
|
|
|
class ImageFolderDataset(Dataset):
    """Dataset backed by an image directory, a .zip archive, or a .hdf5 file.

    * ``dir`` / ``zip``: every file whose extension is registered with PIL is
      an image; images are sorted by relative path and converted to [C, H, W].
      Optional labels come from a root-level ``dataset.json``, whose
      ``"labels"`` entry is converted with ``dict()`` and looked up by relative
      file name (with ``/`` separators).
    * ``hdf5``: images are read from the ``"imgs"`` dataset as stored, and
      optional labels from the ``"labels"`` dataset.
    """

    def __init__(
        self,
        root,             # Path to a directory, a .zip archive, or a .hdf5 file.
        resolution=None,  # Required image resolution; None = accept anything.
        **super_kwargs,   # Forwarded to Dataset.__init__.
    ):
        self._path = root
        self._zipfile = None

        if os.path.isdir(self._path):
            self._type = "dir"
            self._all_fnames = {
                os.path.relpath(os.path.join(dirpath, fname), start=self._path)
                for dirpath, _dirs, files in os.walk(self._path)
                for fname in files
            }
        elif self._file_ext(self._path) == ".zip":
            self._type = "zip"
            self._all_fnames = set(self._get_zipfile().namelist())
        elif self._file_ext(self._path) == ".hdf5":
            self._type = "hdf5"
        else:
            raise IOError("Path must point to a directory, a zip, or an hdf5 file")

        PIL.Image.init()  # Populate PIL.Image.EXTENSION with all known formats.
        if self._type in ["dir", "zip"]:
            self._image_fnames = sorted(
                fname
                for fname in self._all_fnames
                if self._file_ext(fname) in PIL.Image.EXTENSION
            )
            if len(self._image_fnames) == 0:
                raise IOError("No image files found in the specified path")

        name = os.path.splitext(os.path.basename(self._path))[0]
        if self._type == "hdf5":
            with h5.File(self._path, "r") as f:
                # Read the shape from dataset metadata instead of decoding an image.
                raw_shape = list(f["imgs"].shape)
            if raw_shape[0] == 0:
                raise IOError("No images found in the 'imgs' dataset of the specified file")
        else:
            raw_shape = [len(self._image_fnames)] + list(self._load_raw_image(0).shape)
        if resolution is not None and (
            len(raw_shape) != 4
            or raw_shape[2] != resolution
            or raw_shape[3] != resolution
        ):
            raise IOError("Image files do not match the specified resolution")
        super().__init__(name=name, raw_shape=raw_shape, **super_kwargs)

    @staticmethod
    def _file_ext(fname):
        """Return the lower-cased extension of ``fname`` including the dot."""
        return os.path.splitext(fname)[1].lower()

    @staticmethod
    def _normalize_labels(labels):
        """Cast labels to the dtypes Dataset accepts: 1-D -> int64, 2-D -> float32.

        Arrays of any other rank are returned unchanged.
        """
        labels = np.asarray(labels)
        dtype = {1: np.int64, 2: np.float32}.get(labels.ndim)
        return labels if dtype is None else labels.astype(dtype)

    def _get_zipfile(self):
        """Return the (lazily opened) ZipFile for a zip-backed dataset."""
        assert self._type == "zip"
        if self._zipfile is None:
            self._zipfile = zipfile.ZipFile(self._path)
        return self._zipfile

    def _open_file(self, fname):
        """Open ``fname`` (relative to the dataset root) in binary mode.

        Returns ``None`` for hdf5-backed datasets, which do not use this path.
        """
        if self._type == "dir":
            return open(os.path.join(self._path, fname), "rb")
        if self._type == "zip":
            return self._get_zipfile().open(fname, "r")
        return None

    def close(self):
        """Close the zip archive if it was opened."""
        try:
            if self._zipfile is not None:
                self._zipfile.close()
        finally:
            self._zipfile = None

    def __getstate__(self):
        # The open ZipFile handle is not carried over; it is reopened lazily.
        return dict(super().__getstate__(), _zipfile=None)

    def _load_raw_image(self, raw_idx):
        """Load image ``raw_idx``.

        dir/zip images are decoded (pyspng for .png when available, else PIL),
        given a channel axis if grayscale, and transposed from [H, W, C] to
        [C, H, W]. hdf5 images are returned as stored in ``"imgs"``.
        """
        if self._type in ["dir", "zip"]:
            fname = self._image_fnames[raw_idx]
            with self._open_file(fname) as f:
                if pyspng is not None and self._file_ext(fname) == ".png":
                    image = pyspng.load(f.read())
                else:
                    image = np.array(PIL.Image.open(f))
            if image.ndim == 2:
                image = image[:, :, np.newaxis]  # HW => HWC
            image = image.transpose(2, 0, 1)  # HWC => CHW
        elif self._type == "hdf5":
            # The file is opened per call, so no h5py handle is stored on the instance.
            with h5.File(self._path, "r") as f:
                image = f["imgs"][raw_idx]
        return image

    def _load_raw_labels(self, idx):
        """Return ``labels[idx]`` or ``None`` if the dataset has no labels.

        Labels are normalized with ``_normalize_labels`` before indexing, so
        ``idx=slice(None)`` yields the full int64 (1-D) or float32 (2-D) array.
        """
        if self._type in ["dir", "zip"]:
            meta_fname = "dataset.json"
            if meta_fname not in self._all_fnames:
                return None
            with self._open_file(meta_fname) as f:
                labels = json.load(f).get("labels")
            if labels is None:
                return None
            labels = dict(labels)
            # JSON keys use '/' separators; directory listings may use '\\' on Windows.
            labels = np.array(
                [labels[fname.replace("\\", "/")] for fname in self._image_fnames]
            )
        elif self._type == "hdf5":
            with h5.File(self._path, "r") as f:
                if "labels" not in f:
                    return None
                labels = f["labels"][()]
        else:
            return None
        return self._normalize_labels(labels)[idx]
|
|
|
|
|
|
|
|