# artelabsuper
# add local normalization
# 459c031
# raw history blame
# No virus
# 2.68 kB
import gradio as gr
from PIL import Image
import torchvision
from torchvision import transforms
import torch
import matplotlib.pyplot as plt
import numpy as np
from models.modelNetA import Generator as GA
from models.modelNetB import Generator as GB
from models.modelNetC import Generator as GC
# load model
# Checkpoint file holding the generator weights for each selectable model.
modeltype2path = {
    'ModelA': 'DTM_exp_train10%_model_a/g-best.pth',
    'ModelB': 'DTM_exp_train10%_model_b/g-best.pth',
    'ModelC': 'DTM_exp_train10%_model_c/g-best.pth',
}
DEVICE = 'cpu'
# Model names in dict insertion order; position i here matches generators[i].
MODELS_TYPE = list(modeltype2path.keys())

generators = []
for model_type, generator in zip(MODELS_TYPE, (GA(), GB(), GC())):
    # The weights are loaded through a DataParallel wrapper and the bare module
    # is unwrapped right after -- presumably because the checkpoints were saved
    # from a DataParallel model and carry "module."-prefixed keys (confirm
    # against the training code).
    wrapped = torch.nn.DataParallel(generator)
    state_dict = torch.load(
        modeltype2path[model_type],
        map_location=torch.device(DEVICE),  # follow DEVICE instead of a second hard-coded 'cpu'
    )
    wrapped.load_state_dict(state_dict)
    generator = wrapped.module.to(DEVICE)
    generator.eval()  # inference only: switch the module to evaluation mode
    generators.append(generator)

# Input pipeline applied to every uploaded picture: convert to grayscale,
# then turn the PIL image into a torch tensor.
preprocess = transforms.Compose([
    transforms.Grayscale(),
    transforms.ToTensor(),
])
def _min_max_normalize(tensor):
    """Rescale *tensor* linearly so that its values span [0, 1].

    A constant tensor has a zero value range; it is mapped to all zeros
    instead of the NaNs that ``(x - min) / (max - min)`` would produce.
    """
    lowest = torch.min(tensor)
    value_range = torch.max(tensor) - lowest
    if value_range == 0:
        return torch.zeros_like(tensor)
    return (tensor - lowest) / value_range


def _tensor_to_image_array(tensor):
    """Encode *tensor* as a PNG in memory and decode it back to a numpy array.

    The conversion still goes through ``torchvision.utils.save_image`` so the
    resulting pixels are the same as before, but the PNG lives in a private
    buffer rather than in one fixed file on disk that concurrent requests
    would overwrite.
    """
    import io  # local import so this block does not depend on the file header

    with io.BytesIO() as buffer:
        torchvision.utils.save_image(tensor, buffer, format='png')
        buffer.seek(0)
        # np.array forces the lazy PIL decode while the buffer is still open.
        return np.array(Image.open(buffer))


def _render_heatmap(values):
    """Plot *values* with the 'jet' colormap plus a colorbar; return RGB pixels.

    The figure is always closed afterwards, otherwise pyplot keeps every
    figure alive and memory grows with each request.
    """
    fig, ax = plt.subplots()
    try:
        im = ax.imshow(values, cmap='jet', vmin=0, vmax=np.max(values))
        fig.colorbar(im, ax=ax)
        fig.canvas.draw()
        # buffer_rgba() replaces the deprecated tostring_rgb(); drop the alpha
        # channel and copy so the result does not alias the canvas buffer.
        rgba = np.asarray(fig.canvas.buffer_rgba())
        return rgba[..., :3].copy()
    finally:
        plt.close(fig)


def predict(input_image, model_name):
    """Run the selected generator on an image and build the demo outputs.

    Args:
        input_image: Array-like image; it is cast to uint8 and interpreted as
            RGB (assumed to be height x width x 3 -- confirm against the UI).
        model_name: One of the names in ``MODELS_TYPE``.

    Returns:
        A tuple ``(info, sr, data)``: a text line with the model name and its
        parameter count, the first model output as a uint8 image array, and
        an RGB rendering of the second model output as a heatmap.

    Raises:
        ValueError: If ``model_name`` is not a known model (for example when
            nothing was selected in the UI).
    """
    if model_name not in MODELS_TYPE:
        raise ValueError(
            f"Unknown model {model_name!r}; expected one of {MODELS_TYPE}"
        )
    generator = generators[MODELS_TYPE.index(model_name)]

    pil_image = Image.fromarray(input_image.astype('uint8'), 'RGB')
    # transform image to torch and do preprocessing
    torch_img = preprocess(pil_image).unsqueeze(0).to(DEVICE)
    torch_img = _min_max_normalize(torch_img)

    # model predict
    with torch.no_grad():
        output = generator(torch_img)
    sr, sr_dem_selected = output[0], output[1]

    # transform torch to image
    sr = _tensor_to_image_array(sr.squeeze(0).cpu())
    sr_dem_selected = sr_dem_selected.squeeze().cpu().detach().numpy()
    data = _render_heatmap(sr_dem_selected)

    # return correct image and info
    info = f"{model_name} with {sum(p.numel() for p in generator.parameters())} parameters"
    return info, sr, data
# Web UI around predict(): an image upload plus a model selector as inputs;
# the three outputs follow the order of predict()'s return tuple
# (info text, predicted image, heatmap image).
iface = gr.Interface(
    fn=predict,
    inputs=[
        gr.Image(shape=(512, 512)),
        # Top-level component, consistent with gr.Image / gr.Text used here
        # (gr.inputs.* is the deprecated namespace). A model is preselected so
        # submitting without a choice never hands None to predict().
        gr.Radio(MODELS_TYPE, value=MODELS_TYPE[0]),
    ],
    outputs=[
        gr.Text(),
        gr.Image(),
        gr.Image(),
    ],
    examples=[
        ["demo_imgs/fake.jpg", MODELS_TYPE[0]]  # use real image
    ],
    title="Super Resolution and DTM Estimation",
    description="This demo predict Super Resolution and (Super Resolution) DTM from a Grayscale image (if RGB we convert it)."
)
iface.launch()