""" Export the predictions of a model for a given dataloader (e.g. ImageFolder). Use a standalone script with `python3 -m dsfm.scipts.export_predictions dir` or call from another script. """ from pathlib import Path import h5py import numpy as np import torch from tqdm import tqdm from .tensor import batch_to_device @torch.no_grad() def export_predictions( loader, model, output_file, as_half=False, keys="*", callback_fn=None, optional_keys=[], ): assert keys == "*" or isinstance(keys, (tuple, list)) Path(output_file).parent.mkdir(exist_ok=True, parents=True) hfile = h5py.File(str(output_file), "w") device = "cuda" if torch.cuda.is_available() else "cpu" model = model.to(device).eval() for data_ in tqdm(loader): data = batch_to_device(data_, device, non_blocking=True) pred = model(data) if callback_fn is not None: pred = {**callback_fn(pred, data), **pred} if keys != "*": if len(set(keys) - set(pred.keys())) > 0: raise ValueError(f"Missing key {set(keys) - set(pred.keys())}") pred = {k: v for k, v in pred.items() if k in keys + optional_keys} assert len(pred) > 0 # renormalization for k in pred.keys(): if k.startswith("keypoints"): idx = k.replace("keypoints", "") scales = 1.0 / ( data["scales"] if len(idx) == 0 else data[f"view{idx}"]["scales"] ) pred[k] = pred[k] * scales[None] if k.startswith("lines"): idx = k.replace("lines", "") scales = 1.0 / ( data["scales"] if len(idx) == 0 else data[f"view{idx}"]["scales"] ) pred[k] = pred[k] * scales[None] if k.startswith("orig_lines"): idx = k.replace("orig_lines", "") scales = 1.0 / ( data["scales"] if len(idx) == 0 else data[f"view{idx}"]["scales"] ) pred[k] = pred[k] * scales[None] pred = {k: v[0].cpu().numpy() for k, v in pred.items()} if as_half: for k in pred: dt = pred[k].dtype if (dt == np.float32) and (dt != np.float16): pred[k] = pred[k].astype(np.float16) try: name = data["name"][0] grp = hfile.create_group(name) for k, v in pred.items(): grp.create_dataset(k, data=v) except RuntimeError: continue del pred hfile.close() return output_file