Spaces:
Running
Running
import os | |
import torch | |
import requests | |
import gradio as gr | |
from tqdm import tqdm | |
from PIL import Image | |
from model import Model | |
from torchvision import transforms | |
import warnings | |
# Silence every warning category for the whole process (keeps the app log quiet).
warnings.simplefilter("ignore")
def download_model(url="https://huggingface.co/MuGeminorum/SVHN-Recognition/resolve/main/model-122000.pth", local_path="model-122000.pth"):
    """Download the checkpoint at ``url`` to ``local_path`` unless it already exists.

    The data is streamed into a temporary ``<local_path>.part`` file and only
    renamed into place after the whole transfer succeeded. A failed or
    interrupted download therefore never leaves a truncated or error-page
    file at ``local_path``, which later calls would otherwise mistake for a
    valid checkpoint and never re-download.

    Args:
        url: Location of the checkpoint to fetch.
        local_path: Destination path on disk.

    Returns:
        ``local_path``. Callers that ignore the return value are unaffected.

    Raises:
        requests.HTTPError: If the server answers with an error status.
        requests.RequestException: On connection failures or timeouts.
    """
    if os.path.exists(local_path):
        return local_path

    print(f"Downloading file from {url}...")
    tmp_path = os.fspath(local_path) + ".part"
    try:
        # (connect, read) timeout: without it a stalled server hangs the app forever.
        with requests.get(url, stream=True, timeout=(10, 60)) as response:
            # Refuse to save a 404/500 error page as if it were the checkpoint.
            response.raise_for_status()
            # content-length may be absent; None gives tqdm an open-ended bar.
            total_size = int(response.headers.get('content-length', 0)) or None
            with open(tmp_path, 'wb') as file, \
                    tqdm(total=total_size, unit='B', unit_scale=True) as progress_bar:
                for data in response.iter_content(chunk_size=1024 * 1024):
                    file.write(data)
                    progress_bar.update(len(data))
        # Atomic rename: local_path only ever holds a complete download.
        os.replace(tmp_path, local_path)
    except BaseException:
        # Drop the partial file so the next call retries from scratch.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise
    print("Download completed.")
    return local_path
# Preprocessing is identical for every request, so it is built once.
_PREPROCESS = transforms.Compose([
    transforms.Resize([64, 64]),
    transforms.CenterCrop([54, 54]),
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
])

# Restored models keyed by checkpoint path, so weights are read from disk once
# per process instead of on every request. NOTE: a checkpoint file replaced on
# disk after its first use is not picked up until the process restarts.
_MODEL_CACHE = {}


def _load_model(path_to_checkpoint_file):
    """Return an eval-mode ``Model`` restored from the checkpoint, cached by path."""
    model = _MODEL_CACHE.get(path_to_checkpoint_file)
    if model is None:
        model = Model()
        model.restore(path_to_checkpoint_file)
        model.eval()
        _MODEL_CACHE[path_to_checkpoint_file] = model
    return model


def _infer(path_to_checkpoint_file, path_to_input_image):
    """Predict the digit sequence in an image with the given checkpoint.

    The image is converted to RGB, resized to 64x64, center-cropped to 54x54
    and normalized to [-1, 1] before being fed to the model as a batch of one.

    Args:
        path_to_checkpoint_file: Checkpoint passed to ``Model.restore``.
        path_to_input_image: Path of the image to read.

    Returns:
        The first ``length`` digit predictions concatenated into a string,
        where ``length`` is the argmax of the model's first output.
        NOTE(review): if a digit head uses an extra class (e.g. 10 for "no
        digit"), it would be rendered as "10" here; confirm against ``Model``.
    """
    model = _load_model(path_to_checkpoint_file)
    # Context manager releases the file handle that Image.open keeps open.
    with Image.open(path_to_input_image) as img:
        image = _PREPROCESS(img.convert('RGB'))
    images = image.unsqueeze(dim=0)  # add the batch dimension
    with torch.no_grad():
        length_logits, *digit_logits = model(images)
    length = length_logits.max(1)[1].item()
    digits = [logits.max(1)[1].item() for logits in digit_logits]
    # Slicing (rather than indexing) keeps a predicted length larger than
    # the number of digit heads from raising IndexError.
    return ''.join(str(digit) for digit in digits[:length])
def inference(image_path, weight_path="model-122000.pth"):
    """Gradio entry point: predict the digit string shown in an image.

    Args:
        image_path: Path of the image to classify. When falsy (e.g. no image
            was supplied), the sample ``'457.png'`` is used instead.
        weight_path: Checkpoint to load. If the file does not exist, the
            default checkpoint is downloaded to this path first.

    Returns:
        The predicted digit string produced by ``_infer``.
    """
    # Provision the checkpoint at the same path that is loaded below;
    # downloading to the default path would leave a custom weight_path missing.
    download_model(local_path=weight_path)
    if not image_path:
        image_path = '457.png'
    return _infer(weight_path, image_path)
if __name__ == '__main__':
    # One image in (handed to `inference` as a file path), one text box out.
    image_input = gr.Image(type='filepath')
    text_output = gr.Textbox()
    demo = gr.Interface(fn=inference, inputs=image_input, outputs=text_output)
    demo.launch()