Spaces:
Sleeping
Sleeping
import torch | |
import numpy as np | |
from os.path import join | |
from dataset import AbstractDataset | |
SPLITS = ["train", "test"]


class WildDeepfake(AbstractDataset):
    """
    Wild Deepfake Dataset proposed in "WildDeepfake: A Challenging Real-World Dataset for Deepfake Detection"

    The directory ``<root>/<split>/`` must contain ``real.pickle`` and ``fake.pickle``, each
    loaded with ``torch.load``. Their entries are joined onto ``<root>/<split>`` in
    ``__getitem__``, i.e. they are paths relative to that directory. Real images are labelled
    0 and fake images 1.
    """

    def __init__(self, cfg, seed=2022, transforms=None, transform=None, target_transform=None):
        """
        Args:
            cfg: mapping with at least ``'split'`` (one of ``SPLITS``) and ``'root'``. The
                optional keys ``'num_image_train'`` / ``'num_image_test'`` cap how many images
                are sampled for the corresponding split (``None``/absent = use all).
            seed, transforms, transform, target_transform: forwarded unchanged to
                ``AbstractDataset.__init__``.

        Raises:
            ValueError: if ``cfg['split']`` is not one of ``SPLITS``.
        """
        # pre-check: fail fast before the base class does any work
        if cfg['split'] not in SPLITS:
            raise ValueError(f"split should be one of {SPLITS}, but found {cfg['split']}.")
        super().__init__(cfg, seed, transforms, transform, target_transform)
        print(f"Loading data from 'WildDeepfake' of split '{cfg['split']}'"
              f"\nPlease wait patiently...")
        self.categories = ['original', 'fake']
        self.root = cfg['root']
        self.num_train = cfg.get('num_image_train', None)
        self.num_test = cfg.get('num_image_test', None)
        self.images, self.targets = self.__get_images()
        print("Data from 'WildDeepfake' loaded.")
        print(f"Dataset contains {len(self.images)} images.\n")

    @staticmethod
    def __subsample(items, num):
        """Return a ``list`` of ``num`` distinct elements of ``items``, drawn at random.

        Indices are drawn from the global ``np.random`` state. The elements themselves are
        not passed to ``np.random.choice``: that would return an ``ndarray``, and adding two
        ndarrays with ``+`` is element-wise (or fails to broadcast) rather than concatenation.

        Raises:
            ValueError: (from numpy) if ``num`` is negative or larger than ``len(items)``.
        """
        items = list(items)
        picked = np.random.choice(len(items), num, replace=False)
        return [items[i] for i in picked]

    def __get_images(self):
        """Load the real/fake image lists for ``self.split`` and build matching targets.

        ``self.split`` is not assigned in this class. Presumably ``AbstractDataset`` sets it
        from ``cfg['split']`` (TODO confirm). When a cap ``num`` is configured for the split,
        ``num // 3`` real and ``num - num // 3`` fake images are sampled without replacement,
        i.e. roughly a 1:2 real:fake ratio.

        Returns:
            (images, targets): two lists of equal length. Real entries come first, each with
            target ``torch.tensor(0)``, followed by fake entries with ``torch.tensor(1)``.
        """
        if self.split == 'train':
            num = self.num_train
        elif self.split == 'test':
            num = self.num_test
        else:
            num = None

        split_dir = join(self.root, self.split)

        # NOTE(review): torch.load unpickles the file; only point ``root`` at trusted data.
        real_images = list(torch.load(join(split_dir, "real.pickle")))
        if num is not None:
            real_images = self.__subsample(real_images, num // 3)
        # The list repeats one tensor object; fine as long as no one mutates it in place.
        real_tgts = [torch.tensor(0)] * len(real_images)
        print(f"real: {len(real_tgts)}")

        fake_images = list(torch.load(join(split_dir, "fake.pickle")))
        if num is not None:
            fake_images = self.__subsample(fake_images, num - num // 3)
        fake_tgts = [torch.tensor(1)] * len(fake_images)
        print(f"fake: {len(fake_tgts)}")

        # Both sides are plain lists here, so ``+`` concatenates.
        return real_images + fake_images, real_tgts + fake_tgts

    def __getitem__(self, index):
        """Return ``(path, target)`` for ``index``; the image itself is not loaded here.

        ``path`` is ``join(root, split, images[index])``. ``target`` is the 0-d tensor built
        in ``__get_images`` (0 = real, 1 = fake).
        """
        path = join(self.root, self.split, self.images[index])
        tgt = self.targets[index]
        return path, tgt
if __name__ == '__main__':
    # Manual smoke test: builds the dataset from the YAML config and prints a few samples.
    import os

    import yaml

    # Resolve the config relative to this file instead of the current working directory.
    # The original "../config/..." only worked when launched from this file's directory.
    # Assumes this file sits one level below a sibling ``config/`` directory (the layout that
    # path implied) — TODO confirm if the file moves.
    config_path = join(os.path.dirname(os.path.abspath(__file__)),
                       "..", "config", "dataset", "wilddeepfake.yml")
    with open(config_path, encoding="utf-8") as config_file:
        config = yaml.load(config_file, Loader=yaml.FullLoader)
    # Pick the split to exercise: "train_cfg" or "test_cfg".
    config = config["train_cfg"]
    # config = config["test_cfg"]

    def run_dataset():
        """Print the first 10 ``(path, target)`` pairs obtained by indexing the dataset."""
        dataset = WildDeepfake(config)
        print(f"dataset: {len(dataset)}")
        for i, (path, target) in enumerate(dataset):
            print(f"path: {path}, target: {target}")
            if i >= 9:
                break

    def run_dataloader(display_samples=False):
        """Load and print the first 10 shuffled batches of 8 through a ``DataLoader``.

        Images are loaded with ``dataset.load_item``, which is not defined in this file;
        presumably it is provided by ``AbstractDataset``. If ``display_samples`` is true,
        the first image of each batch is shown with matplotlib.
        """
        from torch.utils import data
        import matplotlib.pyplot as plt

        dataset = WildDeepfake(config)
        dataloader = data.DataLoader(dataset, batch_size=8, shuffle=True)
        print(f"dataset: {len(dataset)}")
        for i, (path, targets) in enumerate(dataloader):
            image = dataloader.dataset.load_item(path)
            print(f"image: {image.shape}, target: {targets}")
            if display_samples:
                plt.figure()
                # assumes image is (batch, C, H, W); move channels last for imshow — TODO confirm
                img = image[0].permute([1, 2, 0]).numpy()
                plt.imshow(img)
                # plt.savefig("./img_" + str(i) + ".png")
                plt.show()
            if i >= 9:
                break

    ###########################
    # run the functions below #
    ###########################
    # run_dataset()
    run_dataloader(False)