Spaces:
Runtime error
Runtime error
File size: 2,940 Bytes
21bf7d6 4216279 21bf7d6 66432b9 21bf7d6 66432b9 21bf7d6 d8f83ab 053b94b 21bf7d6 66432b9 21bf7d6 053b94b 21bf7d6 053b94b 21bf7d6 66432b9 459c031 21bf7d6 66432b9 21bf7d6 66432b9 7f268fe 21bf7d6 1513566 053b94b 1513566 66432b9 4216279 66432b9 21bf7d6 053b94b 21bf7d6 63cfb07 053b94b 21bf7d6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 |
import gradio as gr
import os
from PIL import Image
import torchvision
from torchvision import transforms
import torch
import matplotlib.pyplot as plt
import numpy as np
from models.modelNetA import Generator as GA
from models.modelNetB import Generator as GB
from models.modelNetC import Generator as GC
scale_size = 128
# Output sizes (square side, in pixels) offered in the UI.
scale_sizes = [128, 256, 512]

# Checkpoint path for each selectable model. The dict order defines the order
# of MODELS_TYPE, and it must match the order of the generator classes below.
modeltype2path = {
    'ModelA': 'DTM_exp_train10%_model_a/g-best.pth',
    'ModelB': 'DTM_exp_train10%_model_b/g-best.pth',
    'ModelC': 'DTM_exp_train10%_model_c/g-best.pth',
}
DEVICE = 'cpu'
MODELS_TYPE = list(modeltype2path.keys())


def _load_generator(model, checkpoint_path, device=DEVICE):
    """Load weights from ``checkpoint_path`` into ``model``; return it in eval mode.

    The model is wrapped in ``DataParallel`` only while the state dict is loaded.
    This presumably matches checkpoints saved from a DataParallel-wrapped model
    (keys prefixed with ``module.``); TODO confirm against the training code.
    The bare module is then unwrapped, moved to ``device`` and switched to eval mode.
    """
    wrapped = torch.nn.DataParallel(model)
    # Map tensors straight onto the configured device, instead of always 'cpu'.
    # NOTE(review): torch.load unpickles the file, so only load trusted checkpoints.
    state_dict = torch.load(checkpoint_path, map_location=torch.device(device))
    wrapped.load_state_dict(state_dict)
    generator = wrapped.module.to(device)
    generator.eval()
    return generator


# One loaded generator per entry of MODELS_TYPE, in the same order.
generators = [
    _load_generator(model, modeltype2path[name])
    for model, name in zip([GA(), GB(), GC()], MODELS_TYPE)
]
# Input pipeline for the generators: collapse the PIL image to a single
# grayscale channel, then convert it to a float tensor via ToTensor.
preprocess = transforms.Compose(
    [
        transforms.Grayscale(num_output_channels=1),
        transforms.ToTensor(),
    ]
)
def _sr_tensor_to_image(sr):
    """Convert a (C, H, W) super-resolution tensor to a uint8 H x W x 3 array.

    Uses the same pixel pipeline as ``torchvision.utils.save_image``: build a
    grid, scale to 0-255, round, clamp and cast. A single-channel image is
    expanded to three channels by ``make_grid``. Everything happens in memory,
    so concurrent requests no longer race on a shared temporary PNG file.
    """
    grid = torchvision.utils.make_grid(sr)
    return (
        grid.mul(255)
        .add_(0.5)
        .clamp_(0, 255)
        .permute(1, 2, 0)
        .to('cpu', torch.uint8)
        .numpy()
    )


def _render_dtm(dem):
    """Render a 2-D elevation array as a 'jet' heatmap with a colorbar.

    Returns the rendered figure as a uint8 H x W x 3 array.
    """
    fig, ax = plt.subplots()
    try:
        im = ax.imshow(dem, cmap='jet', vmin=0, vmax=np.max(dem))
        fig.colorbar(im, ax=ax)
        fig.canvas.draw()
        # buffer_rgba() replaces tostring_rgb(), which was deprecated in
        # Matplotlib 3.8 and removed in 3.10. The buffer also carries its true
        # (HiDPI-aware) shape, so no manual reshape from get_width_height() is needed.
        rgba = np.asarray(fig.canvas.buffer_rgba())
        return np.ascontiguousarray(rgba[..., :3])
    finally:
        # Close every figure; otherwise each request leaks one figure.
        plt.close(fig)


def predict(input_image, model_name, input_scale_factor):
    """Run the selected generator on an image.

    Args:
        input_image: image array from the Gradio ``Image`` input (RGB, RGBA or
            grayscale array).
        model_name: one of ``MODELS_TYPE``.
        input_scale_factor: side length, in pixels, of the square the input is
            resized to before inference. It is a size, not a factor.

    Returns:
        ``(info, sr, dtm)``:
        - ``info`` is a text description of the model and its parameter count.
        - ``sr`` is the super-resolution output as a uint8 RGB array.
        - ``dtm`` is the rendered DTM heatmap as a uint8 RGB array.

    Raises:
        ValueError: if no image was given, or ``model_name`` is unknown.
    """
    if input_image is None:
        raise ValueError("Please provide an input image.")
    generator = generators[MODELS_TYPE.index(model_name)]
    size = int(input_scale_factor)

    # Let Pillow infer the mode from the array shape, then normalise to RGB.
    # Forcing mode='RGB' (deprecated in Pillow) misreads grayscale/RGBA arrays.
    pil_image = Image.fromarray(np.asarray(input_image).astype(np.uint8)).convert('RGB')
    pil_image = transforms.Resize((size, size))(pil_image)
    torch_img = preprocess(pil_image).unsqueeze(0).to(DEVICE)

    # Min-max normalise to [0, 1]. A constant image (zero range) would
    # otherwise divide 0 by 0 and feed NaNs to the model.
    lo, hi = torch.min(torch_img), torch.max(torch_img)
    span = hi - lo
    if span > 0:
        torch_img = (torch_img - lo) / span
    else:
        torch_img = torch.zeros_like(torch_img)

    with torch.no_grad():
        output = generator(torch_img)
    sr, sr_dem_selected = output[0], output[1]

    sr = _sr_tensor_to_image(sr.squeeze(0).cpu())
    data = _render_dtm(sr_dem_selected.squeeze().cpu().detach().numpy())

    n_params = sum(p.numel() for p in generator.parameters())
    info = f"{model_name} with {n_params} parameters"
    return info, sr, data
# Example inputs shown under the interface: each non-hidden file in
# demo_imgs/, in sorted order so the list is stable across restarts, paired
# with the first model and the smallest output size. A missing directory
# simply yields no examples instead of crashing at startup.
_DEMO_DIR = 'demo_imgs'
_demo_examples = (
    [
        [os.path.join(_DEMO_DIR, name), MODELS_TYPE[0], scale_sizes[0]]
        for name in sorted(os.listdir(_DEMO_DIR))
        if not name.startswith('.')
    ]
    if os.path.isdir(_DEMO_DIR)
    else []
)

iface = gr.Interface(
    fn=predict,
    inputs=[
        gr.Image(),
        # gr.inputs.* was deprecated in Gradio 3 and removed in Gradio 4. The
        # top-level components work in both. Defaults ensure predict never
        # receives None when the user does not pick a value.
        gr.Radio(MODELS_TYPE, value=MODELS_TYPE[0]),
        gr.Radio(scale_sizes, value=scale_sizes[0]),
    ],
    outputs=[
        gr.Text(label='Model info'),
        gr.Image(label='Super Resolution'),
        gr.Image(label='DTM'),
    ],
    examples=_demo_examples or None,
    title="Super Resolution and DTM Estimation",
    description="This demo predicts Super Resolution and (Super Resolution) DTM from a Grayscale image (if RGB we convert it).",
)
iface.launch()