import gradio as gr
import os
from PIL import Image
import torchvision
from torchvision import transforms
import torch
import matplotlib.pyplot as plt
import numpy as np
from models.modelNetA import Generator as GA
from models.modelNetB import Generator as GB
from models.modelNetC import Generator as GC

# Input resolutions the demo accepts
scale_sizes = [128, 256, 512]
# Map each model name to its pretrained checkpoint
modeltype2path = {
    'ModelA': 'DTM_exp_train10%_model_a/g-best.pth',
    'ModelB': 'DTM_exp_train10%_model_b/g-best.pth',
    'ModelC': 'DTM_exp_train10%_model_c/g-best.pth',
}
DEVICE = 'cpu'
MODELS_TYPE = list(modeltype2path.keys())
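# One generator instance per model, in the same order as MODELS_TYPE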
generators = [GA(), GB(), GC()]

for i in range(len(generators)):
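    # The checkpoints were presumably saved from a DataParallel-wrapped model
    # (state-dict keys prefixed with 'module.'): wrap, load, then unwrap via .module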
    generators[i] = torch.nn.DataParallel(generators[i])
    state_dict = torch.load(modeltype2path[MODELS_TYPE[i]], map_location=torch.device('cpu'))
    generators[i].load_state_dict(state_dict)
    generators[i] = generators[i].module.to(DEVICE)
    generators[i].eval()

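# Convert inputs to single-channel grayscale; ToTensor scales pixel values to [0, 1]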
preprocess = transforms.Compose([
    transforms.Grayscale(),
    transforms.ToTensor()
])

def predict(input_image, model_name, input_scale_factor):
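    """Run the selected generator on the uploaded image.

    Returns an info string (model name and parameter count), the
    super-resolved image, and a rendering of the estimated DTM
    (digital terrain model).
    """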
    pil_image = Image.fromarray(input_image.astype('uint8'), 'RGB')
    pil_image = transforms.Resize((input_scale_factor, input_scale_factor))(pil_image)
    # Preprocess, add a batch dimension, and min-max normalize to [0, 1]
    torch_img = preprocess(pil_image).unsqueeze(0).to(DEVICE)
    torch_img = (torch_img - torch.min(torch_img)) / (torch.max(torch_img) - torch.min(torch_img))
    # Run inference; the generator returns the SR image and the estimated DTM
    with torch.no_grad():
        output = generators[MODELS_TYPE.index(model_name)](torch_img)
    sr, sr_dem_selected = output[0], output[1]
    # Quantize the SR tensor to an 8-bit image by saving and reloading a PNG
    sr = sr.squeeze(0).cpu()
    torchvision.utils.save_image(sr, 'sr_pred.png')
    sr = np.array(Image.open('sr_pred.png'))

    # Render the DTM as a color-mapped figure and grab its RGB pixel buffer
    sr_dem_selected = sr_dem_selected.squeeze().cpu().detach().numpy()
    fig, ax = plt.subplots()
    im = ax.imshow(sr_dem_selected, cmap='jet', vmin=0, vmax=np.max(sr_dem_selected))
    plt.colorbar(im, ax=ax)
    fig.canvas.draw()
    data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
    data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
    plt.close(fig)  # free the figure so repeated predictions don't leak memory
    # Return the model info string, the SR image, and the rendered DTM
    info = f"{model_name} with {sum(p.numel() for p in generators[MODELS_TYPE.index(model_name)].parameters())} parameters"
    return info, sr, data

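# Gradio UI: an image upload plus model and input-size selectors wired to predict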
iface = gr.Interface(
    fn=predict, 
    inputs=[
        gr.Image(),
        gr.Radio(MODELS_TYPE, label='Model'),
        gr.Radio(scale_sizes, label='Input size')
    ], 
    outputs=[
        gr.Text(label='Model info'),
        gr.Image(label='Super Resolution'),
        gr.Image(label='DTM')
    ],
    examples=[
        [f"demo_imgs/{name}", MODELS_TYPE[0], 128] for name in os.listdir('demo_imgs')
    ],
    title="Super Resolution and DTM Estimation",
    description="This demo predicts Super Resolution and a super-resolved DTM from a grayscale image (RGB inputs are converted automatically)."
)
iface.launch()