"""Gradio demo: run one of three pretrained generators on an uploaded image.

At import time the script loads the three generator checkpoints listed in
``modeltype2path``, builds the Gradio interface and launches it.
"""
import io

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision
from PIL import Image
from torchvision import transforms

from models.modelNetA import Generator as GA
from models.modelNetB import Generator as GB
from models.modelNetC import Generator as GC

# load model
modeltype2path = {
    'ModelA': 'DTM_exp_train10%_model_a/g-best.pth',
    'ModelB': 'DTM_exp_train10%_model_b/g-best.pth',
    'ModelC': 'DTM_exp_train10%_model_c/g-best.pth',
}
DEVICE = 'cpu'
MODELS_TYPE = list(modeltype2path.keys())


def _load_generator(generator, checkpoint_path):
    """Load ``checkpoint_path`` into ``generator`` and return it in eval mode.

    The state dict is loaded through a ``DataParallel`` wrapper and the inner
    module is then unwrapped.
    NOTE(review): presumably the checkpoints were saved from a DataParallel
    model (keys prefixed with ``module.``) -- confirm against the training code.
    """
    wrapped = torch.nn.DataParallel(generator)
    state_dict = torch.load(checkpoint_path, map_location=DEVICE)
    wrapped.load_state_dict(state_dict)
    model = wrapped.module.to(DEVICE)
    model.eval()
    return model


# generators[i] corresponds to MODELS_TYPE[i]; the pairing relies on the
# insertion order of ``modeltype2path`` matching (GA, GB, GC).
generators = [
    _load_generator(generator_cls(), modeltype2path[model_type])
    for model_type, generator_cls in zip(MODELS_TYPE, (GA, GB, GC))
]

preprocess = transforms.Compose([
    transforms.Grayscale(),
    transforms.ToTensor(),
])


def _tensor_to_image_array(tensor):
    """Convert an image tensor to a numpy array via an in-memory PNG.

    Uses the same ``torchvision.utils.save_image`` encoding as writing a file
    would, but targets a ``BytesIO`` buffer so that concurrent requests cannot
    overwrite each other's output and no writable working directory is needed.
    """
    buffer = io.BytesIO()
    torchvision.utils.save_image(tensor, buffer, format='png')
    buffer.seek(0)
    with Image.open(buffer) as encoded:
        return np.array(encoded)


def _render_heatmap(values):
    """Render ``values`` as a 'jet' heatmap with a colorbar; return RGB pixels.

    The figure is always closed afterwards, otherwise every request would
    leave one more open figure registered with pyplot.
    """
    fig, ax = plt.subplots()
    try:
        im = ax.imshow(values, cmap='jet', vmin=0, vmax=np.max(values))
        # Attach the colorbar to this figure explicitly instead of relying on
        # pyplot's "current figure".
        fig.colorbar(im, ax=ax)
        fig.canvas.draw()
        # buffer_rgba() replaces the deprecated tostring_rgb(); like the old
        # call it requires an Agg-based canvas. Drop alpha and copy so the
        # result does not reference the canvas buffer once the figure closes.
        rgba = np.asarray(fig.canvas.buffer_rgba())
        return rgba[..., :3].copy()
    finally:
        plt.close(fig)


def predict(input_image, model_name):
    """Run the selected generator on ``input_image``.

    Args:
        input_image: Image as a numpy array, interpreted as RGB.
        model_name: One of the entries of ``MODELS_TYPE``.

    Returns:
        A tuple ``(sr, data)`` of numpy arrays: the first model output
        converted to an image, and the second model output rendered as a
        heatmap figure.

    Raises:
        ValueError: If ``input_image`` is None or ``model_name`` is not in
            ``MODELS_TYPE``.
    """
    if input_image is None:
        raise ValueError("No input image was provided.")

    pil_image = Image.fromarray(input_image.astype('uint8'), 'RGB')

    # transform image to torch and do preprocessing
    torch_img = preprocess(pil_image).unsqueeze(0).to(DEVICE)

    # model predict
    generator = generators[MODELS_TYPE.index(model_name)]
    with torch.no_grad():
        output = generator(torch_img)
    sr, sr_dem_selected = output[0], output[1]

    # transform torch to image
    sr = _tensor_to_image_array(sr.squeeze(0).cpu())
    sr_dem_selected = sr_dem_selected.squeeze().cpu().detach().numpy()
    data = _render_heatmap(sr_dem_selected)

    # return correct image
    return sr, data


iface = gr.Interface(
    fn=predict,
    inputs=[
        gr.Image(shape=(512, 512)),
        gr.inputs.Radio(MODELS_TYPE),
    ],
    outputs=[
        gr.Image(),
        gr.Image(),
    ],
    examples=[
        ["demo_imgs/fake.jpg", MODELS_TYPE[0]]  # use real image
    ],
    title="DTM Estimation",
    description="This demo predict a DTM...",
)
iface.launch()