|
import json |
|
|
|
import gradio as gr |
|
import numpy |
|
import onnxruntime as ort |
|
from PIL import Image |
|
|
|
# Load the classifier once at import time so every prediction request
# reuses the same inference session.
_MODEL_PATH = 'tiny_letter_classifier_v2_q8quant.onnx'
ort_sess = ort.InferenceSession(_MODEL_PATH)
|
|
|
|
|
|
|
def get_bounds(img):
    """Return the bounding box of the bright pixels of a 2-D image.

    A pixel counts as "bright" when its value is strictly greater than the
    midpoint between the image's minimum and maximum values.

    Args:
        img: 2-D numpy array of pixel intensities.

    Returns:
        A ``(top, bottom, left, right)`` tuple of inclusive pixel indices
        (plain Python ints). When no pixel exceeds the midpoint (for
        example a uniform image) the result is
        ``(img.shape[0], 0, img.shape[1], 0)``, an inverted box that
        callers can detect via ``top > bottom``.

    Raises:
        ValueError: If ``img`` has no elements (raised by numpy min/max).
    """
    height, width = img.shape[0], img.shape[1]

    # Promote to Python floats before adding: summing two numpy scalars of a
    # narrow integer dtype (e.g. uint8 100 + 200) wraps around and would
    # yield a bogus threshold.
    midpoint = 0.5 * (float(numpy.min(img)) + float(numpy.max(img)))

    # Vectorized replacement for the per-pixel scan.
    rows, cols = numpy.nonzero(img > midpoint)
    if rows.size == 0:
        return (height, 0, width, 0)
    return (int(rows.min()), int(rows.max()), int(cols.min()), int(cols.max()))
|
|
|
def resize_maxpool(img, out_width: int, out_height: int):
    """Resize an image to ``(out_height, out_width)`` using max-pooling.

    Each output pixel is the maximum over a rectangular bin of input
    pixels. Bin edges are spaced proportionally across the whole input, so:

    * when the input size is an exact multiple of the output size the bins
      are the usual equal, non-overlapping windows;
    * when it is not, the bins differ in size by at most one pixel and
      still cover every input pixel (nothing is dropped at the edges);
    * when the input is smaller than the output, pixels are repeated
      (nearest-neighbour upsampling) instead of failing on empty windows.

    Args:
        img: numpy array whose first two axes are (rows, columns). Any
            further axes are collapsed with ``max`` first, so the result is
            always 2-D.
        out_width: Number of columns in the result; must be >= 1.
        out_height: Number of rows in the result; must be >= 1.

    Returns:
        2-D numpy array of shape ``(out_height, out_width)`` with the same
        dtype as ``img``.

    Raises:
        ValueError: If an output dimension is < 1 or ``img`` is empty.
    """
    if out_width < 1 or out_height < 1:
        raise ValueError("out_width and out_height must be positive")

    img = numpy.asarray(img)
    if img.ndim > 2:
        # Collapse trailing axes so every output cell is a single value.
        img = img.max(axis=tuple(range(2, img.ndim)))

    in_height, in_width = img.shape[0], img.shape[1]
    if in_height == 0 or in_width == 0:
        raise ValueError("cannot resize an empty image")

    # Start index of every bin along each axis.
    row_starts = (numpy.arange(out_height) * in_height) // out_height
    col_starts = (numpy.arange(out_width) * in_width) // out_width

    # reduceat reduces img[start_i:start_{i+1}] for each bin (the last bin
    # runs to the end of the axis); when two consecutive starts coincide it
    # returns the single element at that start, which gives the upsampling
    # behaviour described above.
    pooled_rows = numpy.maximum.reduceat(img, row_starts, axis=0)
    return numpy.maximum.reduceat(pooled_rows, col_starts, axis=1)
|
|
|
def process_input(input_msg):
    """Turn a sketchpad payload into the classifier's input array.

    The image is binarized around the midpoint of its value range (pixels
    darker than the midpoint become 1.0, everything else 0.0), max-pooled
    down to 32x32 and given a leading batch axis.

    Args:
        input_msg: Mapping whose ``"composite"`` entry holds the image as a
            numpy array (presumably 2-D grayscale, given the Sketchpad is
            configured with image_mode='L' -- confirm against the UI).

    Returns:
        Float array of shape ``(1, 32, 32)`` containing only 0.0 and 1.0.

    Raises:
        ValueError: If the payload carries no composite image.
    """
    img = input_msg["composite"]
    if img is None:
        raise ValueError("no image provided: draw or upload a character first")
    img = numpy.asarray(img)

    # Convert the extremes to Python floats before adding them: for uint8
    # images max + min can exceed 255 and would otherwise wrap around,
    # producing a threshold far below the real midpoint.
    threshold = 0.5 * (float(numpy.max(img)) + float(numpy.min(img)))

    # Dark pixels become foreground (1.0).
    img = (img < threshold).astype(numpy.float64)

    img = resize_maxpool(img, 32, 32)
    return numpy.expand_dims(img, axis=0)
|
|
|
def softmax(arr):
    """Compute the softmax over the last axis of ``arr``.

    Each slice along the last axis is shifted by its own maximum before
    exponentiation for numerical stability, and normalized by its own sum,
    so batched inputs of shape ``(N, C)`` are handled row by row.

    Args:
        arr: Array-like of scores with at least one dimension.

    Returns:
        numpy array with the same shape as ``arr`` whose entries along the
        last axis are non-negative and sum to 1.
    """
    arr = numpy.asarray(arr)
    # keepdims keeps the reduced axis so the results broadcast back against
    # arr for any batch size, not only a batch of one.
    shifted = arr - numpy.max(arr, axis=-1, keepdims=True)
    exps = numpy.exp(shifted)
    return exps / numpy.sum(exps, axis=-1, keepdims=True)
|
|
|
def normalize(arr):
    """Scale a row or column vector to unit Euclidean length.

    Args:
        arr: Array-like holding a single vector: 1-D, or 2-D with shape
            ``(1, n)`` or ``(n, 1)``.

    Returns:
        Float numpy array of shape ``(1, n)`` (for 1-D or row input) or
        ``(n, 1)`` (for column input) with Euclidean norm 1. An all-zero
        vector has no direction and is returned as zeros rather than NaN.

    Raises:
        ValueError: If ``arr`` is not a single row or column vector.
    """
    vec = numpy.atleast_2d(arr)
    if vec.ndim != 2 or (vec.shape[0] != 1 and vec.shape[1] != 1):
        raise ValueError(
            f"normalize expects a row or column vector, got shape {vec.shape}"
        )

    # Divide by the magnitude itself; the plain dot product v.v is the
    # *squared* magnitude and would not produce a unit vector.
    magnitude = float(numpy.sqrt(numpy.sum(numpy.square(vec, dtype=float))))
    if magnitude == 0.0:
        return numpy.zeros(vec.shape, dtype=float)
    return vec / magnitude
|
|
|
def predict(input_img):
    """Classify a sketched character and build the three UI outputs.

    Args:
        input_img: Sketchpad payload, passed unchanged to ``process_input``.

    Returns:
        A 3-tuple of:
        * a PIL image previewing the preprocessed network input,
        * a headline string naming the top class and its probability,
        * a JSON string mapping each class label to a bar of ``#``
          characters (one ``#`` per full 10% of probability).
    """
    # Label for each output index: digits, then upper-case, then lower-case.
    labels = list("0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz")

    batch = process_input(input_img)
    feed = {'input': batch.astype(numpy.float32)}
    raw_scores = ort_sess.run(None, feed)[0]
    # Only the first row of the softmax result is used.
    probs = softmax(raw_scores)[0]
    best = numpy.argmax(probs)

    bars = {}
    for idx, prob in enumerate(probs):
        bars[labels[idx]] = "#" * int(10 * prob)
    histogram = json.dumps(bars, indent=2)

    # Scale the 0/1 mask up to a visible grayscale preview.
    preview_pixels = numpy.clip(batch[0] * 254, 0, 255).astype(numpy.uint8)
    preview = Image.fromarray(preview_pixels)

    headline = f"Pred: {labels[best]}: {probs[best]}"
    return preview, headline, histogram
|
|
|
|
|
|
|
# Drawing surface handed to `predict`: a 320x320 single-channel ('L') canvas
# delivered as numpy data, with upload and clipboard as image sources and
# layers turned off.
_sketchpad = gr.Sketchpad(
    width=320,
    height=320,
    canvas_size=(320, 320),
    sources=["upload", "clipboard"],
    layers=False,
    image_mode='L',
    type='numpy',
)

# One input (the sketchpad) and three outputs matching predict's return
# tuple: preview image, headline text, per-class bar chart text.
demo = gr.Interface(
    fn=predict,
    inputs=[_sketchpad],
    outputs=["image", "text", "text"],
)

# NOTE(review): share=True presumably requests a public share link --
# confirm this is intended for the deployment target.
demo.launch(share=True)
|
|
|
|