Spaces:
Runtime error
Runtime error
File size: 2,119 Bytes
6672bfb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 |
import numpy as np
from label_colors import colorMap
from PIL import Image
from spade.model import Pix2PixModel
from spade.dataset import get_transform
from torchvision.transforms import ToPILImage
# Earlier hard-coded palette, kept below as an inert string literal for reference.
'''colors = np.array([[56, 79, 131], [239, 239, 239],
[93, 110, 50], [183, 210, 78],
[60, 59, 75], [250, 250, 250]])'''
# Parallel lookup tables derived from colorMap: colors[i] is the color of the
# entry whose label id is id_list[i].
colors = [entry['color'] for entry in colorMap]
id_list = [entry['id'] for entry in colorMap]
def semantic(img, *, fallback_id=156):
    """Convert a color-coded segmentation image into a 2-D array of label ids.

    Each pixel of ``img`` (after conversion to RGB) is looked up in the
    module-level ``colors`` palette. A match yields the ``id_list`` entry at
    the same position; any other color yields ``fallback_id``.

    Args:
        img: Image object exposing ``size``, ``convert`` and ``getdata``
            (used here the way a ``PIL.Image.Image`` is used).
        fallback_id: Label id assigned to pixels whose color is not in
            ``colors``. Defaults to 156, the value previously hard-coded.

    Returns:
        ``numpy.ndarray`` of shape ``(height, width)`` with one id per pixel.
    """
    print("semantic", type(img))

    # PIL reports size as (width, height), and getdata() is flattened row by
    # row, so the flat id list must be reshaped to (height, width). Unpacking
    # it as (h, w) only happened to work for square images.
    width, height = img.size
    imrgb = img.convert("RGB")

    # Build the color -> id table once instead of scanning `colors` twice per
    # pixel. tuple() lets palette entries stored as lists match the RGB tuples
    # produced by getdata().
    # NOTE(review): assumes every palette entry is an RGB sequence — confirm
    # against label_colors.colorMap.
    color_to_id = {}
    for color, label_id in zip(colors, id_list):
        # setdefault keeps the first id for a duplicated color, matching the
        # first-match behaviour of list.index().
        color_to_id.setdefault(tuple(color), label_id)

    mask = [color_to_id.get(pixel, fallback_id) for pixel in imrgb.getdata()]
    return np.array(mask).reshape(height, width)
def evaluate(labelmap):
    """Run ``Pix2PixModel`` in inference mode on a label map.

    A new model is constructed from a fixed option dict on every call. The
    label map is turned into a tensor of label ids, paired with a blank
    placeholder image, and passed to the model.

    Args:
        labelmap: Array-like of label ids; it is converted with
            ``np.array(...).astype(np.uint8)`` before use.

    Returns:
        Whatever the model returns for ``mode='inference'``; its ``shape`` is
        printed before returning.
    """
    options = {
        'label_nc': 182,  # num classes in coco model
        'crop_size': 512,
        'load_size': 512,
        'aspect_ratio': 1.0,
        'isTrain': False,
        'checkpoints_dir': 'app',
        'which_epoch': 'latest',
        'use_gpu': False
    }

    generator = Pix2PixModel(options)
    generator.eval()

    label_image = Image.fromarray(np.array(labelmap).astype(np.uint8))
    to_label_tensor = get_transform(options, method=Image.NEAREST, normalize=False)
    # The ToTensor step inside the label transform maps [0, 255] onto
    # [0.0, 1.0]; multiply back so the values line up with the label ids.
    labels = to_label_tensor(label_image) * 255.0
    # 255 marks 'unknown', which the model expects as label_nc.
    unknown = labels == 255
    labels[unknown] = options['label_nc']
    print("label_tensor:", labels.shape)

    # The encoder is not used, so a blank image stands in for the real one.
    to_image_tensor = get_transform(options)
    blank = Image.new('RGB', (500, 500))
    placeholder = to_image_tensor(blank)

    batch = {
        'label': labels.unsqueeze(0),
        'instance': labels.unsqueeze(0),
        'image': placeholder.unsqueeze(0)
    }
    output = generator(batch, mode='inference')
    print("generated_image:", output.shape)
    return output
def to_image(generated):
    """Turn the generator output into an image via ``ToPILImage``.

    Args:
        generated: Tensor that can be reshaped to ``[3, 512, 512]``.

    Returns:
        The result of ``ToPILImage()`` applied to the rescaled byte tensor
        after it is moved to the CPU.
    """
    chw = generated.reshape([3, 512, 512])
    # Shift and scale: (x + 1) / 2 * 255.
    # NOTE(review): presumably maps a [-1, 1] output range onto [0, 255] — confirm.
    scaled = (chw + 1) / 2.0 * 255.0
    as_bytes = scaled.byte().cpu()
    converter = ToPILImage()
    return converter(as_bytes)
|