File size: 3,286 Bytes
e8e478e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
import cv2
import numpy as np
import torch
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torchvision.transforms.functional as TF
from torchvision.io import read_video
from torch.utils.data import Dataset
from random import random, choice, shuffle
from io import BytesIO
from PIL import Image
from PIL import ImageFile
from scipy.ndimage.filters import gaussian_filter
import pickle
import os 


# Per-channel RGB normalization means for common pretrained backbones;
# keys select the preprocessing convention ("imagenet" or "clip").
MEAN = {
    "imagenet":[0.485, 0.456, 0.406],
    "clip":[0.48145466, 0.4578275, 0.40821073]
}

# Per-channel RGB normalization standard deviations, paired with MEAN above.
STD = {
    "imagenet":[0.229, 0.224, 0.225],
    "clip":[0.26862954, 0.26130258, 0.27577711]
}



def recursively_read(rootdir, must_contain, exts=("mp4", "avi")):
    """Recursively collect video file paths under *rootdir*.

    Walks the directory tree rooted at ``rootdir`` and returns the full
    paths of files whose extension (case-insensitive) is in ``exts`` and
    whose full path contains the substring ``must_contain``.

    Args:
        rootdir: Directory to walk.
        must_contain: Substring that must appear in the full file path
            (the empty string matches everything).
        exts: Accepted extensions, without the leading dot.

    Returns:
        List of matching file paths.
    """
    out = []
    for dirpath, _dirnames, filenames in os.walk(rootdir):
        for fname in filenames:
            # os.path.splitext is robust where the old ``split('.')[1]`` was
            # not: it raised IndexError on extension-less files and misread
            # multi-dot names such as "clip.tar.mp4" (-> "tar").
            ext = os.path.splitext(fname)[1].lstrip('.').lower()
            full_path = os.path.join(dirpath, fname)
            if ext in exts and must_contain in full_path:
                out.append(full_path)
    return out

def get_list(path, must_contain=''):
    """Return all video paths under *path* whose full path contains *must_contain*."""
    return recursively_read(path, must_contain)


def uniform_capture_video_frames(path, num_frames=16):
    """Uniformly sample up to *num_frames* RGB frames from a video file.

    Decodes the video at ``path`` with OpenCV, keeping every
    ``total_frames // num_frames``-th frame so the kept frames are spread
    roughly evenly across the clip.

    Args:
        path: Path to the video file.
        num_frames: Maximum number of frames to return.

    Returns:
        List of ``PIL.Image`` frames in RGB order; may be shorter than
        ``num_frames`` if the video ends early or fails to decode.
    """
    capture = cv2.VideoCapture(path)
    try:
        total_frames = int(capture.get(cv2.CAP_PROP_FRAME_COUNT))
        # BUG FIX: for clips shorter than num_frames (or when OpenCV reports
        # a frame count of 0), the plain floor division yielded 0 and the
        # modulo below raised ZeroDivisionError. Clamp to at least 1 so
        # every frame is considered.
        frame_interval = max(1, total_frames // num_frames)

        frames = []
        frame_count = 0
        while len(frames) < num_frames:
            ret, frame = capture.read()
            if not ret:
                break
            if frame_count % frame_interval == 0:
                # OpenCV decodes to BGR; convert to RGB before wrapping
                # the numpy array in a PIL image.
                frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                frames.append(Image.fromarray(frame_rgb))
            frame_count += 1
    finally:
        # Release the decoder handle even if decoding raises mid-loop.
        capture.release()
    return frames


class RealFakeDataset(Dataset):
    """Video dataset of real (label 0) and fake (label 1) clips.

    Scans ``opt.real_list_path`` and ``opt.fake_list_path`` for video
    files, labels them 0/1 respectively, and yields
    ``(frames_tensor, label)`` pairs where ``frames_tensor`` is
    ``(T, C, H, W)`` with ``T <= num_frames``.
    """

    def __init__(self, opt, clip_model=None, transform=None, num_frames=16):
        """
        Args:
            opt: Options object; must provide ``data_label`` ("train" or
                "val"), ``real_list_path`` and ``fake_list_path``.
            clip_model: Unused; kept for backward compatibility with
                existing callers.
            transform: Optional per-frame transform applied to each frame
                as a PIL image.
            num_frames: Maximum number of frames read from each video.
        """
        self.opt = opt
        assert opt.data_label in ["train", "val"]
        self.data_label = opt.data_label
        self.num_frames = num_frames

        real_list = get_list(os.path.join(opt.real_list_path))
        fake_list = get_list(os.path.join(opt.fake_list_path))

        # Map each path to its binary label: real -> 0, fake -> 1.
        self.labels_dict = {path: 0 for path in real_list}
        self.labels_dict.update({path: 1 for path in fake_list})

        self.total_list = real_list + fake_list
        shuffle(self.total_list)

        self.transform = transform

    def __len__(self):
        return len(self.total_list)

    def __getitem__(self, idx):
        video_path = self.total_list[idx]
        label = self.labels_dict[video_path]

        frames, _, _ = read_video(str(video_path), pts_unit='sec')
        frames = frames[:self.num_frames]
        frames = frames.permute(0, 3, 1, 2)  # (T, H, W, C) -> (T, C, H, W)

        if self.transform is not None:
            # Apply the transform frame-by-frame via PIL, then re-stack
            # into a single (T, C, H, W) tensor.
            frames = torch.stack(
                [self.transform(TF.to_pil_image(frame)) for frame in frames]
            )
        # BUG FIX: the original returned an unbound name (NameError) when
        # transform was None; return the raw uint8 frame tensor instead.
        return frames, label