File size: 1,631 Bytes
eba1c6b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
66432b9
eba1c6b
 
 
 
 
 
 
 
 
 
 
66432b9
 
eba1c6b
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
import torch
import torchvision
from torchvision import transforms
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
from models.modelNetA import Generator as GA
from models.modelNetB import Generator as GB
from models.modelNetC import Generator as GC



# Device on which inference runs.
DEVICE = 'cpu'
# Generator variant to load; must be one of the keys of the tables below.
model_type = 'model_b'

# Checkpoint file for each supported generator variant.
modeltype2path = {
    'model_a': 'DTM_exp_train10%_model_a/g-best.pth',
    'model_b': 'DTM_exp_train10%_model_b/g-best.pth',
    'model_c': 'DTM_exp_train10%_model_c/g-best.pth',
}

# Generator class for each supported variant. Only the selected class is
# instantiated.
modeltype2generator = {
    'model_a': GA,
    'model_b': GB,
    'model_c': GC,
}

# Fail early with a clear message instead of hitting a NameError on an
# unbound `generator` further down.
if model_type not in modeltype2generator:
    raise ValueError(
        f"Unknown model_type {model_type!r}; "
        f"expected one of {sorted(modeltype2generator)}"
    )
generator = modeltype2generator[model_type]()

# The model is wrapped in DataParallel before the weights are loaded and
# unwrapped again right after.
# NOTE(review): presumably the checkpoint was saved from a DataParallel model,
# so its keys carry the 'module.' prefix - confirm against the training code.
generator = torch.nn.DataParallel(generator)
# NOTE: torch.load unpickles the file; only load checkpoints from a trusted
# source. Tensors are mapped to CPU regardless of where they were saved.
state_dict_Gen = torch.load(modeltype2path[model_type], map_location=torch.device('cpu'))
generator.load_state_dict(state_dict_Gen)
# Drop the DataParallel wrapper and move the bare module to the target device.
generator = generator.module.to(DEVICE)
# Switch to evaluation mode for inference.
generator.eval()

# Convert the input image to a single-channel tensor. No resize is applied,
# so the image is passed to the generator at its own resolution.
preprocess = transforms.Compose([
    transforms.Grayscale(),
    transforms.ToTensor(),
])

# Open the image inside a context manager so the underlying file handle is
# released as soon as the pixel data has been converted to a tensor.
with Image.open('demo_imgs/fake.jpg') as input_img:
    # unsqueeze(0) adds the leading batch dimension before the forward pass.
    torch_img = preprocess(input_img).unsqueeze(0).to(DEVICE)

# Inference only: no gradients are needed.
with torch.no_grad():
    output = generator(torch_img)

# The first two elements of the generator output are used.
# NOTE(review): the names suggest a super-resolved image and a selected DEM
# map - confirm against the Generator's forward().
sr, sr_dem_selected = output[0], output[1]
# Remove the batch dimension and bring the tensor back to the CPU for saving.
sr = sr.squeeze(0).cpu()

print(sr.shape)
torchvision.utils.save_image(sr, 'sr.png')

# Drop all singleton dimensions and convert to a NumPy array for plotting.
# NOTE(review): imshow needs a 2-D (or image-like) array, so this assumes a
# single-channel, single-sample output - confirm.
sr_dem_selected = sr_dem_selected.squeeze().cpu().detach().numpy()
print(sr_dem_selected.shape)
# The colour scale runs from 0 up to the largest value in the map.
plt.imshow(sr_dem_selected, cmap='jet', vmin=0, vmax=np.max(sr_dem_selected))
plt.colorbar()
plt.savefig('test.png')