"""
Approximate the bits/dimension for an image model.
"""

import argparse
import os

import numpy as np
import torch.distributed as dist

from improved_diffusion import dist_util, logger
from improved_diffusion.image_datasets import load_data
from improved_diffusion.script_util import (
    model_and_diffusion_defaults,
    create_model_and_diffusion,
    add_dict_to_argparser,
    args_to_dict,
)


def main():
    """Parse CLI arguments, load the model checkpoint and run the evaluation.

    The model is built from the subset of arguments named by
    ``model_and_diffusion_defaults()``, its weights are read from
    ``--model_path``, and it is switched to eval mode before
    ``run_bpd_evaluation`` is called.
    """
    args = create_argparser().parse_args()

    dist_util.setup_dist()
    logger.configure()

    logger.log("creating model and diffusion...")
    model, diffusion = create_model_and_diffusion(
        **args_to_dict(args, model_and_diffusion_defaults().keys())
    )
    # Load onto CPU first, then move the whole module to this process's device.
    model.load_state_dict(
        dist_util.load_state_dict(args.model_path, map_location="cpu")
    )
    model.to(dist_util.dev())
    model.eval()

    logger.log("creating data loader...")
    data = load_data(
        data_dir=args.data_dir,
        batch_size=args.batch_size,
        image_size=args.image_size,
        class_cond=args.class_cond,
        deterministic=True,
    )

    logger.log("evaluating...")
    run_bpd_evaluation(model, diffusion, data, args.num_samples, args.clip_denoised)


def run_bpd_evaluation(model, diffusion, data, num_samples, clip_denoised):
    """Evaluate bits/dim over at least ``num_samples`` samples and save the terms.

    Batches are pulled from ``data`` with ``next()`` until the running sample
    count reaches ``num_samples``. For each batch the "vb", "mse" and
    "xstart_mse" entries returned by ``diffusion.calc_bpd_loop`` are averaged
    over dim 0 and across all processes; "total_bpd" is averaged to a scalar
    the same way. On rank 0, the mean over batches of each term array is
    written to ``<logger dir>/<name>_terms.npz``.

    :param model: the model passed through to ``diffusion.calc_bpd_loop``.
    :param diffusion: object providing ``calc_bpd_loop``.
    :param data: iterator yielding ``(batch, model_kwargs)`` pairs.
    :param num_samples: minimum number of samples to evaluate; must be > 0.
    :param clip_denoised: forwarded to ``calc_bpd_loop``.
    :raises ValueError: if ``num_samples`` is not positive.
    """
    # Validate before any collective call so that every rank fails the same
    # way. Otherwise rank 0 would crash in np.stack([]) while the remaining
    # ranks block forever in dist.barrier().
    if num_samples <= 0:
        raise ValueError(f"num_samples must be positive, got {num_samples}")

    # Both are constant for the lifetime of the process group.
    world_size = dist.get_world_size()
    device = dist_util.dev()

    all_bpd = []
    all_metrics = {"vb": [], "mse": [], "xstart_mse": []}
    num_complete = 0
    while num_complete < num_samples:
        batch, model_kwargs = next(data)
        batch = batch.to(device)
        model_kwargs = {k: v.to(device) for k, v in model_kwargs.items()}
        minibatch_metrics = diffusion.calc_bpd_loop(
            model, batch, clip_denoised=clip_denoised, model_kwargs=model_kwargs
        )

        for key, term_list in all_metrics.items():
            # Pre-dividing by the world size turns the (default, summing)
            # all_reduce into a mean across processes.
            terms = minibatch_metrics[key].mean(dim=0) / world_size
            dist.all_reduce(terms)
            term_list.append(terms.detach().cpu().numpy())

        total_bpd = minibatch_metrics["total_bpd"]
        total_bpd = total_bpd.mean() / world_size
        dist.all_reduce(total_bpd)
        all_bpd.append(total_bpd.item())
        # Counts this rank's batch size once per process, i.e. it assumes
        # every rank handled a batch of the same size.
        num_complete += world_size * batch.shape[0]

        logger.log(f"done {num_complete} samples: bpd={np.mean(all_bpd)}")

    # The reduced values are identical on every rank, so one writer suffices.
    if dist.get_rank() == 0:
        for name, terms in all_metrics.items():
            out_path = os.path.join(logger.get_dir(), f"{name}_terms.npz")
            logger.log(f"saving {name} terms to {out_path}")
            np.savez(out_path, np.mean(np.stack(terms), axis=0))

    # Keep the other ranks alive until rank 0 has finished writing.
    dist.barrier()
    logger.log("evaluation complete")


def create_argparser():
    """Build the CLI parser from script defaults plus model/diffusion defaults.

    Entries from ``model_and_diffusion_defaults()`` are applied last, so they
    override any script-level default with the same key.
    """
    defaults = dict(
        data_dir="", clip_denoised=True, num_samples=1000, batch_size=1, model_path=""
    )
    defaults.update(model_and_diffusion_defaults())
    parser = argparse.ArgumentParser()
    add_dict_to_argparser(parser, defaults)
    return parser


if __name__ == "__main__":
    main()