| """ | |
| Approximate the bits/dimension for an image model. | |
| """ | |

import argparse
import json
import os
from functools import partial

import numpy as np
import torch as th
import torch.distributed as dist
from transformers import set_seed

from improved_diffusion import dist_util, logger
from improved_diffusion.image_datasets import load_data
from improved_diffusion.text_datasets import load_data_text, load_synthetic_data
from improved_diffusion.script_util import (
    model_and_diffusion_defaults,
    create_model_and_diffusion,
    add_dict_to_argparser,
    args_to_dict,
)
from improved_diffusion.test_util import get_weights, denoised_fn_round, compute_logp, load_results


def main():
    set_seed(42)
    args = create_argparser().parse_args()

    # Load the training configuration saved alongside the model checkpoint,
    # then overlay it onto the command-line arguments.
    config_path = os.path.join(os.path.split(args.model_path)[0], "training_args.json")
    print(config_path)
    with open(config_path, 'r') as f:
        training_args = json.load(f)
    training_args['batch_size'] = args.batch_size
    print(args.data_dir)
    del training_args['data_dir']  # keep the data_dir given on the command line
    args.__dict__.update(training_args)
    print(args.__dict__['batch_size'], training_args['batch_size'], args.clip_denoised, args.batch_size)
    print(args.data_dir)

    # Evaluate without input noise; 'roc-aug' shares the base ROCstory data.
    args.noise_level = 0.0
    args.roc_train = 'diffusion_lm/ROCstory'
    if args.modality == 'roc-aug':
        args.modality = 'roc'

    dist_util.setup_dist()
    logger.configure()
| logger.log("creating model and diffusion...") | |
| model, diffusion = create_model_and_diffusion( | |
| **args_to_dict(args, model_and_diffusion_defaults().keys()) | |
| ) | |
| model.load_state_dict(th.load(args.model_path)) | |
| # model.load_state_dict( | |
| # dist_util.load_state_dict(args.model_path, map_location="cpu") | |
| # ) | |
| # diffusion.rescale_timesteps = False # IMPORTANT DEBUG --> REMOVE | |
| model.to(dist_util.dev()) | |
| model.eval() # DEBUG | |
| logger.log("creating data loader...") | |
| if args.modality == 'image': | |
| data = load_data( | |
| data_dir=args.data_dir, | |
| batch_size=args.batch_size, | |
| image_size=args.image_size, | |
| class_cond=args.class_cond, | |
| deterministic=True, | |
| ) | |
    elif args.modality == 'permuted_image':
        # Reuse the pixel permutation saved next to the checkpoint at training time.
        model_path_base = os.path.split(args.model_path)[0]
        print(f'loading permutation from {model_path_base}/permutation.json')
        with open(f'{model_path_base}/permutation.json', 'r') as f:
            perm = json.load(f)
        perm = np.array(perm)
        data = load_data(
            data_dir=args.data_dir,
            batch_size=args.batch_size,
            image_size=args.image_size,
            class_cond=args.class_cond,
            permutation=perm,
        )
    elif args.modality == 'synth':
        from improved_diffusion.rounding import load_models
        model2, tokenizer = load_models(args.modality, args.experiment, args.model_name_or_path,
                                        args.in_channel, os.path.split(args.model_path)[0])
        data = load_synthetic_data(
            data_dir=args.data_dir,
            batch_size=args.batch_size,
            image_size=args.image_size,
            class_cond=args.class_cond,
            data_args=args,
            model=model2,
            split='train',
            deterministic=True,
        )
    elif args.modality == 'pos':
        from improved_diffusion.rounding import load_models
        model2, tokenizer = load_models(args.modality, args.experiment, args.model_name_or_path,
                                        args.in_channel, os.path.split(args.model_path)[0])
        data = load_synthetic_data(
            data_dir=args.data_dir,
            batch_size=args.batch_size,
            image_size=args.image_size,
            class_cond=args.class_cond,
            data_args=args,
            model=model2,
            pos=True,
            deterministic=True,
        )
    else:
        # Text modality: load the embedding model and tokenizer used for rounding.
        from improved_diffusion.rounding import load_models
        model2, tokenizer = load_models(args.modality, args.experiment, args.model_name_or_path,
                                        args.in_channel, os.path.split(args.model_path)[0])
        rev_tokenizer = {v: k for k, v in tokenizer.items()}
        if args.training_mode == 'e2e':
            # In end-to-end training the word embeddings are learned inside the
            # diffusion model, so copy them into the rounding model.
            print('e2e, load the right model embeddings', '*' * 80)
            model2.weight = th.nn.Parameter(model.word_embedding.weight.clone().cpu())
        data = load_data_text(
            data_dir=args.data_dir,
            batch_size=args.batch_size,
            image_size=args.image_size,
            class_cond=args.class_cond,
            data_args=args,
            model=model2,
            deterministic=True,
            task_mode=args.modality,
            padding_mode=args.padding_mode,  # 'block' or 'pad'
            split=args.split,
            load_vocab=rev_tokenizer,
        )
| logger.log("evaluating...") | |
| run_bpd_evaluation(model, diffusion, data, args.num_samples, args.clip_denoised, args, model2) | |
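

# diffusion.calc_bpd_loop returns a dict of per-example tensors; the keys read
# below are "total_bpd" (overall bits/dim per example) plus the per-timestep
# breakdowns "vb", "mse", and "xstart_mse".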
def run_bpd_evaluation(model, diffusion, data, num_samples, clip_denoised, args, model2):
    all_bpd = []
    all_metrics = {"vb": [], "mse": [], "xstart_mse": []}
    num_complete = 0
    model3 = get_weights(model2, args)
    while num_complete < num_samples:
        batch, model_kwargs = next(data)
        batch = batch.to(dist_util.dev())
        model_kwargs = {k: v.to(dist_util.dev()) for k, v in model_kwargs.items()}
        model_kwargs['mapping_func'] = partial(compute_logp, args, model3.cuda())
        minibatch_metrics = diffusion.calc_bpd_loop(
            model, batch, clip_denoised=clip_denoised, model_kwargs=model_kwargs,
            # Round denoised outputs to the nearest word embedding unless disabled.
            denoised_fn=partial(denoised_fn_round, args, model3.cuda()) if args.clamp == 'clamp' else None,
        )

        # Average each metric over the batch and across ranks.
        for key, term_list in all_metrics.items():
            terms = minibatch_metrics[key].mean(dim=0) / dist.get_world_size()
            dist.all_reduce(terms)
            term_list.append(terms.detach().cpu().numpy())

        total_bpd = minibatch_metrics["total_bpd"]
        total_bpd = total_bpd.mean() / dist.get_world_size()
        dist.all_reduce(total_bpd)
        all_bpd.append(total_bpd.item())
        num_complete += dist.get_world_size() * batch.shape[0]

        logger.log(f"done {num_complete} samples on {args.split}: bpd={np.mean(all_bpd)}, "
                   f"per token={np.mean(all_bpd) * args.in_channel} ", args.model_path)

    # Summarize the per-timestep vb terms in 8 equal timestep buckets; when the
    # length is not divisible by 8, report the first and last terms separately.
    temp_cat = np.mean(np.stack(all_metrics['vb']), axis=0)
    if len(temp_cat) % 8 == 0:
        print([y.sum() for y in np.split(temp_cat, 8)])
    else:
        print(temp_cat[0].sum())
        print([y.sum() for y in np.split(temp_cat[1:-1], 8)])
        print(temp_cat[-1].sum())
    vb_temp = temp_cat
    print(vb_temp.shape, vb_temp.sum())
    print(vb_temp[-10:])

    if dist.get_rank() == 0:
        # Save the averaged per-timestep terms for each metric.
        for name, terms in all_metrics.items():
            model_base_name = os.path.basename(
                os.path.split(args.model_path)[0]) + f'.{os.path.split(args.model_path)[1]}'
            out_path = os.path.join(args.out_dir, f"{model_base_name}.{name}_{args.split}_{args.clamp}_terms.npz")
            logger.log(f"saving {name} terms to {out_path}")
            np.savez(out_path, np.mean(np.stack(terms), axis=0))

    dist.barrier()
    logger.log("evaluation complete")

    # Only rank 0 writes the summary JSON (out_path is defined on rank 0 only,
    # and refers to the last metric saved above).
    if dist.get_rank() == 0:
        if 'ema' in args.model_path:
            json_path = os.path.join(os.path.split(args.model_path)[0], f'ema_score_{args.split}_nll.json')
        elif args.clamp == 'noclamp':
            json_path = os.path.join(os.path.split(args.model_path)[0], f'score_{args.split}_nll_noclamp.json')
        else:
            json_path = os.path.join(os.path.split(args.model_path)[0], f'score_{args.split}_nll.json')
        print(f'writing to {json_path}')

        # Drop the first and last terms when needed so the rest splits evenly into 8.
        temp_cat = vb_temp if len(vb_temp) % 8 == 0 else vb_temp[1:-1]
        json_dict = {
            f'score_{args.split}_ppl_token': np.mean(all_bpd) * args.in_channel,
            f'score_{args.split}_ppl_dim': np.mean(all_bpd),
            f'break_down_{args.split}_dim': [y.sum().item() for y in np.split(temp_cat, 8)],
            f'last_10_{args.split}_dim': vb_temp[-10:].tolist(),
            'source_file': out_path,
            'num_samples': num_samples,
        }
        load_results(json_path, json_dict)
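
# The saved .npz files can be inspected offline; a minimal sketch (the file
# name depends on your model path, split, and clamp setting):
#   import numpy as np
#   vb = np.load("scores/<model>.vb_valid_clamp_terms.npz")["arr_0"]
#   print(vb.shape, vb.sum())  # per-timestep vb terms and their total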


def create_argparser():
    defaults = dict(
        data_dir="",
        clip_denoised=False,
        num_samples=128,
        batch_size=64,
        model_path="",
        out_dir="diffusion_lm/improved_diffusion/scores",
        emb_scale_factor=1.0,
        split='train',
        debug_path='',
        clamp='clamp',
    )
    defaults.update(model_and_diffusion_defaults())
    parser = argparse.ArgumentParser()
    # add_dict_to_argparser exposes every key in `defaults` as a CLI flag.
    add_dict_to_argparser(parser, defaults)
    return parser
| if __name__ == "__main__": | |
| main() | |