"""Gradio demo that turns an input image into a line sketch with SketchKeras."""

import os
import random

import cv2
import gradio as gr
import numpy as np
import torch
from PIL import Image
from safetensors.torch import load_model
from torchvision import transforms

from models.structure.model import SketchKeras

# Side length of the square canvas the image is padded into before inference.
INPUT_SIZE = 512

path_to_weights = os.path.join(
    os.path.dirname(__file__), "models/weights/sketch_keras.st"
)

# Load the network once at import time and switch it to inference mode.
model = SketchKeras()
load_model(model, path_to_weights)
model.eval()


def _as_three_channels(img):
    """Return *img* as an (H, W, 3) array, expanding gray or dropping alpha."""
    if img.ndim == 3 and img.shape[2] == 1:
        img = img[..., 0]
    if img.ndim == 2:
        # Replicate the single gray plane so every channel slot is filled.
        return np.stack((img,) * 3, axis=-1)
    if img.shape[2] > 3:
        # Keep the first three channels only (e.g. discard an alpha plane);
        # make the slice contiguous before handing it to OpenCV.
        return np.ascontiguousarray(img[..., :3])
    return img


def preprocess(img):
    """Build the high-pass filtered, zero-padded network input.

    The image is high-pass filtered by subtracting a Gaussian-blurred copy
    (sigma 3), scaled by 1/128, normalised by its maximum, and placed in the
    top-left corner of an ``INPUT_SIZE`` x ``INPUT_SIZE`` x 3 float canvas.

    Args:
        img: Image array of shape (H, W), (H, W, 1), (H, W, 3) or (H, W, 4+),
            with H and W no larger than ``INPUT_SIZE``.

    Returns:
        A float array of shape (INPUT_SIZE, INPUT_SIZE, 3).
    """
    img = _as_three_channels(img)
    h, w, c = img.shape

    blurred = cv2.GaussianBlur(img, (0, 0), 3)
    # Subtract in a signed integer type so negative differences are preserved.
    highpass = (img.astype(int) - blurred.astype(int)).astype(float) / 128.0

    # A flat image has an all-zero high-pass response; dividing by that
    # maximum would fill the input with NaNs, so normalise only when possible.
    peak = np.max(highpass)
    if peak > 0:
        highpass /= peak

    ret = np.zeros((INPUT_SIZE, INPUT_SIZE, 3), dtype=float)
    ret[0:h, 0:w, 0:c] = highpass
    return ret


def postprocess(pred, thresh=0.18, smooth=False):
    """Convert raw network output into an 8-bit sketch image.

    Args:
        pred: Array whose first axis is reduced with a maximum; values below
            ``thresh`` are then suppressed.
        thresh: Suppression threshold, must lie in [0, 1].
        smooth: When true, apply a 3x3 median blur to the result.

    Returns:
        A ``uint8`` array where stronger responses are darker
        (the output is ``255 * (1 - response)`` clipped to [0, 255]).

    Raises:
        ValueError: If ``thresh`` is outside [0, 1].
    """
    # Explicit check instead of ``assert`` so it survives ``python -O``.
    if not 0.0 <= thresh <= 1.0:
        raise ValueError(f"thresh must be within [0, 1], got {thresh!r}")

    pred = np.amax(pred, 0)
    pred[pred < thresh] = 0
    pred = 1 - pred
    pred *= 255
    pred = np.clip(pred, 0, 255).astype(np.uint8)
    if smooth:
        pred = cv2.medianBlur(pred, 3)
    return pred


def output_sketch(img):
    """Run the full image-to-sketch pipeline on a single image.

    The image is resized so its longer side equals ``INPUT_SIZE`` (aspect
    ratio preserved), preprocessed, passed through the model, postprocessed,
    and finally cropped back to the resized image's extent.

    Args:
        img: Input image as a numpy array of shape (H, W[, C]).

    Returns:
        A 2-D ``uint8`` sketch of shape (new_height, new_width).
    """
    height, width = float(img.shape[0]), float(img.shape[1])
    if width > height:
        new_width, new_height = INPUT_SIZE, int(INPUT_SIZE / width * height)
    else:
        new_width, new_height = int(INPUT_SIZE / height * width), INPUT_SIZE
    # Extreme aspect ratios can truncate the short side to 0, which
    # cv2.resize rejects; keep every side at least one pixel.
    new_width, new_height = max(1, new_width), max(1, new_height)

    img = cv2.resize(img, (new_width, new_height))
    img = preprocess(img)

    # (H, W, C) -> (C, 1, H, W): each colour channel is fed to the model as a
    # separate single-channel batch item.
    x = img.reshape(1, *img.shape).transpose(3, 0, 1, 2)
    x = torch.tensor(x).float()

    with torch.no_grad():
        pred = model(x)
    pred = pred.squeeze()

    output = pred.cpu().detach().numpy()
    output = postprocess(output, thresh=0.1, smooth=False)
    # Discard the zero-padded region added by preprocess.
    output = output[:new_height, :new_width]
    return output


# The interface is launched at module level, as in the original script.
gr.Interface(
    title="Turn Any Image Into a Sketch with SketchKeras",
    fn=output_sketch,
    inputs=gr.Image(type="numpy"),
    outputs=gr.Image(type="numpy"),
).launch()