import os.path as op
from zipfile import ZipFile, BadZipFile
import torch.utils.data as data
from PIL import Image
from io import BytesIO
import multiprocessing

_VALID_IMAGE_TYPES = ['.jpg', '.jpeg', '.tiff', '.bmp', '.png']


class ZipData(data.Dataset):
    """Dataset that serves labelled images straight out of a zip archive.

    A map file assigns an integer label to archive members.  Each non-empty
    line holds at least two fields separated by tabs (or, when no tab
    separates two fields, by spaces)::

        <prefix>@<member path inside the zip>    <integer label>

    Everything up to and including the first ``@`` of the first field is
    discarded, as is one leading ``/`` of the remainder.  Fields after the
    second one are ignored.

    Only archive members that are listed in the map file, are non-empty,
    are not directories, do not start with ``.`` and carry one of the
    accepted extensions become samples.

    Attributes set by ``__init__``:
        samples: list of ``(member_name, label)`` tuples.
        class_to_idx: dict mapping member path -> integer label.
        zip_dict: dict mapping process id -> ``ZipFile`` opened by that
            process; filled lazily by ``__getitem__``.
    """

    # Attributes replaced by ``None`` when the dataset is pickled.
    _IGNORE_ATTRS = {'_zip_file'}

    def __init__(self, path, map_file, transform=None, target_transform=None, extensions=None):
        """Index the archive at *path* using the labels in *map_file*.

        Args:
            path: location of the zip archive.
            map_file: location of the text file mapping members to labels.
            transform: optional callable applied to each decoded image.
            target_transform: optional callable applied to each label.
            extensions: accepted lower-case file extensions including the
                dot; falls back to ``_VALID_IMAGE_TYPES`` when falsy.

        Raises:
            AssertionError: on a malformed map line, a member mapped to two
                different labels, or when no sample is found.
            ValueError: when a label field is not an integer.
        """
        self._path = path
        if not extensions:
            extensions = _VALID_IMAGE_TYPES
        # The archive is only needed while indexing (see below); reads use
        # the per-process handles in ``zip_dict`` instead, so no handle is
        # kept on the instance.
        self._zip_file = None
        self.zip_dict = {}
        self.samples = []
        self.transform = transform
        self.target_transform = target_transform
        self.class_to_idx = {}

        self._load_class_map(map_file)
        # Close the indexing handle as soon as the listing has been read
        # instead of holding it open for the lifetime of the dataset.
        with ZipFile(path) as archive:
            self._index_samples(archive, extensions)

        assert len(self), "No images found in: {} with map: {}".format(self._path, map_file)

    def _load_class_map(self, map_file):
        """Fill ``class_to_idx`` from the lines of *map_file*."""
        with open(map_file, 'r') as f:
            for raw_line in f:
                line = raw_line.strip()
                if not line:
                    continue
                # Prefer tab-separated fields; fall back to spaces when the
                # line does not contain two tab-separated fields.
                fields = [tok for tok in line.split('\t') if tok]
                if len(fields) < 2:
                    fields = [tok for tok in line.split(' ') if tok]
                assert len(fields) >= 2, "invalid line: {}".format(line)
                idx = int(fields[1])
                cls = fields[0]
                at_idx = cls.find('@')
                assert at_idx >= 0, "invalid class: {}".format(cls)
                cls = cls[at_idx + 1:]
                if cls.startswith('/'):
                    # Python ZipFile expects no root
                    cls = cls[1:]
                assert cls, "invalid class in line {}".format(line)
                prev_idx = self.class_to_idx.get(cls)
                assert prev_idx is None or prev_idx == idx, "class: {} idx: {} previously had idx: {}".format(
                    cls, idx, prev_idx
                )
                self.class_to_idx[cls] = idx

    def _index_samples(self, archive, extensions):
        """Append every labelled, non-empty image member of *archive* to ``samples``."""
        for info in archive.infolist():
            fname = info.filename
            target = self.class_to_idx.get(fname)
            if target is None:
                continue
            if fname.endswith('/') or fname.startswith('.') or info.file_size == 0:
                continue
            ext = op.splitext(fname)[1].lower()
            if ext in extensions:
                self.samples.append((fname, target))

    def __repr__(self):
        return 'ZipData({}, size={})'.format(self._path, len(self))

    def __getstate__(self):
        """Return a picklable copy of the instance state.

        Attributes named in ``_IGNORE_ATTRS`` are replaced by ``None`` and
        the per-process handle cache is emptied, because open ``ZipFile``
        objects cannot be pickled; the receiving process reopens the
        archive on its first ``__getitem__`` call.
        """
        state = {
            key: val if key not in self._IGNORE_ATTRS else None
            for key, val in self.__dict__.items()
        }
        if 'zip_dict' in state:
            # Fresh dict rather than None: __getitem__ does a membership
            # test on it.
            state['zip_dict'] = {}
        return state

    def __getitem__(self, index):
        """Return ``(image, label)`` for the sample at *index*.

        The image is decoded with PIL and converted to RGB, then passed
        through ``transform`` / ``target_transform`` when those are set.

        Returns:
            ``(sample, target)``, or ``(None, None)`` when reading the
            member raises ``BadZipFile``.

        Raises:
            KeyError: when *index* is negative or not below ``len(self)``.
        """
        if index >= len(self) or index < 0:
            raise KeyError("{} is invalid".format(index))

        # One ZipFile per process id, so processes do not share a handle.
        pid = multiprocessing.current_process().pid
        if pid not in self.zip_dict:
            self.zip_dict[pid] = ZipFile(self._path)
        zip_file = self.zip_dict[pid]

        path, target = self.samples[index]
        try:
            sample = Image.open(BytesIO(zip_file.read(path))).convert('RGB')
        except BadZipFile:
            print("bad zip file")
            return None, None
        if self.transform is not None:
            sample = self.transform(sample)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return sample, target

    def __len__(self):
        """Return the number of indexed samples."""
        return len(self.samples)