# Hugging Face Space status at capture time: Build error
import functools

import gradio as gr
import numpy as np
import torch
import torchvision
from PIL import Image
from torch import nn
from torchvision import transforms
@functools.lru_cache(maxsize=1)
def _load_default_model():
    """Build the ImageNet-pretrained ResNet-18 once and reuse it across calls.

    Returns:
        A ``(model, class_names)`` tuple: the ResNet-18 in eval mode and the
        category names taken from the weights' metadata.
    """
    weights = torchvision.models.ResNet18_Weights.DEFAULT
    model = torchvision.models.resnet18(weights=weights)
    model.eval()
    return model, list(weights.meta["categories"])


@functools.lru_cache(maxsize=1)
def _get_transform():
    """Return the (cached) preprocessing pipeline applied to every image."""
    # NOTE(review): 64x64 is far below the resolution the ImageNet ResNet-18
    # weights were trained at (ResNet18_Weights.DEFAULT.transforms() uses a
    # 224 crop). Kept to preserve current behaviour; confirm this is intended.
    return transforms.Compose(
        [
            transforms.Resize((64, 64)),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225],
            ),
        ]
    )


def predict(img_path, model=None, class_names=None):
    """Classify one image and return a label -> probability mapping.

    Args:
        img_path: A numpy array (what ``gr.Image()`` passes by default;
            presumably uint8 HxW or HxWxC — confirm), or a path/file object
            accepted by ``PIL.Image.open``.
        model: Optional torch module producing logits of shape (1, N). When
            omitted, a cached ImageNet-pretrained ResNet-18 is used.
        class_names: Optional sequence of N labels. Defaults to the ImageNet
            categories for the built-in model, otherwise to "0".."N-1".

    Returns:
        A dict mapping each label to its softmax probability as a float.

    Raises:
        ValueError: If no image is given, or if ``class_names`` does not
            match the number of model outputs.
    """
    if img_path is None:
        # Gradio passes None when the user submits without uploading.
        raise ValueError("predict() requires an image, got None.")

    if model is None:
        model, default_names = _load_default_model()
        if class_names is None:
            class_names = default_names

    if isinstance(img_path, np.ndarray):
        image = Image.fromarray(img_path).convert("RGB")
    else:
        # Context manager so the file handle is released deterministically.
        with Image.open(img_path) as opened:
            image = opened.convert("RGB")

    batch = _get_transform()(image).unsqueeze(0)
    model.eval()
    with torch.inference_mode():
        logits = model(batch)
    # squeeze(0) only drops the batch axis, so a single-class model still
    # yields a 1-D array instead of a 0-d scalar.
    probs = torch.softmax(logits, dim=1).squeeze(0).numpy()

    if class_names is None:
        class_names = [str(i) for i in range(len(probs))]
    if len(class_names) != len(probs):
        raise ValueError(
            f"class_names has {len(class_names)} entries but the model "
            f"produced {len(probs)} outputs."
        )
    return {name: float(p) for name, p in zip(class_names, probs)}
# Gradio UI: one image input, a label panel showing the three most likely
# classes. Kept at module level so Spaces/Gradio can discover `demo`.
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(),
    outputs=gr.Label(num_top_classes=3),
)

if __name__ == "__main__":
    # Start the web server only when run as a script, not on import.
    demo.launch()