Spaces:
Runtime error
Runtime error
import torch | |
import onnx | |
import onnxruntime as rt | |
from torchvision import transforms as T | |
from PIL import Image | |
from tokenizer_base import Tokenizer | |
import pathlib | |
import os | |
import gradio as gr | |
from huggingface_hub import Repository | |
repo = Repository( | |
local_dir="secret_models", | |
repo_type="model", | |
clone_from="docparser/captcha", | |
token=True | |
) | |
repo.git_pull() | |
cwd = pathlib.Path(__file__).parent.resolve() | |
model_file = os.path.join(cwd,"secret_models","captcha.onnx") | |
img_size = (32,128) | |
charset = r"0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ!\"#$%&'()*+,-./:;<=>?@[\\]^_`{|}~" | |
tokenizer_base = Tokenizer(charset) | |
def get_transform(img_size): | |
transforms = [] | |
transforms.extend([ | |
T.Resize(img_size, T.InterpolationMode.BICUBIC), | |
T.ToTensor(), | |
T.Normalize(0.5, 0.5) | |
]) | |
return T.Compose(transforms) | |
def to_numpy(tensor): | |
return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy() | |
def initialize_model(model_file): | |
transform = get_transform(img_size) | |
# Onnx model loading | |
onnx_model = onnx.load(model_file) | |
onnx.checker.check_model(onnx_model) | |
ort_session = rt.InferenceSession(model_file) | |
return transform,ort_session | |
def get_text(img_org): | |
# img_org = Image.open(image_path) | |
# Preprocess. Model expects a batch of images with shape: (B, C, H, W) | |
x = transform(img_org.convert('RGB')).unsqueeze(0) | |
# compute ONNX Runtime output prediction | |
ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(x)} | |
logits = ort_session.run(None, ort_inputs)[0] | |
probs = torch.tensor(logits).softmax(-1) | |
preds, probs = tokenizer_base.decode(probs) | |
preds = preds[0] | |
print(preds) | |
return preds | |
transform,ort_session = initialize_model(model_file=model_file) | |
gr.Interface( | |
get_text, | |
inputs=gr.Image(type="pil"), | |
outputs=gr.outputs.Textbox(), | |
title="Text Captcha Reader", | |
examples=["8000.png","11JW29.png","2a8486.jpg","2nbcx.png"] | |
).launch() | |
# if __name__ == "__main__": | |
# image_path = "8000.png" | |
# preds,probs = get_text(image_path) | |
# print(preds[0]) | |