File size: 7,643 Bytes
8c212a5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
import sys
import os
import os.path as osp
import argparse
import hashlib
import tarfile
import time
import urllib.request
from lib import GENFORCE, GENFORCE_MODELS, SFD, ARCFACE, FAIRFACE, HOPENET, AUDET, CELEBA_ATTRIBUTES, ContraCLIP_models


def reporthook(count, block_size, total_size):
    """Progress callback for ``urllib.request.urlretrieve``: print a one-line download status.

    The status line is rewritten in place (leading carriage return) on every call.

    Args:
        count (int): number of blocks transferred so far; ``0`` marks the start of a new download.
        block_size (int): size of each block in bytes.
        total_size (int): total download size in bytes, or a non-positive value (``urlretrieve``
            passes ``-1``) when the server did not report it.
    """
    global start_time
    if count == 0:
        # First call of a new download: remember when it started and print nothing yet.
        # A monotonic clock keeps the elapsed time non-negative even if the wall clock changes.
        start_time = time.monotonic()
        return
    duration = time.monotonic() - start_time
    progress_size = int(count * block_size)
    # Elapsed time can be 0 with a coarse clock or a very fast link; avoid dividing by it.
    speed = int(progress_size / (1024 * duration)) if duration > 0 else 0
    if total_size > 0:
        percent = min(int(count * block_size * 100 / total_size), 100)
        sys.stdout.write("\r      \\__%d%%, %d MB, %d KB/s, %d seconds passed" %
                         (percent, progress_size / (1024 * 1024), speed, duration))
    else:
        # Total size unknown: a percentage would be meaningless (or divide by zero), so omit it.
        sys.stdout.write("\r      \\__%d MB, %d KB/s, %d seconds passed" %
                         (progress_size / (1024 * 1024), speed, duration))

    sys.stdout.flush()


def download(src, sha256sum, dest):
    """Download a tar archive, verify its SHA-256 digest and extract it into ``dest``.

    The archive is first saved as ``<dest>/.tmp.tar``. That temporary file is removed afterwards
    whether the download, the checksum verification or the extraction succeeds or fails.

    Args:
        src (str): URL of the tar archive.
        sha256sum (str): expected hex-encoded SHA-256 digest of the archive.
        dest (str): directory to extract the archive into (created if it does not exist).

    Raises:
        ConnectionError: if retrieving ``src`` fails.
        Exception: if the downloaded archive's SHA-256 digest does not match ``sha256sum``.
    """
    os.makedirs(dest, exist_ok=True)
    tmp_tar = osp.join(dest, ".tmp.tar")
    try:
        try:
            urllib.request.urlretrieve(src, tmp_tar, reporthook)
        except (OSError, ValueError) as err:
            # URLError / HTTPError / ContentTooShortError derive from OSError; ValueError covers
            # malformed or unsupported URLs. KeyboardInterrupt is deliberately not caught.
            raise ConnectionError("Error: {}".format(src)) from err

        sha256_hash = hashlib.sha256()
        with open(tmp_tar, "rb") as f:
            # Read and update hash string value in blocks of 4K (never load the whole archive)
            for byte_block in iter(lambda: f.read(4096), b""):
                sha256_hash.update(byte_block)
        digest = sha256_hash.hexdigest()

        sha256_check = digest == sha256sum
        # Terminate the in-place progress line written by reporthook.
        print()
        print("      \\__Check sha256: {}".format("OK!" if sha256_check else "Error"))
        if not sha256_check:
            raise Exception("Error: Invalid sha256 sum: {}".format(digest))

        # NOTE(review): extractall trusts member paths inside the archive. This is acceptable only
        # because the digest was verified above; consider tarfile extraction filters
        # (filter="data") where the running Python supports them.
        with tarfile.open(tmp_tar, mode='r') as tar_file:
            tar_file.extractall(dest)
    finally:
        # Never leave a partial or rejected archive behind.
        if osp.exists(tmp_tar):
            os.remove(tmp_tar)


def _download_if_missing(header, label, resource, root, required_files):
    """Print a section header, then download ``resource`` into ``root`` unless all its files exist.

    Args:
        header (str): top-level progress line printed first.
        label (str): indented sub-line naming the resource.
        resource (sequence): ``(url, sha256sum)`` pair forwarded to ``download``.
        root (str): directory the archive is extracted into; ``required_files`` are relative to it.
        required_files (iterable of str): paths that must all exist to skip the download.
    """
    print(header)
    print(label)
    if all(osp.exists(osp.join(root, rel_path)) for rel_path in required_files):
        print("      \\__Already exists.")
    else:
        download(src=resource[0], sha256sum=resource[1], dest=root)


def main():
    """Download pre-trained GAN generators and various pre-trained detectors (used only during testing), as well as
    pre-trained ContraCLIP models:
    -- GenForce GAN generators [1]
    -- SFD face detector [2]
    -- ArcFace [3]
    -- FairFace [4]
    -- Hopenet [5]
    -- AU detector [6] for 12 DISFA [7] Action Units
    -- Facial attributes detector [8] for 5 CelebA [9] attributes
    -- ContraCLIP [10] pre-trained models (only when run with -m / --contraclip-models):
        StyleGAN2@FFHQ
        ProgGAN@CelebA-HQ
        StyleGAN2@AFHQ-Cats
        StyleGAN2@AFHQ-Dogs
        StyleGAN2@AFHQ-Cars

    GAN generators and detectors are extracted under models/pretrained/. Each archive is skipped when all of its
    expected files already exist there. ContraCLIP models are extracted under experiments/complete/ and are
    downloaded on every run with -m (there is no existence check for them).

    References:
         [1] https://genforce.github.io/
         [2] Zhang, Shifeng, et al. "S3FD: Single shot scale-invariant face detector." Proceedings of the IEEE
             international conference on computer vision. 2017.
         [3] Deng, Jiankang, et al. "Arcface: Additive angular margin loss for deep face recognition."
             Proceedings of the IEEE/CVF conference on computer vision and pattern recognition. 2019.
         [4] Karkkainen, Kimmo, and Jungseock Joo. "FairFace: Face attribute dataset for balanced race, gender, and age."
             arXiv preprint arXiv:1908.04913 (2019).
         [5] Doosti, Bardia, et al. "Hope-net: A graph-based model for hand-object pose estimation." Proceedings of the
             IEEE/CVF Conference on Computer Vision and Pattern Recognition. 2020.
             NOTE(review): the downloaded weights (hopenet_alpha1/alpha2/robust_alpha1.pkl) look like the head-pose
             Hopenet of Ruiz et al., "Fine-Grained Head Pose Estimation Without Keypoints" (CVPRW 2018), rather than
             the hand-object HOPE-Net cited here -- confirm and fix the citation.
         [6] Ntinou, Ioanna, et al. "A transfer learning approach to heatmap regression for action unit intensity
             estimation." IEEE Transactions on Affective Computing (2021).
         [7] Mavadati, S. Mohammad, et al. "DISFA: A spontaneous facial action intensity database." IEEE Transactions on
             Affective Computing 4.2 (2013): 151-160.
         [8] Jiang, Yuming, et al. "Talk-to-Edit: Fine-Grained Facial Editing via Dialog." Proceedings of the IEEE/CVF
             International Conference on Computer Vision. 2021.
         [9] Liu, Ziwei, et al. "Deep learning face attributes in the wild." Proceedings of the IEEE international
             conference on computer vision. 2015.
        [10] Tzelepis, C., Oldfield, J., Tzimiropoulos, G., & Patras, I. (2022). ContraCLIP: Interpretable GAN
             generation driven by pairs of contrasting sentences. arXiv preprint arXiv:2206.02104.
    """
    parser = argparse.ArgumentParser(description="Download pre-trained models")
    parser.add_argument('-m', '--contraclip-models', action='store_true', help="download pre-trained ContraCLIP models")
    args = parser.parse_args()

    # Create pre-trained models root directory
    pretrained_models_root = osp.join('models', 'pretrained')
    os.makedirs(pretrained_models_root, exist_ok=True)

    # (section header, sub-label, (url, sha256) resource, files under models/pretrained/ that mark it as present).
    # Order matters only for the printed log; each entry is independent.
    resources = [
        ("#. Download pre-trained GAN generators...", "  \\__.GenForce", GENFORCE,
         [osp.join('genforce', model[0]) for model in GENFORCE_MODELS.values()]),
        ("#. Download pre-trained ArcFace model...", "  \\__.ArcFace", ARCFACE,
         [osp.join('arcface', 'model_ir_se50.pth')]),
        ("#. Download pre-trained SFD face detector model...", "  \\__.Face detector (SFD)", SFD,
         [osp.join('sfd', 's3fd-619a316812.pth')]),
        ("#. Download pre-trained FairFace model...", "  \\__.FairFace", FAIRFACE,
         [osp.join('fairface', 'fairface_alldata_4race_20191111.pt'),
          osp.join('fairface', 'res34_fair_align_multi_7_20190809.pt')]),
        ("#. Download pre-trained Hopenet model...", "  \\__.Hopenet", HOPENET,
         [osp.join('hopenet', 'hopenet_alpha1.pkl'),
          osp.join('hopenet', 'hopenet_alpha2.pkl'),
          osp.join('hopenet', 'hopenet_robust_alpha1.pkl')]),
        ("#. Download pre-trained AU detector model...", "  \\__.FANet", AUDET,
         [osp.join('au_detector', 'disfa_adaptation_f0.pth')]),
        ("#. Download pre-trained CelebA attributes predictors models...", "  \\__.CelebA", CELEBA_ATTRIBUTES,
         [osp.join('celeba_attributes', 'eval_predictor.pth.tar')]),
    ]
    for header, label, resource, required_files in resources:
        _download_if_missing(header, label, resource, pretrained_models_root, required_files)

    # Download pre-trained ContraCLIP models (always re-downloaded when requested; no existence check)
    if args.contraclip_models:
        pretrained_contraclip_root = osp.join('experiments', 'complete')
        os.makedirs(pretrained_contraclip_root, exist_ok=True)

        print("#. Download pre-trained ContraCLIP models...")
        print("  \\__.ContraCLIP pre-trained models...")
        download(src=ContraCLIP_models[0],
                 sha256sum=ContraCLIP_models[1],
                 dest=pretrained_contraclip_root)


if __name__ == "__main__":
    # Only run the downloader when executed as a script, never on import.
    main()