yolov5-classify / app.py
akashAD's picture
Update app.py
2c78f06
raw
history blame
No virus
1.4 kB
import torch
from models.common import DetectMultiBackend
from torchvision import transforms
import gradio as gr
import requests
from PIL import Image
# Alternative loading path kept for reference (local checkpoint via DetectMultiBackend):
# weights='/content/drive/MyDrive/yolov5/yolov5s-cls.pt'
# model = DetectMultiBackend(weights)

# Load the classification checkpoint through PyTorch Hub and switch it to
# inference mode.
model = torch.hub.load('ultralytics/yolov5', 'custom', 'yolov5m-cls.pt').eval()
# NOTE(review): the effect of this flag depends on the class returned by the
# hub loader -- confirm it is still required.
model.classify = True

# Download the ImageNet label list (one label per line).
# A timeout keeps startup from hanging forever on a stalled connection, and
# raise_for_status() stops an HTTP error page from being used as labels.
response = requests.get("https://git.io/JJkYN", timeout=30)
response.raise_for_status()
# splitlines() avoids the trailing empty entry that split("\n") produces when
# the file ends with a newline.
labels = response.text.splitlines()
def preprocess_image(inp):
    """Convert a PIL image into a normalized, batched model input.

    The image is resized so its shorter side is 256 pixels, center-cropped to
    224x224, converted to a float tensor and normalized per channel with the
    mean/std constants below.

    Args:
        inp: PIL image. Non-RGB images (e.g. grayscale or RGBA) are converted
            to RGB first, because the normalization constants are defined for
            exactly three channels.

    Returns:
        A tensor of shape (1, 3, 224, 224): the preprocessed image with a
        leading batch dimension.
    """
    if inp.mode != "RGB":
        inp = inp.convert("RGB")

    preprocess = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    # ToTensor() already yields a torch.Tensor; re-wrapping it in
    # torch.tensor() would make a redundant copy and emit a UserWarning, so
    # only the batch dimension is added here.
    return preprocess(inp).unsqueeze(0)
def predict(inp):
    """Classify an image and return a confidence score for every label.

    Args:
        inp: PIL image, as delivered by the Gradio image input.

    Returns:
        A dict mapping each label name to its softmax probability (a Python
        float), in the form expected by Gradio's "label" output.
    """
    with torch.no_grad():
        # Index 0 selects the first element of the model output -- assumed to
        # be the logits for the single image in the batch (TODO confirm for
        # the loaded model class).
        logits = model(preprocess_image(inp))[0]
        prediction = torch.nn.functional.softmax(logits, dim=0)
    print(prediction)
    # Pair labels with scores instead of indexing a hard-coded range(1000):
    # zip() stops at the shorter sequence, so a label list or model output
    # with a different length no longer raises IndexError.
    return dict(zip(labels, prediction.tolist()))
# Build the web demo: one PIL image in, a label/confidence widget out, and
# start serving immediately.
# NOTE(review): `labels` does not appear to be a documented gr.Interface
# parameter -- confirm the installed gradio version accepts it.
gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs="label",
    labels=labels,
).launch()
# Alternative output that would show only the top five classes:
#outputs=gr.Label(num_top_classes=5))