"""Gradio demo: super-resolution and DTM estimation from a grayscale image.

Three pretrained generators (ModelA/B/C) are loaded once at import time; the
``predict`` callback runs the selected one on the uploaded image and returns a
model summary, the super-resolved image and a colour-mapped DTM plot.
"""
import io
import os

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision
from PIL import Image
from torchvision import transforms

from models.modelNetA import Generator as GA
from models.modelNetB import Generator as GB
from models.modelNetC import Generator as GC

# Default input side length, used when no size is selected in the UI.
scale_size = 128
# Square side lengths (pixels) the input image can be resized to.
scale_sizes = [128, 256, 512]

# load model
modeltype2path = {
    'ModelA': 'DTM_exp_train10%_model_a/g-best.pth',
    'ModelB': 'DTM_exp_train10%_model_b/g-best.pth',
    'ModelC': 'DTM_exp_train10%_model_c/g-best.pth',
}
DEVICE = 'cpu'
MODELS_TYPE = list(modeltype2path.keys())


def _load_generator(net, checkpoint_path):
    """Load the state dict at *checkpoint_path* into *net*; return it in eval mode on DEVICE.

    The network is wrapped in ``DataParallel`` before loading, presumably because
    the checkpoint was saved from a DataParallel model (keys prefixed with
    ``module.``); the bare inner module is returned.
    """
    wrapped = torch.nn.DataParallel(net)
    # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
    state_dict = torch.load(checkpoint_path, map_location=torch.device('cpu'))
    wrapped.load_state_dict(state_dict)
    model = wrapped.module.to(DEVICE)
    model.eval()
    return model


generators = [
    _load_generator(net, modeltype2path[name])
    for name, net in zip(MODELS_TYPE, [GA(), GB(), GC()])
]

# Per-model lookups keyed by the UI name; parameter counts never change, so
# compute them once instead of on every request.
_GENERATOR_BY_NAME = dict(zip(MODELS_TYPE, generators))
_PARAM_COUNT_BY_NAME = {
    name: sum(p.numel() for p in net.parameters())
    for name, net in _GENERATOR_BY_NAME.items()
}

preprocess = transforms.Compose([
    transforms.Grayscale(),
    transforms.ToTensor(),
])


def _tensor_to_uint8_image(tensor):
    """Encode *tensor* as a PNG via ``torchvision.utils.save_image`` and decode it to a numpy array.

    Done in memory rather than through a shared ``sr_pred.png`` file, which
    concurrent requests could overwrite before it is read back.
    """
    buffer = io.BytesIO()
    torchvision.utils.save_image(tensor, buffer, format='png')
    buffer.seek(0)
    with Image.open(buffer) as img:
        return np.array(img)


def _render_dem(dem):
    """Render *dem* as a 'jet' heat map (vmin=0, vmax=max) with a colour bar.

    Returns the rendered figure as an (H, W, 3) uint8 RGB array.
    """
    fig, ax = plt.subplots()
    try:
        im = ax.imshow(dem, cmap='jet', vmin=0, vmax=np.max(dem))
        fig.colorbar(im, ax=ax)
        fig.canvas.draw()
        # buffer_rgba() replaces tostring_rgb() (removed in Matplotlib 3.10);
        # its shape already matches the rendered pixels, so no manual reshape
        # from get_width_height() (which ignores HiDPI scaling) is needed.
        # np.array copies, so the data outlives the closed figure.
        return np.array(fig.canvas.buffer_rgba())[..., :3]
    finally:
        # Without this every request leaks a pyplot figure.
        plt.close(fig)


def predict(input_image, model_name, input_scale_factor):
    """Run the selected generator on *input_image*.

    Args:
        input_image: uploaded image as a numpy array (cast to uint8; colour
            images are converted to grayscale by ``preprocess``).
        model_name: one of ``MODELS_TYPE``.
        input_scale_factor: side length the input is resized to (square);
            falls back to ``scale_size`` when not provided.

    Returns:
        (info, sr, dem_plot): a model summary string, the super-resolved image
        as a uint8 array, and the rendered DTM plot as an RGB uint8 array.

    Raises:
        gr.Error: if no image was uploaded or *model_name* is unknown.
    """
    if input_image is None:
        raise gr.Error("Please upload an input image.")
    if model_name not in _GENERATOR_BY_NAME:
        raise gr.Error(f"Unknown model: {model_name!r}")
    side = int(input_scale_factor) if input_scale_factor is not None else scale_size

    # Let Pillow infer the mode from the array shape (L, RGB or RGBA); the
    # Grayscale transform in `preprocess` handles all of them.
    pil_image = Image.fromarray(np.asarray(input_image).astype('uint8'))
    pil_image = transforms.Resize((side, side))(pil_image)

    # transform image to torch and min-max normalise to [0, 1]
    torch_img = preprocess(pil_image).unsqueeze(0).to(DEVICE)
    lo, hi = torch.min(torch_img), torch.max(torch_img)
    if hi > lo:
        torch_img = (torch_img - lo) / (hi - lo)
    else:
        # A constant image would otherwise divide by zero and feed NaNs.
        torch_img = torch.zeros_like(torch_img)

    # model predict
    with torch.no_grad():
        output = _GENERATOR_BY_NAME[model_name](torch_img)
    # Assumes output[0] is the SR image batch and output[1] the DTM batch
    # -- TODO confirm against the model definitions.
    sr, sr_dem_selected = output[0], output[1]

    sr = _tensor_to_uint8_image(sr.squeeze(0).cpu())
    dem_plot = _render_dem(sr_dem_selected.squeeze().cpu().detach().numpy())

    info = f"{model_name} with {_PARAM_COUNT_BY_NAME[model_name]} parameters"
    return info, sr, dem_plot


iface = gr.Interface(
    fn=predict,
    inputs=[
        gr.Image(),
        gr.Radio(MODELS_TYPE, value=MODELS_TYPE[0]),
        gr.Radio(scale_sizes, value=scale_size),
    ],
    outputs=[
        gr.Text(label='Model info'),
        gr.Image(label='Super Resolution'),
        gr.Image(label='DTM'),
    ],
    examples=[
        [f"demo_imgs/{name}", MODELS_TYPE[0], 128]
        for name in sorted(os.listdir('demo_imgs'))
    ],
    title="Super Resolution and DTM Estimation",
    description="This demo predicts Super Resolution and (Super Resolution) DTM from a Grayscale image (if RGB we convert it).",
)

iface.launch()