"""Gradio demo that runs a pretrained generator on a hand-drawn sketch mask.

The generator (``utils.Generator.Generator``, 2 input channels, 1 output
channel) is loaded from a local checkpoint at import time, and the Gradio
interface is launched when this script runs.
"""
import gradio as gr
import numpy as np
import sys
sys.path.insert(0, '../.')  # NOTE(review): resolved against the CWD, not this file's location
from utils import Generator as netG
import torch.nn as nn
import torch
import PIL.Image as Image
import PIL
import PIL.ImageOps  # explicit: `import PIL` alone does not load the ImageOps submodule
from torchvision import transforms as tfs
from skimage import measure
import matplotlib
matplotlib.use('Agg')

available_device = 'cpu'

# Side length (pixels) of the square images fed to and produced by the model.
IMAGE_SIZE = 128

# Input transformation: grayscale -> resize -> invert -> float tensor in [0, 1].
transformations = [
    tfs.Grayscale(),
    tfs.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    tfs.Lambda(PIL.ImageOps.invert),
    tfs.ToTensor(),
]
trans = tfs.Compose(transformations)

# Model initialisation.
model_G = netG.Generator(nc_input=2, nc_output=1).to(available_device)
# NOTE(review): torch.load unpickles the file, which can execute arbitrary code;
# only load checkpoints that come from a trusted source.
checkpoint = torch.load("generador_v9_current_5000.pkl",
                        map_location=torch.device(available_device))
model_G.load_state_dict(checkpoint)
model_G = model_G.eval()


def sketch_recognition(input_img):
    """Run the generator on the user's sketch mask and binarise the output.

    Args:
        input_img: Value produced by the Gradio sketch component. Its
            ``'mask'`` entry is indexed as an H x W x C array with at least
            three channels; only the first three are used.

    Returns:
        An ``IMAGE_SIZE x IMAGE_SIZE`` ``int32`` array with values 0 or 255
        (pixels whose model output exceeds the mean become 0, all others
        255), or ``None`` when no image/mask was supplied.
    """
    if input_img is None or input_img.get('mask') is None:
        return None

    mask = input_img['mask'][:, :, [0, 1, 2]]  # fancy indexing -> a copy
    # Per channel: painted (non-zero) values -> 0, untouched values -> 255.
    # Stated directly instead of relying on uint8 wrap-around of `abs(img - 1)`,
    # so the result no longer depends on the mask's dtype.
    binary = np.where(mask != 0, 0, 255).astype(np.uint8)
    sketch = Image.fromarray(binary).convert('L')

    image = trans(sketch).reshape(-1, 1, IMAGE_SIZE, IMAGE_SIZE).to(available_device)
    # Second input channel is all zeros; zeros_like already matches image's device.
    blacks = torch.zeros_like(image)
    origin_A = torch.cat((image, blacks), 1)

    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        predicted_A = model_G(origin_A) + image

    result = predicted_A.cpu().numpy().reshape(IMAGE_SIZE, IMAGE_SIZE)
    # Threshold at the mean: above -> 0 (black), otherwise -> 255 (white).
    # Scaling by 255 before comparing (as done previously) does not change this.
    return np.where(result > result.mean(), 0, 255).astype(np.int32)


img = gr.Image(tool="sketch", source="upload", label="Mask", value="CC_02_7.png",
               invert_colors=True, shape=(IMAGE_SIZE, IMAGE_SIZE))
gr.Interface(fn=sketch_recognition, inputs=img, outputs="image").launch()