"""Gradio demo that classifies an uploaded image with a TorchScript MobileNet."""

import gradio as gr
import torch
from torchvision import transforms

# The inputs built below are CPU tensors, so load the model onto CPU. This also
# lets a checkpoint that was saved from a GPU load on a CPU-only machine.
model = torch.jit.load("../models/mobilenet.pt", map_location="cpu")
model.eval()

# Deterministic inference preprocessing. The previous RandomResizedCrop /
# RandomHorizontalFlip are training-time augmentations: they made predictions
# change between identical requests and could crop the subject out entirely.
# Resize(256) + CenterCrop(224) is the standard ImageNet evaluation pipeline.
# NOTE(review): confirm this matches the validation transform used in training.
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    # ImageNet channel mean / std (RGB order).
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

# Label for each model output index; the order must match the training labels.
CLASSES = ["Ak", "Ala_Idris", "Buzgulu", "Dimnit", "Nazli"]


def classify_image(inp):
    """Return the predicted class name for a PIL image.

    Args:
        inp: PIL image supplied by the Gradio ``Image`` component, or ``None``
            when the user submits without uploading anything.

    Returns:
        The entry of ``CLASSES`` at the arg-max of the model's output.

    Raises:
        gr.Error: if no image was provided (shown to the user in the UI).
    """
    if inp is None:
        raise gr.Error("Please upload an image.")
    # Normalize expects exactly 3 channels; guard against RGBA / greyscale input.
    batch = transform(inp.convert("RGB")).unsqueeze(0)  # add batch dimension
    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        out = model(batch)
    return CLASSES[out.argmax().item()]


iface = gr.Interface(
    fn=classify_image,
    inputs=gr.Image(type="pil", label="Input Image"),
    outputs="text",
    examples=[
        "../data/app_data/ak.png",
        "../data/app_data/idris.png",
        "../data/app_data/buzgulu.png",
        "../data/app_data/dimnit.png",
        "../data/app_data/nazli.png",
    ],
)

if __name__ == "__main__":
    iface.launch()