import os

import cv2
import gradio as gr
import numpy as np
import torch
from safetensors.torch import load_model

from models.structure.model import SketchKeras

# Load the SketchKeras weights once at startup and keep the model in eval mode.
path_to_weights = os.path.join(
    os.path.dirname(__file__), "models/weights/sketch_keras.st"
)
model = SketchKeras()
load_model(model, path_to_weights)
model.eval()

def preprocess(img):
    # Build a high-pass version of the image (image minus a Gaussian-blurred copy),
    # normalize it, and place it in the 512x512 canvas the network expects.
    h, w, c = img.shape
    blurred = cv2.GaussianBlur(img, (0, 0), 3)
    highpass = img.astype(int) - blurred.astype(int)
    highpass = highpass.astype(float) / 128.0
    highpass /= np.max(highpass)
    ret = np.zeros((512, 512, 3), dtype=float)
    ret[0:h, 0:w, 0:c] = highpass
    return ret

def postprocess(pred, thresh=0.18, smooth=False):
    # Take the strongest response across channels, drop weak edges, invert so
    # lines are dark on a white background, and convert to an 8-bit image.
    assert 0.0 <= thresh <= 1.0
    pred = np.amax(pred, 0)
    pred[pred < thresh] = 0
    pred = 1 - pred
    pred *= 255
    pred = np.clip(pred, 0, 255).astype(np.uint8)
    if smooth:
        pred = cv2.medianBlur(pred, 3)
    return pred

def output_sketch(img):
    # Resize so the longer side is 512 pixels, preserving the aspect ratio.
    height, width = float(img.shape[0]), float(img.shape[1])
    if width > height:
        new_width, new_height = 512, int(512 / width * height)
    else:
        new_width, new_height = int(512 / height * width), 512
    img = cv2.resize(img, (new_width, new_height))

    # Preprocess, then stack the three color channels as a batch of
    # single-channel inputs with shape (3, 1, 512, 512).
    img = preprocess(img)
    x = img.reshape(1, *img.shape).transpose(3, 0, 1, 2)
    x = torch.tensor(x).float()

    # Run inference without tracking gradients.
    with torch.no_grad():
        pred = model(x)
    pred = pred.squeeze()

    # Postprocess and crop away the padding added in preprocess().
    output = pred.cpu().detach().numpy()
    output = postprocess(output, thresh=0.1, smooth=False)
    output = output[:new_height, :new_width]
    return output

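# A minimal local usage sketch (the "example.jpg" / "sketch.png" filenames are
# hypothetical). Kept as comments so the Space only runs the Gradio app below:
#
#   test_img = cv2.imread("example.jpg")
#   test_img = cv2.cvtColor(test_img, cv2.COLOR_BGR2RGB)  # Gradio passes RGB arrays
#   sketch = output_sketch(test_img)
#   cv2.imwrite("sketch.png", sketch)
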
# Build and launch the Gradio demo.
gr.Interface(
    title="Turn Any Image Into a Sketch with SketchKeras",
    fn=output_sketch,
    inputs=gr.Image(type="numpy"),
    outputs=gr.Image(type="numpy"),
).launch()