img_to_sketch / .history /app_20231208160008.py
pawlo2013's picture
added new title
526e0a6
import gradio as gr
from PIL import Image
from torchvision import transforms
import torch
import random
import os
from models.structure.model import SketchKeras
from safetensors.torch import load_model
import cv2
import numpy as np
# Locate the SketchKeras checkpoint next to this script (rather than relative
# to the current working directory), then build the network, load the
# safetensors weights into it, and switch it to evaluation mode once at import
# time so every request reuses the same model instance.
_app_dir = os.path.dirname(__file__)
path_to_weights = os.path.join(_app_dir, "models/weights/sketch_keras.st")

model = SketchKeras()
load_model(model, path_to_weights)
model.eval()
def preprocess(img):
    """Build the high-pass input canvas fed to the SketchKeras model.

    Subtracts a Gaussian-blurred copy (sigma = 3) from *img* so only fine
    detail remains, scales the result so its largest value is 1, and pastes
    it into the top-left corner of a zero-filled 512x512x3 float canvas.

    Args:
        img: image array of shape (h, w, c) with h, w <= 512; c is expected
            to be 3. Presumably 8-bit (uint8) pixels given the /128 scaling --
            TODO confirm for other dtypes.

    Returns:
        float64 array of shape (512, 512, 3); the region outside
        [0:h, 0:w] is zero padding.
    """
    h, w, c = img.shape
    blurred = cv2.GaussianBlur(img, (0, 0), 3)
    # Cast to a signed type before subtracting so negative differences do not
    # wrap around as they would with unsigned pixel arrays.
    highpass = img.astype(int) - blurred.astype(int)
    highpass = highpass.astype(float) / 128.0
    peak = np.max(highpass)
    # A flat image has no high-pass response: dividing by a zero peak would
    # fill the canvas with NaN/inf, and a negative peak would flip the sign.
    # Only normalise when there is a positive peak to normalise by.
    if peak > 0:
        highpass /= peak
    ret = np.zeros((512, 512, 3), dtype=float)
    ret[0:h, 0:w, 0:c] = highpass
    return ret
def postprocess(pred, thresh=0.18, smooth=False):
    """Collapse stacked model predictions into a single uint8 sketch image.

    The first axis of *pred* is reduced with a max. Values below *thresh* are
    zeroed, and the result is inverted and scaled to [0, 255], so low
    activations become white (255) and high activations become dark pixels.

    Args:
        pred: array whose axis 0 is reduced with ``np.amax`` (in this file,
            one prediction per input colour channel). It is not modified.
        thresh: activation cut-off; must lie in [0, 1].
        smooth: if True, apply a 3x3 median blur to the final image.

    Returns:
        uint8 array with axis 0 of *pred* removed.

    Raises:
        ValueError: if *thresh* is outside [0, 1] (including NaN).
    """
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if not 0.0 <= thresh <= 1.0:
        raise ValueError(f"thresh must be in [0, 1], got {thresh!r}")
    # np.amax returns a new array, so the in-place edits below never touch
    # the caller's `pred`.
    pred = np.amax(pred, 0)
    pred[pred < thresh] = 0
    pred = 1 - pred
    pred *= 255
    pred = np.clip(pred, 0, 255).astype(np.uint8)
    if smooth:
        pred = cv2.medianBlur(pred, 3)
    return pred
def output_sketch(img):
    """Convert an input image into a SketchKeras line-art sketch.

    The image is resized so its longer side is 512 px (keeping the aspect
    ratio), high-pass filtered and padded by ``preprocess``, run through the
    model with each colour channel as a separate sample, merged back by
    ``postprocess``, and finally cropped to the resized image's extent.

    Args:
        img: image array of shape (H, W), (H, W, 1), (H, W, 3) or (H, W, 4),
            or None when the Gradio input is empty.

    Returns:
        2-D uint8 array of shape (new_height, new_width), where the longer
        side is 512, or None if *img* is None.
    """
    # Gradio calls the function with None when submitted without an image.
    if img is None:
        return None

    # preprocess() expects exactly three channels: expand grayscale inputs
    # and drop an alpha channel. A contiguous copy keeps OpenCV happy.
    if img.ndim == 2:
        img = img[:, :, np.newaxis]
    if img.shape[2] == 1:
        img = np.repeat(img, 3, axis=2)
    elif img.shape[2] == 4:
        img = np.ascontiguousarray(img[:, :, :3])

    # Scale so the longer side is 512 px, matching the 512x512 canvas that
    # preprocess() pads onto. Clamp the short side to at least 1 px so very
    # thin images do not request a zero-size resize.
    height, width = float(img.shape[0]), float(img.shape[1])
    if width > height:
        new_width, new_height = 512, max(1, int(512 / width * height))
    else:
        new_width, new_height = max(1, int(512 / height * width)), 512
    img = cv2.resize(img, (new_width, new_height))

    img = preprocess(img)
    # Feed each colour channel as its own single-channel sample:
    # (512, 512, 3) -> (1, 512, 512, 3) -> (3, 1, 512, 512).
    x = img.reshape(1, *img.shape).transpose(3, 0, 1, 2)
    x = torch.tensor(x).float()
    with torch.no_grad():
        pred = model(x)
    # Presumably (3, 1, 512, 512) -> (3, 512, 512); postprocess() then takes
    # the max over the channel axis -- TODO confirm model output shape.
    pred = pred.squeeze()

    output = pred.cpu().numpy()
    output = postprocess(output, thresh=0.1, smooth=False)
    # Crop away the zero padding preprocess() added around the image.
    return output[:new_height, :new_width]
# Web UI: upload an image, get its sketch back. Bound to `demo`, the variable
# name Gradio's reload mode (`gradio app.py`) looks for. The server is only
# launched when this file is run as a script, so importing the module (e.g.
# to reuse output_sketch) does not start a blocking web server.
demo = gr.Interface(
    title="Turn Any Image Into a Sketch with SketchKeras",
    fn=output_sketch,
    inputs=gr.Image(type="numpy"),
    outputs=gr.Image(type="numpy"),
)

if __name__ == "__main__":
    demo.launch()