File size: 1,141 Bytes
b4eade4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
from collections import OrderedDict
import torch
from models.model import GLPDepth
from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as plt
import numpy as np

# Device used for the model's parameters and for the input tensor.
DEVICE = "cpu"

def load_mde_model(path, max_depth=700.0):
    """Build a GLPDepth model and load trained weights from a checkpoint.

    Args:
        path: Path to a checkpoint written with ``torch.save``. It must hold
            a ``'model_state_dict'`` entry containing the model weights.
        max_depth: Forwarded to ``GLPDepth(max_depth=...)``. Defaults to
            700.0, the value previously hard-coded here. It should match the
            setting the checkpoint was trained with.

    Returns:
        The GLPDepth model on ``DEVICE`` with weights loaded, in eval mode.

    Raises:
        KeyError: If the checkpoint has no ``'model_state_dict'`` entry.
        RuntimeError: From ``load_state_dict`` when the weights do not match
            the model's parameters (including an empty state dict).
    """
    model = GLPDepth(max_depth=max_depth, is_train=False).to(DEVICE)
    # NOTE(review): torch.load unpickles the file, which can execute arbitrary
    # code -- only load checkpoints from a trusted source.
    # Map tensors onto the same device the model was moved to above.
    checkpoint = torch.load(path, map_location=torch.device(DEVICE))
    state_dict = checkpoint['model_state_dict']
    # Presumably the weights were saved from a DataParallel-style wrapper,
    # which prefixes every key with 'module.'. Strip that exact prefix key by
    # key rather than testing one key for a substring and cutting a fixed
    # 7 characters from all keys. This leaves names that merely contain
    # "module" intact.
    prefix = 'module.'
    state_dict = OrderedDict(
        (key[len(prefix):] if key.startswith(prefix) else key, value)
        for key, value in state_dict.items()
    )
    model.load_state_dict(state_dict)
    model.eval()
    return model

model = load_mde_model('best_model.ckpt')
# Resize to a fixed 512x512 input (aspect ratio is not preserved), then
# convert the PIL image to a float tensor of shape (C, H, W).
preprocess = transforms.Compose([
    transforms.Resize((512, 512)),
    transforms.ToTensor(),
])

# Use a context manager so the file handle is released. Force 3-channel RGB
# because grayscale, RGBA or CMYK images would otherwise yield a tensor with
# 1 or 4 channels. The model presumably expects 3 -- confirm against
# GLPDepth's encoder.
with Image.open('demo_imgs/fake.jpg') as img_file:
    input_img = img_file.convert('RGB')
# Add a batch dimension: (3, 512, 512) -> (1, 3, 512, 512).
torch_img = preprocess(input_img).to(DEVICE).unsqueeze(0)

with torch.no_grad():
    output_patch = model(torch_img)
# The model returns a dict; 'pred_d' is presumably the predicted depth map.
# squeeze() drops singleton dims (assumes (1, 1, H, W) -> (H, W); verify).
# No detach() is needed because the forward pass ran under no_grad.
output_patch = output_patch['pred_d'].squeeze().cpu().numpy()
print(output_patch.shape)

plt.imshow(output_patch, cmap='jet', vmin=0, vmax=np.max(output_patch))
plt.colorbar()
plt.savefig('test.png')