import os
import time

import numpy as np
import torch as th
import torchvision.datasets as dset
import torchvision.transforms as transforms
import imageio

import ttools

import rendering


BASE_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), os.pardir)
DATA = os.path.join(BASE_DIR, "data")

LOG = ttools.get_logger(__name__)


class QuickDrawImageDataset(th.utils.data.Dataset):
    """QuickDraw cat drawings served as raster images in [-1, 1].

    Loads the 28x28 bitmap dump (``cat.npy``) and returns each drawing as a
    single-channel float tensor resized to ``imsize`` x ``imsize``.

    Args:
        imsize (int): side length, in pixels, of the square output image.
        train (bool): unused; kept for interface symmetry with the other
            datasets in this file.
    """

    BASE_DATA_URL = \
        "https://console.cloud.google.com/storage/browser/_details/quickdraw_dataset/full/numpy_bitmap/cat.npy"

    def __init__(self, imsize, train=True):
        super(QuickDrawImageDataset, self).__init__()
        self.imsize = imsize
        file = os.path.join(DATA, "cat.npy")
        if not os.path.exists(file):
            msg = "Dataset file %s does not exist, please download" \
                " it from %s" % (file, QuickDrawImageDataset.BASE_DATA_URL)
            LOG.error(msg)
            raise RuntimeError(msg)
        self.data = np.load(file, allow_pickle=True, encoding="latin1")

    def __len__(self):
        return self.data.shape[0]

    def __getitem__(self, idx):
        # Add batch and channel dims so `interpolate` accepts the image.
        im = np.reshape(self.data[idx], (1, 1, 28, 28))
        im = th.from_numpy(im).float() / 255.0
        im = th.nn.functional.interpolate(im, size=(self.imsize, self.imsize))
        # Bring it to [-1, 1]
        im = th.clamp(im, 0, 1)
        im -= 0.5
        im /= 0.5
        return im.squeeze(0)  # drop the batch dim -> (1, imsize, imsize)


class QuickDrawDataset(th.utils.data.Dataset):
    """QuickDraw sketches in stroke-sequence (sketch-rnn) form.

    Args:
        dataset (str): name of the sketch-rnn archive to load; the file
            ``DATA/sketchrnn_<dataset>`` must exist locally.
        mode (str): one of "train", "test", "valid".
        max_seq_length (int): sequences longer than this are discarded.
        spatial_limit (int): maximum spatial extent in pixels; larger
            per-step displacements are clamped to this value.
    """

    BASE_DATA_URL = \
        "https://storage.cloud.google.com/quickdraw_dataset/sketchrnn"

    def __init__(self, dataset, mode="train", max_seq_length=250,
                 spatial_limit=1000):
        super(QuickDrawDataset, self).__init__()
        file = os.path.join(DATA, "sketchrnn_" + dataset)
        remote = os.path.join(QuickDrawDataset.BASE_DATA_URL, dataset)

        self.max_seq_length = max_seq_length
        self.spatial_limit = spatial_limit

        if mode not in ["train", "test", "valid"]:
            # BUG FIX: the original `return`ed the exception object instead of
            # raising it, so invalid modes were silently accepted.
            raise ValueError("Only allowed data mode are 'train' and 'test',"
                             " 'valid'.")

        if not os.path.exists(file):
            msg = "Dataset file %s does not exist, please download" \
                " it from %s" % (file, remote)
            LOG.error(msg)
            raise RuntimeError(msg)

        data = np.load(file, allow_pickle=True, encoding="latin1")[mode]
        data = self.purify(data)
        data = self.normalize(data)

        # Length of longest sequence in the dataset
        self.nmax = max(len(seq) for seq in data)
        self.sketches = data

    def __repr__(self):
        return "Dataset with %d sequences of max length %d" % \
            (len(self.sketches), self.nmax)

    def __len__(self):
        return len(self.sketches)

    def __getitem__(self, idx):
        """Return the idx-th stroke in 5-D format, padded to length (Nmax+2).

        The first and last element of the sequence are fixed to "start-" and
        "end-of-sequence" tokens.

        Each row is (dx, dy) + a 3-way one-hot pen state:
            1 0 0: pen touching paper till next point
            0 1 0: pen lifted from paper after current point
            0 0 1: drawing has ended, next points (including current) will
                   not be drawn
        """
        sample_data = self.sketches[idx]

        # Allow two extra slots for start/end of sequence tokens.
        sample = np.zeros((self.nmax + 2, 5), dtype=np.float32)
        n = sample_data.shape[0]

        # Normalize dx, dy by the sketch's absolute spatial extent.
        deltas = sample_data[:, :2]
        positions = deltas.cumsum(0)  # absolute coordinates
        maxi = np.abs(positions).max() + 1e-8
        # Out-of-place divide: must not mutate the cached sketch.
        deltas = deltas / (1.1 * maxi)  # leave some margin on edges

        # Fill in dx, dy coordinates.
        sample[1:n + 1, :2] = deltas
        # On-paper indicator: 0 means touching paper in the 3-col format,
        # so flip it.
        sample[1:n + 1, 2] = 1 - sample_data[:, 2]
        # Off-paper indicator, complement of the previous flag.
        sample[1:n + 1, 3] = 1 - sample[1:n + 1, 2]
        # Fill the remainder with end-of-sequence tokens.
        sample[n + 1:, 4] = 1
        # Start-of-sequence token.
        sample[0] = [0, 0, 1, 0, 0]
        return sample

    def purify(self, strokes):
        """Drop too-long sequences and clamp large spatial gaps."""
        data = []
        for seq in strokes:
            if seq.shape[0] <= self.max_seq_length:
                # Limit large spatial gaps to +/- spatial_limit.
                seq = np.clip(seq, -self.spatial_limit, self.spatial_limit)
                data.append(np.array(seq, dtype=np.float32))
        return data

    def calculate_normalizing_scale_factor(self, strokes):
        """Calculate the normalizing factor explained in the appendix of
        sketch-rnn: the std of all (dx, dy) values pooled together."""
        deltas = np.concatenate([seq[:, :2].ravel() for seq in strokes])
        return np.std(deltas)

    def normalize(self, strokes):
        """Normalize entire dataset (delta_x, delta_y) by the scaling factor.

        Note: mutates the (dx, dy) columns of each sequence in place.
        """
        scale_factor = self.calculate_normalizing_scale_factor(strokes)
        data = []
        for seq in strokes:
            seq[:, 0:2] /= scale_factor
            data.append(seq)
        return data


class FixedLengthQuickDrawDataset(QuickDrawDataset):
    """A variant of the QuickDraw dataset where the strokes are represented
    as a fixed-length sequence of triplets (dx, dy, opacity), with
    opacity in {0, 1}, returned together with a rendered raster image.

    Args:
        canvas_size (int): side length, in pixels, of the rendered image.
    """

    def __init__(self, *args, canvas_size=64, **kwargs):
        super(FixedLengthQuickDrawDataset, self).__init__(*args, **kwargs)
        self.canvas_size = canvas_size

    def __getitem__(self, idx):
        sample = super(FixedLengthQuickDrawDataset, self).__getitem__(idx)

        # Construct a stroke opacity variable from the pen-down state;
        # dx, dy remain unchanged.
        strokes = sample[:, :3]

        # Render the ground-truth image for these strokes.
        im = rendering.opacityStroke2diffvg(
            th.from_numpy(strokes).unsqueeze(0),
            canvas_size=self.canvas_size, relative=True, debug=False)
        im = im.squeeze(0).numpy()

        return strokes, im


class MNISTDataset(th.utils.data.Dataset):
    """MNIST digits resized to ``imsize`` and rescaled to [-1, 1].

    Args:
        imsize (int): side length, in pixels, of the square output image.
        train (bool): selects the train or test split.
    """

    def __init__(self, imsize, train=True):
        super(MNISTDataset, self).__init__()
        self.mnist = dset.MNIST(
            root=os.path.join(DATA, "mnist"), train=train, download=True,
            transform=transforms.Compose([
                transforms.Resize((imsize, imsize)),
                transforms.ToTensor(),
            ]))

    def __len__(self):
        return len(self.mnist)

    def __getitem__(self, idx):
        im, _ = self.mnist[idx]  # label is unused by this dataset
        # Make sure data uses the [0, 1] range.
        im -= im.min()
        im /= im.max() + 1e-8
        # Bring it to [-1, 1]
        im -= 0.5
        im /= 0.5
        return im