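# Author: Muhammad Nouman Khan
# Commit: ad6a979 ("update app.py")
# File size at that commit: 940 bytes
# (Captured from a web file viewer; its "raw / history / blame / contribute /
# delete" controls and "No virus" scan badge are not part of the program.)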
import torch
import gradio as gr
from model import AlexNet
from torchvision import transforms
#More Libraries ...
# Trained weights file; the relative path is resolved against the current
# working directory when torch.load opens it.
model_path = './alexnet_model_v1.pth'


def _load_model(path):
    """Return an ``AlexNet`` with weights from *path*, ready for inference.

    The saved tensors are mapped onto the CPU while loading, so the app runs
    on machines without a GPU. The network is switched to eval mode, which
    affects layers such as dropout or batch-norm if the model has any.
    """
    net = AlexNet()
    weights = torch.load(path, map_location=torch.device('cpu'))
    net.load_state_dict(weights)
    net.eval()
    return net


model = _load_model(model_path)

# Class names, indexed by the model's output position (predict pairs
# labels[i] with score i). These appear to be the CIFAR-10 classes.
labels = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
def predict(inp):
    """Classify one image and return the probability of every class.

    Args:
        inp: the image delivered by the Gradio ``Image(type="pil")`` input.
            It is only converted with ``transforms.ToTensor()``. No resizing
            or normalization happens here, so it presumably must already
            match the input size ``AlexNet`` expects (TODO confirm against
            model.py).

    Returns:
        dict mapping each name in ``labels`` to its softmax probability as a
        Python float, in the form ``gr.Label`` displays.

    Raises:
        ValueError: if the model's number of output scores differs from
            ``len(labels)``.
    """
    # ToTensor gives a (C, H, W) float tensor; unsqueeze adds a batch axis of size 1.
    batch = transforms.ToTensor()(inp).unsqueeze(0)
    with torch.no_grad():  # inference only: no autograd graph needed
        logits = model(batch)[0]  # scores for the single image in the batch
        probabilities = torch.nn.functional.softmax(logits, dim=0)

    # The original iterated a hard-coded range(10). Check the sizes explicitly
    # so a label/model mismatch fails loudly instead of with an IndexError or
    # silently dropping classes.
    if probabilities.numel() != len(labels):
        raise ValueError(
            f"model produced {probabilities.numel()} class scores "
            f"but {len(labels)} labels are defined"
        )
    return dict(zip(labels, probabilities.tolist()))
# Web UI: one uploaded PIL image in, the five most likely classes out.
image_input = gr.components.Image(type="pil")
label_output = gr.components.Label(num_top_classes=5)
# Sample images offered as clickable example inputs.
example_images = ["frog.jpeg", "car.jpeg", "cat.jpeg", "ship.jpeg", "dog.jpeg"]

demo = gr.Interface(
    fn=predict,
    inputs=image_input,
    outputs=label_output,
    examples=example_images,
    theme="default",
    # Hide the page footer.
    css=".footer{display:none !important}",
)
demo.launch()