bodypartxr / app.py
Jason Adrian
bodypartxr classifier
d360108
raw
history blame
No virus
2.14 kB
import gradio as gr
import torch
from torchvision.transforms import transforms
import numpy as np
from resnet18 import ResNet18
# Raw string on purpose: in a normal literal "\b" (before "bodypartxr") and
# "\a" (before "acc=") are escape sequences for backspace and bell, which
# silently corrupts the path and makes torch.load look for the wrong file.
CHECKPOINT_PATH = r'C:\jason\semester 8\Magang\Hugging-face-bodypartxr\bodypartxr\acc=0.94.ckpt'
# Prefix the training wrapper puts in front of every parameter name.
STATE_DICT_PREFIX = 'net.'


def _strip_prefix(state_dict, prefix):
    """Return a copy of *state_dict* with a leading *prefix* removed from keys.

    Only a prefix at the very start of a key is stripped, so parameter names
    that merely contain the prefix text elsewhere (e.g. a submodule whose name
    ends in "net") are left intact. Keys without the prefix are kept as-is.
    """
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }


model = ResNet18(1, 5)
# map_location='cpu' lets a checkpoint that was saved from a GPU load on a
# CPU-only host; the model is not moved to a GPU anywhere in this file.
# NOTE(review): torch>=2.6 defaults torch.load to weights_only=True; if this
# checkpoint stores non-tensor objects it may need weights_only=False
# (acceptable only because the file is a trusted local checkpoint).
checkpoint = torch.load(CHECKPOINT_PATH, map_location='cpu')
# The checkpoint's state dict names parameters ``net.<layer_name>``, but the
# bare ResNet18 has no ``net`` attribute, so the prefix must be removed.
state_dict = _strip_prefix(checkpoint['state_dict'], STATE_DICT_PREFIX)
model.load_state_dict(state_dict)
# Switch to evaluation mode (affects layers such as BatchNorm/Dropout, if the
# network contains any) before serving predictions.
model.eval()
# Output labels, paired positionally with the model's logits by
# image_classifier, so their order must match the training label indices.
# Alphabetical order presumably mirrors how those indices were assigned
# (e.g. folder-name order) — TODO confirm against the training code.
class_names = sorted(['abdominal', 'adult', 'others', 'pediatric', 'spine'])
# Preprocessing applied to each uploaded image before inference:
#   1. numpy array -> PIL image
#   2. collapse to a single channel (the network is built as ResNet18(1, 5),
#      presumably meaning 1 input channel — confirm in resnet18.py)
#   3. keep the central 384x384 region
#   4. convert to a float tensor
#   5. standardize with a fixed mean/std (presumably training-set statistics
#      — TODO confirm)
_preprocessing_steps = [
    transforms.ToPILImage(),
    transforms.Grayscale(num_output_channels=1),
    transforms.CenterCrop((384, 384)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.50807575], std=[0.20823]),
]
transformation_pipeline = transforms.Compose(_preprocessing_steps)
def preprocess_image(image: np.ndarray):
    """Turn a callback image into a model-ready batch of one.

    The incoming image is in RGB mode. ``transformation_pipeline`` converts
    it into a normalized single-channel tensor, and a leading batch axis of
    size 1 is then added.

    Parameters
    ----------
    image: np.ndarray
        Input image from callback.

    Returns
    -------
    torch.Tensor
        The transformed image with an extra leading dimension of size 1.
    """
    return transformation_pipeline(image).unsqueeze(0)
def image_classifier(inp):
    """Classify an uploaded image into one of ``class_names``.

    Parameters
    ----------
    inp: Optional[np.ndarray] = None
        Input image from the Gradio callback, or ``None`` when nothing was
        submitted.

    Returns
    -------
    Dict
        Mapping from each class name to its softmax probability. For a
        missing image an empty dict is returned, which Gradio's Label output
        treats as "no prediction", instead of fabricated scores.
    """
    if inp is None:
        # Nothing uploaded: report no scores rather than placeholder labels
        # that have nothing to do with this model.
        return {}
    batch = preprocess_image(inp).to(dtype=torch.float32)
    # Pure inference: skip building the autograd graph (saves memory and
    # makes a .detach() unnecessary).
    with torch.no_grad():
        logits = model(batch)
    # Softmax over the class dimension, then take the single batch element.
    probabilities = torch.nn.functional.softmax(logits, dim=1)[0].tolist()
    # class_names order must match the logit order of the model's output.
    return dict(zip(class_names, probabilities))
# Web UI: per preprocess_image, the "image" input hands the callback an
# np.ndarray; the "label" output displays the returned {class: probability}.
demo = gr.Interface(fn=image_classifier, inputs="image", outputs="label")

# Start the server only when run as a script (e.g. `python app.py`), so that
# importing this module — for tests or tooling — does not block on launch().
if __name__ == "__main__":
    demo.launch()