File size: 3,482 Bytes
e1e7fa2
 
 
 
 
 
 
 
 
de12ba7
e1e7fa2
 
 
de12ba7
 
 
 
e1e7fa2
 
 
 
 
 
 
de12ba7
e1e7fa2
 
de12ba7
e1e7fa2
 
de12ba7
e1e7fa2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
718237e
de12ba7
e1e7fa2
 
de12ba7
 
 
 
 
 
 
 
e1e7fa2
 
de12ba7
e1e7fa2
718237e
e1e7fa2
de12ba7
 
 
 
 
 
 
 
e1e7fa2
 
 
 
 
 
 
 
 
 
 
 
 
de12ba7
e1e7fa2
 
 
 
 
 
 
 
 
58e5f22
 
 
 
 
 
e1e7fa2
 
de12ba7
e1e7fa2
 
 
 
de12ba7
 
b7472a4
e1e7fa2
 
de12ba7
 
 
e1e7fa2
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
import os
import torch
import requests
import gradio as gr
from tqdm import tqdm
from PIL import Image
from model import Model
from torchvision import transforms
import warnings

# Silence every warning category for the whole process (same filter as
# filterwarnings("ignore") with its defaults).
# NOTE(review): this also hides warnings raised by this script itself;
# consider narrowing to specific categories/modules.
warnings.simplefilter("ignore")


def download_model(
    url="https://huggingface.co/MuGeminorum/SVHN-Recognition/resolve/main/model-122000.pth",
    local_path="model-122000.pth",
    timeout=30,
):
    """Download the model checkpoint from *url* to *local_path* if it is missing.

    The body is streamed into a temporary ``<local_path>.part`` file and only
    renamed to *local_path* after it has been received completely. An
    interrupted or failed download therefore never leaves a truncated
    checkpoint behind, which the existence check below would otherwise accept
    as complete on every later call.

    Args:
        url: Address to fetch the checkpoint from.
        local_path: Destination file. Nothing is downloaded if it already exists.
        timeout: Seconds to wait on connect/read, forwarded to ``requests.get``.

    Raises:
        requests.HTTPError: If the server answers with a 4xx/5xx status.
        requests.RequestException: On connection errors or timeouts.
        OSError: If the file cannot be written or renamed.
    """
    if os.path.exists(local_path):
        return

    print(f"Downloading file from {url}...")
    tmp_path = f"{local_path}.part"
    try:
        with requests.get(url, stream=True, timeout=timeout) as response:
            # Without this an error page (e.g. a 404 body) would be written
            # out as if it were the checkpoint.
            response.raise_for_status()

            # content-length may be absent; tqdm treats total=None as unknown.
            total_size = int(response.headers.get("content-length", 0)) or None

            with open(tmp_path, "wb") as file, tqdm(
                total=total_size, unit="B", unit_scale=True
            ) as progress_bar:
                for data in response.iter_content(chunk_size=8192):
                    file.write(data)
                    progress_bar.update(len(data))

        # Only expose the file under its final name once it is complete.
        os.replace(tmp_path, local_path)
    except BaseException:
        # Drop the partial download so a retry starts from scratch.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise

    print("Download completed.")


def _infer(path_to_checkpoint_file, path_to_input_image):
    """Recognise the digit sequence in a single image.

    A fresh ``Model`` is built and restored from *path_to_checkpoint_file* on
    every call. The image is resized to 64x64, center-cropped to 54x54 and
    normalised to [-1, 1] per channel. The model's first output is decoded as
    the sequence length and the remaining outputs as per-position digits. No
    ``.cuda()``/``.to()`` call is made here, so tensors stay on the default
    device.

    Args:
        path_to_checkpoint_file: Path handed to ``Model.restore``.
        path_to_input_image: Path of an image readable by PIL.

    Returns:
        The first ``length`` predicted digits concatenated into a string
        (empty when the predicted length is 0).
    """
    model = Model()
    model.restore(path_to_checkpoint_file)
    model.eval()

    transform = transforms.Compose(
        [
            transforms.Resize([64, 64]),
            transforms.CenterCrop([54, 54]),
            transforms.ToTensor(),
            transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
        ]
    )

    # convert() loads the pixels into a new image, so the file handle opened
    # by Image.open can be closed immediately instead of being leaked.
    with Image.open(path_to_input_image) as raw_image:
        rgb_image = raw_image.convert("RGB")
    images = transform(rgb_image).unsqueeze(dim=0)  # batch of one

    with torch.no_grad():
        length_logits, *digit_logits = model(images)

    length = length_logits.max(1)[1].item()
    digits = [logits.max(1)[1].item() for logits in digit_logits]

    # Slicing (instead of indexing digits[i] for i < length) cannot raise
    # IndexError if the length head predicts more positions than there are
    # digit heads.
    # NOTE(review): the class ranges of the length/digit heads live in
    # model.py; confirm whether an "overflow" length or a "no digit" class
    # needs special handling here.
    return "".join(str(digit) for digit in digits[:length])


def inference(image_path, weight_path="model-122000.pth"):
    """Gradio callback: make sure the checkpoint exists, then recognise digits.

    The checkpoint is fetched from Hugging Face. If that download raises, it
    is fetched from ModelScope instead. Both downloads target *weight_path*,
    so the file that gets loaded is the one that was downloaded.

    Args:
        image_path: Path of the image to recognise. A falsy value (e.g.
            ``None`` when the Gradio input is empty) selects the bundled
            ``./examples/03.png``.
        weight_path: Local checkpoint path to download to and load from.

    Returns:
        The recognised digit string.

    Raises:
        Exception: Whatever the fallback download raises if both sources fail.
    """
    try:
        download_model(local_path=weight_path)
    except Exception:
        # download_model() returns early whenever the file already exists, so
        # any file present now was written by the failed attempt. Remove it so
        # the fallback does not mistake a partial file for a complete one.
        if os.path.exists(weight_path):
            os.remove(weight_path)
        download_model(
            url="https://www.modelscope.cn/api/v1/models/MuGeminorum/SVHN-Recognition/repo?Revision=master&FilePath=model-122000.pth",
            local_path=weight_path,
        )

    if not image_path:
        image_path = "./examples/03.png"

    return _infer(weight_path, image_path)


if __name__ == "__main__":
    # Web UI: upload an image (passed to `inference` as a file path) and show
    # the recognised digits in a copyable textbox, with bundled examples.
    demo = gr.Interface(
        fn=inference,
        inputs=gr.Image(type="filepath", label="Upload photo"),
        outputs=gr.Textbox(label="Recognition result", show_copy_button=True),
        examples=[
            "./examples/03.png",
            "./examples/457.png",
            "./examples/2003.png",
        ],
    )
    demo.launch()