iberianGAN / app.py
pablo1n7's picture
Upload app.py
054eef0
raw history blame
No virus
1.84 kB
import gradio as gr
import numpy as np
import sys
sys.path.insert(0,'../.')
from utils import Generator as netG
import torch.nn as nn
import torch
import PIL.Image as Image
import PIL
import matplotlib.pyplot as plt
from torchvision import transforms as tfs
from skimage import measure
import matplotlib
# `import PIL` alone does not load the PIL.ImageOps submodule; import it
# explicitly instead of relying on torchvision having imported it first.
from PIL import ImageOps

# Non-interactive matplotlib backend (no display required).
matplotlib.use('Agg')

# Device for the generator and all tensors fed to it.
available_device = 'cpu'

# Preprocessing applied to the sketch image before inference:
# grayscale -> resize to the 128x128 model resolution -> invert intensities
# -> float tensor (torchvision's ToTensor scales to [0, 1]).
transformations = [
    tfs.Grayscale(),
    tfs.Resize((128, 128)),
    tfs.Lambda(ImageOps.invert),
    tfs.ToTensor(),
]
trans = tfs.Compose(transformations)

# Generator with 2 input channels (sketch_recognition feeds the sketch plus an
# all-zero channel) and 1 output channel; weights are loaded onto the CPU.
model_G = netG.Generator(nc_input=2, nc_output=1).to(available_device)
checkpoint = torch.load(
    "generador_v9_current_5000.pkl",
    map_location=torch.device(available_device),
)
model_G.load_state_dict(checkpoint)
# Inference mode (disables training-only behaviour such as dropout).
model_G = model_G.eval()
def _mask_to_image(mask):
    """Convert a sketch mask into an 8-bit grayscale ('L') PIL image.

    Every zero value becomes 255 and every non-zero value becomes 0
    (per channel), and the result is converted to mode 'L'. A 3-D mask keeps
    only its first three channels; a 2-D mask is used as-is.
    """
    pixels = np.asarray(mask)
    if pixels.ndim == 3:
        pixels = pixels[:, :, [0, 1, 2]]
    # Map explicitly to 0/255 instead of relying on uint8 wrap-around of
    # `0 - 1` == 255, so masks of any numeric dtype are handled correctly.
    pixels = np.where(pixels != 0, 0, 255).astype(np.uint8)
    return Image.fromarray(pixels).convert('L')


def _binarize_above_mean(values):
    """Return an int32 array: 0 where ``values`` exceed their mean, 255 elsewhere."""
    return np.where(values > np.mean(values), 0, 255).astype(np.int32)


def sketch_recognition(input_img):
    """Run the generator on a user-drawn sketch and return a binary image.

    Args:
        input_img: Value produced by the ``gr.Image(tool="sketch")`` input.
            Only its ``'mask'`` entry is read (assumed to be an ``(H, W, C>=3)``
            or ``(H, W)`` array -- confirm against the Gradio version in use).

    Returns:
        A ``(128, 128)`` ``np.int32`` array containing only 0 and 255: 0 where
        the generator output plus the input sketch is above its mean, 255
        elsewhere.
    """
    sketch = _mask_to_image(input_img['mask'])
    image = trans(sketch).reshape(-1, 1, 128, 128).to(available_device)
    # The generator takes two channels: the sketch and an all-zero channel.
    blacks = torch.zeros_like(image)
    origin_A = torch.cat((image, blacks), 1)
    # Inference only: avoid building an autograd graph on every request.
    with torch.no_grad():
        predicted_A = model_G(origin_A)
    # The network's output is added back onto the input sketch.
    predicted_A = predicted_A + image
    result = predicted_A.cpu().numpy().reshape(128, 128)
    # Thresholding against the mean is scale-invariant, so no [0, 255]
    # rescaling is needed before binarizing.
    return _binarize_above_mean(result)
# Sketch-capable input starting from the image CC_02_7.png.
# NOTE(review): `tool`, `source` and `shape` are Gradio 3.x arguments that
# newer Gradio releases no longer accept -- pin the gradio version accordingly.
img = gr.Image(
    tool="sketch",
    source="upload",
    label="Mask",
    value="CC_02_7.png",
    invert_colors=True,
    shape=(128, 128),
)
demo = gr.Interface(fn=sketch_recognition, inputs=img, outputs="image")
demo.launch()