"""Gradio demo that reads multi-digit house numbers (SVHN) from an image."""
import functools
import os
import torch
import requests
import gradio as gr
from tqdm import tqdm
from PIL import Image
from model import Model
from torchvision import transforms
import warnings

warnings.filterwarnings("ignore")

# Run on the GPU when one is available, otherwise fall back to CPU instead of
# crashing on `.cuda()`.
# NOTE(review): if Model.restore() calls torch.load without map_location on a
# CUDA-saved checkpoint, CPU-only hosts may still fail there — confirm in model.py.
_DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Preprocessing pipeline; built once instead of on every request.
_TRANSFORM = transforms.Compose([
    transforms.Resize([64, 64]),
    transforms.CenterCrop([54, 54]),
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
])


def download_model(url="https://www.modelscope.cn/api/v1/models/MuGeminorum/SVHN-Recognition/repo?Revision=master&FilePath=model-122000.pth", local_path="model-122000.pth"):
    """Download the checkpoint at ``url`` to ``local_path`` unless it already exists.

    The data is streamed into ``local_path + ".part"`` and renamed into place
    only after the transfer finishes. An interrupted or failed download
    therefore never leaves a truncated file that a later call would mistake
    for a complete checkpoint.

    Raises:
        requests.HTTPError: if the server answers with an error status.
        requests.RequestException: on connection problems or timeouts.
    """
    if os.path.exists(local_path):
        return

    print(f"Downloading file from {url}...")
    tmp_path = local_path + ".part"
    try:
        with requests.get(url, stream=True, timeout=60) as response:
            # Without this, an HTML error page would be saved as the checkpoint.
            response.raise_for_status()
            total_size = int(response.headers.get('content-length', 0))
            # total=None lets tqdm show an open-ended bar when the size is unknown.
            with tqdm(total=total_size or None, unit='B', unit_scale=True) as progress_bar, \
                    open(tmp_path, 'wb') as file:
                for data in response.iter_content(chunk_size=1024):
                    progress_bar.update(len(data))
                    file.write(data)
        # Atomic on the same filesystem: readers see either nothing or the full file.
        os.replace(tmp_path, local_path)
    except BaseException:
        # Remove the partial file (also on Ctrl-C) so the next call retries cleanly.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise
    print("Download completed.")


@functools.lru_cache(maxsize=4)
def _load_model(path_to_checkpoint_file):
    """Build a Model, restore weights from the checkpoint and put it on _DEVICE in eval mode.

    The model is cached per checkpoint path, so each Gradio request no longer
    reloads the weights. A checkpoint file replaced on disk is therefore not
    picked up until the process restarts.
    """
    model = Model()
    model.restore(path_to_checkpoint_file)
    # Model exposes .cuda()/.eval() like a torch.nn.Module, so .to() is assumed available.
    model.to(_DEVICE)
    model.eval()
    return model


def _infer(path_to_checkpoint_file, path_to_input_image):
    """Return the digit string the model reads from the image at ``path_to_input_image``.

    The model has one length head plus five per-digit heads. The predicted
    length selects how many of the five digit predictions are joined.
    """
    model = _load_model(path_to_checkpoint_file)

    # Close the underlying file once the pixels are converted.
    with Image.open(path_to_input_image) as image:
        tensor = _TRANSFORM(image.convert('RGB'))
    images = tensor.unsqueeze(dim=0).to(_DEVICE)

    with torch.no_grad():
        length_logits, digit1_logits, digit2_logits, digit3_logits, digit4_logits, digit5_logits = model(images)

    length = length_logits.argmax(dim=1).item()
    digits = [
        logits.argmax(dim=1).item()
        for logits in (digit1_logits, digit2_logits, digit3_logits, digit4_logits, digit5_logits)
    ]
    # Slice instead of indexing. If the length head can predict more than 5,
    # this keeps the available digits instead of raising IndexError.
    # (The original SVHN design has a "more than 5" length class; confirm in model.py.)
    return ''.join(str(d) for d in digits[:length])


def inference(image_path, weight_path="model-122000.pth"):
    """Gradio callback: read the digits in ``image_path`` using the checkpoint at ``weight_path``.

    The default checkpoint is downloaded to ``weight_path`` first if that file
    is missing. When no image is supplied, the bundled sample '457.png' is used.
    """
    # Ensure the file that is actually loaded exists. The original call always
    # fetched the default path, even when a different weight_path was given.
    download_model(local_path=weight_path)
    if not image_path:
        image_path = '457.png'
    return _infer(weight_path, image_path)


if __name__ == '__main__':
    iface = gr.Interface(
        fn=inference,
        inputs=gr.Image(type='filepath'),
        outputs=gr.Textbox()
    )
    iface.launch()