# Hugging Face Space status: Runtime error
import gradio as gr | |
import os | |
from PIL import Image | |
import torchvision | |
from torchvision import transforms | |
import torch | |
import matplotlib.pyplot as plt | |
import numpy as np | |
from models.modelNetA import Generator as GA | |
from models.modelNetB import Generator as GB | |
from models.modelNetC import Generator as GC | |
# Default side length (pixels) of the square image fed to the generators.
scale_size = 128
# Side lengths offered in the UI: 1x, 2x and 4x the default.
scale_sizes = [scale_size * factor for factor in (1, 2, 4)]

# Checkpoint file for each selectable model.
modeltype2path = dict(
    ModelA='DTM_exp_train10%_model_a/g-best.pth',
    ModelB='DTM_exp_train10%_model_b/g-best.pth',
    ModelC='DTM_exp_train10%_model_c/g-best.pth',
)
DEVICE = 'cpu'
# Model names in declaration order; used as UI choices and to index `generators`.
MODELS_TYPE = [*modeltype2path]
# Generator class for each model type; keys must match `modeltype2path`.
_model_classes = {'ModelA': GA, 'ModelB': GB, 'ModelC': GC}


def _load_generator(generator_cls, checkpoint_path):
    """Instantiate `generator_cls`, load `checkpoint_path` into it and return it in eval mode on DEVICE."""
    # The checkpoints appear to have been saved from a DataParallel-wrapped
    # model (keys prefixed with 'module.'), so load through a DataParallel
    # wrapper and then keep only the inner module.
    wrapper = torch.nn.DataParallel(generator_cls())
    state_dict = torch.load(checkpoint_path, map_location=torch.device('cpu'))
    wrapper.load_state_dict(state_dict)
    model = wrapper.module.to(DEVICE)
    model.eval()
    return model


# One loaded generator per entry of MODELS_TYPE, in the same order, so
# `generators[MODELS_TYPE.index(name)]` selects the model called `name`.
generators = [
    _load_generator(_model_classes[name], modeltype2path[name])
    for name in MODELS_TYPE
]
# Per-image input pipeline: reduce to a single luminance channel, then convert
# the PIL image to a float tensor in [0, 1] of shape (1, H, W).
preprocess = transforms.Compose(
    [
        transforms.Grayscale(num_output_channels=1),
        transforms.ToTensor(),
    ]
)
def _to_model_input(input_image, size):
    """Turn an uploaded image into a min-max normalised (1, 1, size, size) tensor on DEVICE."""
    # Let Pillow infer the mode from the array shape (L / RGB / RGBA): forcing
    # 'RGB' misreads 2-D grayscale and 4-channel uploads. `preprocess`
    # converts to grayscale either way.
    pil_image = Image.fromarray(np.asarray(input_image).astype('uint8'))
    pil_image = transforms.Resize((size, size))(pil_image)
    torch_img = preprocess(pil_image).unsqueeze(0).to(DEVICE)
    lo, hi = torch.min(torch_img), torch.max(torch_img)
    if hi > lo:
        return (torch_img - lo) / (hi - lo)
    # A constant image would give 0/0 = NaN everywhere; map it to zeros instead.
    return torch.zeros_like(torch_img)


def _sr_to_array(sr):
    """Convert the generator's super-resolution tensor to an (H, W, C) uint8 array.

    Single-channel input is replicated to 3 channels by ``make_grid``.
    """
    # Same conversion torchvision.utils.save_image applies, done in memory:
    # round-tripping through a fixed 'sr_pred.png' in the working directory
    # could race between concurrent requests and left the file behind.
    grid = torchvision.utils.make_grid(sr.squeeze(0).cpu())
    return (
        grid.mul(255).add_(0.5).clamp_(0, 255)
        .permute(1, 2, 0)
        .to('cpu', torch.uint8)
        .numpy()
    )


def _dem_to_heatmap(dem):
    """Plot `dem` with the 'jet' colormap and a colorbar; return the figure as an (H, W, 3) uint8 array."""
    fig, ax = plt.subplots()
    try:
        im = ax.imshow(dem, cmap='jet', vmin=0, vmax=np.max(dem))
        fig.colorbar(im, ax=ax)
        fig.canvas.draw()
        # buffer_rgba() is already shaped (H, W, 4); tostring_rgb() was removed
        # in Matplotlib 3.10 and the manual reshape broke on HiDPI canvases.
        return np.ascontiguousarray(np.asarray(fig.canvas.buffer_rgba())[:, :, :3])
    finally:
        # pyplot keeps every figure alive until closed; without this each
        # request leaks one figure.
        plt.close(fig)


def predict(input_image, model_name, input_scale_factor):
    """Run the selected generator on an image and render its two outputs.

    Args:
        input_image: Image from the ``gr.Image`` input (array-like); ``None``
            when nothing was uploaded.
        model_name: One of ``MODELS_TYPE``; falls back to the first model
            when unset.
        input_scale_factor: Side length in pixels the input is resized to
            before inference (one of ``scale_sizes``); falls back to
            ``scale_size`` when unset.

    Returns:
        Tuple ``(info, sr, dem_plot)``: a text summary with the model's
        parameter count, the super-resolved image as a uint8 array, and a
        colour-mapped plot of the predicted DTM as an (H, W, 3) uint8 array.

    Raises:
        gr.Error: If no image was provided.
    """
    if input_image is None:
        raise gr.Error("Please upload an image.")
    if model_name is None:
        model_name = MODELS_TYPE[0]
    if input_scale_factor is None:
        input_scale_factor = scale_size
    generator = generators[MODELS_TYPE.index(model_name)]

    torch_img = _to_model_input(input_image, int(input_scale_factor))
    with torch.no_grad():
        output = generator(torch_img)
    # Presumably (super-resolved image, DEM, ...); only the first two outputs are used.
    sr, sr_dem_selected = output[0], output[1]

    sr = _sr_to_array(sr)
    data = _dem_to_heatmap(sr_dem_selected.squeeze().cpu().detach().numpy())

    n_params = sum(p.numel() for p in generator.parameters())
    info = f"{model_name} with {n_params} parameters"
    return info, sr, data
# Example rows: every non-hidden file in demo_imgs/, sorted so the gallery
# order is stable across hosts, paired with the first model and 128 px input.
_example_rows = [
    [os.path.join('demo_imgs', name), MODELS_TYPE[0], 128]
    for name in sorted(os.listdir('demo_imgs'))
    if not name.startswith('.')
]

iface = gr.Interface(
    fn=predict,
    inputs=[
        gr.Image(),
        # gr.inputs.* was removed in Gradio 4; use the top-level components.
        # Defaults ensure predict never receives an unset (None) selection.
        gr.Radio(MODELS_TYPE, value=MODELS_TYPE[0], label='Model'),
        gr.Radio(scale_sizes, value=scale_sizes[0], label='Input size'),
    ],
    outputs=[
        gr.Text(label='Model info'),
        gr.Image(label='Super Resolution'),
        gr.Image(label='DTM'),
    ],
    examples=_example_rows,
    title="Super Resolution and DTM Estimation",
    description="This demo predicts Super Resolution and (Super Resolution) DTM from a Grayscale image (if RGB we convert it).",
)
iface.launch()