bone-fracture / app.py
Maria-Dolgaya's picture
Update app.py
280c7c8 verified
raw
history blame
No virus
1.59 kB
import gradio as gr
import torch
from torchvision import transforms
from torchvision.models import resnet18, ResNet18_Weights
from torch import nn
from PIL import Image # pip install pillow
# Output class names; labels[i] names the i-th output of the classifier head.
# NOTE(review): this order must match the class indices used at training time — confirm.
labels = ['fractured', 'not fractured']

# Side length (pixels) of the square every input image is resized to.
imgSize = 128


def _build_inference_transform(side):
    """Return the preprocessing pipeline applied to each input image.

    Mirrors the transform used for the training inputs, with data
    augmentation left out: resize to ``side`` x ``side``, convert to a
    tensor, then normalize each of the three channels. The mean/std values
    appear to be the standard ImageNet channel statistics used with the
    pretrained ResNet weights (assumption — confirm against training code).
    """
    steps = [
        transforms.Resize(size=(side, side)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
    return transforms.Compose(steps)


data_transform = _build_inference_transform(imgSize)
# Rebuild the ResNet-18 architecture, then load the fine-tuned weights.
# https://pytorch.org/tutorials/beginner/saving_loading_models.html
# Loading Model for Inference with state_dict (recommended)
#
# weights=None: the strict load_state_dict below overwrites every parameter
# and buffer, so first downloading the ImageNet weights (the original
# ResNet18_Weights.DEFAULT) only cost startup time and network access, and
# made startup fail on an offline host. The resulting model is identical.
model = resnet18(weights=None)
# Swap the classifier head for one with one output per label. in_features is
# taken from the existing head instead of hard-coding 512.
model.fc = nn.Linear(in_features=model.fc.in_features, out_features=len(labels))
# map_location lets weights that were saved on a GPU load on a CPU-only host.
model.load_state_dict(torch.load("model.pth", map_location=torch.device('cpu')))
# Inference mode: BatchNorm layers use their stored running statistics.
model.eval()
def predict(img):
    """Classify one image and return a probability for each label.

    Args:
        img: a PIL image (the Gradio input component below uses type="pil").
            ``None``, which is what Gradio passes when no image was
            submitted, is reported back to the user as an error.

    Returns:
        dict mapping each name in ``labels`` to its softmax probability
        (a Python float), the format ``gr.Label`` displays.

    Raises:
        gr.Error: if no image was provided.
    """
    if img is None:
        raise gr.Error("Please upload an image.")
    # The transform normalizes exactly three channels and the ResNet stem
    # expects three, so grayscale ("L") or RGBA images would fail. Convert
    # defensively, even though Gradio usually hands over RGB already.
    X = data_transform(img.convert("RGB")).unsqueeze(0)  # add a leading batch dim
    with torch.no_grad():
        # A batch of one, flattened to a 1-D vector with one logit per label.
        logits = model(X).flatten()
    # dim=0 is explicit. It matches the implicit choice for 1-D input without
    # the deprecation warning.
    probabilities = torch.nn.functional.softmax(logits, dim=0)
    return {label: float(p) for label, p in zip(labels, probabilities)}
# Text shown at the top of the Gradio page.
title = "Bone Fractures"
description = "Bone fractures classifier trained on the Kaggle dataset using Resnet18"

# Image in (as a PIL image, which is what predict expects) -> per-label
# probabilities out. The Label shows every class.
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=gr.Label(num_top_classes=len(labels)),
    title=title,
    description=description,
    # NOTE(review): assumed to be sample images stored next to app.py — confirm.
    examples=["2.jpg", "6.jpg"],
)
# `share` must be passed by keyword. The previous call passed the *string*
# 'share=True' positionally, so it landed in launch()'s first parameter and
# never enabled sharing. (Hosted platforms such as HF Spaces may ignore
# share=True with a warning.)
demo.launch(share=True)