|
import os.path as op |
|
from zipfile import ZipFile, BadZipFile |
|
import torch.utils.data as data |
|
from PIL import Image |
|
from io import BytesIO |
|
import multiprocessing |
|
|
|
_VALID_IMAGE_TYPES = ['.jpg', '.jpeg', '.tiff', '.bmp', '.png'] |
|
|
|
|
|
class ZipData(data.Dataset): |
|
_IGNORE_ATTRS = {'_zip_file'} |
|
|
|
def __init__(self, path, map_file, |
|
transform=None, target_transform=None, |
|
extensions=None): |
|
self._path = path |
|
if not extensions: |
|
extensions = _VALID_IMAGE_TYPES |
|
self._zip_file = ZipFile(path) |
|
self.zip_dict = {} |
|
self.samples = [] |
|
self.transform = transform |
|
self.target_transform = target_transform |
|
self.class_to_idx = {} |
|
with open(map_file, 'r') as f: |
|
for line in iter(f.readline, ""): |
|
line = line.strip() |
|
if not line: |
|
continue |
|
cls_idx = [l for l in line.split('\t') if l] |
|
if not cls_idx: |
|
continue |
|
if (len(cls_idx) < 2): |
|
cls_idx = [l for l in line.split(' ') if l] |
|
if not cls_idx: |
|
continue |
|
assert len(cls_idx) >= 2, "invalid line: {}".format(line) |
|
idx = int(cls_idx[1]) |
|
cls = cls_idx[0] |
|
del cls_idx |
|
at_idx = cls.find('@') |
|
assert at_idx >= 0, "invalid class: {}".format(cls) |
|
cls = cls[at_idx + 1:] |
|
if cls.startswith('/'): |
|
|
|
cls = cls[1:] |
|
assert cls, "invalid class in line {}".format(line) |
|
prev_idx = self.class_to_idx.get(cls) |
|
assert prev_idx is None or prev_idx == idx, "class: {} idx: {} previously had idx: {}".format( |
|
cls, idx, prev_idx |
|
) |
|
self.class_to_idx[cls] = idx |
|
|
|
for fst in self._zip_file.infolist(): |
|
fname = fst.filename |
|
target = self.class_to_idx.get(fname) |
|
if target is None: |
|
continue |
|
if fname.endswith('/') or fname.startswith('.') or fst.file_size == 0: |
|
continue |
|
ext = op.splitext(fname)[1].lower() |
|
if ext in extensions: |
|
self.samples.append((fname, target)) |
|
assert len(self), "No images found in: {} with map: {}".format(self._path, map_file) |
|
|
|
def __repr__(self): |
|
return 'ZipData({}, size={})'.format(self._path, len(self)) |
|
|
|
def __getstate__(self): |
|
return { |
|
key: val if key not in self._IGNORE_ATTRS else None |
|
for key, val in self.__dict__.iteritems() |
|
} |
|
|
|
def __getitem__(self, index): |
|
proc = multiprocessing.current_process() |
|
pid = proc.pid |
|
if pid not in self.zip_dict: |
|
self.zip_dict[pid] = ZipFile(self._path) |
|
zip_file = self.zip_dict[pid] |
|
|
|
if index >= len(self) or index < 0: |
|
raise KeyError("{} is invalid".format(index)) |
|
path, target = self.samples[index] |
|
try: |
|
sample = Image.open(BytesIO(zip_file.read(path))).convert('RGB') |
|
except BadZipFile: |
|
print("bad zip file") |
|
return None, None |
|
if self.transform is not None: |
|
sample = self.transform(sample) |
|
if self.target_transform is not None: |
|
target = self.target_transform(target) |
|
return sample, target |
|
|
|
def __len__(self): |
|
return len(self.samples) |
|
|