File size: 2,333 Bytes
1603bbc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
be173c0
 
 
1603bbc
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import torch
import onnx
import onnxruntime as rt
from torchvision import transforms as T
from PIL import Image
from tokenizer_base import Tokenizer
import pathlib
import os
import gradio as gr
from huggingface_hub import Repository

# Clone the model repository "docparser/captcha" into ./secret_models and pull
# the latest revision. "secret_models" is relative to the current working
# directory, not to this file.
# NOTE(review): token=True presumably makes huggingface_hub use the locally
# saved access token (the repo looks private) — confirm.
# NOTE(review): the git-based `Repository` helper is deprecated in recent
# huggingface_hub releases; `snapshot_download` is the suggested replacement.
repo = Repository(
    local_dir="secret_models",
    repo_type="model",
    clone_from="docparser/captcha",
    token=True
)
repo.git_pull()

# The model path is resolved relative to this file's directory, whereas the
# clone above lands relative to the working directory; the two only agree when
# the script is launched from its own directory.
cwd = pathlib.Path(__file__).parent.resolve()
model_file = os.path.join(cwd,"secret_models","captcha.onnx")
# Target size passed to T.Resize in get_transform, i.e. (height, width).
img_size = (32,128)
# Character set handed to the tokenizer; each character's position is assumed
# to line up with the model's output classes.
# NOTE(review): this is a raw string, so `\"` and `\\` keep their backslashes:
# the literal is 96 characters long and contains a backslash three times,
# whereas the 94-character printable-ASCII set contains it once. If the model
# was trained on the 94-character set, decoded punctuation after "!" would be
# shifted — verify against the training configuration.
charset = r"0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ!\"#$%&'()*+,-./:;<=>?@[\\]^_`{|}~"
tokenizer_base = Tokenizer(charset)

def get_transform(img_size):
    """Build the preprocessing pipeline applied to every input image.

    Args:
        img_size: Target size forwarded to ``T.Resize``.

    Returns:
        A ``T.Compose`` that resizes with bicubic interpolation, converts the
        image to a tensor, and normalizes it with mean 0.5 and std 0.5.
    """
    resize = T.Resize(img_size, T.InterpolationMode.BICUBIC)
    return T.Compose([resize, T.ToTensor(), T.Normalize(0.5, 0.5)])

def to_numpy(tensor):
    """Return the contents of a torch tensor as a NumPy array on the host.

    Tensors that track gradients are detached first, because ``.numpy()``
    cannot be called on a tensor that requires grad.
    """
    host = tensor.detach() if tensor.requires_grad else tensor
    return host.cpu().numpy()

def initialize_model(model_file):
    """Load the ONNX captcha model together with its preprocessing transform.

    Args:
        model_file: Path to the ``.onnx`` model file.

    Returns:
        A ``(transform, ort_session)`` tuple: the pipeline built by
        ``get_transform(img_size)`` and an onnxruntime ``InferenceSession``
        created from ``model_file``.
    """
    preprocess = get_transform(img_size)
    # Run the ONNX checker on the parsed graph before creating the session;
    # the parsed model object is only needed for this validation step.
    onnx.checker.check_model(onnx.load(model_file))
    session = rt.InferenceSession(model_file)
    return preprocess, session

def get_text(img_org):
    """Run the captcha model on one image and return the decoded text.

    Relies on the module-level ``transform``, ``ort_session`` and
    ``tokenizer_base`` objects created at startup.

    Args:
        img_org: A PIL image (as delivered by ``gr.Image(type="pil")``), or
            ``None`` when the user submits without providing an image.

    Returns:
        The predicted string for the image, or ``""`` when no image is given.
    """
    # Gradio hands over None for an empty image input; without this guard the
    # call below would fail with AttributeError on None.convert.
    if img_org is None:
        return ""

    # Preprocess into a batch of one: the model expects (B, C, H, W).
    batch = transform(img_org.convert('RGB')).unsqueeze(0)

    # Compute the ONNX Runtime prediction; the first output is taken as logits.
    input_name = ort_session.get_inputs()[0].name
    logits = ort_session.run(None, {input_name: to_numpy(batch)})[0]
    # Softmax over the last axis, then let the tokenizer turn it into text.
    char_probs = torch.tensor(logits).softmax(-1)
    preds, _ = tokenizer_base.decode(char_probs)
    text = preds[0]
    print(text)
    return text

transform, ort_session = initialize_model(model_file=model_file)

# Web UI: one PIL image in, the decoded captcha text out. The example images
# are referenced by relative path.
gr.Interface(
    get_text,
    inputs=gr.Image(type="pil"),
    # gr.outputs.* was deprecated in Gradio 3 and removed in Gradio 4; the
    # unified gr.Textbox component works as an output in both versions.
    outputs=gr.Textbox(),
    title="Text Captcha Reader",
    examples=["8000.png","11JW29.png","2a8486.jpg","2nbcx.png",
             "000679.png","000HU.png","00Uga.png.jpg","00bAQwhAZU.jpg",
             "00h57kYf.jpg","0EoHdtVb.png","0JS21.png","0p98z.png","10010.png"]
).launch()