"""
Sample Command

Run a model over a test loader and save all of its predictions as one
``.npy`` array.

Example (angle-bracket values are placeholders)::

    python <this_script>.py -d SEVIR_13_12 -m <MODEL_DEFINITION> -f <path/to/checkpoint.pt> --type 3D --c_str 1.0 -s -1

Outputs:

* Log file: ``<output>.log`` when ``-o`` is given, otherwise
  ``<checkpoint dir>/logs/<checkpoint stem>.log`` (directory created if missing).
* Predictions: ``CFG=<c_str>_samples/BTCHW_total-no:<N>_e=<e_id>.npy`` inside
  the grandparent directory of the log file (the checkpoint directory when
  ``-o`` is not given), where ``N`` is the number of samples concatenated
  along the first axis.
"""
import os, sys, logging, argparse
import torch
from torch import nn
import numpy as np
import torch.nn.functional as F

from stldm import *
import utilspp as utpp
from data.config import SEVIR_13_12, HKO7_5_20, METEONET_5_20
from data.loader import GET_TestLoader
from data.dutils import resize


def _build_parser() -> argparse.ArgumentParser:
    """Return the command-line parser for this script."""
    parser = argparse.ArgumentParser()

    parser.add_argument('-d', '--dataset', type=str, default='', help='the dataset definition to be set')

    parser.add_argument('-f', dest='checkpt', type=str, default='', help='model checkpoint to be loaded from (Empty = not loading)')
    parser.add_argument('-m', '--model', type=str, default='', help='the model definition to be created')
    # `choices` replaces the former `assert`, which is stripped under `python -O`.
    parser.add_argument('--type', type=str, default='3D', choices=['2D', '3D'], help='Determine which kind of model to use, 2D or 3D')
    parser.add_argument('--c_str', type=float, default=0.0, help='CFG strength')
    parser.add_argument('--e_id', type=int, default=0, help='Ensemble ID')

    parser.add_argument('-s', '--step', type=int, default=-1, help='The number of steps to run. -1: the entire dataloader')
    parser.add_argument('-b', '--batch_size', type=int, default=16, help='The batch size')

    parser.add_argument('--print_every', type=int, default=100, help='Log progress every N generated batches (<= 0: only after the first batch and at the end)')
    parser.add_argument('-o', '--output', default=None, help='The path to save the log files')
    return parser


def _log_file_path(checkpt: str, output) -> str:
    """Return ``<output>.log`` if *output* is given, else ``<checkpoint dir>/logs/<checkpoint stem>.log``."""
    if output is not None:
        return f'{output}.log'
    ckpt_dir, ckpt_file = os.path.split(checkpt)
    # os.path keeps absolute paths absolute (split('/') + join dropped the leading '/'),
    # and splitext strips any extension, not just 3-character ones such as '.pt'.
    stem = os.path.splitext(ckpt_file)[0]
    return os.path.join(ckpt_dir, 'logs', f'{stem}.log')


def _setup_logging(logfile_name: str) -> None:
    """Send log records to *logfile_name* and to the console, creating the log directory if needed."""
    log_dir = os.path.dirname(logfile_name)
    if log_dir:
        os.makedirs(log_dir, exist_ok=True)  # logging.FileHandler does not create missing directories
    # INFO (rather than NOTSET) keeps third-party DEBUG chatter out of the sample log.
    logging.basicConfig(level=logging.INFO, handlers=[logging.FileHandler(logfile_name), logging.StreamHandler()], format='%(message)s')


def _next_input(loader, dataset_name: str, in_len: int, args: argparse.Namespace):
    """Return the next model-input batch ``x`` from *loader*, or ``None`` when the run should stop.

    Targets are discarded because this script only stores predictions. Sampling
    failures are logged and end the run (best effort), so batches already
    generated can still be saved.

    Raises:
        ValueError: if *dataset_name* matches none of the supported loaders.
    """
    if dataset_name == 'HKO-7':
        # NOTE(review): presumably read by utpp.hko7_preprocess; set before every call as before.
        args.seq_len = in_len
        try:
            data = loader.sample(batch_size=args.batch_size)
        except Exception as e:
            logging.error(e)
            return None
        x_seq, x_mask, dt_clip, _ = data
        x, _ = utpp.hko7_preprocess(x_seq, x_mask, dt_clip, args)
        return x
    if dataset_name == 'SEVIR':
        data = loader.sample(batch_size=args.batch_size)
        if data is None:  # a None batch ends the run
            return None
        # The first in_len entries along dim 1 are the input; the remainder were the targets.
        return data['vil'][:, :in_len]
    if dataset_name.startswith('meteo'):
        try:
            x, _ = next(loader)
        except StopIteration:  # normal end of data: nothing to report
            return None
        except Exception as e:
            logging.error(e)
            return None
        return x
    raise ValueError(f'Unsupported dataset: {dataset_name!r}')


def main() -> None:
    """Parse arguments, build the loader and model, generate predictions and save them as ``.npy``."""
    parser = _build_parser()
    args = parser.parse_args()
    # Without a checkpoint there is no path to derive the default log location from.
    if not args.checkpt and args.output is None:
        parser.error('-o/--output is required when no checkpoint (-f) is given')
    # Dataset and model definitions are looked up by name among this module's globals
    # (including everything star-imported from stldm).
    for name, kind in ((args.dataset, 'dataset'), (args.model, 'model')):
        if name not in globals():
            parser.error(f'unknown {kind} definition: {name!r}')

    logfile_name = _log_file_path(args.checkpt, args.output)
    _setup_logging(logfile_name)
    logging.info(f'Model checkpoint: {args.checkpt}')
    logging.info(f'Steps: {args.step}')

    # Samples are written to the grandparent of the log file (the checkpoint directory by default).
    sampler_dir = os.path.join(os.path.dirname(os.path.dirname(logfile_name)), f'CFG={args.c_str}_samples')
    os.makedirs(sampler_dir, exist_ok=True)

    dataset_config = globals()[args.dataset]
    dataset_param, dataset_meta = dataset_config['param'], dataset_config['meta']
    loader = GET_TestLoader(dataset_meta, dataset_param, args.batch_size)

    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    model_config = globals()[args.model]
    # A CFG strength of exactly 0.0 is passed as cfg_str=None (presumably: no guidance).
    model = n2n_setup[args.type](model_config, print_info=True, cfg_str=args.c_str if args.c_str != 0.0 else None).to(device)
    logging.info(f'CFG Scheduler: Const-{args.c_str}')

    if args.checkpt:
        # NOTE(review): torch.load unpickles the file - only load checkpoints from trusted sources.
        state = torch.load(args.checkpt, map_location=device)
        # Accept both a bare state_dict and a checkpoint that nests it under 'model'.
        model.load_state_dict(state['model'] if 'model' in state else state)
    else:
        # Honors the documented "-f empty = not loading" behavior.
        logging.info('No checkpoint given; sampling with freshly initialised weights')
    model.eval()

    in_len = model_config['vp_param']['shape_in'][0]
    img_size = model_config['vp_param']['shape_in'][-1]
    pre, post = model_config['pre'], model_config['post']
    dataset_name = dataset_meta['dataset']

    step = 0
    out = []
    # A negative --step runs until the loader stops; otherwise exactly args.step batches
    # (the old `step <= args.step` ran one extra batch).
    while args.step < 0 or step < args.step:
        x = _next_input(loader, dataset_name, in_len, args)
        if x is None:
            break
        x = x.to(device)

        with torch.no_grad():
            if x.shape[-1] != img_size:
                x = resize(x, img_size)
            if pre is not None:
                x = pre(x)
            y_pred = model(x)
            if post is not None:
                y_pred = post(y_pred)
            y_pred = y_pred.clamp(0, 1)

        out.append(y_pred.cpu())
        step += 1

        # Guard print_every <= 0, which previously raised ZeroDivisionError.
        if step == 1 or (args.print_every > 0 and step % args.print_every == 0):
            logging.info(f'{step} Steps Generated, {len(out)} in out array')

    logging.info(f'{step} Steps Generated, {len(out)} in out array')
    if not out:
        # torch.cat([]) would otherwise fail with an obscure RuntimeError.
        logging.error('No batches were generated; nothing to save')
        sys.exit(1)
    samples = torch.cat(out, dim=0).numpy()
    save_path = os.path.join(sampler_dir, f'BTCHW_total-no:{len(samples)}_e={args.e_id}.npy')
    np.save(save_path, samples)
    print('Output saved in', save_path)


if __name__ == '__main__':
    main()