Spaces:
Runtime error
Runtime error
# MIT License
# Copyright (c) 2022 Intelligent Systems Lab Org
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
# File author: Shariq Farooq Bhat
import argparse | |
from pprint import pprint | |
import torch | |
from zoedepth.utils.easydict import EasyDict as edict | |
from tqdm import tqdm | |
from zoedepth.data.data_mono import DepthDataLoader | |
from zoedepth.models.builder import build_model | |
from zoedepth.utils.arg_utils import parse_unknown | |
from zoedepth.utils.config import change_dataset, get_config, ALL_EVAL_DATASETS, ALL_INDOOR, ALL_OUTDOOR | |
from zoedepth.utils.misc import (RunningAverageDict, colors, compute_metrics, | |
count_parameters) | |
def infer(model, images, **kwargs):
    """Inference with flip augmentation.

    Runs ``model`` on ``images`` and on their horizontal mirror. The mirrored
    prediction is flipped back along the width axis so both predictions are
    pixel-aligned, and the element-wise mean of the two is returned.

    Args:
        model: callable invoked as ``model(images, **kwargs)``.
        images: batch tensor laid out as (N, C, H, W); the last axis is width.
        **kwargs: forwarded unchanged to both ``model`` calls.

    Returns:
        torch.Tensor: element-wise mean of the direct and un-mirrored predictions.

    Raises:
        NotImplementedError: if the model output is not a tensor, a list/tuple,
            or a dict.

    Gradient tracking follows the caller's autograd mode; wrap the call in
    ``torch.no_grad()`` for pure inference.
    """
    def get_depth_from_prediction(pred):
        # A tensor is used as-is, a list/tuple contributes its last element, and
        # a dict its 'metric_depth' entry (falling back to 'out').
        if isinstance(pred, torch.Tensor):
            return pred
        if isinstance(pred, (list, tuple)):
            return pred[-1]
        if isinstance(pred, dict):
            return pred['metric_depth'] if 'metric_depth' in pred else pred['out']
        raise NotImplementedError(f"Unknown output type {type(pred)}")

    pred1 = get_depth_from_prediction(model(images, **kwargs))
    pred2 = get_depth_from_prediction(model(torch.flip(images, [-1]), **kwargs))
    # Un-mirror along the last (width) axis. Using -1 instead of a fixed dim 3
    # also supports predictions without a channel axis, e.g. (N, H, W).
    pred2 = torch.flip(pred2, [-1])
    return 0.5 * (pred1 + pred2)
def _save_visualization(save_dir, index, image, depth, pred):
    """Save the input image and colorized ground-truth/predicted depth for one sample.

    Files are written as ``{index}_img.png``, ``{index}_depth.png`` and
    ``{index}_pred.png`` inside ``save_dir``, which is created if missing.
    Both depth maps are colorized over the fixed range [0, 10].
    """
    # Imported lazily so these visualization-only dependencies are needed
    # only when image saving is enabled.
    import os
    from PIL import Image
    import torchvision.transforms as transforms
    from zoedepth.utils.misc import colorize

    os.makedirs(save_dir, exist_ok=True)
    d = colorize(depth.squeeze().cpu().numpy(), 0, 10)
    p = colorize(pred.squeeze().cpu().numpy(), 0, 10)
    im = transforms.ToPILImage()(image.squeeze().cpu())
    im.save(os.path.join(save_dir, f"{index}_img.png"))
    Image.fromarray(d).save(os.path.join(save_dir, f"{index}_depth.png"))
    Image.fromarray(p).save(os.path.join(save_dir, f"{index}_pred.png"))


@torch.no_grad()
def evaluate(model, test_loader, config, round_vals=True, round_precision=3):
    """Evaluate ``model`` on every sample of ``test_loader`` and return averaged metrics.

    Runs under ``torch.no_grad()``. Evaluation needs no autograd graph, and
    without it the optional image saving fails because ``.numpy()`` cannot be
    called on a tensor that requires grad.

    Args:
        model: depth model, called through the flip-augmented ``infer``.
        test_loader: sized iterable of sample dicts with ``'image'``,
            ``'depth'`` and ``'dataset'`` keys, plus optional
            ``'has_valid_depth'`` and ``'focal'``. Presumably batch size 1 —
            TODO confirm (ground truth is squeezed and re-expanded below).
        config: evaluation config. If it has a truthy ``save_images`` entry,
            per-sample visualizations are written to that directory.
        round_vals: if True, round every metric to ``round_precision`` digits.
        round_precision: number of decimal digits used when ``round_vals`` is True.

    Returns:
        dict: metric name -> value averaged over the evaluated samples.
    """
    model.eval()
    metrics = RunningAverageDict()
    save_dir = config.save_images if ("save_images" in config and config.save_images) else None
    # Fallback focal length, only used for evaluating the BTS model.
    # Built once here instead of being re-allocated on the GPU for every sample.
    default_focal = torch.Tensor([715.0873]).cuda()

    for i, sample in tqdm(enumerate(test_loader), total=len(test_loader)):
        # Skip samples explicitly flagged as lacking valid ground-truth depth.
        if 'has_valid_depth' in sample and not sample['has_valid_depth']:
            continue
        image = sample['image'].cuda()
        depth = sample['depth'].cuda()
        # Drop singleton axes, then prepend two: (H, W) -> (1, 1, H, W) for a
        # single-sample batch (assumes batch size 1 — TODO confirm).
        depth = depth.squeeze().unsqueeze(0).unsqueeze(0)
        focal = sample.get('focal', default_focal)
        pred = infer(model, image, dataset=sample['dataset'][0], focal=focal)
        if save_dir:
            _save_visualization(save_dir, i, image, depth, pred)
        metrics.update(compute_metrics(depth, pred, config=config))

    # NOTE(review): defensive guard in case get_value() returns None when no
    # sample was evaluated (e.g. all skipped); confirm against RunningAverageDict.
    values = metrics.get_value() or {}
    if round_vals:
        return {k: round(v, round_precision) for k, v in values.items()}
    return dict(values)
def main(config):
    """Build the model from ``config``, evaluate it on the 'online_eval' split and report.

    The metrics are printed between the green colour code and the colour reset.
    They are returned with an extra ``'#params'`` entry: the value of
    ``count_parameters(model, include_all=True)`` in millions, rounded to two
    decimals and suffixed with ``'M'``.
    """
    net = build_model(config)
    loader = DepthDataLoader(config, 'online_eval').data
    net = net.cuda()
    results = evaluate(net, loader, config)
    # Same output as three separate prints: colour on, metrics, colour reset.
    print(f"{colors.fg.green}", results, f"{colors.reset}", sep="\n")
    millions = round(count_parameters(net, include_all=True) / 1e6, 2)
    results['#params'] = f"{millions}M"
    return results
def eval_model(model_name, pretrained_resource, dataset='nyu', **kwargs):
    """Build the "eval" config for ``model_name`` on ``dataset``, print it, and run ``main``.

    Extra ``kwargs`` are forwarded to ``get_config`` as keyword overrides.
    ``pretrained_resource`` is added to them only when truthy; otherwise the
    default pretrained resource from the model config is used.

    Returns:
        The metrics dict produced by ``main``.
    """
    overrides = dict(kwargs)
    if pretrained_resource:
        overrides["pretrained_resource"] = pretrained_resource
    config = get_config(model_name, "eval", dataset, **overrides)
    pprint(config)
    print(f"Evaluating {model_name} on {dataset}...")
    return main(config)
def resolve_datasets(spec, groups):
    """Expand a ``--dataset`` argument into the ordered list of dataset names to evaluate.

    ``spec`` is a comma-separated list. Surrounding whitespace and empty items
    are ignored. Each item is checked against ``groups``, an ordered sequence of
    ``(alias, datasets)`` pairs: the first alias contained in the item expands
    to that group's datasets. An item matching no alias is taken as a literal
    dataset name. Duplicates are dropped, keeping first-occurrence order.
    """
    resolved = []
    for item in (part.strip() for part in spec.split(",")):
        if not item:
            continue
        for alias, members in groups:
            if alias in item:
                resolved.extend(members)
                break
        else:
            resolved.append(item)
    return list(dict.fromkeys(resolved))


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("-m", "--model", type=str,
                        required=True, help="Name of the model to evaluate")
    parser.add_argument("-p", "--pretrained_resource", type=str,
                        required=False, default=None, help="Pretrained resource to use for fetching weights. If not set, default resource from model config is used, Refer models.model_io.load_state_from_resource for more details.")
    parser.add_argument("-d", "--dataset", type=str, required=False,
                        default='nyu', help="Dataset to evaluate on")
    args, unknown_args = parser.parse_known_args()
    overwrite_kwargs = parse_unknown(unknown_args)

    # Each comma-separated item is expanded on its own, so mixes such as
    # "ALL_INDOOR,ALL_OUTDOOR" or "nyu, kitti" work. Order matters: the
    # specific aliases contain "ALL" and must be tested before it.
    dataset_groups = (
        ("ALL_INDOOR", ALL_INDOOR),
        ("ALL_OUTDOOR", ALL_OUTDOOR),
        ("ALL", ALL_EVAL_DATASETS),
    )
    datasets = resolve_datasets(args.dataset, dataset_groups)
    if not datasets:
        parser.error(f"no dataset given in --dataset {args.dataset!r}")

    for dataset in datasets:
        eval_model(args.model, pretrained_resource=args.pretrained_resource,
                   dataset=dataset, **overwrite_kwargs)