# NOTE: the lines below are metadata scraped from the original paste/viewer
# page (status banner, file size, revision hash, gutter line numbers),
# preserved here as comments so the file parses as Python:
#   status: Runtime error
#   file size: 1,141 bytes
#   revision: b4eade4
from collections import OrderedDict
import torch
from models.model import GLPDepth
from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as plt
import numpy as np
DEVICE = 'cpu'

def load_mde_model(path):
    """Load a GLPDepth monocular-depth-estimation model from a checkpoint.

    Parameters
    ----------
    path : str
        Path to a torch checkpoint file containing a 'model_state_dict' key.

    Returns
    -------
    GLPDepth
        Model on DEVICE, weights loaded, switched to eval mode.
    """
    model = GLPDepth(max_depth=700.0, is_train=False).to(DEVICE)
    checkpoint = torch.load(path, map_location=torch.device('cpu'))
    state_dict = checkpoint['model_state_dict']
    # Checkpoints saved from nn.DataParallel prefix every key with 'module.';
    # strip it so keys match the bare model. The original test
    # ("'module' in key") would also trigger on any key that merely contains
    # the substring "module" and then slice off 7 unrelated characters, so
    # check the actual prefix instead.
    first_key = next(iter(state_dict))
    if first_key.startswith('module.'):
        state_dict = OrderedDict(
            (k[len('module.'):], v) for k, v in state_dict.items()
        )
    model.load_state_dict(state_dict)
    model.eval()
    return model
# --- Inference demo: run the depth model on one image and save a colormap ---
model = load_mde_model('best_model.ckpt')

# Resize to the network's input resolution and convert PIL image -> tensor.
preprocess = transforms.Compose([
    transforms.Resize((512, 512)),
    transforms.ToTensor(),
])

input_img = Image.open('demo_imgs/fake.jpg')
# Add a leading batch dimension: (C, H, W) -> (1, C, H, W).
torch_img = preprocess(input_img).to(DEVICE).unsqueeze(0)

# Inference only — no gradient tracking needed.
with torch.no_grad():
    output_patch = model(torch_img)

# 'pred_d' holds the predicted depth map; squeeze drops the singleton
# batch/channel dims so the result is a 2-D array suitable for imshow.
output_patch = output_patch['pred_d'].squeeze().cpu().detach().numpy()
print(output_patch.shape)

plt.imshow(output_patch, cmap='jet', vmin=0, vmax=np.max(output_patch))
plt.colorbar()
# NOTE: the original line ended with a stray " |" (scrape residue) that made
# the file a syntax error; removed here.
plt.savefig('test.png')