# Source: MuGeminorum/SVHN-Recognition (Hugging Face), commit de12ba7
# ("add show copy btn"), 3.48 kB.
import functools
import os
import warnings

import gradio as gr
import requests
import torch
from PIL import Image
from torchvision import transforms
from tqdm import tqdm

from model import Model
# Drop every Python warning process-wide (equivalent to filterwarnings("ignore")),
# presumably to keep the demo's console output quiet.
warnings.simplefilter("ignore")
def download_model(
    url="https://huggingface.co/MuGeminorum/SVHN-Recognition/resolve/main/model-122000.pth",
    local_path="model-122000.pth",
    timeout=60,
):
    """Download the model checkpoint to ``local_path`` unless it already exists.

    The response body is streamed into ``<local_path>.part`` and renamed onto
    ``local_path`` only after the whole body was received. An interrupted or
    failed download therefore never leaves a truncated file behind that a
    later call would mistake for a complete checkpoint.

    Args:
        url: HTTP(S) URL to fetch the checkpoint from.
        local_path: Destination file. If it already exists, nothing is done.
        timeout: Seconds handed to ``requests.get`` as its connect/read timeout.

    Raises:
        requests.HTTPError: If the server answers with a 4xx/5xx status.
        requests.RequestException: On connection failures or timeouts.
    """
    # Already downloaded: nothing to do.
    if os.path.exists(local_path):
        return

    print(f"Downloading file from {url}...")
    tmp_path = f"{local_path}.part"
    try:
        with requests.get(url, stream=True, timeout=timeout) as response:
            # Without this, an error page (404/500) would be saved as the model
            # and no exception would reach the caller's fallback logic.
            response.raise_for_status()
            # content-length may be absent; None lets tqdm show an open-ended bar.
            total_size = int(response.headers.get("content-length", 0))
            with tqdm(
                total=total_size or None, unit="B", unit_scale=True
            ) as progress_bar, open(tmp_path, "wb") as out_file:
                for data in response.iter_content(chunk_size=64 * 1024):
                    out_file.write(data)
                    progress_bar.update(len(data))
        # Atomic rename: local_path only ever appears once complete.
        os.replace(tmp_path, local_path)
    finally:
        # On any failure, discard the partial download.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
    print("Download completed.")
@functools.lru_cache(maxsize=1)
def _load_model(path_to_checkpoint_file):
    """Build ``Model``, restore it from ``path_to_checkpoint_file`` and set eval mode.

    Cached by checkpoint path so successive requests reuse one loaded model
    instead of re-reading the checkpoint from disk on every call.
    NOTE: if the file at that path is replaced while the process runs, the
    previously loaded model keeps being used until the process restarts.
    """
    model = Model()
    model.restore(path_to_checkpoint_file)
    # model.cuda()
    return model.eval()


def _infer(path_to_checkpoint_file, path_to_input_image):
    """Recognize the digit sequence in a single image.

    Args:
        path_to_checkpoint_file: Checkpoint path passed to ``Model.restore``.
        path_to_input_image: Path of the image file to classify.

    Returns:
        str: The per-position digit predictions concatenated, keeping only the
        first ``length`` of them, where ``length`` is the argmax of the length
        head. The result is "" when that argmax is 0.
        NOTE(review): each digit is rendered as its raw class index. Confirm
        how the model encodes "no digit" if a head can predict an index >= 10.
    """
    model = _load_model(path_to_checkpoint_file)
    transform = transforms.Compose(
        [
            transforms.Resize([64, 64]),
            transforms.CenterCrop([54, 54]),
            transforms.ToTensor(),
            transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
        ]
    )
    # The context manager releases the file handle. convert() returns a new,
    # fully loaded image that stays valid after the ``with`` block.
    with Image.open(path_to_input_image) as source:
        image = source.convert("RGB")

    with torch.no_grad():
        images = transform(image).unsqueeze(dim=0)  # .cuda()
        # First output: length logits. The rest: one logits tensor per digit slot.
        length_logits, *digit_logits = model(images)
        length = length_logits.max(1)[1].item()
        digits = [logits.max(1)[1].item() for logits in digit_logits]

    # Slice rather than index: if the length head predicts more digits than
    # there are digit heads, this returns all available digits instead of
    # raising IndexError.
    return "".join(str(digit) for digit in digits[:length])
def inference(image_path, weight_path="model-122000.pth"):
    """Gradio callback: ensure the checkpoint is present, then recognize digits.

    The checkpoint is fetched into ``weight_path`` from the default Hugging Face
    URL. If that raises for any reason, it is fetched from the ModelScope mirror.

    Args:
        image_path: Path of the uploaded image (the UI uses ``type="filepath"``).
            When empty or None, the bundled ``./examples/03.png`` is used.
        weight_path: Where the checkpoint is stored and loaded from.

    Returns:
        The recognized digit string produced by ``_infer``.
    """
    try:
        # Download to the same path _infer will load from, so a non-default
        # weight_path is actually populated.
        download_model(local_path=weight_path)
    except Exception as err:  # any primary-host failure -> try the mirror
        print(f"Primary download failed ({err}); retrying from ModelScope mirror...")
        download_model(
            url="https://www.modelscope.cn/api/v1/models/MuGeminorum/SVHN-Recognition/repo?Revision=master&FilePath=model-122000.pth",
            local_path=weight_path,
        )
    if not image_path:
        image_path = "./examples/03.png"
    return _infer(weight_path, image_path)
if __name__ == "__main__":
    # Bundled sample images offered as clickable examples in the UI.
    demo_examples = [
        "./examples/03.png",
        "./examples/457.png",
        "./examples/2003.png",
    ]
    # Input is handed to `inference` as a file path; output is plain text
    # with a copy button.
    photo_input = gr.Image(type="filepath", label="Upload photo")
    result_output = gr.Textbox(label="Recognition result", show_copy_button=True)
    gr.Interface(
        fn=inference,
        inputs=photo_input,
        outputs=result_output,
        examples=demo_examples,
    ).launch()