Spaces:
Running
Running
import os | |
import numbers | |
import torch | |
import mxnet as mx | |
from PIL import Image | |
from torch.utils import data | |
from torchvision import transforms | |
import numpy as np | |
import PIL.Image as Image | |
""" Original MXNet RecordIO dataset (train.rec / train.idx).
"""
class MXFaceDataset(data.Dataset):
    """Training dataset backed by an MXNet indexed RecordIO pair.

    Reads ``train.rec`` / ``train.idx`` from ``root_dir``. Each item is
    decoded, converted to a PIL image, cropped to ``crop_param`` and passed
    through ``transform``.

    Args:
        root_dir: directory containing ``train.rec`` and ``train.idx``.
        crop_param: ``(left, upper, right, lower)`` box handed to
            ``PIL.Image.Image.crop``.
        transform: callable applied to the cropped PIL image. When ``None``
            (the default) a random horizontal flip, ``ToTensor`` and
            ``Normalize(mean=0.5, std=0.5)`` are used, which maps pixel
            values from ``[0, 1]`` to ``[-1, 1]``.

    Returns (per item):
        ``(sample, label)`` where ``label`` is a ``torch.long`` tensor.
    """

    def __init__(self, root_dir, crop_param=(0, 0, 112, 112), transform=None):
        super().__init__()
        if transform is None:
            transform = transforms.Compose([
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
            ])
        self.transform = transform
        self.root_dir = root_dir
        self.crop_param = crop_param
        path_imgrec = os.path.join(root_dir, 'train.rec')
        path_imgidx = os.path.join(root_dir, 'train.idx')
        self.imgrec = mx.recordio.MXIndexedRecordIO(path_imgidx, path_imgrec, 'r')
        s = self.imgrec.read_idx(0)
        header, _ = mx.recordio.unpack(s)
        if header.flag > 0:
            # Record 0 is treated as metadata: image records are read from
            # ids 1 .. int(header.label[0]) - 1.
            self.header0 = (int(header.label[0]), int(header.label[1]))
            self.imgidx = np.array(range(1, int(header.label[0])))
        else:
            # No metadata header: every key in the index file is an image.
            self.imgidx = np.array(list(self.imgrec.keys))

    def __getitem__(self, index):
        idx = self.imgidx[index]
        s = self.imgrec.read_idx(idx)
        header, img = mx.recordio.unpack(s)
        label = header.label
        # Multi-valued labels are stored as a sequence; the class id is the
        # first entry.
        if not isinstance(label, numbers.Number):
            label = label[0]
        label = torch.tensor(label, dtype=torch.long)
        sample = mx.image.imdecode(img).asnumpy()
        # Always crop, independently of whether a transform is configured.
        sample = transforms.ToPILImage()(sample).crop(self.crop_param)
        if self.transform is not None:
            sample = self.transform(sample)
        return sample, label

    def __len__(self):
        return len(self.imgidx)
""" MXNet binary (.bin) verification-set reader.
See https://github.com/deepinsight/insightface.
"""
import pickle | |
from typing import List | |
from mxnet import ndarray as nd | |
class ReadMXNet(object):
    """Reader for verification ``.bin`` files (``<name>.bin``).

    A ``.bin`` file is a pickled ``(bins, issame_list)`` pair. ``bins`` holds
    encoded images, two per entry of ``issame_list``.
    """

    def __init__(self, val_targets, rec_prefix, image_size=(112, 112)):
        self.ver_list: List[object] = []
        self.ver_name_list: List[str] = []
        self.rec_prefix = rec_prefix
        self.val_targets = val_targets
        # Previously accepted but discarded; kept so callers can read it back.
        # init_dataset/load_bin still take an explicit image_size.
        self.image_size = image_size

    def init_dataset(self, val_targets, data_dir, image_size):
        """Load ``<data_dir>/<name>.bin`` for each name in ``val_targets``.

        Missing files are skipped silently. Each loaded set is appended to
        ``self.ver_list`` and its name to ``self.ver_name_list``.
        """
        for name in val_targets:
            path = os.path.join(data_dir, name + ".bin")
            if not os.path.exists(path):
                continue
            self.ver_list.append(self.load_bin(path, image_size))
            self.ver_name_list.append(name)

    def load_bin(self, path, image_size):
        """Decode every image in a ``.bin`` file.

        Args:
            path: path to the pickled ``.bin`` file.
            image_size: target size. If an image's width differs from
                ``image_size[0]``, its short edge is resized to
                ``image_size[0]``.

        Returns:
            ``(data_list, issame_list)``, where ``data_list`` holds
            ``2 * len(issame_list)`` RGB ``PIL.Image`` objects and
            ``issame_list`` is returned unchanged from the pickle.
        """
        import PIL.Image as Image  # hoisted out of the per-image loop

        with open(path, 'rb') as f:
            raw = f.read()
        # NOTE(review): pickle can execute arbitrary code while loading;
        # only use this on trusted .bin files.
        try:
            bins, issame_list = pickle.loads(raw)
        except UnicodeDecodeError:
            # Pickles written by Python 2 contain byte strings that the
            # default ASCII decoding rejects; keep them as bytes instead.
            bins, issame_list = pickle.loads(raw, encoding='bytes')

        data_list = []
        for idx in range(len(issame_list) * 2):
            img = mx.image.imdecode(bins[idx])  # (H, W, C)
            if img.shape[1] != image_size[0]:
                img = mx.image.resize_short(img, image_size[0])
            # The old (2, 0, 1) -> (1, 2, 0) transpose pair was an identity and
            # has been dropped; the array is already (H, W, C).
            data_list.append(Image.fromarray(img.asnumpy(), mode='RGB'))
            if idx % 1000 == 0:
                print('loading bin', idx)
        print('load finished', len(data_list))
        return data_list, issame_list
"""
Evaluation benchmark (face-verification pairs loaded from a .bin file).
"""
class EvalDataset(data.Dataset):
    """Pairwise verification benchmark loaded from ``<rec_folder>/<target>.bin``.

    Each item is ``(img1, img2, flag)``. ``flag`` is 0 when the stored
    ``issame`` value is truthy (same identity) and 1 otherwise.

    Args:
        target: benchmark name, e.g. ``'lfw'``; selects ``<target>.bin``.
        rec_folder: directory containing the ``.bin`` file.
        transform: callable applied to each cropped PIL image. Defaults to
            ``ToTensor`` followed by ``Normalize(mean=0.5, std=0.5)``.
        crop_param: ``(left, upper, right, lower)`` box handed to
            ``PIL.Image.Image.crop``.
        image_size: size passed to ``ReadMXNet.load_bin``. Defaults to the
            previously hard-coded ``(112, 112)``.
    """

    def __init__(self,
                 target: str = 'lfw',
                 rec_folder: str = '',
                 transform=None,
                 crop_param=(0, 0, 112, 112),
                 image_size=(112, 112),
                 ):
        print("=> Pre-loading images ...")
        self.target = target
        self.rec_folder = rec_folder
        mx_reader = ReadMXNet(target, rec_folder)
        path = os.path.join(rec_folder, target + ".bin")
        all_img, issame_list = mx_reader.load_bin(path, image_size)
        self.all_img = all_img
        # 0: same identity, 1: different identity.
        self.issame_list = [0 if same else 1 for same in issame_list]
        if transform is None:
            transform = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
            ])
        self.transform = transform
        self.crop_param = crop_param

    def __getitem__(self, index):
        # Images are stored pairwise: entries 2*i and 2*i + 1 form pair i.
        # crop() returns a new image, so the cached originals stay intact.
        # (Removed leftover debug code that wrote img*_ori.jpg / img*_crop.jpg
        # into the working directory for index 11.)
        img1 = self.all_img[index * 2].crop(self.crop_param)
        img2 = self.all_img[index * 2 + 1].crop(self.crop_param)
        return self.transform(img1), self.transform(img2), self.issame_list[index]

    def __len__(self):
        return len(self.issame_list)
if __name__ == '__main__':
    # Smoke test: read 100 training samples, save the first 15 as JPEGs and
    # report the elapsed time.
    import PIL.Image as Image
    import time

    # Seed the NumPy, torch (CPU + CUDA) and MXNet RNGs for reproducibility.
    np.random.seed(1)
    torch.manual_seed(1)
    torch.cuda.manual_seed(1)
    torch.cuda.manual_seed_all(1)
    mx.random.seed(1)

    is_gray = False
    # NOTE(review): FaceByRandOccMask is not defined in this file; presumably
    # it is provided elsewhere in the project. Confirm before running.
    train_set = FaceByRandOccMask(
        root_dir='/tmp/train_tmp/casia',
        local_rank=0,
        use_norm=True,
        is_gray=is_gray,
    )
    start = time.time()
    for idx in range(100):
        face, mask, label = train_set[idx]
        if idx >= 15:
            continue
        # Assumes face is a (C, H, W) tensor normalised to [-1, 1]
        # (use_norm=True). TODO: confirm. Scale by 127.5 and clip so +1.0 maps
        # to 255; the old `* 128` produced 256, which overflows the uint8 cast.
        face = np.clip(((face + 1) * 127.5).numpy(), 0, 255).astype(np.uint8)
        face = np.transpose(face, (1, 2, 0))
        if is_gray:
            face = Image.fromarray(face[:, :, 0], mode='L')
        else:
            face = Image.fromarray(face, mode='RGB')
        face.save('face_{}.jpg'.format(idx))
    print('time cost: %d ms' % (int((time.time() - start) * 1000)))