artelabsuper
remain scaling commented on test script
ad62063
import torch
import torchvision
from torchvision import transforms
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
from models.modelNetA import Generator as GA
from models.modelNetB import Generator as GB
from models.modelNetC import Generator as GC
# Inference device: use the GPU when CUDA is available, otherwise fall back
# to the CPU so the script still runs on machines without a GPU.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
# Which trained generator variant to run; must be a key of ``modeltype2path``.
model_type = 'model_c'
# Checkpoint file for each generator variant (relative paths, resolved
# against the current working directory).
modeltype2path = {
    'model_a': 'DTM_exp_train10%_model_a/g-best.pth',
    'model_b': 'DTM_exp_train10%_model_b/g-best.pth',
    'model_c': 'DTM_exp_train10%_model_c/g-best.pth',
}
# Instantiate the generator architecture selected by ``model_type``.
# An unsupported value fails fast with a clear message instead of leaving
# ``generator`` unbound (which would surface later as a NameError).
if model_type == 'model_a':
    generator = GA()
elif model_type == 'model_b':
    generator = GB()
elif model_type == 'model_c':
    generator = GC()
else:
    raise ValueError(
        f"Unknown model_type {model_type!r}; expected 'model_a', 'model_b' or 'model_c'"
    )

# Load the weights through a DataParallel wrapper, then unwrap it.
# Presumably the checkpoint was saved from a DataParallel model, so its keys
# carry a "module." prefix that only the wrapper accepts (TODO confirm).
# Weights are mapped to the CPU first, independent of where they were saved.
# NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
generator = torch.nn.DataParallel(generator)
state_dict_Gen = torch.load(modeltype2path[model_type], map_location=torch.device('cpu'))
generator.load_state_dict(state_dict_Gen)
# Unwrap to the bare module and move it to the inference device.
generator = generator.module.to(DEVICE)
# Switch to evaluation mode for inference.
generator.eval()
# Convert the input image to a single-channel float tensor (values in [0, 1]).
preprocess = transforms.Compose([
    transforms.Grayscale(),
    # Resizing is intentionally left disabled so the image is processed at
    # its native resolution; re-enable to force a 128x128 input.
    # transforms.Resize((128, 128)),
    transforms.ToTensor(),
])
input_img = Image.open('demo_imgs/fake.jpg')
# Add a leading batch axis and move the tensor to the inference device once.
torch_img = preprocess(input_img).unsqueeze(0).to(DEVICE)
# Min-max stretch so the darkest pixel maps to 0 and the brightest to 1.
# A constant image has max == min, which would produce 0/0 = NaN everywhere;
# map it to all zeros instead.
img_min = torch.min(torch_img)
img_max = torch.max(torch_img)
if img_max > img_min:
    torch_img = (torch_img - img_min) / (img_max - img_min)
else:
    torch_img = torch.zeros_like(torch_img)
# Run the generator without building an autograd graph. Only the first two
# elements of its output are used below.
with torch.no_grad():
    output = generator(torch_img)
    sr = output[0]
    sr_dem_selected = output[1]

# First output: drop the batch axis, bring it to the CPU, report its shape
# and write it to disk as an image.
sr = sr.squeeze(0).cpu()
print(sr.shape)
torchvision.utils.save_image(sr, 'sr.png')

# Second output: squeeze away every singleton axis, convert to a NumPy array,
# and render it with the 'jet' colormap (range 0..max) plus a colour bar.
sr_dem_selected = sr_dem_selected.detach().cpu().squeeze().numpy()
print(sr_dem_selected.shape)
plt.imshow(
    sr_dem_selected,
    vmin=0,
    vmax=sr_dem_selected.max(),
    cmap='jet',
)
plt.colorbar()
plt.savefig('test.png')