"""Demo: run a trained DTM generator on a single image and save its outputs.

Loads the ``g-best.pth`` checkpoint for the selected ``model_type``, feeds it
``demo_imgs/fake.jpg`` (converted to grayscale and min-max normalised), and
writes:

* ``sr.png``   -- the generator's first output, via
  ``torchvision.utils.save_image``;
* ``test.png`` -- the generator's second output rendered as a ``jet``
  colour map with a colour bar.
"""
import numpy as np
import matplotlib.pyplot as plt
import torch
import torchvision
from PIL import Image
from torchvision import transforms

from models.modelNetA import Generator as GA
from models.modelNetB import Generator as GB
from models.modelNetC import Generator as GC

# Use the GPU when present; fall back to CPU instead of crashing on .to('cuda').
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

model_type = 'model_c'

# Checkpoint path for each supported model variant.
modeltype2path = {
    'model_a': 'DTM_exp_train10%_model_a/g-best.pth',
    'model_b': 'DTM_exp_train10%_model_b/g-best.pth',
    'model_c': 'DTM_exp_train10%_model_c/g-best.pth',
}

# Generator class for each supported model variant.
modeltype2class = {
    'model_a': GA,
    'model_b': GB,
    'model_c': GC,
}

if model_type not in modeltype2class:
    # Previously an unknown model_type left `generator` undefined (NameError).
    raise ValueError(
        f"unknown model_type {model_type!r}; "
        f"expected one of {sorted(modeltype2class)}"
    )
generator = modeltype2class[model_type]()

# NOTE(review): the model is wrapped in DataParallel only so that the checkpoint
# keys match -- presumably it was saved from a DataParallel model and its keys
# carry a 'module.' prefix. After loading we unwrap back to the bare generator.
generator = torch.nn.DataParallel(generator)
# torch.load unpickles the file: only load checkpoints from a trusted source.
# map_location='cpu' lets a GPU-saved checkpoint load on any machine.
state_dict_Gen = torch.load(modeltype2path[model_type],
                            map_location=torch.device('cpu'))
generator.load_state_dict(state_dict_Gen)
generator = generator.module.to(DEVICE)
generator.eval()

# Grayscale -> single-channel PIL image; ToTensor -> float tensor of shape (1, H, W).
# Add transforms.Resize(...) here if the generator requires a fixed input size.
preprocess = transforms.Compose([
    transforms.Grayscale(),
    transforms.ToTensor(),
])

# Context manager closes the underlying file once the pixels are converted.
with Image.open('demo_imgs/fake.jpg') as input_img:
    # unsqueeze(0) adds the batch dimension: (1, H, W) -> (1, 1, H, W).
    torch_img = preprocess(input_img).unsqueeze(0).to(DEVICE)

# Min-max normalise to [0, 1]. Clamping the range avoids 0/0 -> NaN on a
# constant image (which then maps to all zeros); any real 8-bit range is far
# above eps, so non-constant images are unaffected.
img_min = torch.min(torch_img)
img_max = torch.max(torch_img)
img_range = (img_max - img_min).clamp_min(torch.finfo(torch_img.dtype).eps)
torch_img = (torch_img - img_min) / img_range

with torch.no_grad():
    output = generator(torch_img)

# The generator returns (at least) two outputs; only the first two are used here.
sr, sr_dem_selected = output[0], output[1]

# Drop the batch dimension before saving.
sr = sr.squeeze(0).cpu()
print(sr.shape)
torchvision.utils.save_image(sr, 'sr.png')

# Squeeze all singleton dims so imshow receives a 2-D array.
sr_dem_selected = sr_dem_selected.squeeze().cpu().detach().numpy()
print(sr_dem_selected.shape)
plt.imshow(sr_dem_selected, cmap='jet', vmin=0, vmax=np.max(sr_dem_selected))
plt.colorbar()
plt.savefig('test.png')
plt.close()