P-DFD / dataset / wild_deepfake.py
import torch
import numpy as np
from os.path import join

from dataset import AbstractDataset

SPLITS = ["train", "test"]

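
# Sketch of the cfg dict read by WildDeepfake below (keys inferred from the
# attribute accesses in __init__; the values are placeholders and the actual
# config files in this repo may carry additional keys):
#
#   cfg = {
#       "split": "train",                   # must be one of SPLITS
#       "root": "/path/to/wild_deepfake",   # holds <split>/real.pickle and <split>/fake.pickle
#       "num_image_train": 30000,           # optional: subsample size for the train split
#       "num_image_test": 5000,             # optional: subsample size for the test split
#   }
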
class WildDeepfake(AbstractDataset):
    """
    Wild Deepfake Dataset proposed in "WildDeepfake: A Challenging Real-World Dataset for Deepfake Detection".
    """

    def __init__(self, cfg, seed=2022, transforms=None, transform=None, target_transform=None):
        # pre-check
        if cfg['split'] not in SPLITS:
            raise ValueError(f"split should be one of {SPLITS}, but found {cfg['split']}.")
        super(WildDeepfake, self).__init__(cfg, seed, transforms, transform, target_transform)
        print(f"Loading data from 'WildDeepfake' of split '{cfg['split']}'"
              f"\nPlease wait patiently...")
        self.categories = ['original', 'fake']
        self.root = cfg['root']
        self.num_train = cfg.get('num_image_train', None)
        self.num_test = cfg.get('num_image_test', None)
        self.images, self.targets = self.__get_images()
        print("Data from 'WildDeepfake' loaded.")
        print(f"Dataset contains {len(self.images)} images.\n")

    def __get_images(self):
        # number of images to keep for this split (None -> keep everything)
        if self.split == 'train':
            num = self.num_train
        elif self.split == 'test':
            num = self.num_test
        else:
            num = None
        # real.pickle / fake.pickle are expected to hold lists of relative image paths
        real_images = torch.load(join(self.root, self.split, "real.pickle"))
        if num is not None:
            # subsample at roughly a 1:2 real/fake ratio; convert back to a list so
            # the two halves can be concatenated with `+` below
            real_images = list(np.random.choice(real_images, num // 3, replace=False))
        real_tgts = [torch.tensor(0)] * len(real_images)
        print(f"real: {len(real_tgts)}")
        fake_images = torch.load(join(self.root, self.split, "fake.pickle"))
        if num is not None:
            fake_images = list(np.random.choice(fake_images, num - num // 3, replace=False))
        fake_tgts = [torch.tensor(1)] * len(fake_images)
        print(f"fake: {len(fake_tgts)}")
        return real_images + fake_images, real_tgts + fake_tgts

    def __getitem__(self, index):
        # return the image path and its target label; the actual image is loaded
        # downstream (e.g. via load_item in run_dataloader below)
        path = join(self.root, self.split, self.images[index])
        tgt = self.targets[index]
        return path, tgt
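

# A hypothetical sketch of ../config/dataset/wilddeepfake.yml, inferred from the
# keys read above (values are placeholders); the real config file shipped with
# the repo may differ:
#
#   train_cfg:
#     split: train
#     root: /path/to/wild_deepfake
#     num_image_train: 30000
#   test_cfg:
#     split: test
#     root: /path/to/wild_deepfake
#     num_image_test: 5000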


if __name__ == '__main__':
    import yaml

    config_path = "../config/dataset/wilddeepfake.yml"
    with open(config_path) as config_file:
        config = yaml.load(config_file, Loader=yaml.FullLoader)
    config = config["train_cfg"]
    # config = config["test_cfg"]

    def run_dataset():
        dataset = WildDeepfake(config)
        print(f"dataset: {len(dataset)}")
        for i, (path, target) in enumerate(dataset):
            print(f"path: {path}, target: {target}")
            if i >= 9:
                break

    def run_dataloader(display_samples=False):
        from torch.utils import data
        import matplotlib.pyplot as plt

        dataset = WildDeepfake(config)
        dataloader = data.DataLoader(dataset, batch_size=8, shuffle=True)
        print(f"dataset: {len(dataset)}")
        for i, (path, targets) in enumerate(dataloader):
            # load_item (inherited from AbstractDataset) turns the batch of paths
            # into a batch of image tensors
            image = dataloader.dataset.load_item(path)
            print(f"image: {image.shape}, target: {targets}")
            if display_samples:
                plt.figure()
                img = image[0].permute([1, 2, 0]).numpy()
                plt.imshow(img)
                # plt.savefig("./img_" + str(i) + ".png")
                plt.show()
            if i >= 9:
                break

    ###########################
    # run the functions below #
    ###########################
    # run_dataset()
    run_dataloader(False)