File size: 9,800 Bytes
1ecb721
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# This work is licensed under the Creative Commons Attribution-NonCommercial
# 4.0 International License. To view a copy of this license, visit
# http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to
# Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.

"""k-NN precision and recall."""

from time import time


# ----------------------------------------------------------------------------

import numpy as np
from tqdm import tqdm


def batch_pairwise_distances(U, V):
    """Return squared Euclidean distances between every row of U and every row of V.

    Uses the expansion ``|u - v|^2 = |u|^2 - 2 * u.v + |v|^2`` so the whole
    matrix is produced by a single matrix product.

    Args:
        U: 2-D array, one feature vector per row.
        V: 2-D array with the same number of columns as ``U``.

    Returns:
        Array of shape ``(U.shape[0], V.shape[0])``; entry ``[i, j]`` is the
        squared distance between ``U[i]`` and ``V[j]``, clamped at zero.
    """
    # Squared norms, shaped as a column (for U) and a row (for V) so that
    # they broadcast against the cross-term matrix.
    sq_norm_rows = np.sum(np.square(U), axis=1)[:, np.newaxis]
    sq_norm_cols = np.sum(np.square(V), axis=1)[np.newaxis, :]

    cross_term = np.dot(U, V.T)
    squared = sq_norm_rows - 2 * cross_term + sq_norm_cols

    # Round-off can push near-zero distances slightly negative; clamp them.
    return np.maximum(squared, 0.0)


# ----------------------------------------------------------------------------

class DistanceBlock:
    """Pairwise-distance provider backed by ``batch_pairwise_distances``.

    Exposes the ``pairwise_distances(U, V)`` method, so the distance
    computation is passed around as an object rather than a bare function.
    """

    def __init__(self, num_features):
        """Remember the feature dimensionality.

        Args:
            num_features: Number of columns in the feature matrices. It is
                only stored; ``pairwise_distances`` does not read it.
        """
        self.num_features = num_features

    def pairwise_distances(self, U, V):
        """Return the distance matrix between the rows of ``U`` and ``V``.

        Delegates unchanged to the module-level ``batch_pairwise_distances``.
        """
        distances = batch_pairwise_distances(U, V)
        return distances



# ----------------------------------------------------------------------------

class ManifoldEstimator():
    """Estimates the manifold of given feature vectors.

    For every reference feature vector the estimator stores, in ``self.D``,
    the value found at sorted position ``k`` of its distances to all
    reference vectors (one column per ``k`` in ``nhood_sizes``).  Each row of
    distances includes the sample's distance to itself, so position ``k`` is
    the k-th nearest neighbour when that self-distance is the smallest entry.
    ``evaluate`` then checks whether new vectors fall within those radii.

    Distances are whatever ``distance_block.pairwise_distances`` returns;
    radii and query distances are always compared in those same units.
    """

    def __init__(self, distance_block, features, row_batch_size=16, col_batch_size=16,
                 nhood_sizes=None, clamp_to_percentile=None, eps=1e-5, mute=False):
        """Estimate the manifold of given feature vectors.

        Args:
            distance_block: Object exposing ``pairwise_distances(U, V)`` that
                returns a ``(len(U), len(V))`` distance matrix.
            features: Matrix of reference feature vectors, one row per sample.
            row_batch_size (int): Row batch size to compute pairwise distances
                (parameter to trade-off between memory usage and performance).
            col_batch_size (int): Column batch size to compute pairwise distances.
            nhood_sizes (list): Neighbourhood sizes ``k`` used to estimate the
                manifold. ``None`` (the default) means ``[3]``.
            clamp_to_percentile (float): If given, radii above this percentile
                (computed per neighbourhood size) are set to 0.
            eps (float): Small number for numerical stability.
            mute (bool): If true, no tqdm progress bar is shown.

        Raises:
            ValueError: Raised by ``np.partition`` when there are features
                but not more of them than ``max(nhood_sizes)``.
        """
        # Build the default per instance: a shared ``[3]`` literal in the
        # signature would be stored by reference in ``self.nhood_sizes`` and
        # any mutation would leak into every later default-constructed object.
        if nhood_sizes is None:
            nhood_sizes = [3]

        num_images = features.shape[0]
        self.nhood_sizes = nhood_sizes
        self.num_nhoods = len(nhood_sizes)
        self.eps = eps
        self.row_batch_size = row_batch_size
        self.col_batch_size = col_batch_size
        self._ref_features = features
        self._distance_block = distance_block
        self.mute = mute

        # Estimate manifold of features by calculating distances to k-NN of each sample.
        self.D = np.zeros([num_images, self.num_nhoods], dtype=np.float32)
        distance_batch = np.zeros([row_batch_size, num_images], dtype=np.float32)
        # Partitioning on every index 0..max(k) puts each of those positions
        # in its final sorted place, so any requested k can be read directly.
        seq = np.arange(max(self.nhood_sizes) + 1, dtype=np.int32)

        # Same loop either way; the progress bar is the only difference.
        row_starts = range(0, num_images, row_batch_size)
        if not mute:
            row_starts = tqdm(row_starts)

        for begin1 in row_starts:
            end1 = min(begin1 + row_batch_size, num_images)
            # View onto the rows in use (the last batch may be short).
            distances = distance_batch[0:end1 - begin1]
            self._distances_to_reference(features[begin1:end1], distances)

            # Find the k-nearest neighbor from the current batch.
            self.D[begin1:end1, :] = np.partition(distances, seq, axis=1)[:, self.nhood_sizes]

        if clamp_to_percentile is not None:
            max_distances = np.percentile(self.D, clamp_to_percentile, axis=0)
            self.D[self.D > max_distances] = 0

    def _distances_to_reference(self, row_batch, out):
        """Fill ``out`` with distances from ``row_batch`` to all reference features.

        ``out`` must have one row per vector in ``row_batch`` and one column
        per reference feature; it is written in ``col_batch_size`` chunks.
        """
        num_ref_images = out.shape[1]
        for begin2 in range(0, num_ref_images, self.col_batch_size):
            end2 = min(begin2 + self.col_batch_size, num_ref_images)
            ref_batch = self._ref_features[begin2:end2]

            # Compute distances between batches.
            out[:, begin2:end2] = self._distance_block.pairwise_distances(row_batch, ref_batch)

    def evaluate(self, eval_features, return_realism=False, return_neighbors=False):
        """Evaluate if new feature vectors are at the manifold.

        Args:
            eval_features: Matrix of feature vectors to test, one row per sample.
            return_realism (bool): Also return the per-sample realism score.
            return_neighbors (bool): Also return the index of the closest
                reference sample for each evaluated sample.

        Returns:
            ``batch_predictions``: int32 array of shape
            ``(num_eval, len(nhood_sizes))`` holding 1 where the sample lies
            within the radius of at least one reference sample, else 0.
            When ``return_realism`` and/or ``return_neighbors`` is set, a
            tuple is returned instead, with ``max_realism_score`` (float32,
            shape ``(num_eval,)``) and/or ``nearest_indices`` (int32, shape
            ``(num_eval,)``) appended in that order.
        """
        num_eval_images = eval_features.shape[0]
        num_ref_images = self.D.shape[0]
        distance_batch = np.zeros([self.row_batch_size, num_ref_images], dtype=np.float32)
        batch_predictions = np.zeros([num_eval_images, self.num_nhoods], dtype=np.int32)
        max_realism_score = np.zeros([num_eval_images, ], dtype=np.float32)
        nearest_indices = np.zeros([num_eval_images, ], dtype=np.int32)

        for begin1 in range(0, num_eval_images, self.row_batch_size):
            end1 = min(begin1 + self.row_batch_size, num_eval_images)
            # View onto the rows in use (the last batch may be short).
            distances = distance_batch[0:end1 - begin1]
            self._distances_to_reference(eval_features[begin1:end1], distances)

            # From the minibatch of new feature vectors, determine if they are in the estimated manifold.
            # If a feature vector is inside a hypersphere of some reference sample, then
            # the new sample lies at the estimated manifold.
            # The radii of the hyperspheres are determined from distances of neighborhood size k.
            samples_in_manifold = distances[:, :, None] <= self.D
            batch_predictions[begin1:end1] = np.any(samples_in_manifold, axis=1).astype(np.int32)

            # Realism uses only the first neighbourhood size: the largest
            # ratio of a reference radius to the distance from that reference.
            max_realism_score[begin1:end1] = np.max(self.D[:, 0] / (distances + self.eps), axis=1)
            nearest_indices[begin1:end1] = np.argmin(distances, axis=1)

        results = [batch_predictions]
        if return_realism:
            results.append(max_realism_score)
        if return_neighbors:
            results.append(nearest_indices)

        # A bare array when no extras were requested, otherwise a tuple.
        if len(results) == 1:
            return batch_predictions
        return tuple(results)


# ----------------------------------------------------------------------------

def knn_precision_recall_features(ref_features, eval_features, nhood_sizes=None,
                                  row_batch_size=10000, col_batch_size=50000, mute=False):
    """Calculates k-NN precision and recall for two sets of feature vectors.

    Args:
        ref_features: Feature vectors of reference images, one row per image.
        eval_features: Feature vectors of generated images, one row per image.
        nhood_sizes (list): Number of neighbors used to estimate the manifold.
            ``None`` (the default) means ``[3]``.
        row_batch_size (int): Row batch size to compute pairwise distances
            (parameter to trade-off between memory usage and performance).
        col_batch_size (int): Column batch size to compute pairwise distances.
        mute (bool): If true, suppress the progress messages printed here and
            pass ``mute=True`` on to both ``ManifoldEstimator`` instances.

    Returns:
        State (dict): Dict with keys ``'precision'`` and ``'recall'``, each the
        mean over samples of the in-manifold predictions (one value per
        neighbourhood size).
    """
    # Create the default here rather than in the signature: the list is stored
    # by both estimators, so a shared default object could be mutated through
    # them and silently change the default for later calls.
    if nhood_sizes is None:
        nhood_sizes = [3]

    state = dict()
    num_images = ref_features.shape[0]
    num_features = ref_features.shape[1]

    # Initialize DistanceBlock and ManifoldEstimators.
    distance_block = DistanceBlock(num_features)
    ref_manifold = ManifoldEstimator(distance_block, ref_features, row_batch_size, col_batch_size,
                                     nhood_sizes, mute=mute)
    eval_manifold = ManifoldEstimator(distance_block, eval_features, row_batch_size, col_batch_size,
                                      nhood_sizes, mute=mute)

    # Evaluate precision and recall using k-nearest neighbors.
    if not mute:
        print('Evaluating k-NN precision and recall with %i samples...' % num_images)
    start = time()

    # Precision: How many points from eval_features are in ref_features manifold.
    precision = ref_manifold.evaluate(eval_features)
    state['precision'] = precision.mean(axis=0)

    # Recall: How many points from ref_features are in eval_features manifold.
    recall = eval_manifold.evaluate(ref_features)
    state['recall'] = recall.mean(axis=0)

    if not mute:
        print('Evaluated k-NN precision and recall in: %gs' % (time() - start))

    return state

# ----------------------------------------------------------------------------