"""Gradio demo that runs a GLPDepth model on an image and shows a colormapped result.

The checkpoint is loaded once at import time; the interface is launched at the
bottom of the module.
"""

import os
from collections import OrderedDict

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image
from torchvision import transforms

from models.model import GLPDepth

# Device used for both the model and the input tensors.
DEVICE = 'cpu'

# Key prefix stripped from checkpoint entries before loading
# (presumably left by a torch.nn.DataParallel wrapper -- confirm against the
# training code).
_WRAPPER_PREFIX = 'module.'


def load_mde_model(path):
    """Build a GLPDepth model and load the weights stored at *path*.

    Args:
        path: Checkpoint file readable by ``torch.load``; it must contain a
            ``'model_state_dict'`` entry.

    Returns:
        The GLPDepth model on ``DEVICE``, switched to eval mode.
    """
    model = GLPDepth(max_depth=700.0, is_train=False).to(DEVICE)

    # NOTE: torch.load unpickles the file; only use checkpoints you trust.
    checkpoint = torch.load(path, map_location=torch.device(DEVICE))
    state_dict = checkpoint['model_state_dict']

    # Strip the wrapper prefix only where it is really a *leading* "module.".
    # A plain substring test followed by an unconditional k[7:] would mangle
    # every key whenever the first key merely contained "module" somewhere.
    prefix_len = len(_WRAPPER_PREFIX)
    state_dict = OrderedDict(
        (key[prefix_len:] if key.startswith(_WRAPPER_PREFIX) else key, value)
        for key, value in state_dict.items()
    )

    model.load_state_dict(state_dict)
    model.eval()
    return model


model = load_mde_model('best_model.ckpt')

# Resize to the fixed 512x512 input used by this demo and convert to a tensor.
preprocess = transforms.Compose([
    transforms.Resize((512, 512)),
    transforms.ToTensor(),
])


def _render_depth_map(depth):
    """Render a 2-D array with the 'jet' colormap and a colorbar as an RGB image.

    Args:
        depth: Array accepted by ``Axes.imshow``.

    Returns:
        A ``uint8`` array of shape ``(height, width, 3)`` holding the drawn
        figure.
    """
    fig, ax = plt.subplots()
    try:
        im = ax.imshow(depth, cmap='jet', vmin=0, vmax=np.max(depth))
        fig.colorbar(im, ax=ax)
        fig.canvas.draw()
        # buffer_rgba() replaces tostring_rgb(), which newer Matplotlib
        # releases no longer provide; it already has shape (H, W, 4).
        rgba = np.asarray(fig.canvas.buffer_rgba())
        # Copy before closing: the buffer belongs to the figure's renderer.
        return np.ascontiguousarray(rgba[..., :3])
    finally:
        # pyplot keeps every figure alive until it is closed explicitly;
        # without this each request would leak one figure.
        plt.close(fig)


def predict(input_image):
    """Run the model on *input_image* and return a colormapped visualization.

    Args:
        input_image: Image as a NumPy array (as delivered by ``gr.Image``).

    Returns:
        A ``uint8`` RGB array containing the rendered prediction with a
        colorbar.
    """
    # Let Pillow infer the mode, then normalise to RGB; for 3-channel input
    # this is equivalent to forcing mode 'RGB'.
    pil_image = Image.fromarray(input_image.astype('uint8')).convert('RGB')

    # Preprocess and add the batch dimension.
    torch_img = preprocess(pil_image).to(DEVICE).unsqueeze(0)

    with torch.no_grad():
        output_patch = model(torch_img)

    # Drop singleton dimensions and move the prediction to NumPy for plotting.
    predicted_image = output_patch['pred_d'].squeeze().cpu().detach().numpy()

    return _render_depth_map(predicted_image)


iface = gr.Interface(
    fn=predict,
    inputs=gr.Image(shape=(512, 512)),
    outputs=[
        gr.Image(shape=(512, 512)),
    ],
    # os.listdir() order is arbitrary; sort so the examples are listed stably.
    examples=[
        [f"demo_imgs/{name}"] for name in sorted(os.listdir('demo_imgs'))
    ],
    title="DTM Estimation",
    description="This demo predict a DTM using GLP Depth model. It will scale input image to 512x512 and at the end it will apply a colormap to better visualize the output."
)

# Launched at module level on purpose: the script is the app entry point.
iface.launch()