Spaces:
Running
Running
import functools
import os
import random
import warnings

import gradio as gr
import huggingface_hub
import modelscope
import torch
from PIL import Image
from torchvision import transforms

from model import Model
# English UI (and the Hugging Face Hub mirror) unless the LANG environment
# variable is exactly "zh_CN.UTF-8".
EN_US = os.getenv("LANG") != "zh_CN.UTF-8"

# Download the checkpoint repository: Hugging Face Hub when EN_US is set,
# ModelScope otherwise. Both calls use the same repo id and local cache dir.
if EN_US:
    MODEL_DIR = huggingface_hub.snapshot_download(
        "Genius-Society/svhn",
        cache_dir="./__pycache__",
    )
else:
    MODEL_DIR = modelscope.snapshot_download(
        "Genius-Society/svhn",
        cache_dir="./__pycache__",
    )

# Chinese UI label -> English UI label; looked up by _L() when EN_US is set.
ZH2EN = {
    "上传图片": "Upload an image",
    "状态栏": "Status",
    "选择模型": "Select a model",
    "识别结果": "Recognition result",
    "门牌号识别": "Door Number Recognition",
}
def _L(zh_txt: str) -> str:
    """Localize a UI label.

    Returns the English translation from ``ZH2EN`` when ``EN_US`` is set and
    the Chinese text unchanged otherwise. A label missing from ``ZH2EN``
    falls back to the original text instead of raising ``KeyError`` while
    the UI is being built.
    """
    return ZH2EN.get(zh_txt, zh_txt) if EN_US else zh_txt
@functools.lru_cache(maxsize=8)
def _load_model(checkpoint_file: str):
    """Return the model for *checkpoint_file*, restored and switched to eval mode.

    The value of ``model.eval()`` is cached per checkpoint name, so weights
    are read from ``MODEL_DIR`` once instead of on every request. A restore
    that raises is not cached, so the next request retries it.
    """
    model = Model()
    model.restore(f"{MODEL_DIR}/{checkpoint_file}")
    return model.eval()


def _preprocess(input_img: str) -> torch.Tensor:
    """Load the image file at *input_img* as a normalized (1, 3, 54, 54) batch."""
    transform = transforms.Compose(
        [
            transforms.Resize([64, 64]),
            transforms.CenterCrop([54, 54]),
            transforms.ToTensor(),
            # Map [0, 1] pixel values to [-1, 1] on each RGB channel.
            transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
        ]
    )
    # convert() returns a new in-memory image, so the file can be closed here.
    with Image.open(input_img) as image:
        rgb = image.convert("RGB")
    return transform(rgb).unsqueeze(dim=0)


def _decode_digits(length, digits):
    """Join the first *length* digit predictions into a string.

    Slicing clamps *length* to ``len(digits)``: a length prediction larger
    than the number of digit heads yields every available digit instead of
    an IndexError.
    """
    return "".join(str(digit) for digit in digits[:length])


def infer(input_img: str, checkpoint_file: str):
    """Recognize the house number in an image.

    Args:
        input_img: Path to the image file (Gradio ``type="filepath"``); may be
            None/empty when nothing was uploaded.
        checkpoint_file: File name of a checkpoint inside ``MODEL_DIR``.

    Returns:
        ``(status, digits)``: ``status`` is ``"Success"`` or the text of the
        exception that stopped inference; ``digits`` is the recognized number
        as a string (``""`` on failure).
    """
    if not input_img:
        return "Please upload an image.", ""
    try:
        model = _load_model(checkpoint_file)
        images = _preprocess(input_img)
        with torch.no_grad():
            # Outputs: length logits, then one logits tensor per digit position.
            length_logits, *digit_logits = model(images)
        length = length_logits.max(1)[1].item()
        digits = [logits.max(1)[1].item() for logits in digit_logits]
    except Exception as e:  # UI boundary: report the failure in the status box.
        return f"{e}", ""
    # NOTE(review): if the length head has a class above the number of digit
    # heads (e.g. "more than 5 digits"), _decode_digits clamps rather than
    # failing; confirm head sizes against model.Model.
    return "Success", _decode_digits(length, digits)
def get_files(dir_path=MODEL_DIR, ext=".pth"):
    """Return the names of entries in *dir_path* whose names end with *ext*.

    ``os.listdir`` returns entries in arbitrary order. The names are sorted so
    the model dropdown order, and therefore its default (first) entry, is
    stable across runs.
    """
    return sorted(name for name in os.listdir(dir_path) if name.endswith(ext))
if __name__ == "__main__":
    warnings.filterwarnings("ignore")

    models = get_files()
    if not models:
        # Fail with a clear message instead of an IndexError on models[0].
        raise FileNotFoundError(f"No .pth checkpoints found in {MODEL_DIR}")

    # Pair every bundled example image with a randomly chosen checkpoint.
    # The examples directory is optional; without it the UI has no examples.
    examples_dir = f"{MODEL_DIR}/examples"
    samples = (
        [
            [f"{examples_dir}/{img}", random.choice(models)]
            for img in get_files(examples_dir, ".png")
        ]
        if os.path.isdir(examples_dir)
        else []
    )

    gr.Interface(
        fn=infer,
        inputs=[
            gr.Image(label=_L("上传图片"), type="filepath"),
            gr.Dropdown(
                label=_L("选择模型"),
                choices=models,
                value=models[0],
            ),
        ],
        outputs=[
            gr.Textbox(label=_L("状态栏"), show_copy_button=True),
            gr.Textbox(label=_L("识别结果"), show_copy_button=True),
        ],
        examples=samples or None,
        title=_L("门牌号识别"),
        flagging_mode="never",
        cache_examples=False,
    ).launch(ssr_mode=False)