# Spaces: Sleeping Sleeping
import torch | |
import numpy as np | |
from os.path import join | |
from dataset import AbstractDataset | |
# Values accepted for cfg['method']; 'all' selects every entry of the index file.
METHOD = [
    'all',
    'Deepfakes',
    'Face2Face',
    'FaceSwap',
    'NeuralTextures',
]

# Values accepted for cfg['split'].
SPLIT = ['train', 'val', 'test']

# Compression codes accepted for cfg['compression'], mapped to a name.
# Only the keys are used by the code visible in this file.
COMP2NAME = {
    'c0': 'raw',
    'c23': 'c23',
    'c40': 'c40',
}

# Integer id per data source. Not referenced by the code visible in this file.
SOURCE_MAP = {
    'youtube': 2,
    'Deepfakes': 3,
    'Face2Face': 4,
    'FaceSwap': 5,
    'NeuralTextures': 6,
}
class FaceForensics(AbstractDataset):
    """
    FaceForensics++ Dataset proposed in "FaceForensics++: Learning to Detect Manipulated Facial Images"

    Builds two parallel lists, ``self.images`` (image paths) and ``self.targets``
    (labels), from an index file named ``<split>_<compression>.pickle`` that is
    looked up under ``self.root``.
    """

    def __init__(self, cfg, seed=2022, transforms=None, transform=None, target_transform=None):
        """
        Validate the configuration, load the index file and collect image paths and labels.

        Args:
            cfg: mapping providing at least 'split' (one of SPLIT), 'method'
                (one of METHOD), 'compression' (a key of COMP2NAME) and 'root'
                (prefix joined onto every path stored in the index file).
            seed: forwarded unchanged to ``AbstractDataset.__init__``.
            transforms: forwarded unchanged to ``AbstractDataset.__init__``.
            transform: forwarded unchanged to ``AbstractDataset.__init__``.
            target_transform: forwarded unchanged to ``AbstractDataset.__init__``.

        Raises:
            ValueError: if 'split', 'method' or 'compression' holds an unsupported value.
        """
        # pre-check
        if cfg['split'] not in SPLIT:
            raise ValueError(f"split should be one of {SPLIT}, "
                             f"but found {cfg['split']}.")
        if cfg['method'] not in METHOD:
            raise ValueError(f"method should be one of {METHOD}, "
                             f"but found {cfg['method']}.")
        if cfg['compression'] not in COMP2NAME:
            raise ValueError(f"compression should be one of {COMP2NAME.keys()}, "
                             f"but found {cfg['compression']}.")
        super().__init__(cfg, seed, transforms, transform, target_transform)
        print(f"Loading data from 'FF++ {cfg['method']}' of split '{cfg['split']}' "
              f"and compression '{cfg['compression']}'\nPlease wait patiently...")
        self.categories = ['original', 'fake']

        # load the path of dataset images
        # NOTE(review): self.root is presumably set by AbstractDataset.__init__ - confirm.
        index_path = join(self.root, cfg['split'] + "_" + cfg['compression'] + ".pickle")
        # NOTE(review): torch.load unpickles the file; only load index files from a
        # trusted source.
        # Each entry is indexed as entry[0] -> path relative to cfg['root'],
        # entry[1] -> label.
        indices = torch.load(index_path)

        if cfg['method'] == "all":
            # full dataset
            self.images = [join(cfg['root'], entry[0]) for entry in indices]
            self.targets = [entry[1] for entry in indices]
        else:
            # specific manipulated method: keep entries whose path contains the method name
            method = cfg['method']
            self.images = []
            self.targets = []
            for entry in indices:
                if method in entry[0]:
                    self.images.append(join(cfg['root'], entry[0]))
                    self.targets.append(entry[1])
            num_fake = len(self.targets)

            # candidate original images, recognised by their path
            originals = [join(cfg['root'], entry[0]) for entry in indices
                         if "original_sequences" in entry[0]]

            # Sampling without replacement cannot draw more items than exist, so cap
            # the request instead of letting numpy raise an opaque ValueError.
            num_real = min(num_fake, len(originals))
            if num_real < num_fake:
                print(f"Warning: only {len(originals)} original images available for "
                      f"{num_fake} '{method}' images; using all of them.")
            choices = np.random.choice(originals, size=num_real, replace=False)
            # .tolist() turns numpy.str_ items back into plain str, so self.images
            # holds a single element type.
            self.images.extend(choices.tolist())
            # sampled originals are labelled 0
            self.targets.extend([0] * num_real)
        print("Data from 'FF++' loaded.\n")
        print(f"Dataset contains {len(self.images)} images.\n")
if __name__ == '__main__':
    # Manual smoke test: load the dataset configuration and print a few samples.
    import yaml

    # Relative path: resolved against the current working directory.
    config_path = "../config/dataset/faceforensics.yml"
    # NOTE(review): yaml.FullLoader is meant for trusted files only; the config here
    # is a local project file.
    with open(config_path, encoding="utf-8") as config_file:
        config = yaml.load(config_file, Loader=yaml.FullLoader)
    config = config["train_cfg"]
    # config = config["test_cfg"]

    def run_dataset():
        """Iterate the dataset directly and print the first 10 (path, target) pairs."""
        dataset = FaceForensics(config)
        print(f"dataset: {len(dataset)}")
        for i, sample in enumerate(dataset):
            path, target = sample
            print(f"path: {path}, target: {target}")
            if i >= 9:
                break

    def run_dataloader(display_samples=False):
        """
        Iterate the dataset through a DataLoader and print the first 10 batches.

        Args:
            display_samples: when true, show the first image of each batch with matplotlib.
        """
        from torch.utils import data
        if display_samples:
            # imported lazily so matplotlib is only required when plotting
            import matplotlib.pyplot as plt

        dataset = FaceForensics(config)
        dataloader = data.DataLoader(dataset, batch_size=8, shuffle=True)
        print(f"dataset: {len(dataset)}")
        for i, batch in enumerate(dataloader):
            path, targets = batch
            image = dataloader.dataset.load_item(path)
            print(f"image: {image.shape}, target: {targets}")
            if display_samples:
                plt.figure()
                # move axis 0 of the first image to the end before plotting
                img = image[0].permute([1, 2, 0]).numpy()
                plt.imshow(img)
                # plt.savefig("./img_" + str(i) + ".png")
                plt.show()
            if i >= 9:
                break

    ###########################
    # run the functions below #
    ###########################

    # run_dataset()
    run_dataloader(False)