# pylint: disable=R0801
"""
This module contains the code for a dataset class called FaceMaskDataset, which is used to process and
load image data related to face masks. The dataset class inherits from the PyTorch Dataset class and
provides methods for data augmentation, getting items from the dataset, and determining the length of the
dataset. The module also includes imports for necessary libraries such as json, random, pathlib, torch,
PIL, and transformers.
"""

import json
import random
from pathlib import Path

import torch
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms
from transformers import CLIPImageProcessor


class FaceMaskDataset(Dataset):
    """
    FaceMaskDataset is a custom dataset for face mask images.
    
    Args:
        img_size (int): The size of the input images.
        drop_ratio (float, optional): The ratio of dropped pixels during data augmentation. Defaults to 0.1.
        data_meta_paths (list, optional): The paths to the metadata files containing image paths and labels. Defaults to ["./data/HDTF_meta.json"].
        sample_margin (int, optional): The margin for sampling regions in the image. Defaults to 30.

    Attributes:
        img_size (int): The size of the input images.
        drop_ratio (float): The ratio of dropped pixels during data augmentation.
        data_meta_paths (list): The paths to the metadata files containing image paths and labels.
        sample_margin (int): The margin for sampling regions in the image.
        processor (CLIPImageProcessor): The image processor for preprocessing images.
        transform (transforms.Compose): The image augmentation transform.
    """

    def __init__(
        self,
        img_size,
        drop_ratio=0.1,
        data_meta_paths=None,
        sample_margin=30,
    ):
        super().__init__()

        self.img_size = img_size
        self.sample_margin = sample_margin

        # Fall back to the documented default metadata path when none is given.
        if data_meta_paths is None:
            data_meta_paths = ["./data/HDTF_meta.json"]

        # Merge the entries from every metadata JSON file into a single list.
        vid_meta = []
        for data_meta_path in data_meta_paths:
            with open(data_meta_path, "r", encoding="utf-8") as f:
                vid_meta.extend(json.load(f))
        self.vid_meta = vid_meta
        self.length = len(self.vid_meta)

        self.clip_image_processor = CLIPImageProcessor()

        self.transform = transforms.Compose(
            [
                transforms.Resize(self.img_size),
                transforms.ToTensor(),
                transforms.Normalize([0.5], [0.5]),
            ]
        )

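        # self.transform maps images from [0, 1] to [-1, 1] via Normalize; the mask
        # transform below skips Normalize, so the mask stays in [0, 1].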
        self.cond_transform = transforms.Compose(
            [
                transforms.Resize(self.img_size),
                transforms.ToTensor(),
            ]
        )

        self.drop_ratio = drop_ratio

    def augmentation(self, image, transform, state=None):
        """
        Apply data augmentation to the input image.

        Args:
            image (PIL.Image): The input image.
            transform (torchvision.transforms.Compose): The data augmentation transforms.
            state (torch.Tensor, optional): A torch RNG state to restore before the
                transform is applied, so paired images receive identical random
                augmentations. Defaults to None.

        Returns:
            torch.Tensor: The transformed image tensor.
        """
        if state is not None:
            torch.set_rng_state(state)
        return transform(image)

    def __getitem__(self, index):
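        """
        Build one sample: an augmented target frame and reference frame from the
        same video, the (3-channel) target face mask, the face embedding, and the
        video directory path.
        """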
        video_meta = self.vid_meta[index]
        video_path = video_meta["image_path"]
        mask_path = video_meta["mask_path"]
        face_emb_path = video_meta["face_emb"]

        video_frames = sorted(Path(video_path).iterdir())
        video_length = len(video_frames)

        margin = min(self.sample_margin, video_length)

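        # Pick a random reference frame, then a target frame at least `margin`
        # frames after it when possible, otherwise at least `margin` frames before
        # it; for very short videos fall back to any frame.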
        ref_img_idx = random.randint(0, video_length - 1)
        if ref_img_idx + margin < video_length:
            tgt_img_idx = random.randint(
                ref_img_idx + margin, video_length - 1)
        elif ref_img_idx - margin > 0:
            tgt_img_idx = random.randint(0, ref_img_idx - margin)
        else:
            tgt_img_idx = random.randint(0, video_length - 1)

        ref_img_pil = Image.open(video_frames[ref_img_idx])
        tgt_img_pil = Image.open(video_frames[tgt_img_idx])

        tgt_mask_pil = Image.open(mask_path)

        assert ref_img_pil is not None, "Failed to load reference image."
        assert tgt_img_pil is not None, "Failed to load target image."
        assert tgt_mask_pil is not None, "Failed to load target mask."

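        # Capture the torch RNG state once and replay it for each augmentation call
        # so the target image, mask, and reference image share the same random state.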
        state = torch.get_rng_state()
        tgt_img = self.augmentation(tgt_img_pil, self.transform, state)
        tgt_mask_img = self.augmentation(
            tgt_mask_pil, self.cond_transform, state)
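        # Repeat the mask along the channel dimension so it matches the 3-channel
        # image tensors (the mask is assumed to be single-channel).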
        tgt_mask_img = tgt_mask_img.repeat(3, 1, 1)
        ref_img_vae = self.augmentation(
            ref_img_pil, self.transform, state)
        face_emb = torch.load(face_emb_path)
        sample = {
            "video_dir": video_path,
            "img": tgt_img,
            "tgt_mask": tgt_mask_img,
            "ref_img": ref_img_vae,
            "face_emb": face_emb,
        }

        return sample

    def __len__(self):
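        """Return the number of metadata entries in the dataset."""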
        return len(self.vid_meta)


if __name__ == "__main__":
    data = FaceMaskDataset(img_size=(512, 512))
    train_dataloader = torch.utils.data.DataLoader(
        data, batch_size=4, shuffle=True, num_workers=1
    )
    for step, batch in enumerate(train_dataloader):
        print(batch["tgt_mask"].shape)
        break