import glob
import os

import numpy as np
import torch
import torch.nn as nn
from PIL import Image
from torchvision import transforms
from tqdm import tqdm

import model_io
import utils
from adabins import UnetAdaptiveBins


def _is_pil_image(img):
    return isinstance(img, Image.Image)


def _is_numpy_image(img):
    return isinstance(img, np.ndarray) and (img.ndim in {2, 3})


class ToTensor(object):
    def __init__(self):
        # ImageNet statistics, matching what the pretrained encoder expects
        self.normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                              std=[0.229, 0.224, 0.225])

    def __call__(self, image, target_size=(640, 480)):
        # image = image.resize(target_size)
        image = self.to_tensor(image)
        image = self.normalize(image)
        return image

    def to_tensor(self, pic):
        if not (_is_pil_image(pic) or _is_numpy_image(pic)):
            raise TypeError(
                'pic should be PIL Image or ndarray. Got {}'.format(type(pic)))

        if isinstance(pic, np.ndarray):
            # HWC -> CHW
            img = torch.from_numpy(pic.transpose((2, 0, 1)))
            return img

        # handle PIL Image
        if pic.mode == 'I':
            img = torch.from_numpy(np.array(pic, np.int32, copy=False))
        elif pic.mode == 'I;16':
            img = torch.from_numpy(np.array(pic, np.int16, copy=False))
        else:
            img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes()))

        # PIL image mode: 1, L, P, I, F, RGB, YCbCr, RGBA, CMYK
        if pic.mode == 'YCbCr':
            nchannel = 3
        elif pic.mode == 'I;16':
            nchannel = 1
        else:
            nchannel = len(pic.mode)
        img = img.view(pic.size[1], pic.size[0], nchannel)

        # HWC -> CHW
        img = img.transpose(0, 1).transpose(0, 2).contiguous()
        if isinstance(img, torch.ByteTensor):
            return img.float()
        else:
            return img


class InferenceHelper:
    def __init__(self, models_path, dataset='nyu', device='cuda:0'):
        self.toTensor = ToTensor()
        self.device = device
        if dataset == 'nyu':
            self.min_depth = 1e-3
            self.max_depth = 10
            self.saving_factor = 1000  # used to save in 16 bit
            model = UnetAdaptiveBins.build(n_bins=256, min_val=self.min_depth,
                                           max_val=self.max_depth)
            pretrained_path = os.path.join(models_path, "AdaBins_nyu.pt")
        elif dataset == 'kitti':
            self.min_depth = 1e-3
            self.max_depth = 80
            self.saving_factor = 256
            model = UnetAdaptiveBins.build(n_bins=256, min_val=self.min_depth,
                                           max_val=self.max_depth)
            pretrained_path = os.path.join(models_path, "AdaBins_kitti.pt")
        else:
            raise ValueError("dataset can be either 'nyu' or 'kitti' but got {}".format(dataset))

        model, _, _ = model_io.load_checkpoint(pretrained_path, model)
        model.eval()
        self.model = model.to(self.device)

    @torch.no_grad()
    def predict_pil(self, pil_image, visualized=False):
        # pil_image = pil_image.resize((640, 480))
        img = np.asarray(pil_image) / 255.
        img = self.toTensor(img).unsqueeze(0).float().to(self.device)
        bin_centers, pred = self.predict(img)

        if visualized:
            viz = utils.colorize(torch.from_numpy(pred).unsqueeze(0),
                                 vmin=None, vmax=None, cmap='magma')
            # pred = np.asarray(pred*1000, dtype='uint16')
            viz = Image.fromarray(viz)
            return bin_centers, pred, viz
        return bin_centers, pred

    @torch.no_grad()
    def predict(self, image):
        bins, pred = self.model(image)
        pred = np.clip(pred.cpu().numpy(), self.min_depth, self.max_depth)

        # Test-time augmentation: predict on the horizontally flipped input as well
        image = torch.Tensor(np.array(image.cpu().numpy())[..., ::-1].copy()).to(self.device)
        pred_lr = self.model(image)[-1]
        pred_lr = np.clip(pred_lr.cpu().numpy()[..., ::-1], self.min_depth, self.max_depth)

        # Take average of original and mirror
        final = 0.5 * (pred + pred_lr)
        # The network predicts at reduced resolution; upsample to the input size
        final = nn.functional.interpolate(torch.Tensor(final), image.shape[-2:],
                                          mode='bilinear', align_corners=True).cpu().numpy()

        final[final < self.min_depth] = self.min_depth
        final[final > self.max_depth] = self.max_depth
        final[np.isinf(final)] = self.max_depth
        final[np.isnan(final)] = self.min_depth

        # Bin edges -> bin centers, keeping only those inside the valid depth range
        centers = 0.5 * (bins[:, 1:] + bins[:, :-1])
        centers = centers.cpu().squeeze().numpy()
        centers = centers[centers > self.min_depth]
        centers = centers[centers < self.max_depth]

        return centers, final

    @torch.no_grad()
    def predict_dir(self, test_dir, out_dir):
        os.makedirs(out_dir, exist_ok=True)
        transform = ToTensor()
        all_files = glob.glob(os.path.join(test_dir, "*"))
        self.model.eval()
        for f in tqdm(all_files):
            # Convert to RGB so grayscale/RGBA inputs don't break normalization
            image = np.asarray(Image.open(f).convert('RGB'), dtype='float32') / 255.
            image = transform(image).unsqueeze(0).to(self.device)
            centers, final = self.predict(image)
            # final = final.squeeze().cpu().numpy()

            # Scale to 16-bit integer depth (saving_factor: 1000 for NYU, 256 for KITTI)
            final = (final * self.saving_factor).astype('uint16')
            basename = os.path.splitext(os.path.basename(f))[0]
            save_path = os.path.join(out_dir, basename + ".png")

            Image.fromarray(final.squeeze()).save(save_path)

    def to(self, device):
        self.device = device
        self.model.to(device)


if __name__ == '__main__':
    import matplotlib.pyplot as plt
    from time import time

    img = Image.open("test_imgs/classroom__rgb_00283.jpg")
    start = time()
    # models_path must point at the directory holding the AdaBins_*.pt
    # checkpoints; "pretrained" is assumed here, adjust as needed.
    inferHelper = InferenceHelper(models_path="pretrained")
    centers, pred = inferHelper.predict_pil(img)
    print(f"took: {time() - start}s")
    plt.imshow(pred.squeeze(), cmap='magma_r')
    plt.show()
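    # A minimal usage sketch for batch inference (the directory names below
    # are illustrative, not part of the original script). predict_dir writes
    # 16-bit PNGs scaled by saving_factor, so metric depth can be recovered
    # by dividing the raw pixel values by that factor:
    #
    #   inferHelper.predict_dir("test_imgs/", "test_imgs/out/")
    #   raw = np.asarray(Image.open("test_imgs/out/classroom__rgb_00283.png"))
    #   depth_in_meters = raw / inferHelper.saving_factor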