import gradio as gr
from PIL import Image
import torchvision
from torchvision import transforms
import torch
import matplotlib.pyplot as plt
import numpy as np

from models.modelNetA import Generator as GA
from models.modelNetB import Generator as GB
from models.modelNetC import Generator as GC

# checkpoint paths for the three generator variants
modeltype2path = {
    'ModelA': 'DTM_exp_train10%_model_a/g-best.pth',
    'ModelB': 'DTM_exp_train10%_model_b/g-best.pth',
    'ModelC': 'DTM_exp_train10%_model_c/g-best.pth',
}

DEVICE = 'cpu'
MODELS_TYPE = list(modeltype2path.keys())

# load the models: the checkpoints were saved from DataParallel-wrapped models,
# so wrap each generator before loading the state dict and unwrap it afterwards
generators = [GA(), GB(), GC()]
for i in range(len(generators)):
    generators[i] = torch.nn.DataParallel(generators[i])
    state_dict = torch.load(modeltype2path[MODELS_TYPE[i]], map_location=torch.device('cpu'))
    generators[i].load_state_dict(state_dict)
    generators[i] = generators[i].module.to(DEVICE)
    generators[i].eval()

preprocess = transforms.Compose([
    transforms.Grayscale(),
    transforms.ToTensor()
])


def predict(input_image, model_name):
    pil_image = Image.fromarray(input_image.astype('uint8'), 'RGB')

    # convert the image to a torch tensor, preprocess, and min-max normalise to [0, 1]
    torch_img = preprocess(pil_image).unsqueeze(0).to(DEVICE)
    torch_img = (torch_img - torch.min(torch_img)) / (torch.max(torch_img) - torch.min(torch_img))

    # run the selected generator
    with torch.no_grad():
        output = generators[MODELS_TYPE.index(model_name)](torch_img)
        sr, sr_dem_selected = output[0], output[1]

    # convert the SR tensor back to an 8-bit image via a temporary PNG
    sr = sr.squeeze(0).cpu()
    torchvision.utils.save_image(sr, 'sr_pred.png')
    sr = np.array(Image.open('sr_pred.png'))

    # render the predicted DTM as a colour-mapped figure and grab it as an RGB array
    sr_dem_selected = sr_dem_selected.squeeze().cpu().detach().numpy()
    fig, ax = plt.subplots()
    im = ax.imshow(sr_dem_selected, cmap='jet', vmin=0, vmax=np.max(sr_dem_selected))
    plt.colorbar(im, ax=ax)
    fig.canvas.draw()
    data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
    data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
    plt.close(fig)  # avoid accumulating open figures across requests

    # info string: model name and parameter count
    info = f"{model_name} with {sum(p.numel() for p in generators[MODELS_TYPE.index(model_name)].parameters())} parameters"
    return info, sr, data


iface = gr.Interface(
    fn=predict,
    inputs=[
        gr.Image(shape=(512, 512)),
        gr.Radio(MODELS_TYPE)
    ],
    outputs=[
        gr.Text(),
        gr.Image(),
        gr.Image()
    ],
    examples=[
        ["demo_imgs/fake.jpg", MODELS_TYPE[0]]  # use real image
    ],
    title="Super Resolution and DTM Estimation",
    description="This demo predicts a super-resolved image and the corresponding super-resolution DTM from a grayscale image (RGB inputs are converted to grayscale)."
)

iface.launch()
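# --- Optional smoke-test sketch (kept commented out; iface.launch() above blocks) ---
# Exercises predict() directly, bypassing the UI, using the bundled example image.
# Run it instead of launch() if you only want to check that the checkpoints load
# and that the three outputs (info string, SR image, DTM plot) are produced.
#
# if __name__ == "__main__":
#     test_img = np.array(Image.open("demo_imgs/fake.jpg").convert("RGB"))
#     info, sr_img, dem_plot = predict(test_img, MODELS_TYPE[0])
#     print(info, sr_img.shape, dem_plot.shape)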