|
import onnxruntime as ort |
|
import numpy |
|
import gradio as gr |
|
from PIL import Image |
|
|
|
# Pre-trained doodle-embedding model, loaded once at import time and shared by
# every request. NOTE(review): the path is relative to the current working
# directory — confirm the app is launched from the folder holding the .onnx file.
ort_sess = ort.InferenceSession('tiny_doodle_embedding.onnx')
|
|
|
|
|
|
|
def get_bounds(img):
    """Return the inclusive bounding box of the "bright" pixels of *img*.

    A pixel counts as bright when it is strictly above the midpoint between
    the image's minimum and maximum values.

    Parameters:
        img: 2-D numpy array (grayscale image).

    Returns:
        (top, bottom, left, right) — inclusive row/column indices. If no
        pixel exceeds the midpoint (e.g. a constant image), the degenerate
        tuple (height, 0, width, 0) is returned, matching the fall-through
        values of the original per-pixel loop.
    """
    # Threshold halfway between the darkest and brightest pixel.
    min_color = numpy.min(img)
    max_color = numpy.max(img)
    mean_color = 0.5 * (min_color + max_color)

    # Vectorized replacement for the original Python double loop: nonzero
    # gives the row/column indices of every above-threshold pixel at C speed.
    ys, xs = numpy.nonzero(img > mean_color)
    if ys.size == 0:
        # No bright pixels: preserve the original degenerate result.
        return (img.shape[0], 0, img.shape[1], 0)
    return (int(ys.min()), int(ys.max()), int(xs.min()), int(xs.max()))
|
|
|
def resize_maxpool(img, out_width: int, out_height: int):
    """Downsample *img* to (out_height, out_width) via max-pooling.

    Each output pixel is the maximum over a (scale_y x scale_x) block of the
    input, with integer block sizes; trailing remainder rows/columns are
    dropped, as in the original implementation.

    Fix: the original computed scale factors with floor division, so an input
    smaller than the target yielded a scale of 0, an empty slice, and a
    ValueError from numpy.max. Block size and start index are now clamped so
    small inputs degrade to nearest-pixel sampling instead of crashing.
    For inputs at least as large as the target the output is unchanged.
    """
    out = numpy.zeros((out_height, out_width), dtype=img.dtype)
    # Clamp to >= 1 so the slice below is never empty.
    scale_y = max(1, img.shape[0] // out_height)
    scale_x = max(1, img.shape[1] // out_width)
    for y in range(out_height):
        # Clamp the block start inside the image (no-op when input >= output).
        y0 = min(y * scale_y, img.shape[0] - 1)
        for x in range(out_width):
            x0 = min(x * scale_x, img.shape[1] - 1)
            out[y, x] = numpy.max(img[y0:y0 + scale_y, x0:x0 + scale_x])
    return out
|
|
|
def process_input(input_msg):
    """Convert a Gradio Sketchpad payload into a (1, 32, 32) binary float image.

    The sketch arrives as input_msg["composite"] — presumably a grayscale
    (H, W) array where strokes are darker than the background (image_mode='L';
    TODO confirm against the Sketchpad components below). Steps: binarize
    (stroke = 1.0), crop to the stroke bounding box, max-pool down to 32x32,
    then add a leading batch axis for the ONNX model.
    """
    img = input_msg["composite"]

    # Binarize: pixels darker than the min/max midpoint count as strokes.
    img_mean = 0.5 * (numpy.max(img) + numpy.min(img))
    img = 1.0 * (img < img_mean)

    # get_bounds returns INCLUSIVE indices. The original sliced
    # [top:bottom, left:right], silently dropping the bottom row and right
    # column of the drawing — fixed with the +1 below.
    top, bottom, left, right = get_bounds(img)
    if top <= bottom and left <= right:
        img = img[top:bottom + 1, left:right + 1]
    # else: nothing was drawn (degenerate bounds); keep the full canvas so the
    # resize below does not operate on an empty crop.

    img = resize_maxpool(img, 32, 32)

    # Shape (1, 32, 32): leading axis expected by the model's 'input' tensor.
    img = numpy.expand_dims(img, axis=0)
    return img
|
|
|
|
|
def compare(input_img_a, input_img_b):
    """Embed two sketches with the ONNX model and report their similarity.

    Returns a 3-tuple matching the Gradio outputs:
      * a PIL image previewing both preprocessed 32x32 inputs side by side,
      * the cosine similarity of the two embeddings (float),
      * a text dump of both embedding vectors.
    """
    text_out = ""

    img_a = process_input(input_img_a)
    img_b = process_input(input_img_b)

    # The model takes a float32 tensor named 'input'; keep the first output.
    a_embedding = ort_sess.run(None, {'input': img_a.astype(numpy.float32)})[0]
    b_embedding = ort_sess.run(None, {'input': img_b.astype(numpy.float32)})[0]

    # Normalize to unit length so the dot product below is a true cosine
    # similarity in [-1, 1]. (The original divided by hard-coded 1.0
    # placeholders, leaving the reported similarity unbounded.) The `or 1.0`
    # guards against a zero-magnitude embedding producing NaNs.
    a_mag = float(numpy.linalg.norm(a_embedding)) or 1.0
    b_mag = float(numpy.linalg.norm(b_embedding)) or 1.0
    a_embedding /= a_mag
    b_embedding /= b_mag

    text_out += f"img_a_embedding: {a_embedding}\n"
    text_out += f"img_b_embedding: {b_embedding}\n"

    sim = numpy.dot(a_embedding, b_embedding.T)
    print(sim)
    print(text_out)

    # Side-by-side preview of the binarized inputs, scaled to 8-bit
    # (255 = stroke; the original scaled by 254, leaving strokes off-white).
    preview = numpy.clip(numpy.hstack([img_a[0], img_b[0]]) * 255, 0, 255)
    return Image.fromarray(preview.astype(numpy.uint8)), sim[0][0], text_out
|
|
|
|
|
|
|
# Gradio UI wiring: two grayscale sketchpads in; three outputs (preprocessed
# preview image, similarity number, embedding text dump) matching the 3-tuple
# returned by compare().
demo = gr.Interface(
    fn=compare,
    inputs=[
        gr.Sketchpad(image_mode='L', type='numpy'),
        gr.Sketchpad(image_mode='L', type='numpy'),
    ],
    outputs=["image", "number", "text"],
)

# share=True also exposes a temporary public gradio.live URL, not just localhost.
demo.launch(share=True)
|
|
|
|