# Spaces: Running
import os | |
import torch | |
import requests | |
import gradio as gr | |
from tqdm import tqdm | |
from PIL import Image | |
from model import Model | |
from torchvision import transforms | |
import warnings | |
warnings.filterwarnings("ignore") | |
def download_model(
    url="https://huggingface.co/MuGeminorum/SVHN-Recognition/resolve/main/model-122000.pth",
    local_path="model-122000.pth",
    timeout=60,
):
    """Download the file at ``url`` to ``local_path`` unless it already exists.

    The payload is streamed into ``<local_path>.part`` and only renamed to
    ``local_path`` once the whole body has been written, so an interrupted or
    failed download never leaves a truncated file that a later call would
    mistake for a finished one.

    Args:
        url: Location of the file to fetch.
        local_path: Destination path; if it already exists nothing is done.
        timeout: Value passed to ``requests.get(timeout=...)``.

    Raises:
        requests.HTTPError: If the server answers with a 4xx/5xx status.
        requests.RequestException: On connection problems or timeouts.
        OSError: If the destination cannot be written.
    """
    # An existing file is treated as a completed download and left untouched.
    if os.path.exists(local_path):
        return

    print(f"Downloading file from {url}...")
    partial_path = f"{local_path}.part"
    try:
        with requests.get(url, stream=True, timeout=timeout) as response:
            # Fail loudly on an HTTP error instead of saving the error page
            # as if it were the requested file.
            response.raise_for_status()
            # Total size in bytes; 0 when the server sends no content-length.
            total_size = int(response.headers.get("content-length", 0))
            with open(partial_path, "wb") as file, tqdm(
                total=total_size, unit="B", unit_scale=True
            ) as progress_bar:
                for data in response.iter_content(chunk_size=1024):
                    file.write(data)
                    progress_bar.update(len(data))
        # Publish the file only after every chunk has been written.
        os.replace(partial_path, local_path)
    finally:
        # After a successful rename the partial file is gone; on any failure
        # it is removed so the next attempt starts from scratch.
        if os.path.exists(partial_path):
            os.remove(partial_path)
    print("Download completed.")
def _infer(path_to_checkpoint_file, path_to_input_image):
    """Predict the digit sequence shown in one image.

    Args:
        path_to_checkpoint_file: Checkpoint path handed to ``Model.restore``.
        path_to_input_image: Path of the image file to read.

    Returns:
        The predicted digits concatenated into a string. The string has as
        many characters as the predicted length, capped at the number of
        digit outputs the model returns; it is empty for a predicted length
        of 0.
    """
    model = Model()
    model.restore(path_to_checkpoint_file)
    # NOTE: neither the model nor the input batch is moved with ``.cuda()``.

    preprocess = transforms.Compose(
        [
            transforms.Resize([64, 64]),
            transforms.CenterCrop([54, 54]),
            transforms.ToTensor(),
            transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
        ]
    )

    # Pillow keeps the underlying file open after ``Image.open``; the context
    # manager releases the handle as soon as the RGB copy has been made.
    with Image.open(path_to_input_image) as raw_image:
        image = raw_image.convert("RGB")

    with torch.no_grad():
        # Add a leading batch dimension of size 1.
        images = preprocess(image).unsqueeze(dim=0)
        # The first output scores the sequence length, the rest score one
        # digit position each.
        length_logits, *per_digit_logits = model.eval()(images)
        # ``max(1)[1]`` is the index of the highest score along dim 1.
        length_prediction = length_logits.max(1)[1].item()
        digit_predictions = [
            logits.max(1)[1].item() for logits in per_digit_logits
        ]

    # Slicing (rather than indexing position by position) keeps this safe if
    # the predicted length exceeds the number of digit outputs.
    return "".join(str(digit) for digit in digit_predictions[:length_prediction])
def inference(image_path, weight_path="model-122000.pth"):
    """Make sure the checkpoint is present, then recognise digits in an image.

    Args:
        image_path: Path of the image to recognise. A falsy value (``None``
            or an empty string) falls back to ``./examples/03.png``.
        weight_path: Where the checkpoint is stored. It is downloaded to this
            path if missing and then loaded from the same path.

    Returns:
        The string produced by ``_infer`` for the image.
    """
    try:
        # Download to the same path that is loaded below, so a non-default
        # ``weight_path`` does not end up pointing at a file never fetched.
        download_model(local_path=weight_path)
    except Exception as err:
        # Best-effort fallback: whatever went wrong with the primary host,
        # report it and retry once from the ModelScope mirror. A failure of
        # the mirror propagates to the caller.
        print(f"Primary download failed ({err}); retrying from the ModelScope mirror...")
        download_model(
            url="https://www.modelscope.cn/api/v1/models/MuGeminorum/SVHN-Recognition/repo?Revision=master&FilePath=model-122000.pth",
            local_path=weight_path,
        )

    if not image_path:
        image_path = "./examples/03.png"

    return _infer(weight_path, image_path)
if __name__ == "__main__":
    # Script entry point: build the Gradio UI around ``inference`` and serve it.
    example_images = [
        "./examples/03.png",
        "./examples/457.png",
        "./examples/2003.png",
    ]
    # Input component: hands the uploaded photo to ``inference`` as a file path.
    photo_input = gr.Image(type="filepath", label="Upload photo")
    # Output component: shows the recognised string with a copy button.
    result_box = gr.Textbox(label="Recognition result", show_copy_button=True)
    iface = gr.Interface(
        fn=inference,
        inputs=photo_input,
        outputs=result_box,
        examples=example_images,
    )
    iface.launch()