File size: 1,903 Bytes
526e0a6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import gradio as gr
from PIL import Image
from torchvision import transforms
import torch
import random
import os
from models.structure.model import SketchKeras
from safetensors.torch import load_model
import cv2
import numpy as np


# One shared network instance, created at import time and reused by every
# request handled by the interface defined at the bottom of this file.
model = SketchKeras()

# The checkpoint is looked up next to this file rather than relative to the
# current working directory.
path_to_weights = os.path.join(
    os.path.dirname(__file__),
    "models/weights/sketch_keras.st",
)
load_model(model, path_to_weights)

# Evaluation mode: the model is only used for prediction in this script.
model.eval()


def preprocess(img):
    """Build the normalised high-pass canvas that is fed to the model.

    A Gaussian-blurred copy (sigma 3) is subtracted from the image, the
    difference is scaled by 1/128 and then divided by its maximum. The result
    is written into the top-left corner of a zero-filled 512 x 512 x 3 canvas.

    Args:
        img: Three-dimensional array (height, width, channels). It has to fit
            inside the 512 x 512 x 3 canvas. Presumably a uint8 RGB image as
            delivered by the Gradio image input -- TODO confirm.

    Returns:
        Float array of shape (512, 512, 3); everything outside the
        ``height`` x ``width`` region stays zero.
    """
    h, w, c = img.shape
    blurred = cv2.GaussianBlur(img, (0, 0), 3)
    # Go through int so the subtraction can produce negative values.
    highpass = img.astype(int) - blurred.astype(int)
    highpass = highpass.astype(float) / 128.0

    peak = np.max(highpass)
    # A flat image has no detail left after the subtraction, so the peak is 0.
    # Dividing by it would turn the whole array into NaN; leave it as zeros.
    if peak != 0:
        highpass /= peak

    ret = np.zeros((512, 512, 3), dtype=float)
    ret[0:h, 0:w, 0:c] = highpass
    return ret


def postprocess(pred, thresh=0.18, smooth=False):
    """Collapse model output into a single 8-bit sketch image.

    The maximum is taken along the first axis of ``pred``, values below
    ``thresh`` are zeroed, and the result is inverted and scaled to 0-255.
    Positions with no remaining response therefore come out as 255 and a
    response of 1.0 comes out as 0.

    Args:
        pred: Array whose leading axis is reduced with a maximum.
        thresh: Cut-off in the range [0, 1]; checked with an assertion.
        smooth: When true, a median blur with aperture size 3 is applied to
            the final image.

    Returns:
        ``numpy.uint8`` array shaped like ``pred`` without its first axis.
    """
    assert 0.0 <= thresh <= 1.0

    # The reduction allocates a new array, so the caller's data is untouched
    # by the in-place edits below.
    strongest = np.max(pred, axis=0)
    too_weak = strongest < thresh
    strongest[too_weak] = 0

    inverted = 1 - strongest
    inverted *= 255
    sketch = np.clip(inverted, 0, 255).astype(np.uint8)

    return cv2.medianBlur(sketch, 3) if smooth else sketch


def output_sketch(img):
    """Convert an image into a line sketch with the module-level model.

    The image is resized so that its longer side is 512 pixels, run through
    ``preprocess``, the model and ``postprocess``, and the padded canvas is
    cropped back to the resized image dimensions.

    Args:
        img: Image array whose first two axes are height and width, or
            ``None`` when no image was supplied.

    Returns:
        ``numpy.uint8`` sketch of shape ``(new_height, new_width)``.

    Raises:
        gr.Error: If ``img`` is ``None``.
    """
    if img is None:
        # Without this guard an empty submission fails on ``img.shape`` with
        # an AttributeError that tells the user nothing.
        raise gr.Error("Please provide an image before requesting a sketch.")

    # resize: longer side becomes 512, the other side keeps the aspect ratio
    height, width = float(img.shape[0]), float(img.shape[1])
    if width > height:
        # max(1, ...) because very elongated images would otherwise truncate
        # to a zero-sized side, which cv2.resize rejects.
        new_width, new_height = 512, max(1, int(512 / width * height))
    else:
        new_width, new_height = max(1, int(512 / height * width)), 512
    img = cv2.resize(img, (new_width, new_height))

    img = preprocess(img)
    # (512, 512, 3) -> (1, 512, 512, 3) -> (3, 1, 512, 512): the last axis of
    # the canvas is moved to the front of the array handed to the model.
    x = img.reshape(1, *img.shape).transpose(3, 0, 1, 2)
    x = torch.tensor(x).float()

    with torch.no_grad():
        pred = model(x)

    pred = pred.squeeze()

    # postprocess
    output = pred.cpu().detach().numpy()
    output = postprocess(output, thresh=0.1, smooth=False)
    # Drop the zero padding that preprocess added around the resized image.
    output = output[:new_height, :new_width]

    return output


# Expose output_sketch through an image-in / image-out web UI and start
# serving it as soon as this script runs.
demo = gr.Interface(
    fn=output_sketch,
    inputs=gr.Image(type="numpy"),
    outputs=gr.Image(type="numpy"),
    title="Turn Any Image Into a Sketch with SketchKeras",
)
demo.launch()