"""Generate a CSV that assigns cropped face images to cross-validation folds.

The script walks the ``dfdc*`` part directories under ``--root-dir``, maps
every video listed in their ``metadata.json`` files to a fold, collects the
crop images that exist under ``<root-dir>/crops/<video>/`` and writes one CSV
row per crop.
"""
import argparse
import json
import os
import random
from functools import partial
from multiprocessing.pool import Pool
from pathlib import Path

# Pin the numeric back-ends to one thread each. These are set before the
# pandas/cv2 imports below, presumably so they take effect when those libraries
# initialise and do not oversubscribe the CPU next to the process pool.
os.environ["MKL_NUM_THREADS"] = "1"
os.environ["NUMEXPR_NUM_THREADS"] = "1"
os.environ["OMP_NUM_THREADS"] = "1"

import pandas as pd
from tqdm import tqdm

from utils import get_original_with_fakes

import cv2

cv2.ocl.setUseOpenCL(False)
cv2.setNumThreads(0)


def get_paths(vid, label, root_dir):
    """Return the existing crop images for one (original, fake) video pair.

    Every 10th frame in ``[0, 320)`` and two actor slots per frame are probed;
    only files that exist on disk are reported.

    Args:
        vid: ``(ori_vid, fake_vid)`` pair of video ids.
        label: ``0`` to look in the original video's crop directory, any other
            value to look in the fake video's crop directory.
        root_dir: Dataset root containing the ``crops`` directory.

    Returns:
        A list of ``[img_path, label, ori_vid]`` entries.
    """
    ori_vid, fake_vid = vid
    crop_dir = os.path.join(root_dir, "crops", ori_vid if label == 0 else fake_vid)
    data = []
    for frame in range(0, 320, 10):
        for actor in range(2):
            img_path = os.path.join(crop_dir, "{}_{}.png".format(frame, actor))
            # os.path.exists reports missing/unreadable paths as False, so no
            # exception handling is needed around this check.
            if os.path.exists(img_path):
                data.append([img_path, label, ori_vid])
    return data


def parse_args():
    """Parse the command-line arguments of the fold generator."""
    parser = argparse.ArgumentParser(description="Generate Folds")
    parser.add_argument("--root-dir", help="root directory",
                        default="/mnt/sota/datasets/deepfake")
    parser.add_argument("--out", type=str, default="folds02.csv",
                        help="CSV file to save")
    parser.add_argument("--seed", type=int, default=777,
                        help="Seed to split, default 777")
    parser.add_argument("--n_splits", type=int, default=2,
                        help="Num folds, default 2")
    return parser.parse_args()


def _build_folds(n_splits, num_parts):
    """Split part indices ``0..num_parts-1`` into ``n_splits`` contiguous folds.

    The last fold absorbs the remainder when ``num_parts`` is not divisible by
    ``n_splits``.
    """
    size = num_parts // n_splits
    folds = []
    for fold in range(n_splits):
        start = size * fold
        stop = start + size if fold < n_splits - 1 else num_parts
        folds.append(list(range(start, stop)))
    return folds


def _assign_video_folds(root_dir, folds):
    """Map every video id found in the ``dfdc*`` metadata files to a fold index."""
    video_fold = {}
    for d in os.listdir(root_dir):
        if "dfdc" not in d:
            continue
        # The part number is the integer after the last underscore of the
        # directory name.
        part = int(d.split("_")[-1])
        # The fold depends only on the part, so resolve it once per directory.
        fold = next((i for i, fold_dirs in enumerate(folds) if part in fold_dirs), None)
        part_dir = os.path.join(root_dir, d)
        for f in os.listdir(part_dir):
            if "metadata.json" not in f:
                continue
            with open(os.path.join(part_dir, "metadata.json")) as metadata_json:
                metadata = json.load(metadata_json)
            for video_file in metadata:
                assert fold is not None
                # Drop the last four characters of the key (presumably the
                # ".mp4" extension) to obtain the video id.
                video_fold[video_file[:-4]] = fold
    return video_fold


def _collect_crops(ori_fakes, root_dir):
    """Gather crop entries for originals (label 0) and fakes (label 1) in parallel."""
    data = []
    # Each original is paired with itself so get_paths can be reused for label 0;
    # the set removes duplicates for originals that have several fakes.
    ori_ori = {(ori, ori) for ori, _ in ori_fakes}
    with Pool(processes=os.cpu_count()) as p:
        for label, pairs in ((0, ori_ori), (1, ori_fakes)):
            with tqdm(total=len(pairs)) as pbar:
                func = partial(get_paths, label=label, root_dir=root_dir)
                for v in p.imap_unordered(func, pairs):
                    pbar.update()
                    data.extend(v)
    return data


def main():
    """Build the fold assignment and write it to the CSV given by ``--out``."""
    args = parse_args()
    ori_fakes = get_original_with_fakes(args.root_dir)

    # Number of dataset parts distributed over the folds (hard-coded to 2 in
    # this revision of the script).
    num_parts = 2
    folds = _build_folds(args.n_splits, num_parts)
    print(folds)

    video_fold = _assign_video_folds(args.root_dir, folds)
    for fold in range(len(folds)):
        holdoutset = {k for k, v in video_fold.items() if v == fold}
        trainset = {k for k, v in video_fold.items() if v != fold}
        assert holdoutset.isdisjoint(trainset), "Folds have leaks"

    data = _collect_crops(ori_fakes, args.root_dir)

    fold_data = []
    for img_path, label, ori_vid in data:
        path = Path(img_path)
        video = path.parent.name
        file = path.name
        assert video_fold[video] == video_fold[ori_vid], \
            "original video and fake have leak {} {}".format(ori_vid, video)
        fold_data.append([video, file, label, ori_vid,
                          int(file.split("_")[0]), video_fold[video]])

    # imap_unordered yields results in a run-dependent order, so sort first and
    # then shuffle with the user-supplied seed to make the output reproducible.
    fold_data.sort()
    random.Random(args.seed).shuffle(fold_data)
    pd.DataFrame(
        fold_data,
        columns=["video", "file", "label", "original", "frame", "fold"],
    ).to_csv(args.out, index=False)


if __name__ == '__main__':
    main()

# NOTE(review): unused assignment carried over from the original file; it looks
# like a leftover - confirm nothing imports it and remove.
z = 2