File size: 1,742 Bytes
eba1c6b
 
 
 
 
 
 
 
 
 
 
 
ad62063
7f268fe
 
eba1c6b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ad62063
eba1c6b
 
 
 
459c031
eba1c6b
 
 
 
 
 
 
66432b9
 
eba1c6b
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
import torch
import torchvision
from torchvision import transforms
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
from models.modelNetA import Generator as GA
from models.modelNetB import Generator as GB
from models.modelNetC import Generator as GC



# Inference configuration.
# Prefer the GPU, but fall back to the CPU so the demo also runs on machines
# without a CUDA device instead of failing when tensors are moved to 'cuda'.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
# Which generator architecture to run; must be one of the keys of modeltype2path.
model_type = 'model_c'

# Checkpoint file for each supported architecture (relative to the working directory).
modeltype2path = {
    'model_a': 'DTM_exp_train10%_model_a/g-best.pth',
    'model_b': 'DTM_exp_train10%_model_b/g-best.pth',
    'model_c': 'DTM_exp_train10%_model_c/g-best.pth',
}

# Map each supported model_type to the Generator class that implements it.
model_builders = {
    'model_a': GA,
    'model_b': GB,
    'model_c': GC,
}
if model_type not in model_builders:
    # Fail fast with a clear message rather than leaving `generator` unbound.
    raise ValueError(
        f"Unknown model_type {model_type!r}; expected one of {sorted(model_builders)}"
    )
generator = model_builders[model_type]()

# Loading into a DataParallel wrapper means load_state_dict expects parameter
# keys prefixed with 'module.' (the layout of a checkpoint saved from a
# DataParallel model); the wrapper is discarded again right after loading.
generator = torch.nn.DataParallel(generator)
# NOTE(review): torch.load unpickles the file, so only load trusted checkpoints
# (consider weights_only=True on torch versions that support it).
state_dict_Gen = torch.load(modeltype2path[model_type], map_location=torch.device('cpu'))
generator.load_state_dict(state_dict_Gen)
# Unwrap to the plain module and move it to the inference device.
generator = generator.module.to(DEVICE)
# Put layers such as dropout / batch-norm into inference mode.
generator.eval()

# Convert the input image to a single-channel tensor.
preprocess = transforms.Compose([
    transforms.Grayscale(),
    # transforms.Resize((128, 128)),
    transforms.ToTensor(),
])

# Load the demo image; the context manager releases the file handle once
# preprocessing has read the pixels.
with Image.open('demo_imgs/fake.jpg') as input_img:
    torch_img = preprocess(input_img)
# Add a leading batch dimension and move to the inference device.
torch_img = torch_img.unsqueeze(0).to(DEVICE)

# Min-max normalize to [0, 1]. The denominator is clamped to float eps so a
# constant image yields zeros instead of NaN (0 / 0); images whose value range
# exceeds eps are normalized exactly as before.
img_min, img_max = torch.min(torch_img), torch.max(torch_img)
value_range = (img_max - img_min).clamp(min=torch.finfo(torch_img.dtype).eps)
torch_img = (torch_img - img_min) / value_range

with torch.no_grad():
    output = generator(torch_img)
# NOTE(review): presumably output[0] is the super-resolved image and output[1]
# the selected DEM map — confirm against the Generator implementations.
sr, sr_dem_selected = output[0], output[1]
# Save the first generator output as an image: drop the leading (batch)
# dimension and bring the tensor back to host memory first.
sr = sr.squeeze(0).cpu()
print(sr.shape)
torchvision.utils.save_image(sr, 'sr.png')

# Render the second generator output as a colour-mapped figure with a colour
# bar; the colour scale runs from 0 up to the array's maximum value.
sr_dem_selected = sr_dem_selected.squeeze().detach().cpu().numpy()
print(sr_dem_selected.shape)
fig, ax = plt.subplots()
dem_plot = ax.imshow(sr_dem_selected, cmap='jet', vmin=0, vmax=sr_dem_selected.max())
fig.colorbar(dem_plot, ax=ax)
fig.savefig('test.png')