File size: 2,619 Bytes
fc3814c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
"""
Video Face Manipulation Detection Through Ensemble of CNNs

Image and Sound Processing Lab - Politecnico di Milano

Nicolò Bonettini
Edoardo Daniele Cannas
Sara Mandelli
Luca Bondi
Paolo Bestagini
"""
from typing import List

import albumentations as A
import pandas as pd
from albumentations.pytorch import ToTensorV2

from .data import FrameFaceIterableDataset, get_iterative_real_fake_idxs


class FrameFaceTripletIterableDataset(FrameFaceIterableDataset):

    def __init__(self,
                 roots: List[str],
                 dfs: List[pd.DataFrame],
                 size: int,
                 scale: str,
                 num_triplets: int = -1,
                 transformer: A.BasicTransform = ToTensorV2(),
                 seed: int = None):
        """

        :param roots: List of root folders for frames cache
        :param dfs: List of DataFrames of cached frames with 'bb' column as array of 4 elements (left,top,right,bottom)
                   and 'label' column
        :param size: face size
        :param num_triplets: number of samples for the dataset
        :param idxs: sampling indexes triplets (each element is a key for anchor, positive, negative)
        :param scale: Rescale the face to the given size, preserving the aspect ratio.
                      If false crop around center to the given size
        :param transformer:
        :param seed:
        """
        super(FrameFaceTripletIterableDataset, self).__init__(
            roots=roots,
            dfs=dfs,
            size=size,
            scale=scale,
            num_samples=num_triplets * 3,
            transformer=transformer,
            seed=seed
        )

        self.num_triplet_couples = self.num_samples // 6
        self.num_triplets = self.num_triplet_couples * 2
        self.num_samples = self.num_triplets * 3

    def __len__(self):
        return self.num_triplets

    def __iter__(self):
        random_fake_idxs, random_real_idxs = get_iterative_real_fake_idxs(
            df_real=self.df_real,
            df_fake=self.df_fake,
            num_samples=self.num_samples,
            seed0=self.seed0
        )

        while len(random_fake_idxs) >= 3 and len(random_real_idxs) >= 3:
            a = self._get_face(random_fake_idxs.pop())[0]
            p = self._get_face(random_fake_idxs.pop())[0]
            n = self._get_face(random_real_idxs.pop())[0]
            yield a, p, n

            a = self._get_face(random_real_idxs.pop())[0]
            p = self._get_face(random_real_idxs.pop())[0]
            n = self._get_face(random_fake_idxs.pop())[0]
            yield a, p, n