# File size: 3,092 Bytes
# 52d714a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
import json

import gradio as gr
import numpy
import onnxruntime as ort
from PIL import Image

# Load the quantized ONNX letter classifier once at module import time;
# the session is shared by every predict() call.
ort_sess = ort.InferenceSession('tiny_letter_classifier_v2_q8quant.onnx')

# force reload now!

def get_bounds(img):
    """Return (top, bottom, left, right) bounds of the bright region in *img*.

    Assumes white letters on a BLACK background: a pixel is foreground when
    it is brighter than the midpoint between the image's min and max values.
    If nothing qualifies (e.g. a uniform image), the degenerate sentinel
    bounds (height, 0, width, 0) are returned, matching the original
    loop-based implementation.
    """
    min_color = numpy.min(img)
    max_color = numpy.max(img)
    mean_color = 0.5 * (min_color + max_color)
    # One vectorized C-level pass instead of the O(H*W) Python double loop.
    ys, xs = numpy.nonzero(img > mean_color)
    if ys.size == 0:
        # No foreground pixel: keep the original "inverted" sentinel bounds.
        return (img.shape[0], 0, img.shape[1], 0)
    return (int(ys.min()), int(ys.max()), int(xs.min()), int(xs.max()))

def resize_maxpool(img, out_width: int, out_height: int):
    """Downscale a 2-D *img* to (out_height, out_width) via max-pooling.

    Each output pixel is the max of a (scale_y x scale_x) window, where the
    scales are the floor of the input/output size ratios.  Trailing rows and
    columns that do not fill a whole window are ignored, exactly like the
    original per-pixel loop.  Requires the input to be at least as large as
    the output in both dimensions (scale factors >= 1).
    """
    scale_y = img.shape[0] // out_height
    scale_x = img.shape[1] // out_width
    # Crop to a whole number of windows, then expose each window as its own
    # pair of axes and reduce -- a single vectorized pass instead of one
    # numpy.max call per output pixel.
    cropped = img[:out_height * scale_y, :out_width * scale_x]
    windows = cropped.reshape(out_height, scale_y, out_width, scale_x)
    return windows.max(axis=(1, 3))

def process_input(input_msg):
    """Convert a sketchpad payload into a (1, 32, 32) float array.

    The canvas delivers a white (255) background with dark strokes; pixels
    darker than the midpoint of the observed value range become 1.0,
    everything else 0.0, then the result is max-pooled down to 32x32.
    """
    img = input_msg["composite"]
    # Threshold at the midpoint between darkest and brightest pixels,
    # flipping polarity so strokes are 1.0 on a 0.0 background.
    threshold = 0.5 * (numpy.max(img) + numpy.min(img))
    img = (img < threshold).astype(numpy.float64)
    # NOTE: cropping to the glyph's bounding box is currently disabled.
    #crop_area = get_bounds(img)
    #img = img[crop_area[0]:crop_area[1]+2, crop_area[2]:crop_area[3]+2]
    img = resize_maxpool(img, 32, 32)
    # Add the leading channel/batch axis the model expects.
    return img[numpy.newaxis, ...]

def softmax(arr):
    """Numerically stable softmax over the last axis of *arr*."""
    # Shifting by the global max cannot change the result (softmax is
    # shift-invariant) but prevents exp() from overflowing.
    exps = numpy.exp(arr - numpy.max(arr))
    return exps / numpy.sum(exps, axis=-1)
    
def normalize(arr):
    """Scale *arr* to unit Euclidean length.

    Bug fix: the previous version divided by the *squared* magnitude
    (``arr @ arr.T``), so the output did not actually have unit norm unless
    the input already did, and it raised NameError for 2-D inputs that were
    neither a single row nor a single column.  We now divide by the
    Euclidean (Frobenius) norm, which reduces to the vector norm for the
    row/column cases the old code handled.

    Returns the input viewed as at least 2-D (matching the old shape
    behavior) divided by its norm.
    """
    arr = numpy.atleast_2d(arr)
    magnitude = numpy.linalg.norm(arr)
    return arr / magnitude

def predict(input_img):
    """Run the ONNX letter classifier on a sketchpad drawing.

    Returns (preview image of the 32x32 model input, a summary string with
    the top class and its probability, and a JSON text "bar chart" of
    per-class confidence).
    """
    class_idx_to_name = list("0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz")
    img = process_input(input_img)
    # The session expects a float32 tensor named 'input'; take the first output.
    logits = ort_sess.run(None, {'input': img.astype(numpy.float32)})[0]
    probs = softmax(logits)[0]
    best = numpy.argmax(probs)
    # Render each class confidence as a run of '#' (10 characters = 100%).
    bars = {class_idx_to_name[i]: "#" * int(10 * p) for i, p in enumerate(probs)}
    text_out = json.dumps(bars, indent=2)
    preview = Image.fromarray(numpy.clip((img[0] * 254), 0, 255).astype(numpy.uint8))
    summary = f"Pred: {class_idx_to_name[best]}: {probs[best]}"
    return preview, summary, text_out


# Gradio UI: a 320x320 grayscale sketchpad feeds predict(); outputs are the
# downsampled preview image, the top-prediction summary, and the JSON
# confidence "bar chart".
demo = gr.Interface(
    fn=predict,
    inputs=[
        #gr.Sketchpad(image_mode='L', type='numpy'),
        #gr.ImageEditor(
        gr.Sketchpad(
            width=320, height=320, 
            canvas_size=(320, 320),
            sources = ["upload", "clipboard"], # Webcam
            layers=False,
            image_mode='L', 
            type='numpy', 
        ),
    ],
    outputs=["image", "text", "text"],
)

# share=True also exposes a temporary public gradio.live URL, not just localhost.
demo.launch(share=True)