File size: 5,202 Bytes
a00ee36
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#
# All contributions by Andy Brock:
# Copyright (c) 2019 Andy Brock
#
# MIT License
""" BigGAN: The Authorized Unofficial PyTorch release
    Code by A. Brock and A. Andonian
    This code is an unofficial reimplementation of
    "Large-Scale GAN Training for High Fidelity Natural Image Synthesis,"
    by A. Brock, J. Donahue, and K. Simonyan (arXiv 1809.11096).

    Let's go.
"""

import os
import numpy as np
from tqdm import tqdm, trange
import json

from imageio import imwrite as imsave

# Import my stuff
import sys

sys.path.insert(1, os.path.join(sys.path[0], ".."))
import inference.utils as inference_utils
import BigGAN_PyTorch.utils as biggan_utils


class Tester:
    """Sample images from a trained generator and write them to disk.

    The run configuration is kept as a plain dict in ``self.config``. Every
    setting (dataset, seed, output roots, batch size, number of samples,
    backbone, ...) is read from it when the instance is called.
    """

    def __init__(self, config):
        """Store the run configuration.

        Args:
            config: Either a ``dict`` (stored as-is) or an object accepted by
                ``vars()``, e.g. an ``argparse.Namespace``, which is converted
                to a dict.
        """
        self.config = vars(config) if not isinstance(config, dict) else config

    def __call__(self) -> None:
        """Generate ``sample_num_npz`` images and save each one as a file.

        Images are written to
        ``<samples_root>/<experiment_name>/<dataset>_images_seed<seed>[_test][_hd<N>][_<K>centers]``
        as zero-padded ``%06d`` files, using ``jpg`` for COCO and ``png``
        otherwise.
        """
        # Seed RNG
        biggan_utils.seed_rng(self.config["seed"])

        import torch

        # Setup cudnn.benchmark for free speed
        torch.backends.cudnn.benchmark = True

        self.config = biggan_utils.update_config_roots(
            self.config, change_weight_folder=False
        )
        # Prepare root folders if necessary
        biggan_utils.prepare_root(self.config)

        # Load model. The loader returns a (possibly updated) config, so all
        # reads below go through self.config, never a module-level global.
        self.G, self.config = inference_utils.load_model_inference(self.config)
        biggan_utils.count_parameters(self.G)

        # Get sampling function (the reference-statistics filename is unused here)
        print("Eval reference set is ", self.config["eval_reference_set"])
        sample, _ = inference_utils.get_sampling_funct(
            self.config,
            self.G,
            instance_set=self.config["eval_instance_set"],
            reference_set=self.config["eval_reference_set"],
            which_dataset=self.config["which_dataset"],
        )

        dataset = self.config["which_dataset"]
        is_coco = dataset == "coco"
        image_format = "jpg" if is_coco else "png"
        # COCO "val" instances mean the evaluation set is used; the folder
        # name gets a "_test" tag in that case.
        test_part = is_coco and self.config["eval_instance_set"] == "val"

        hd_suffix = (
            "_hd" + str(self.config["filter_hd"])
            if self.config["filter_hd"] > -1
            else ""
        )
        kmeans_suffix = (
            ""
            if self.config["kmeans_subsampled"] == -1
            else "_" + str(self.config["kmeans_subsampled"]) + "centers"
        )
        path_samples = os.path.join(
            self.config["samples_root"],
            self.config["experiment_name"],
            "%s_images_seed%i%s%s%s"
            % (
                dataset,
                self.config["seed"],
                "_test" if test_part else "",
                hd_suffix,
                kmeans_suffix,
            ),
        )

        print("Path samples will be ", path_samples)
        # makedirs also creates the intermediate <samples_root>/<experiment_name>.
        os.makedirs(path_samples, exist_ok=True)

        num_samples = self.config["sample_num_npz"]
        backbone = self.config["model_backbone"]
        print(
            "Sampling %d images and saving them with %s format..."
            % (num_samples, image_format)
        )
        num_batches = int(np.ceil(num_samples / float(self.config["batch_size"])))
        counter_i = 0
        for _ in trange(num_batches):
            with torch.no_grad():
                images, _, _ = sample()

            # Move the channel axis last for the image writer
            # (assumes generator output is N x C x H x W — TODO confirm).
            fake_imgs = images.cpu().detach().numpy().transpose(0, 2, 3, 1)
            if backbone == "biggan":
                # Affine map [-1, 1] -> [0, 1] (assumes tanh-range output).
                fake_imgs = fake_imgs * 0.5 + 0.5
            elif backbone == "stylegan2":
                fake_imgs = np.clip(fake_imgs * 127.5 + 128, 0, 255).astype(np.uint8)

            for fake_img in fake_imgs:
                imsave("%s/%06d.%s" % (path_samples, counter_i, image_format), fake_img)
                counter_i += 1
                if counter_i >= num_samples:
                    break
            # Stop sampling once the requested number of images has been written.
            if counter_i >= num_samples:
                break


if __name__ == "__main__":
    parser = biggan_utils.prepare_parser()
    parser = biggan_utils.add_sample_parser(parser)
    parser = inference_utils.add_backbone_parser(parser)
    config = vars(parser.parse_args())
    config["n_classes"] = 1000
    if config["json_config"] != "":
        # Values from the JSON file override command-line arguments; any key
        # containing "exp_name" is stored under "experiment_name".
        with open(config["json_config"], encoding="utf-8") as json_file:
            data = json.load(json_file)
        for key, value in data.items():
            if "exp_name" in key:
                config["experiment_name"] = value
            else:
                config[key] = value
    else:
        print("No json file to load configuration from")

    tester = Tester(config)

    tester()