# Sketch-to-image GAN demo (HuggingFace Space).
# NOTE(review): the scraped page header read "Spaces: / Runtime error / Runtime error" —
# the Space appeared to be failing at capture time; converted to a comment so the file parses.
import gradio as gr
import numpy as np
import sys

sys.path.insert(0, '../.')  # make the project-local `utils` package importable

from utils import Generator as netG
import torch.nn as nn
import torch
import PIL.Image as Image
import PIL
import PIL.ImageOps  # fix: `import PIL` alone does not load the ImageOps submodule used below
from torchvision import transforms as tfs
from skimage import measure
import matplotlib

matplotlib.use('Agg')  # headless backend — no display available on the server

# Inference runs on CPU only.
available_device = 'cpu'

# Preprocessing pipeline: grayscale -> 128x128 -> invert (strokes become white
# on black, matching the training data — TODO confirm) -> float tensor in [0, 1].
transformations = [
    tfs.Grayscale(),
    tfs.Resize((128, 128)),
    tfs.Lambda(lambda x: PIL.ImageOps.invert(x)),
    tfs.ToTensor()
]
trans = tfs.Compose(transformations)

# Model initialisation: generator takes 2 input channels (sketch + zeros
# placeholder) and emits 1 output channel; weights loaded from a local checkpoint.
model_G = netG.Generator(nc_input=2, nc_output=1).to(available_device)
checkpoint = torch.load("generador_v9_current_5000.pkl", map_location=torch.device(available_device))
model_G.load_state_dict(checkpoint)
model_G = model_G.eval()
def sketch_recognition(input_img):
    """Run the GAN generator on a user sketch and return a binarised image.

    Parameters
    ----------
    input_img : dict
        Gradio sketch-tool payload; the ``'mask'`` key holds an H x W x C
        array of the drawn strokes (assumes at least 3 channels — TODO confirm
        against the Gradio version in use).

    Returns
    -------
    numpy.ndarray
        128 x 128 ``int32`` array with strokes at 255 and background at 0.
    """
    # Keep the RGB channels (fancy indexing copies, so the caller's array is
    # untouched) and binarise: any drawn pixel -> 1.
    mask = input_img['mask'][:, :, [0, 1, 2]]
    mask[mask != 0] = 1
    # Invert so strokes are 0 on a background of 1.
    mask = np.abs(mask - 1)

    # Back to PIL for the torchvision pipeline (grayscale/resize/invert/tensor).
    pil_img = Image.fromarray(np.uint8(mask)).convert('L')
    image = trans(pil_img).reshape(-1, 1, 128, 128)
    image = image.to(available_device)

    # The generator expects 2 channels: the sketch plus an all-zeros channel.
    blacks = torch.zeros_like(image).to(available_device)
    origin_A = torch.cat((image, blacks), 1)

    # Inference only — no_grad avoids building the autograd graph per request.
    with torch.no_grad():
        predicted_A = model_G(origin_A)
    predicted_A = predicted_A + image

    result = predicted_A.detach().cpu().numpy().reshape(128, 128)
    result = result * 255

    # Threshold at the mean, then flip values so pixels above the mean end up
    # at 0 and the rest at 255 (net effect of the -255/0/abs shuffle below).
    out = np.zeros_like(result)
    mean = np.mean(result)
    out[result > mean] = 255
    out[out == 0] = -255
    out[out == 255] = 0
    out = np.abs(out)
    return np.array(out, dtype=np.int32)
# Sketchpad input pre-loaded with an example file; the model's binarised
# output is rendered back as an image.
img = gr.Image(tool="sketch", source="upload", label="Mask", value="CC_02_7.png", invert_colors=True, shape=(128, 128))
gr.Interface(fn=sketch_recognition, inputs=img, outputs="image").launch()