DTM_Estimation / test.py
artelabsuper
demo with model
b4eade4
from collections import OrderedDict
import torch
from models.model import GLPDepth
from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as plt
import numpy as np
# Device string passed to .to() for both the model and the input tensor.
DEVICE = 'cpu'
def load_mde_model(path):
    """Build a GLPDepth model in eval mode and load weights from a checkpoint.

    Args:
        path: Path to a checkpoint file readable by ``torch.load``. The
            checkpoint must be a mapping with a ``'model_state_dict'`` entry
            holding the model weights.

    Returns:
        The ``GLPDepth`` model (``max_depth=700.0``, ``is_train=False``)
        moved to ``DEVICE`` and switched to eval mode.

    Raises:
        KeyError: if the checkpoint has no ``'model_state_dict'`` entry.
    """
    model = GLPDepth(max_depth=700.0, is_train=False).to(DEVICE)
    # Map stored tensors onto the same device the model lives on, so the
    # checkpoint can be loaded regardless of where it was saved.
    checkpoint = torch.load(path, map_location=torch.device(DEVICE))
    model_weight = checkpoint['model_state_dict']

    # Weights saved from an nn.DataParallel / DistributedDataParallel wrapper
    # carry a 'module.' prefix on their keys. Strip exactly that prefix, and
    # only from keys that have it. A substring test on the first key would
    # misfire on names such as 'encoder.submodule.weight'. When no key is
    # prefixed, the original mapping (and any metadata attached to it) is
    # passed through unchanged.
    prefix = 'module.'
    if any(k.startswith(prefix) for k in model_weight):
        model_weight = OrderedDict(
            (k[len(prefix):] if k.startswith(prefix) else k, v)
            for k, v in model_weight.items()
        )

    model.load_state_dict(model_weight)
    model.eval()
    return model
model = load_mde_model('best_model.ckpt')

# Resize to the fixed 512x512 input used by this demo, then convert the PIL
# image to a CHW float tensor.
preprocess = transforms.Compose([
    transforms.Resize((512, 512)),
    transforms.ToTensor(),
])

# Force 3-channel RGB so grayscale / CMYK / RGBA / palette files still produce
# a 3-channel tensor. The context manager closes the file that Image.open
# keeps open for lazy loading.
with Image.open('demo_imgs/fake.jpg') as img:
    input_img = img.convert('RGB')
torch_img = preprocess(input_img).to(DEVICE).unsqueeze(0)  # add batch dim

with torch.no_grad():
    output_patch = model(torch_img)
# The model output is indexed by 'pred_d' (presumably the predicted depth
# map). Tensors produced under no_grad carry no autograd history, so no
# detach() is needed before converting to NumPy.
output_patch = output_patch['pred_d'].squeeze().cpu().numpy()
print(output_patch.shape)

# Visualize with a colour scale running from 0 to the prediction's maximum.
plt.imshow(output_patch, cmap='jet', vmin=0, vmax=np.max(output_patch))
plt.colorbar()
plt.savefig('test.png')