Donmill commited on
Commit
bb42532
1 Parent(s): bab6abe
Files changed (1) hide show
  1. test.py +0 -63
test.py DELETED
@@ -1,63 +0,0 @@
1
- import numpy as np
2
- from label_colors import colorMap
3
- from PIL import Image
4
- from spade.model import Pix2PixModel
5
- from spade.dataset import get_transform
6
- from torchvision.transforms import ToPILImage
7
-
8
- '''colors = np.array([[56, 79, 131], [239, 239, 239],
9
- [93, 110, 50], [183, 210, 78],
10
- [60, 59, 75], [250, 250, 250]])'''
11
- colors = [key['color'] for key in colorMap]
12
- id_list = [key['id'] for key in colorMap]
13
-
14
-
15
- def semantic(img):
16
- print("semantic", type(img))
17
- h, w = img.size
18
- imrgb = img.convert("RGB")
19
- pix = list(imrgb.getdata())
20
- mask = [id_list[colors.index(i)] if i in colors else 156 for i in pix]
21
- return np.array(mask).reshape(h, w)
22
-
23
-
24
- def evaluate(labelmap):
25
- opt = {
26
- 'label_nc': 182, # num classes in coco model
27
- 'crop_size': 512,
28
- 'load_size': 512,
29
- 'aspect_ratio': 1.0,
30
- 'isTrain': False,
31
- 'checkpoints_dir': 'app',
32
- 'which_epoch': 'latest',
33
- 'use_gpu': False
34
- }
35
- model = Pix2PixModel(opt)
36
- model.eval()
37
- image = Image.fromarray(np.array(labelmap).astype(np.uint8))
38
- transform_label = get_transform(opt, method=Image.NEAREST, normalize=False)
39
- # transforms.ToTensor in transform_label rescales image from [0,255] to [0.0,1.0]
40
- # lets rescale it back to [0,255] to match our label ids
41
- label_tensor = transform_label(image) * 255.0
42
- label_tensor[label_tensor == 255] = opt['label_nc'] # 'unknown' is opt.label_nc
43
- print("label_tensor:", label_tensor.shape)
44
-
45
- # not using encoder, so creating a blank image...
46
- transform_image = get_transform(opt)
47
- image_tensor = transform_image(Image.new('RGB', (500, 500)))
48
-
49
- data = {
50
- 'label': label_tensor.unsqueeze(0),
51
- 'instance': label_tensor.unsqueeze(0),
52
- 'image': image_tensor.unsqueeze(0)
53
- }
54
- generated = model(data, mode='inference')
55
- print("generated_image:", generated.shape)
56
-
57
- return generated
58
-
59
-
60
- def to_image(generated):
61
- to_img = ToPILImage()
62
- normalized_img = ((generated.reshape([3, 512, 512]) + 1) / 2.0) * 255.0
63
- return to_img(normalized_img.byte().cpu())