# Copyright (C) 2022-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
# --------------------------------------------------------
# Main test function
# --------------------------------------------------------
import os
import argparse
import pickle
from PIL import Image
import numpy as np
from tqdm import tqdm

import torch
from torch.utils.data import DataLoader

import utils.misc as misc
from models.croco_downstream import CroCoDownstreamBinocular
from models.head_downstream import PixelwiseTaskWithDPT
from stereoflow.criterion import *
from stereoflow.datasets_stereo import get_test_datasets_stereo, vis_disparity
from stereoflow.datasets_flow import get_test_datasets_flow, flowToColor
from stereoflow.engine import tiled_pred
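
# Overall flow of this script (summary of the code below): parse the command-line
# arguments, load a finetuned CroCo checkpoint (stereo or optical-flow task), build
# the requested test dataset(s), run tiled inference on every image pair, then
# accumulate metrics and/or save predictions, visualizations, error maps and
# submission files.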

def get_args_parser():
    parser = argparse.ArgumentParser("Test CroCo models on stereo/flow", add_help=False)
    # important arguments
    parser.add_argument(
        "--model", required=True, type=str, help="Path to the model to evaluate"
    )
    parser.add_argument(
        "--dataset",
        required=True,
        type=str,
        help="test dataset (there can be multiple datasets separated by a +)",
    )
    # tiling
    parser.add_argument(
        "--tile_conf_mode",
        type=str,
        default="",
        help="Weights for the tiling aggregation based on confidence (empty means use the formula from the loaded checkpoint)",
    )
    parser.add_argument(
        "--tile_overlap", type=float, default=0.7, help="overlap between tiles"
    )
    # save (it will automatically go to <model_path>_<dataset_str>/<tile_str>_<save>)
    parser.add_argument(
        "--save",
        type=str,
        nargs="+",
        default=[],
        help="what to save: \
            metrics (pickle file), \
            pred (raw prediction saved as a torch tensor), \
            visu (visualization of each prediction as a png), \
            err10 (visualization of the error clamped at 10 for each prediction, as a png), \
            submission (submission file)",
    )
    # other (no impact)
    parser.add_argument("--num_workers", default=4, type=int)
    return parser
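
# Example invocation (a sketch only; the checkpoint path and dataset string below are
# placeholders, not values shipped with this code, and the module is assumed to be run
# from the repository root):
#   python -m stereoflow.test --model path/to/finetuned_checkpoint.pth \
#       --dataset <dataset_name> --save metrics visu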

def _load_model_and_criterion(model_path, do_load_metrics, device):
    print("loading model from", model_path)
    assert os.path.isfile(model_path)
    ckpt = torch.load(model_path, "cpu")

    ckpt_args = ckpt["args"]
    task = ckpt_args.task
    tile_conf_mode = ckpt_args.tile_conf_mode
    num_channels = {"stereo": 1, "flow": 2}[task]
    with_conf = eval(ckpt_args.criterion).with_conf
    if with_conf:
        num_channels += 1
    print("head: PixelwiseTaskWithDPT()")
    head = PixelwiseTaskWithDPT()
    head.num_channels = num_channels
    print("croco_args:", ckpt_args.croco_args)
    model = CroCoDownstreamBinocular(head, **ckpt_args.croco_args)
    msg = model.load_state_dict(ckpt["model"], strict=True)
    model.eval()
    model = model.to(device)

    if do_load_metrics:
        if task == "stereo":
            metrics = StereoDatasetMetrics().to(device)
        else:
            metrics = FlowDatasetMetrics().to(device)
    else:
        metrics = None

    return model, metrics, ckpt_args.crop, with_conf, task, tile_conf_mode

def _save_batch(
    pred, gt, pairnames, dataset, task, save, outdir, time, submission_dir=None
):
    for i in range(len(pairnames)):
        pairname = (
            eval(pairnames[i]) if pairnames[i].startswith("(") else pairnames[i]
        )  # unbatch pairname
        fname = os.path.join(outdir, dataset.pairname_to_str(pairname))
        os.makedirs(os.path.dirname(fname), exist_ok=True)

        predi = pred[i, ...]
        if gt is not None:
            gti = gt[i, ...]

        if "pred" in save:
            torch.save(predi.squeeze(0).cpu(), fname + "_pred.pth")

        if "visu" in save:
            if task == "stereo":
                disparity = predi.permute((1, 2, 0)).squeeze(2).cpu().numpy()
                m, M = None, None
                if gt is not None:
                    mask = torch.isfinite(gti)
                    m = gti[mask].min().item()
                    M = gti[mask].max().item()
                img_disparity = vis_disparity(disparity, m=m, M=M)
                Image.fromarray(img_disparity).save(fname + "_pred.png")
            else:
                # normalize flowToColor according to the maxnorm of gt (or of the prediction if gt is not available)
                flowNorm = (
                    torch.sqrt(
                        torch.sum((gti if gt is not None else predi) ** 2, dim=0)
                    )
                    .max()
                    .item()
                )
                imgflow = flowToColor(
                    predi.permute((1, 2, 0)).cpu().numpy(), maxflow=flowNorm
                )
                Image.fromarray(imgflow).save(fname + "_pred.png")

        if "err10" in save:
            assert gt is not None
            L2err = torch.sqrt(torch.sum((gti - predi) ** 2, dim=0))
            valid = torch.isfinite(gti[0, :, :])
            L2err[~valid] = 0.0
            L2err = torch.clamp(L2err, max=10.0)
            red = (L2err * 255.0 / 10.0).to(dtype=torch.uint8)[:, :, None]
            zer = torch.zeros_like(red)
            imgerr = torch.cat((red, zer, zer), dim=2).cpu().numpy()
            Image.fromarray(imgerr).save(fname + "_err10.png")

        if "submission" in save:
            assert submission_dir is not None
            predi_np = (
                predi.permute(1, 2, 0).squeeze(2).cpu().numpy()
            )  # transform into HxWx2 for flow or HxW for stereo
            dataset.submission_save_pairname(pairname, predi_np, submission_dir, time)

def main(args):
    # load the pretrained model and metrics
    device = (
        torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
    )
    model, metrics, cropsize, with_conf, task, tile_conf_mode = (
        _load_model_and_criterion(args.model, "metrics" in args.save, device)
    )
    if args.tile_conf_mode == "":
        args.tile_conf_mode = tile_conf_mode

    # load the datasets
    datasets = (
        get_test_datasets_stereo if task == "stereo" else get_test_datasets_flow
    )(args.dataset)
    dataloaders = [
        DataLoader(
            dataset,
            batch_size=1,
            shuffle=False,
            num_workers=args.num_workers,
            pin_memory=True,
            drop_last=False,
        )
        for dataset in datasets
    ]

    # run
    for i, dataloader in enumerate(dataloaders):
        dataset = datasets[i]
        dstr = args.dataset.split("+")[i]

        outdir = args.model + "_" + misc.filename(dstr)
| if "metrics" in args.save and len(args.save) == 1: | |
| fname = os.path.join( | |
| outdir, f"conf_{args.tile_conf_mode}_overlap_{args.tile_overlap}.pkl" | |
| ) | |
| if os.path.isfile(fname) and len(args.save) == 1: | |
| print(" metrics already compute in " + fname) | |
| with open(fname, "rb") as fid: | |
| results = pickle.load(fid) | |
| for k, v in results.items(): | |
| print("{:s}: {:.3f}".format(k, v)) | |
| continue | |
| if "submission" in args.save: | |
| dirname = ( | |
| f"submission_conf_{args.tile_conf_mode}_overlap_{args.tile_overlap}" | |
| ) | |
| submission_dir = os.path.join(outdir, dirname) | |
| else: | |
| submission_dir = None | |
| print("") | |
| print("saving {:s} in {:s}".format("+".join(args.save), outdir)) | |
| print(repr(dataset)) | |
| if metrics is not None: | |
| metrics.reset() | |
        for data_iter_step, (image1, image2, gt, pairnames) in enumerate(
            tqdm(dataloader)
        ):
            do_flip = (
                task == "stereo"
                and dstr.startswith("Spring")
                and any("right" in p for p in pairnames)
            )  # we flip the images and will flip the prediction afterwards, as we assume img1 is on the left

            image1 = image1.to(device, non_blocking=True)
            image2 = image2.to(device, non_blocking=True)
            gt = (
                gt.to(device, non_blocking=True) if gt.numel() > 0 else None
            )  # special case for test time, when no ground truth is available

            if do_flip:
                assert all("right" in p for p in pairnames)
                image1 = image1.flip(
                    dims=[3]
                )  # this is already the right frame, let's flip it
                image2 = image2.flip(dims=[3])
                gt = gt  # gt stays as-is: the prediction is flipped back before it is compared or saved
            with torch.inference_mode():
                pred, _, _, time = tiled_pred(
                    model,
                    None,
                    image1,
                    image2,
                    None if dataset.name == "Spring" else gt,
                    conf_mode=args.tile_conf_mode,
                    overlap=args.tile_overlap,
                    crop=cropsize,
                    with_conf=with_conf,
                    return_time=True,
                )
            if do_flip:
                pred = pred.flip(dims=[3])

            if metrics is not None:
                metrics.add_batch(pred, gt)

            if any(k in args.save for k in ["pred", "visu", "err10", "submission"]):
                _save_batch(
                    pred,
                    gt,
                    pairnames,
                    dataset,
                    task,
                    args.save,
                    outdir,
                    time,
                    submission_dir=submission_dir,
                )

        if metrics is not None:
            results = metrics.get_results()
            for k, v in results.items():
                print("{:s}: {:.3f}".format(k, v))

        # save if needed
        if "metrics" in args.save:
            os.makedirs(os.path.dirname(fname), exist_ok=True)
            with open(fname, "wb") as fid:
                pickle.dump(results, fid)
            print("metrics saved in", fname)

        # finalize submission if needed
        if "submission" in args.save:
            dataset.finalize_submission(submission_dir)
| if __name__ == "__main__": | |
| args = get_args_parser() | |
| args = args.parse_args() | |
| main(args) | |