""" | |
Video Face Manipulation Detection Through Ensemble of CNNs | |
Image and Sound Processing Lab - Politecnico di Milano | |
Nicolò Bonettini | |
Edoardo Daniele Cannas | |
Sara Mandelli | |
Luca Bondi | |
Paolo Bestagini | |
""" | |
import os
from pathlib import Path
from typing import List, Tuple, Union

import albumentations as A
import numpy as np
import pandas as pd
import torch
from PIL import Image
from albumentations.pytorch import ToTensorV2
from torch.utils.data import Dataset, IterableDataset

from .utils import extract_bb

def load_face(record: pd.Series, root: str, size: int, scale: str, transformer: A.BasicTransform) -> torch.Tensor:
    path = os.path.join(str(root), str(record.name))
    autocache = size < 256 or scale == 'tight'
    if scale in ['crop', 'scale', ]:
        cached_path = str(Path(root).joinpath('autocache', scale, str(size), str(record.name)).with_suffix('.jpg'))
    else:
        # when scale == 'tight' the extracted face does not depend on size
        cached_path = str(Path(root).joinpath('autocache', scale, str(record.name)).with_suffix('.jpg'))

    face = np.zeros((size, size, 3), dtype=np.uint8)
    if os.path.exists(cached_path):
        try:
            face = Image.open(cached_path)
            face = np.array(face)
            if len(face.shape) != 3:
                raise RuntimeError('Incorrect format: {}'.format(cached_path))
        except KeyboardInterrupt as e:
            # We want keyboard interrupts to be propagated
            raise e
        except (OSError, IOError) as e:
            print('Deleting corrupted cache file: {}'.format(cached_path))
            print(e)
            os.unlink(cached_path)
            face = np.zeros((size, size, 3), dtype=np.uint8)

    if not os.path.exists(cached_path):
        try:
            frame = Image.open(path)
            bb = record['left'], record['top'], record['right'], record['bottom']
            face = extract_bb(frame, bb=bb, size=size, scale=scale)
            if autocache:
                os.makedirs(os.path.dirname(cached_path), exist_ok=True)
                face.save(cached_path, quality=95, subsampling='4:4:4')
            face = np.array(face)
            if len(face.shape) != 3:
                raise RuntimeError('Incorrect format: {}'.format(path))
        except KeyboardInterrupt as e:
            # We want keyboard interrupts to be propagated
            raise e
        except (OSError, IOError) as e:
            print('Error while reading: {}'.format(path))
            print(e)
            face = np.zeros((size, size, 3), dtype=np.uint8)

    face = transformer(image=face)['image']
    return face
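
# Usage sketch for load_face (hypothetical names; assumes a DataFrame whose index holds
# each frame's path relative to `root` and whose 'left'/'top'/'right'/'bottom' columns
# hold the face bounding box, as read above):
#
#   record = df_faces.iloc[0]
#   face = load_face(record=record, root='/data/frames', size=224, scale='scale',
#                    transformer=ToTensorV2())
#   # face is a CHW torch.Tensor of the extracted (and, for small sizes, disk-cached) face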

class FrameFaceIterableDataset(IterableDataset):

    def __init__(self,
                 roots: List[str],
                 dfs: List[pd.DataFrame],
                 size: int, scale: str,
                 num_samples: int = -1,
                 transformer: A.BasicTransform = ToTensorV2(),
                 output_index: bool = False,
                 labels_map: dict = None,
                 seed: int = None):
        """
        :param roots: List of root folders for frames cache
        :param dfs: List of DataFrames of cached frames with 'left', 'top', 'right', 'bottom' bounding-box columns
                    and a 'label' column
        :param size: face size
        :param num_samples: maximum number of samples to yield; if <= 0, twice the size of the larger class is used
        :param scale: face extraction mode passed to extract_bb: 'scale' rescales the face to the given size,
                      'crop' crops around the bounding-box center at the given size,
                      'tight' extracts a tight crop independent of size
        :param transformer: albumentations transform applied to each extracted face
        :param output_index: if True, also yield the frame's DataFrame index along with face and label
        :param labels_map: dict mapping the DataFrame 'label' values (False for real, True for fake by default)
                           to the output labels
        :param seed: base seed for the per-worker random sampling
        """
        self.dfs = dfs
        self.size = int(size)
        self.seed0 = int(seed) if seed is not None else np.random.choice(2 ** 32)

        # adapt indices
        dfs_adapted = [df.copy() for df in self.dfs]
        for df_idx, df in enumerate(dfs_adapted):
            mi = pd.MultiIndex.from_tuples([(df_idx, key) for key in df.index], names=['df_idx', 'df_key'])
            df.index = mi

        # Concat
        self.df = pd.concat(dfs_adapted, axis=0, join='inner')

        self.df_real = self.df[self.df['label'] == 0]
        self.df_fake = self.df[self.df['label'] == 1]
        self.longer_set = 'real' if len(self.df_real) > len(self.df_fake) else 'fake'
        self.num_samples = max(len(self.df_real), len(self.df_fake)) * 2
        self.num_samples = min(self.num_samples, num_samples) if num_samples > 0 else self.num_samples

        self.output_idx = bool(output_index)

        self.scale = str(scale)
        self.roots = [str(r) for r in roots]
        self.transformer = transformer

        self.labels_map = labels_map
        if self.labels_map is None:
            self.labels_map = {False: np.array([0., ]), True: np.array([1., ])}
        else:
            self.labels_map = dict(self.labels_map)

    def _get_face(self, item: pd.Index) -> Union[Tuple[torch.Tensor, torch.Tensor],
                                                 Tuple[torch.Tensor, torch.Tensor, str]]:
        record = self.dfs[item[0]].loc[item[1]]
        face = load_face(record=record,
                         root=self.roots[item[0]],
                         size=self.size,
                         scale=self.scale,
                         transformer=self.transformer)
        label = self.labels_map[record.label]
        if self.output_idx:
            return face, label, record.name
        else:
            return face, label

    def __len__(self):
        return self.num_samples

    def __iter__(self):
        random_fake_idxs, random_real_idxs = get_iterative_real_fake_idxs(
            df_real=self.df_real,
            df_fake=self.df_fake,
            num_samples=self.num_samples,
            seed0=self.seed0
        )
        # Alternate fake and real samples until either index pool is exhausted
        while len(random_fake_idxs) >= 1 and len(random_real_idxs) >= 1:
            yield self._get_face(random_fake_idxs.pop())
            yield self._get_face(random_real_idxs.pop())
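
# Usage sketch for FrameFaceIterableDataset (hypothetical paths and DataFrames; the
# dataset already alternates fake/real faces, so it pairs with a plain DataLoader
# without any sampler or shuffling):
#
#   train_dataset = FrameFaceIterableDataset(roots=['/data/dfdc/faces'], dfs=[df_train],
#                                            size=224, scale='scale', num_samples=-1,
#                                            transformer=ToTensorV2(), seed=41)
#   train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=32, num_workers=4)
#   for faces, labels in train_loader:
#       ...  # faces: batched face tensors, labels: shape (batch_size, 1)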

def get_iterative_real_fake_idxs(df_real: pd.DataFrame, df_fake: pd.DataFrame,
                                 num_samples: int, seed0: int):
    # The smaller class is drawn with replacement so both index lists end up the same length
    longer_set = 'real' if len(df_real) > len(df_fake) else 'fake'
    worker_info = torch.utils.data.get_worker_info()
    if worker_info is None:
        # Single-process data loading: sample everything with the base seed
        seed = seed0
        np.random.seed(seed)
        worker_num_couple_samples = num_samples // 2
        fake_idxs_portion = np.random.choice(df_fake.index, worker_num_couple_samples,
                                             replace=longer_set == 'real')
        real_idxs_portion = np.random.choice(df_real.index, worker_num_couple_samples,
                                             replace=longer_set == 'fake')
    else:
        # Multi-process data loading: each worker takes a disjoint slice of the longer set
        # and samples the shorter one with replacement, using a per-worker seed
        worker_id = worker_info.id
        seed = seed0 + worker_id
        np.random.seed(seed)
        worker_num_couple_samples = (num_samples // 2) // worker_info.num_workers
        if longer_set == 'fake':
            fake_idxs_portion = df_fake.index[
                                worker_id * worker_num_couple_samples:(worker_id + 1) * worker_num_couple_samples]
            real_idxs_portion = np.random.choice(df_real.index, worker_num_couple_samples, replace=True)
        else:
            real_idxs_portion = df_real.index[
                                worker_id * worker_num_couple_samples:(worker_id + 1) * worker_num_couple_samples]
            fake_idxs_portion = np.random.choice(df_fake.index, worker_num_couple_samples, replace=True)

    random_fake_idxs = list(np.random.permutation(fake_idxs_portion))
    random_real_idxs = list(np.random.permutation(real_idxs_portion))

    assert (len(random_fake_idxs) == len(random_real_idxs))

    return random_fake_idxs, random_real_idxs
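
# Worker-split sketch (illustrative numbers): with num_samples=1000 and a DataLoader
# using num_workers=4, each worker draws 1000 // 2 // 4 = 125 fake/real couples, so
# the four workers together yield 1000 balanced samples per epoch.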

class FrameFaceDatasetTest(Dataset):

    def __init__(self, root: str, df: pd.DataFrame,
                 size: int, scale: str,
                 transformer: A.BasicTransform = ToTensorV2(),
                 labels_map: dict = None,
                 aug_transformers: List[A.BasicTransform] = None):
        """
        :param root: root folder for frames cache
        :param df: DataFrame of cached frames with 'left', 'top', 'right', 'bottom' bounding-box columns
                   and a 'label' column
        :param size: face size
        :param scale: face extraction mode passed to extract_bb: 'scale' rescales the face to the given size,
                      'crop' crops around the bounding-box center at the given size,
                      'tight' extracts a tight crop independent of size
        :param transformer: albumentations transform applied to each extracted face
        :param labels_map: dict to map the DataFrame 'label' values to the output labels
        :param aug_transformers: if not None, creates multiple copies of the same sample according to the provided augmentations
        """
        self.df = df
        self.size = int(size)
        self.scale = str(scale)
        self.root = str(root)
        self.transformer = transformer
        self.aug_transformers = aug_transformers

        self.labels_map = labels_map
        if self.labels_map is None:
            self.labels_map = {False: np.array([0., ]), True: np.array([1., ])}
        else:
            self.labels_map = dict(self.labels_map)

    def _get_face(self, item: pd.Index) -> Tuple[torch.Tensor, torch.Tensor]:
        record = self.df.loc[item]
        label = self.labels_map[record.label]
        if self.aug_transformers is None:
            face = load_face(record=record,
                             root=self.root,
                             size=self.size,
                             scale=self.scale,
                             transformer=self.transformer)
            return face, label
        else:
            # Stack one copy of the face per augmentation transformer
            faces = []
            for aug_transf in self.aug_transformers:
                faces.append(
                    load_face(record=record,
                              root=self.root,
                              size=self.size,
                              scale=self.scale,
                              transformer=A.Compose([aug_transf, self.transformer])
                              ))
            faces = torch.stack(faces)
            return faces, label

    def __len__(self):
        return len(self.df)

    def __getitem__(self, item):
        return self._get_face(self.df.index[item])
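
# Usage sketch for FrameFaceDatasetTest (hypothetical paths/DataFrame): the test dataset
# is map-style, so it can be indexed directly or wrapped in a DataLoader; when
# aug_transformers is provided, each item is a stack of augmented views instead of a
# single face.
#
#   test_dataset = FrameFaceDatasetTest(root='/data/dfdc/faces', df=df_test,
#                                       size=224, scale='scale', transformer=ToTensorV2())
#   face, label = test_dataset[0]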