# Source: zdou0830/desco @ 749745d (raw file view, 13 kB).
import time
import pickle
import logging
import os
import numpy as np
import torch
import torch.nn as nn
from collections import OrderedDict
from yaml import safe_dump
from yacs.config import load_cfg, CfgNode # , _to_dict
from maskrcnn_benchmark.config import cfg
from maskrcnn_benchmark.engine.inference import _accumulate_predictions_from_multiple_gpus
from maskrcnn_benchmark.modeling.backbone.nas import get_layer_name
from maskrcnn_benchmark.utils.comm import synchronize, get_rank, is_main_process, get_world_size, all_gather
from maskrcnn_benchmark.data.datasets.evaluation import evaluate
from maskrcnn_benchmark.utils.flops import profile
def choice(x):
    """Return one element of ``x`` drawn uniformly using ``np.random``.

    A non-tuple iterable (list, set, generator, ...) is first materialised into
    a tuple, then indexed with ``np.random.randint``. An empty ``x`` raises
    ``ValueError`` from ``np.random.randint``.
    """
    if not isinstance(x, tuple):
        return choice(tuple(x))
    return x[np.random.randint(len(x))]
def gather_candidates(all_candidates):
    """Collect candidate architectures from every rank and drop duplicates.

    Args:
        all_candidates: this rank's list of hashable candidates (tuples of op
            indices in this module).

    Returns:
        A list containing each distinct candidate once, in first-seen order
        over the per-rank lists returned by ``all_gather``. Unlike
        ``list(set(...))``, the order does not depend on hashing details, so
        (assuming ``all_gather`` returns the same lists on every rank) every
        rank builds an identical list.
    """
    per_rank = all_gather(all_candidates)
    # dict.fromkeys de-duplicates while keeping insertion order.
    return list(dict.fromkeys(cand for candidates in per_rank for cand in candidates))
def gather_stats(all_candidates):
    """Merge the per-rank ``{candidate: score}`` dicts into a single dict.

    If a candidate appears in several ranks' dicts, the value from the dict
    that comes last in the ``all_gather`` result wins.
    """
    merged = {}
    for rank_stats in all_gather(all_candidates):
        merged.update(rank_stats)
    return merged
def compute_on_dataset(model, rngs, data_loader, device=cfg.MODEL.DEVICE):
    """Run ``model`` in eval mode over ``data_loader`` for one architecture path.

    Args:
        model: network whose forward accepts ``rngs=`` (the architecture
            choice, e.g. a list of per-layer op indices).
        rngs: architecture choice forwarded to the model.
        data_loader: yields ``(images, targets, image_ids)`` batches; the
            targets are not used here.
        device: device the images are moved to before the forward pass.

    Returns:
        dict mapping each image id to its model output, moved to the CPU.
    """
    model.eval()
    cpu = torch.device("cpu")
    predictions = {}
    for images, _targets, image_ids in data_loader:
        with torch.no_grad():
            outputs = model(images.to(device), rngs=rngs)
        moved = [out.to(cpu) for out in outputs]
        for img_id, pred in zip(image_ids, moved):
            predictions[img_id] = pred
    return predictions
def bn_statistic(model, rngs, data_loader, device=cfg.MODEL.DEVICE, max_iter=500):
    """Re-estimate normalisation running statistics for one architecture path.

    Every buffer whose name contains ``running_mean`` is zeroed and every
    buffer whose name contains ``running_var`` is set to one. The model is then
    run forward in train mode, without gradients, on at most ``max_iter``
    batches so that those buffers are re-accumulated for the path ``rngs``.

    NOTE(review): buffers are matched purely by name, so a module that stores
    ``running_mean``/``running_var`` but does not update them in train mode
    (e.g. a frozen-BN variant, if the backbone uses one) would be left at 0/1.
    Confirm the searched backbone only contains updating BN layers.

    Returns:
        The same ``model`` object, left in train mode.
    """
    for buf_name, buf in model.named_buffers():
        if "running_mean" in buf_name:
            nn.init.constant_(buf, 0)
        if "running_var" in buf_name:
            nn.init.constant_(buf, 1)

    model.train()
    for step, (images, targets, _) in enumerate(data_loader, 1):
        images = images.to(device)
        targets = [t.to(device) for t in targets]
        with torch.no_grad():
            # Only the forward's side effect on BN buffers matters; the
            # returned loss dict is discarded.
            model(images, targets, rngs)
        if step >= max_iter:
            break
    return model
def inference(
    model,
    rngs,
    data_loader,
    iou_types=("bbox",),
    box_only=False,
    device="cuda",
    expected_results=(),
    expected_results_sigma_tol=4,
    output_folder=None,
):
    """Evaluate one architecture path of ``model`` on ``data_loader``'s dataset.

    Each process computes predictions for the batches it sees. The
    predictions are then accumulated across processes and evaluated on the
    main process.

    Returns:
        The result of ``evaluate`` on the main process; ``None`` elsewhere.
    """
    device = torch.device(device)  # convert once rather than per batch
    dataset = data_loader.dataset
    local_predictions = compute_on_dataset(model, rngs, data_loader, device)
    # Wait for every process to finish before gathering predictions.
    synchronize()
    all_predictions = _accumulate_predictions_from_multiple_gpus(local_predictions)
    if not is_main_process():
        return None
    return evaluate(
        dataset=dataset,
        predictions=all_predictions,
        output_folder=output_folder,
        box_only=box_only,
        iou_types=iou_types,
        expected_results=expected_results,
        expected_results_sigma_tol=expected_results_sigma_tol,
    )
def fitness(cfg, model, rngs, val_loaders):
    """Evaluate the architecture path ``rngs`` on every validation loader.

    Args:
        cfg: config supplying MODEL.MASK_ON, MODEL.DEVICE,
            TEST.EXPECTED_RESULTS and TEST.EXPECTED_RESULTS_SIGMA_TOL (this
            parameter shadows the module-level ``cfg``).
        model: network to evaluate; its forward accepts ``rngs=``.
        rngs: architecture choice forwarded to the model.
        val_loaders: iterable of validation data loaders.

    Returns:
        The ``inference`` result for the LAST loader only; earlier results
        are discarded. The value is ``None`` on non-main processes. It is also
        ``None`` (instead of raising ``UnboundLocalError``) when
        ``val_loaders`` is empty.
    """
    iou_types = ("bbox",)
    if cfg.MODEL.MASK_ON:
        iou_types = iou_types + ("segm",)
    results = None  # defined even if val_loaders is empty
    for data_loader_val in val_loaders:
        results = inference(
            model,
            rngs,
            data_loader_val,
            iou_types=iou_types,
            box_only=False,
            device=cfg.MODEL.DEVICE,
            expected_results=cfg.TEST.EXPECTED_RESULTS,
            expected_results_sigma_tol=cfg.TEST.EXPECTED_RESULTS_SIGMA_TOL,
        )
        synchronize()
    return results
class EvolutionTrainer(object):
    """Evolutionary search over the single paths of a one-shot supernet.

    A candidate is a tuple with one op index per searchable layer;
    ``self.states[i]`` is the number of choices for layer ``i``. Each epoch
    the current population is scored (BN re-calibration followed by bbox AP on
    the validation set). The best candidates are kept, and the next population
    is bred from them by mutation, crossover and random sampling. The
    population/mutation/crossover sizes in the config are totals over all
    processes; each process generates its share, and the shares are gathered
    and de-duplicated.
    """

    def __init__(self, cfg, model, flops_limit=None, is_distributed=True):
        self.log_dir = cfg.OUTPUT_DIR
        self.checkpoint_name = os.path.join(self.log_dir, "evolution.pth")
        self.is_distributed = is_distributed
        # Number of op choices per searchable layer; the distributed wrapper
        # exposes the real network as ``.module``.
        self.states = model.module.mix_nums if is_distributed else model.mix_nums
        # Deep copy of the supernet weights (pickle round-trip). It is restored
        # before every candidate because bn_statistic overwrites BN buffers.
        self.supernet_state_dict = pickle.loads(pickle.dumps(model.state_dict()))
        self.flops_limit = flops_limit
        self.model = model
        self.candidates = []
        self.vis_dict = {}  # evaluated candidate -> bbox AP
        self.max_epochs = cfg.SEARCH.MAX_EPOCH
        self.select_num = cfg.SEARCH.SELECT_NUM
        self.world_size = get_world_size()
        # Total population over all processes; the per-process shares below
        # may be fractional (random_can/get_* round them up).
        self.total_population_num = cfg.SEARCH.POPULATION_NUM
        self.population_num = cfg.SEARCH.POPULATION_NUM / self.world_size
        self.mutation_num = cfg.SEARCH.MUTATION_NUM / self.world_size
        self.crossover_num = cfg.SEARCH.CROSSOVER_NUM / self.world_size
        # NOTE(review): scaling a per-layer mutation probability by the number
        # of processes looks unintended (the search behaves differently on
        # different GPU counts). It is kept as-is to preserve existing results.
        self.mutation_prob = cfg.SEARCH.MUTATION_PROB / self.world_size
        # Each key doubles as the maximum length of its top-k list.
        self.keep_top_k = {self.select_num: [], 50: []}
        self.epoch = 0
        self.cfg = cfg

    def save_checkpoint(self):
        """Save the search state to ``<OUTPUT_DIR>/evolution.pth`` (main process only)."""
        if not is_main_process():
            return
        os.makedirs(self.log_dir, exist_ok=True)
        info = {
            "candidates": self.candidates,
            "vis_dict": self.vis_dict,
            "keep_top_k": self.keep_top_k,
            "epoch": self.epoch,
        }
        torch.save(info, self.checkpoint_name)
        print("Save checkpoint to", self.checkpoint_name)

    def load_checkpoint(self):
        """Restore state written by ``save_checkpoint``.

        Returns:
            False if no checkpoint file exists, True after a successful load.
        """
        if not os.path.exists(self.checkpoint_name):
            return False
        info = torch.load(self.checkpoint_name)
        self.candidates = info["candidates"]
        self.vis_dict = info["vis_dict"]
        self.keep_top_k = info["keep_top_k"]
        self.epoch = info["epoch"]
        print("Load checkpoint from", self.checkpoint_name)
        return True

    def legal(self, cand):
        """Return True if ``cand`` is unseen and within the optional FLOPs budget.

        A candidate is rejected if it was already evaluated (it is a key of
        ``vis_dict``). When ``flops_limit`` is set, it is also rejected if the
        backbone, profiled on a 1x3x224x224 input along this path, exceeds the
        limit after dividing by 1e6 (presumably MFLOPs; depends on ``profile``).
        """
        assert isinstance(cand, tuple) and len(cand) == len(self.states)
        if cand in self.vis_dict:
            return False
        if self.flops_limit is not None:
            net = self.model.module.backbone if self.is_distributed else self.model.backbone
            inp = (1, 3, 224, 224)
            flops, _params = profile(net, inp, extra_args={"paths": list(cand)})
            flops = flops / 1e6
            print("flops:", flops)
            if flops > self.flops_limit:
                return False
        return True

    def update_top_k(self, candidates, *, k, key, reverse=False):
        """Merge ``candidates`` into the top-``k`` list, sort by ``key`` and keep ``k``.

        The list is sorted ascending by ``key`` unless ``reverse`` is True, so
        ``key`` must give better candidates smaller values.
        """
        assert k in self.keep_top_k
        t = self.keep_top_k[k]
        t += candidates
        t.sort(key=key, reverse=reverse)
        self.keep_top_k[k] = t[:k]

    def eval_candidates(self, train_loader, val_loader):
        """Score every candidate in ``self.candidates`` and record it in ``vis_dict``.

        For each candidate, the supernet weights are restored and BN statistics
        are re-estimated on ``train_loader`` along that path. Bbox AP is then
        measured by ``fitness`` on ``val_loader`` (an iterable of validation
        loaders). Only the main process records scores; ``train`` shares them
        via ``gather_stats`` afterwards.
        """
        for cand in self.candidates:
            t0 = time.time()
            # bn_statistic overwrites BN buffers in place, so start each
            # candidate from the original supernet weights.
            self.model.load_state_dict(self.supernet_state_dict)
            model = bn_statistic(self.model, list(cand), train_loader)
            # Use the config this trainer was built with, not the module global.
            evals = fitness(self.cfg, model, list(cand), val_loader)
            if is_main_process():
                acc = evals[0].results["bbox"]["AP"]
                self.vis_dict[cand] = acc
                print("candiate ", cand)
                print("time: {}s".format(time.time() - t0))
                print("acc ", acc)

    def stack_random_cand(self, random_func, *, batchsize=10):
        """Yield candidates from ``random_func`` forever, producing ``batchsize`` at a time."""
        while True:
            cands = [random_func() for _ in range(batchsize)]
            for cand in cands:
                yield cand

    def random_can(self, num):
        """Sample uniformly random legal candidates until at least ``num`` are collected.

        ``num`` may be fractional (a per-process share), in which case the
        count is rounded up. A non-positive ``num`` returns ``[]``. The result
        may contain duplicates, because ``legal`` only rejects candidates that
        are already in ``vis_dict``.
        NOTE(review): this loops forever if no legal candidate is left.
        """
        candidates = []
        cand_iter = self.stack_random_cand(lambda: tuple(np.random.randint(i) for i in self.states))
        while len(candidates) < num:
            cand = next(cand_iter)
            if not self.legal(cand):
                continue
            candidates.append(cand)
        return candidates

    def get_mutation(self, k, mutation_num, m_prob):
        """Generate up to ``mutation_num`` legal mutants of the top-``k`` candidates.

        Each mutant copies a random top-``k`` candidate and re-samples every
        layer's op independently with probability ``m_prob``. At most
        ``10 * mutation_num`` attempts are made, so fewer (possibly zero)
        mutants are returned when legal ones are hard to find.
        """
        assert k in self.keep_top_k
        res = []
        max_iters = mutation_num * 10

        def random_func():
            cand = list(choice(self.keep_top_k[k]))
            for i in range(len(self.states)):
                if np.random.random_sample() < m_prob:
                    cand[i] = np.random.randint(self.states[i])
            return tuple(cand)

        cand_iter = self.stack_random_cand(random_func)
        while len(res) < mutation_num and max_iters > 0:
            # Count every attempt, legal or not. Counting only successes made
            # the cap ineffective and could loop forever when no unseen mutant
            # within the FLOPs budget exists.
            max_iters -= 1
            cand = next(cand_iter)
            if not self.legal(cand):
                continue
            res.append(cand)
        return res

    def get_crossover(self, k, crossover_num):
        """Generate up to ``crossover_num`` legal children of top-``k`` candidate pairs.

        Each child takes every layer's op from one of two randomly chosen
        top-``k`` parents. At most ``10 * crossover_num`` attempts are made,
        so fewer (possibly zero) children are returned when legal ones are
        hard to find (e.g. a small top-k whose children were all seen).
        """
        assert k in self.keep_top_k
        res = []
        max_iters = 10 * crossover_num

        def random_func():
            p1 = choice(self.keep_top_k[k])
            p2 = choice(self.keep_top_k[k])
            return tuple(choice([i, j]) for i, j in zip(p1, p2))

        cand_iter = self.stack_random_cand(random_func)
        while len(res) < crossover_num and max_iters > 0:
            # Count every attempt so repeated illegal children cannot hang here.
            max_iters -= 1
            cand = next(cand_iter)
            if not self.legal(cand):
                continue
            res.append(cand)
        return res

    def train(self, train_loader, val_loader):
        """Run the evolutionary search until ``max_epochs``, resuming from a checkpoint if present.

        Each epoch evaluates ``self.candidates``, shares the scores across
        processes, updates the top-k lists, logs the best ``select_num``, and
        breeds the next population (mutation + crossover, with random
        candidates filling the rest). A checkpoint is saved after every epoch.
        """
        logger = logging.getLogger("maskrcnn_benchmark.evolution")
        if not self.load_checkpoint():
            self.candidates = gather_candidates(self.random_can(self.population_num))
        while self.epoch < self.max_epochs:
            self.eval_candidates(train_loader, val_loader)
            self.vis_dict = gather_stats(self.vis_dict)
            # Higher AP sorts first. Iterating the dict keys (instead of
            # updating select_num and 50 separately) avoids merging the
            # population twice into the same list when SELECT_NUM == 50.
            for k in self.keep_top_k:
                self.update_top_k(self.candidates, k=k, key=lambda x: 1 - self.vis_dict[x])
            if is_main_process():
                logger.info("Epoch {} : top {} result".format(self.epoch + 1, len(self.keep_top_k[self.select_num])))
                for i, cand in enumerate(self.keep_top_k[self.select_num]):
                    logger.info(" No.{} {} perf = {}".format(i + 1, cand, self.vis_dict[cand]))
            mutation = gather_candidates(self.get_mutation(self.select_num, self.mutation_num, self.mutation_prob))
            crossover = gather_candidates(self.get_crossover(self.select_num, self.crossover_num))
            # mutation/crossover are already gathered over all processes, so
            # compute the shortfall against the TOTAL population and split it
            # across processes again. The original subtracted the totals from
            # a per-process share, which under-filled distributed runs.
            num_random = (self.total_population_num - len(mutation) - len(crossover)) / self.world_size
            rand = gather_candidates(self.random_can(num_random))
            # The three sources can produce the same candidate; evaluate it once.
            self.candidates = list(dict.fromkeys(mutation + crossover + rand))
            self.epoch += 1
            self.save_checkpoint()

    @staticmethod
    def _cfg_to_dict(node):
        """Recursively convert a config node (yacs CfgNode subclasses dict) into plain dicts."""
        if isinstance(node, dict):
            return {key: EvolutionTrainer._cfg_to_dict(val) for key, val in node.items()}
        return node

    def save_candidates(self, cand, template):
        """Export the ``cand``-th (1-based) best architecture as a config and a weight file.

        Writes ``<OUTPUT_DIR>/<template name>_cand{cand}.yaml``, with
        ``LAYER_SETUP`` fixed to the chosen ops and ``LAYER_SEARCH`` removed.
        Also writes ``<OUTPUT_DIR>/init_cand{cand}.pth``, containing the
        supernet weights of the chosen ops only, with the ``.ops.<idx>`` level
        removed from their keys.
        """
        paths = self.keep_top_k[self.select_num][cand - 1]
        with open(template, "r") as f:
            super_cfg = load_cfg(f)
        search_spaces = {}
        for mix_ops in super_cfg.MODEL.BACKBONE.LAYER_SEARCH:
            search_spaces[mix_ops] = super_cfg.MODEL.BACKBONE.LAYER_SEARCH[mix_ops]
        search_layers = super_cfg.MODEL.BACKBONE.LAYER_SETUP
        layer_setup = []
        for i, layer in enumerate(search_layers):
            name, setup = get_layer_name(layer, search_spaces)
            if not isinstance(name, list):
                name = [name]
            name = name[paths[i]]
            layer_setup.append("('{}', {})".format(name, str(setup)[1:-1]))
        super_cfg.MODEL.BACKBONE.LAYER_SETUP = layer_setup
        # The file never imported yacs' private ``_to_dict`` (the import is
        # commented out), so convert locally; safe_dump cannot represent
        # dict subclasses.
        cand_cfg = self._cfg_to_dict(super_cfg)
        del cand_cfg["MODEL"]["BACKBONE"]["LAYER_SEARCH"]
        # Rename only the file name, not any ".yaml" inside OUTPUT_DIR.
        cfg_name = os.path.basename(template).replace(".yaml", "_cand{}.yaml".format(cand))
        with open(os.path.join(self.cfg.OUTPUT_DIR, cfg_name), "w") as f:
            f.write(safe_dump(cand_cfg))

        cand_weight = OrderedDict()
        cand_keys = ["layers.{}.ops.{}".format(i, c) for i, c in enumerate(paths)]
        for key, val in self.supernet_state_dict.items():
            if "ops" not in key:
                cand_weight[key] = val
                continue
            # Keys of unchosen ops match no entry and are dropped.
            dotted_key = "." + key
            for ck in cand_keys:
                # Match whole dotted components so "ops.2" does not also
                # match "ops.21" (or "layers.1" match "layers.11").
                pattern = "." + ck + "."
                if pattern in dotted_key:
                    layer_prefix = ck.split(".ops.")[0]
                    cand_weight[dotted_key.replace(pattern, "." + layer_prefix + ".")[1:]] = val
        torch.save({"model": cand_weight}, os.path.join(self.cfg.OUTPUT_DIR, "init_cand{}.pth".format(cand)))