# --- Hugging Face Spaces page residue (scrape artifact, not code) ---
# Spaces: Runtime error
# File size: 1,808 Bytes
# Commits: 054eef0 1a07d42
import sys

import gradio as gr
import numpy as np
import PIL
import PIL.Image as Image
from PIL import ImageOps
import torch
import torch.nn as nn
from skimage import measure
from torchvision import transforms as tfs

# Project-local import: Generator lives one directory up.
sys.path.insert(0,'../.')
from utils import Generator as netG
import matplotlib
matplotlib.use('Agg')
available_device = 'cpu'
#transformation initial
transformations = [
tfs.Grayscale(),
tfs.Resize((128, 128)),
tfs.Lambda(lambda x: PIL.ImageOps.invert(x)),
tfs.ToTensor()
]
trans = tfs.Compose(transformations)
# Model Initial
model_G = netG.Generator(nc_input=2, nc_output=1).to(available_device)
checkpoint = torch.load("generador_v9_current_5000.pkl", map_location = torch.device(available_device))
model_G.load_state_dict(checkpoint)
model_G = model_G.eval()
def sketch_recognition(input_img):
img = input_img['mask'][:,:,[0,1,2]]
img[img!=0] = 1
img = np.abs(img - 1)
input_img = Image.fromarray(np.uint8(img)).convert('L')
image = trans(input_img).reshape(-1, 1, 128, 128)
image = image.to(available_device)
blacks = torch.zeros_like(image).to(available_device)
origin_A = torch.cat((image, blacks), 1)
predicted_A = model_G(origin_A)
predicted_A = predicted_A + image
result = predicted_A.detach().cpu().numpy().reshape(128, 128)
result = result * 255
img = np.zeros_like(result)
mean = np.mean(result)
img[result>mean] = 255
img[img==0] = -255
img[img==255] = 0
img = np.abs(img)
return np.array(img, dtype=np.int32)
img = gr.Image(tool="sketch", source="upload", label="Mask", value="CC_02_7.png", invert_colors=True, shape=(128, 128))
gr.Interface(fn=sketch_recognition, inputs=img, outputs="image").launch()
# (end of scraped listing)