import gradio as gr
import numpy as np

import sys
sys.path.insert(0,'../.')


from utils import Generator as netG
import torch.nn as nn
import torch
import PIL
import PIL.Image as Image
import PIL.ImageOps  # imported explicitly: used by the invert transform below
from torchvision import transforms as tfs
from skimage import measure

import matplotlib
matplotlib.use('Agg')
available_device = 'cpu'

# Input transformations: grayscale, resize to 128x128, invert colors, convert to tensor

transformations = [
                   tfs.Grayscale(),
                   tfs.Resize((128, 128)),
                   tfs.Lambda(lambda x: PIL.ImageOps.invert(x)),
                   tfs.ToTensor()
]

trans = tfs.Compose(transformations)

# Model initialization: build the generator, load the trained weights, switch to eval mode

model_G = netG.Generator(nc_input=2, nc_output=1).to(available_device)
checkpoint = torch.load("generador_v9_current_5000.pkl", map_location = torch.device(available_device))
model_G.load_state_dict(checkpoint)
model_G = model_G.eval()

def sketch_recognition(input_img):
    
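    # Keep the first three (RGB) channels of the sketched mask.
    img = input_img['mask'][:,:,[0,1,2]]
    # Binarize: every drawn (non-zero) pixel becomes 1.
    img[img!=0] = 1
    # Invert the binary mask. Assuming the mask is a uint8 array, 0 - 1 wraps
    # around to 255, so background pixels become 255 and drawn pixels become 0.
    img = np.abs(img - 1)
    input_img = Image.fromarray(np.uint8(img)).convert('L')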
    
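    # Shape (1, 1, 128, 128): a batch containing one single-channel image.
    image = trans(input_img).reshape(-1, 1, 128, 128).to(available_device)

    # The generator expects two input channels; the second one is all zeros.
    blacks = torch.zeros_like(image)
    origin_A = torch.cat((image, blacks), 1)
    # Inference only, so gradient tracking is not needed.
    with torch.no_grad():
        predicted_A = model_G(origin_A)
    predicted_A = predicted_A + image
    result = predicted_A.detach().cpu().numpy().reshape(128, 128)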
    result = result * 255
    # Threshold at the mean: pixels above it become black (0), all others white (255).
    mean = np.mean(result)
    img = np.where(result > mean, 0, 255)

    return img.astype(np.int32)



img = gr.Image(tool="sketch", source="upload", label="Mask", value="CC_02_7.png", invert_colors=True, shape=(128, 128))
            

gr.Interface(fn=sketch_recognition, inputs=img, outputs="image").launch()