File size: 5,514 Bytes
a00ee36 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 |
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#
# All contributions by Andy Brock:
# Copyright (c) 2019 Andy Brock
#
# MIT License
""" Obtain nearest neighbors and store them in a HDF5 file. """
import os
import sys
from argparse import ArgumentParser
from tqdm import tqdm, trange
import h5py as h5
import numpy as np
import torch
import utils
def prepare_parser() -> ArgumentParser:
    """Build the command-line parser for the nearest-neighbor HDF5 script.

    Returns:
        An ``ArgumentParser`` exposing the dataset selection options
        (``--which_dataset``, ``--split``, ``--resolution``), the input and
        output locations, the feature-extractor choice, the number of
        neighbors to keep, and the HDF5 storage options (``--chunk_size``,
        ``--compression``).
    """
    usage = "Parser for ImageNet HDF5 scripts."
    parser = ArgumentParser(description=usage)
    parser.add_argument(
        "--resolution",
        type=int,
        default=128,
        help="Which Dataset resolution to train on, out of 64, 128, 256 (default: %(default)s)",
    )
    parser.add_argument(
        "--split",
        type=str,
        default="train",
        # "test" is accepted by run(), which maps it onto the "val" split.
        help="Which Dataset split to use: train, val, test (default: %(default)s)",
    )
    parser.add_argument(
        "--data_root",
        type=str,
        default="data",
        help="Default location where data is stored (default: %(default)s)",
    )
    parser.add_argument(
        "--out_path",
        type=str,
        default="data",
        help="Default location where data in hdf5 format will be stored (default: %(default)s)",
    )
    parser.add_argument(
        "--num_workers",
        type=int,
        default=16,
        help="Number of dataloader workers (default: %(default)s)",
    )
    parser.add_argument(
        "--chunk_size",
        type=int,
        default=500,
        # Used only as the leading dimension of the HDF5 chunk shape.
        help="Number of rows per HDF5 chunk (default: %(default)s)",
    )
    parser.add_argument(
        "--compression",
        action="store_true",
        default=False,
        help="Use LZF compression? (default: %(default)s)",
    )
    parser.add_argument(
        "--feature_extractor",
        type=str,
        default="classification",
        choices=["classification", "selfsupervised"],
        help="Choice of feature extractor",
    )
    parser.add_argument(
        "--backbone_feature_extractor",
        type=str,
        default="resnet50",
        choices=["resnet50"],
        help="Choice of feature extractor backbone",
    )
    parser.add_argument(
        "--k_nn",
        type=int,
        default=100,
        help="Number of nearest neighbors (default: %(default)s)",
    )
    parser.add_argument(
        "--which_dataset", type=str, default="imagenet", help="Dataset choice."
    )
    return parser
def run(config):
    """Load nearest neighbors for a dataset and store them in an HDF5 file.

    The dataset object returned by ``utils.get_dataset_hdf5`` (requested with
    ``compute_nns=True``) supplies ``sample_nns``, ``kth_values`` and
    ``labels``. The first ``k_nn`` neighbor indices of every sample and one
    radius value per sample are written to the ``sample_nns`` and
    ``sample_nns_radius`` datasets of a new HDF5 file under
    ``config["out_path"]``. The fraction of neighbors that share the label of
    their query sample is printed along the way.

    Args:
        config: Dict of options with the keys defined by ``prepare_parser``.
            It is modified in place: ``"compression"`` is replaced by
            ``"lzf"`` or ``None``, and a ``"test"`` split is rewritten to
            ``"val"``.
    """
    # h5py takes a filter name (or None) rather than the boolean CLI flag.
    config["compression"] = "lzf" if config["compression"] else None

    # "test" is loaded through the "val" split and flagged via ``test_part``.
    test_part = config["split"] == "test"
    if test_part:
        config["split"] = "val"

    which_dataset = config["which_dataset"]
    if which_dataset in ("imagenet", "imagenet_lt"):
        dataset_name_prefix = "ILSVRC"
    elif which_dataset == "coco":
        dataset_name_prefix = "COCO"
    else:
        dataset_name_prefix = which_dataset

    train_dataset = utils.get_dataset_hdf5(
        resolution=config["resolution"],
        data_path=config["data_root"],
        load_in_mem_feats=True,
        compute_nns=True,
        longtail=which_dataset == "imagenet_lt",
        split=config["split"],
        instance_cond=True,
        feature_extractor=config["feature_extractor"],
        backbone_feature_extractor=config["backbone_feature_extractor"],
        k_nn=config["k_nn"],
        ddp=False,
        which_dataset=which_dataset,
        test_part=test_part,
    )

    # Keep only the first k_nn neighbor indices of every sample.
    all_nns = np.array(train_dataset.sample_nns)[:, : config["k_nn"]]
    # NOTE(review): the radius is read from column ``k_nn`` (not ``k_nn - 1``);
    # presumably ``kth_values`` carries at least k_nn + 1 columns -- confirm
    # against ``utils.get_dataset_hdf5``.
    all_nns_radii = train_dataset.kth_values[:, config["k_nn"]]
    print("NNs shape ", all_nns.shape, all_nns_radii.shape)

    # Fraction of (sample, neighbor) pairs whose labels agree, accumulated one
    # neighbor rank at a time.
    labels_ = torch.Tensor(train_dataset.labels)
    acc = np.array(
        [
            (labels_[all_nns[:, i_nn]] == labels_).sum()
            for i_nn in range(config["k_nn"])
        ]
    ).sum() / (len(labels_) * config["k_nn"])
    print("For k ", config["k_nn"], " accuracy:", acc)

    h5file_name_nns = os.path.join(
        config["out_path"],
        "%s%i%s%s%s_feats_%s_%s_nn_k%i.hdf5"
        % (
            dataset_name_prefix,
            config["resolution"],
            "" if which_dataset != "imagenet_lt" else "longtail",
            "_val" if config["split"] == "val" else "",
            "_test" if test_part else "",
            config["feature_extractor"],
            config["backbone_feature_extractor"],
            config["k_nn"],
        ),
    )
    print("Filename is ", h5file_name_nns)

    # h5py rejects a chunk shape that exceeds the dataset's maxshape, so cap
    # the row chunk at the number of stored rows (matters for datasets with
    # fewer samples than ``chunk_size``).
    nns_chunk_rows = min(config["chunk_size"], all_nns.shape[0])
    radii_chunk_rows = min(config["chunk_size"], all_nns_radii.shape[0])

    with h5.File(h5file_name_nns, "w") as f:
        nns_dset = f.create_dataset(
            "sample_nns",
            all_nns.shape,
            dtype="int64",
            maxshape=all_nns.shape,
            chunks=(nns_chunk_rows, all_nns.shape[1]),
            compression=config["compression"],
        )
        nns_dset[...] = all_nns

        nns_radii_dset = f.create_dataset(
            "sample_nns_radius",
            all_nns_radii.shape,
            dtype="float",
            maxshape=all_nns_radii.shape,
            chunks=(radii_chunk_rows,),
            compression=config["compression"],
        )
        nns_radii_dset[...] = all_nns_radii
def main():
    """Parse the command-line options, echo them, and hand them to ``run``."""
    cli_options = prepare_parser().parse_args()
    # ``run`` works on a plain dict, so convert the argparse namespace.
    config = vars(cli_options)
    print(config)
    run(config)
# Execute the script only when run directly, not when imported as a module.
if __name__ == "__main__":
    main()
|