import os

import cv2
import gradio as gr
import numpy as np
import torch
from safetensors.torch import load_model

from models.structure.model import SketchKeras

# Build the model, load the safetensors checkpoint, and switch to eval mode.
path_to_weights = os.path.join(
    os.path.dirname(__file__), "models/weights/sketch_keras.st"
)
model = SketchKeras()
load_model(model, path_to_weights)
model.eval()


def preprocess(img):
    # High-pass filter: subtract a Gaussian-blurred copy so only edges remain.
    h, w, c = img.shape
    blurred = cv2.GaussianBlur(img, (0, 0), 3)
    highpass = img.astype(int) - blurred.astype(int)
    highpass = highpass.astype(float) / 128.0
    highpass /= np.max(highpass)

    # Pad into the fixed 512x512x3 canvas the model expects.
    ret = np.zeros((512, 512, 3), dtype=float)
    ret[0:h, 0:w, 0:c] = highpass
    return ret


def postprocess(pred, thresh=0.18, smooth=False):
    assert 0.0 <= thresh <= 1.0
    # Take the per-pixel maximum across channels and drop weak strokes.
    pred = np.amax(pred, 0)
    pred[pred < thresh] = 0
    # Invert so lines are dark on white, then scale to 8-bit.
    pred = 1 - pred
    pred *= 255
    pred = np.clip(pred, 0, 255).astype(np.uint8)
    if smooth:
        pred = cv2.medianBlur(pred, 3)
    return pred


def output_sketch(img):
    # Resize so the longer side becomes 512 while keeping the aspect ratio.
    height, width = float(img.shape[0]), float(img.shape[1])
    if width > height:
        new_width, new_height = 512, int(512 / width * height)
    else:
        new_width, new_height = int(512 / height * width), 512
    img = cv2.resize(img, (new_width, new_height))

    # High-pass + pad, then move the color channels into the batch dimension
    # so the input tensor has shape (3, 1, 512, 512).
    img = preprocess(img)
    x = img.reshape(1, *img.shape).transpose(3, 0, 1, 2)
    x = torch.tensor(x).float()

    with torch.no_grad():
        pred = model(x)
    pred = pred.squeeze()

    # Postprocess and crop away the padding added in preprocess().
    output = pred.cpu().detach().numpy()
    output = postprocess(output, thresh=0.1, smooth=False)
    output = output[:new_height, :new_width]
    return output
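

# A minimal local sanity check (hypothetical, not part of the original Space):
# it assumes an "example.jpg" next to this file and is left commented out so it
# never runs when the app starts.
# if __name__ == "__main__":
#     test_img = cv2.cvtColor(cv2.imread("example.jpg"), cv2.COLOR_BGR2RGB)
#     cv2.imwrite("sketch.png", output_sketch(test_img))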


gr.Interface(
    title="Turn Any Image Into a Sketch with SketchKeras",
    fn=output_sketch,
    inputs=gr.Image(type="numpy"),
    outputs=gr.Image(type="numpy"),
).launch()