Spaces:
Runtime error
Runtime error
| """ | |
| PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation | |
| Official implementation of the paper: | |
| "PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation" | |
| by Xiaohang Yu, Ti Wang, and Mackenzie Weygandt Mathis | |
| Licensed under a modified MIT license | |
| """ | |
| import numpy as np | |
| from tqdm import tqdm | |
| import torch | |
| from prima.utils import recursive_to | |
| from prima.utils.evaluate_metric import Evaluator | |
| from prima.datasets.datasets import EvaluationDataset | |
| import argparse | |
| from torch.utils.data import DataLoader | |
| from prima.models.prima import PRIMA | |
| from prima.configs import get_config | |
| torch.multiprocessing.set_sharing_strategy('file_system') | |
| def main(args): | |
| cfg = get_config(args.config) | |
| default_cfg = get_config(args.default_eval_config) | |
| model = PRIMA.load_from_checkpoint(args.checkpoint, cfg=cfg, strict=False) | |
| model.eval() | |
| smal_evaluator = Evaluator(smal_model=model.smal, image_size=cfg.MODEL.IMAGE_SIZE) | |
| cfg_eval_dataset = dict(default_cfg.DATASETS) | |
| aug_cfg = cfg_eval_dataset.pop("CONFIG", None) # augmentation config is not used in evaluation | |
| if args.dataset.upper() == "ALL": | |
| for key in cfg_eval_dataset.keys(): | |
| print(f"-------- Evaluate {key} dataset ------------") | |
| eval_one_dataset(cfg_eval_dataset[key], default_cfg, cfg, model, | |
| evaluator=smal_evaluator, | |
| aug_cfg=aug_cfg, | |
| key=key, | |
| device=args.device) | |
| print(f"-------{key} Dataset evaluate finish ------") | |
| else: | |
| print(f"-------- Evaluate {args.dataset} dataset ------------") | |
| eval_one_dataset(cfg_eval_dataset[args.dataset], default_cfg, cfg, model, | |
| evaluator=smal_evaluator, | |
| aug_cfg=aug_cfg, | |
| key=args.dataset, | |
| device=args.device) | |
| print(f"-------{args.dataset} Dataset evaluate finish ------") | |
| def eval_one_dataset(dataset_cfg, default_cfg, cfg, model, evaluator, aug_cfg, key, device='cuda'): | |
| dataset = EvaluationDataset(root_image=dataset_cfg['ROOT_IMAGE'], | |
| json_file=dataset_cfg['JSON_FILE']['TEST'], | |
| augm_config=aug_cfg, focal_length=cfg.SMAL.get("FOCAL_LENGTH", 1000), | |
| image_size=cfg.MODEL.IMAGE_SIZE, | |
| ) | |
| dataloader = DataLoader(dataset, batch_size=1, num_workers=cfg.GENERAL.NUM_WORKERS) | |
| bar = tqdm(dataloader) | |
| pa_mpjpe_list, pck_list, auc_list, pa_mpvpe_list = [], [], [], [] | |
| for i, batch in enumerate(bar): | |
| batch = recursive_to(batch, device) | |
| with torch.no_grad(): | |
| output = model(batch) | |
| if key in ["ANIMAL3D", "CONTROL_ANIMAL3D"]: | |
| pa_mpjpe, pa_mpvpe = evaluator.eval_3d(output, batch) | |
| else: | |
| pa_mpjpe, pa_mpvpe = 0., 0. | |
| pck, auc = evaluator.eval_2d(output, batch, pck_threshold=default_cfg.METRIC.PCK_THRESHOLD) | |
| pa_mpjpe_list.append(pa_mpjpe) | |
| pa_mpvpe_list.append(pa_mpvpe) | |
| auc_list.append(auc) | |
| pck_list.append(pck) | |
| bar.set_postfix(PA_MPJPE=pa_mpjpe, | |
| PA_MPVPE=pa_mpvpe, | |
| AUC=auc, | |
| pck=pck,) | |
| print("---------------- 3D metric -----------------") | |
| print(f"Avg PA-MPJPE: {np.mean(pa_mpjpe_list)}") | |
| print(f"Avg PA-MPVPE: {np.mean(pa_mpvpe_list)}") | |
| print("--------------- 2D metric ------------------") | |
| print(f"AUC: {np.mean(auc_list)}") | |
| pck_list = np.array(pck_list) | |
| for _, th in enumerate(default_cfg.METRIC.PCK_THRESHOLD): | |
| print(f"PCK@{th}: {np.mean(pck_list[:, _])}") | |
| if __name__ == "__main__": | |
| parser = argparse.ArgumentParser() | |
| parser.add_argument("--config", type=str, help="Path to config file", required=True) | |
| parser.add_argument("--checkpoint", type=str, help="Path to checkpoint file", required=True) | |
| parser.add_argument("--default_eval_config", type=str, default="./configs_hydra/experiment/default_val.yaml") | |
| parser.add_argument("--dataset", type=str, default="ALL") | |
| parser.add_argument("--device", type=str, default="cuda", help="Device to use for evaluation") | |
| args = parser.parse_args() | |
| main(args) | |