|
import cv2 |
|
import numpy as np |
|
import torch |
|
import torchvision.datasets as datasets |
|
import torchvision.transforms as transforms |
|
import torchvision.transforms.functional as TF |
|
from torchvision.io import read_video |
|
from torch.utils.data import Dataset |
|
from random import random, choice, shuffle |
|
from io import BytesIO |
|
from PIL import Image |
|
from PIL import ImageFile |
|
from scipy.ndimage.filters import gaussian_filter |
|
import pickle |
|
import os |
|
|
|
|
|
# Per-channel normalization statistics, keyed by the name of the
# pretraining regime ("imagenet" or "clip") they belong to.
MEAN = dict(
    imagenet=[0.485, 0.456, 0.406],
    clip=[0.48145466, 0.4578275, 0.40821073],
)

STD = dict(
    imagenet=[0.229, 0.224, 0.225],
    clip=[0.26862954, 0.26130258, 0.27577711],
)
|
|
|
|
|
|
|
def recursively_read(rootdir, must_contain, exts=("mp4", "avi")):
    """Collect the paths of files under ``rootdir`` with an accepted extension.

    Args:
        rootdir: Directory that is walked recursively with ``os.walk``.
        must_contain: Substring the joined file path has to contain; the
            empty string matches every path.
        exts: Accepted file extensions, written without the leading dot.
            Matching is case-sensitive.

    Returns:
        A list of joined paths, in the order ``os.walk`` yields them.
    """
    out = []
    for dirpath, _dirnames, filenames in os.walk(rootdir):
        for filename in filenames:
            # splitext looks at the text after the *last* dot, so names such
            # as "clip.v2.mp4" are classified correctly and names without
            # any dot yield an empty extension instead of raising IndexError.
            ext = os.path.splitext(filename)[1][1:]
            full_path = os.path.join(dirpath, filename)
            if ext in exts and must_contain in full_path:
                out.append(full_path)
    return out
|
|
|
def get_list(path, must_contain=''):
    """Return the video paths found beneath ``path``.

    Delegates to ``recursively_read`` with its default extensions and
    keeps only paths that contain ``must_contain``.
    """
    return recursively_read(path, must_contain)
|
|
|
|
|
def uniform_capture_video_frames(path, num_frames=16):
    """Read up to ``num_frames`` evenly spaced frames from a video file.

    The video is decoded sequentially and every ``frame_interval``-th frame
    is kept, where the interval is derived from the frame count reported by
    OpenCV.

    Args:
        path: Path of the video, passed to ``cv2.VideoCapture``.
        num_frames: Maximum number of frames to return.

    Returns:
        A list of RGB ``PIL.Image`` frames. It holds fewer than
        ``num_frames`` entries when decoding stops early, and is empty when
        no frame can be read or ``num_frames`` is not positive.
    """
    if num_frames <= 0:
        return []

    capture = cv2.VideoCapture(path)
    frames = []
    try:
        total_frames = int(capture.get(cv2.CAP_PROP_FRAME_COUNT))
        # Clamp to 1: when the clip is shorter than num_frames (or the
        # reported count is zero/unknown) the integer division yields 0,
        # which would make the modulo below raise ZeroDivisionError.
        frame_interval = max(1, total_frames // num_frames)

        frame_count = 0
        while len(frames) < num_frames:
            ret, frame = capture.read()
            if not ret:
                break
            if frame_count % frame_interval == 0:
                # OpenCV decodes to BGR; convert before building the image.
                frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                frames.append(Image.fromarray(frame_rgb))
            frame_count += 1
    finally:
        # Release the decoder even if frame conversion raises.
        capture.release()
    return frames
|
|
|
|
|
class RealFakeDataset(Dataset):
    """Video dataset that pairs each clip with a real (0) / fake (1) label.

    Video paths are discovered with ``get_list`` under
    ``opt.real_list_path`` and ``opt.fake_list_path``, merged, and shuffled
    once at construction time.
    """

    def __init__(self, opt, clip_model=None, transform=None, num_frames=16):
        """Build the shuffled list of labelled video paths.

        Args:
            opt: Options object; ``data_label`` (must be "train" or "val"),
                ``real_list_path`` and ``fake_list_path`` are read from it.
            clip_model: Not used by this class; kept so existing callers
                that pass it keep working.
            transform: Optional callable applied to every frame after it
                has been converted to a PIL image.
            num_frames: Number of leading frames kept from each video.
        """
        super().__init__()
        self.opt = opt
        assert opt.data_label in ["train", "val"]
        self.data_label = opt.data_label
        self.num_frames = num_frames

        real_list = get_list(opt.real_list_path)
        fake_list = get_list(opt.fake_list_path)

        # Real clips map to 0, fake clips to 1. Fake entries are written
        # last, so a path present in both lists ends up labelled fake.
        self.labels_dict = {path: 0 for path in real_list}
        self.labels_dict.update({path: 1 for path in fake_list})

        self.total_list = real_list + fake_list
        shuffle(self.total_list)

        self.transform = transform

    def __len__(self):
        """Return the number of videos in the dataset."""
        return len(self.total_list)

    def _apply_transform(self, frames):
        """Apply ``self.transform`` to each frame and stack the results.

        When no transform was configured the frames are returned unchanged,
        so the dataset stays usable without one.
        """
        if self.transform is None:
            return frames
        return torch.stack(
            [self.transform(TF.to_pil_image(frame)) for frame in frames]
        )

    def __getitem__(self, idx):
        """Load the clip at ``idx`` and return ``(video_frames, label)``.

        Only the first ``self.num_frames`` decoded frames are kept; a
        shorter video yields a correspondingly shorter clip.

        Raises:
            RuntimeError: If no frame could be decoded from the video.
        """
        video_path = self.total_list[idx]
        label = self.labels_dict[video_path]

        frames, _, _ = read_video(str(video_path), pts_unit='sec')
        # Move the last axis to position 1 (channels-first per frame), the
        # layout TF.to_pil_image expects for tensors.
        frames = frames[:self.num_frames].permute(0, 3, 1, 2)
        if frames.shape[0] == 0:
            raise RuntimeError(
                f"No frames could be decoded from {video_path}"
            )

        return self._apply_transform(frames), label