Spaces:
Sleeping
Sleeping
import gradio as gr | |
import torch | |
from torchvision.transforms import transforms | |
import numpy as np | |
from resnet18 import ResNet18 | |
# Build the network. NOTE(review): the constructor args (1, 5) presumably mean
# 1 input channel and 5 output classes, matching the grayscale pipeline and
# the five class names — confirm against resnet18.ResNet18.
model = ResNet18(1, 5)

# Raw string literal: in a plain string "\b" (in "\bodypartxr") and "\a" (in
# "\acc=0.94") are backspace/bell escape codes, which silently corrupted the
# path. map_location='cpu' lets a checkpoint saved from a GPU run be loaded
# on a CPU-only host.
# NOTE(review): this absolute Windows path only exists on the author's
# machine; a deployed Space presumably needs a path relative to this file.
checkpoint = torch.load(
    r'C:\jason\semester 8\Magang\Hugging-face-bodypartxr\bodypartxr\acc=0.94.ckpt',
    map_location='cpu',
)

# The checkpoint's state dict stores parameters as "net.<layer_name>", while
# our bare ResNet18 expects "<layer_name>", so strip that prefix. Only a
# *leading* "net." is removed; keys without it are kept unchanged.
_STATE_DICT_PREFIX = 'net.'
state_dict = {
    (key[len(_STATE_DICT_PREFIX):] if key.startswith(_STATE_DICT_PREFIX) else key): value
    for key, value in checkpoint['state_dict'].items()
}
model.load_state_dict(state_dict)
# Put the model in evaluation mode for inference.
model.eval()
# Output labels, one per logit position of the model, kept in sorted order.
# NOTE(review): sorting presumably mirrors the alphabetical class indexing
# used at training time — confirm before changing this list.
class_names = sorted(['abdominal', 'adult', 'others', 'pediatric', 'spine'])
# Preprocessing applied to every input image before inference:
# array -> PIL image -> single-channel grayscale -> 384x384 centre crop
# -> tensor -> normalised with the mean/std constants below.
# NOTE(review): CenterCrop crops (or pads) without resizing, so large inputs
# lose their borders — confirm this matches the training-time pipeline.
_CROP_SIZE = (384, 384)
_GRAY_MEAN = [0.50807575]
_GRAY_STD = [0.20823]

_preprocessing_steps = [
    transforms.ToPILImage(),
    transforms.Grayscale(num_output_channels=1),
    transforms.CenterCrop(_CROP_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(mean=_GRAY_MEAN, std=_GRAY_STD),
]
transformation_pipeline = transforms.Compose(_preprocessing_steps)
def preprocess_image(image: np.ndarray):
    """Turn a raw callback image into a single-item batch for the model.

    The image arrives in RGB mode; it is passed through
    ``transformation_pipeline`` and then given a leading batch axis.

    Parameters
    ----------
    image: np.ndarray
        Input image from callback.

    Returns
    -------
    torch.Tensor
        The transformed image with an extra leading dimension of size 1
        (with the current pipeline: shape (1, 1, 384, 384)).
    """
    transformed = transformation_pipeline(image)
    return transformed.unsqueeze(0)
def image_classifier(inp):
    """Run the classifier on one input image.

    Parameters
    ----------
    inp: Optional[np.ndarray] = None
        Input image from callback; ``None`` when no image was provided.

    Returns
    -------
    Dict[str, float]
        Mapping from each class name to its softmax probability. An empty
        dict is returned when ``inp`` is ``None``.
    """
    # No image: return an empty result rather than placeholder labels
    # ('cat'/'dog') that are not classes of this model.
    if inp is None:
        return {}

    # preprocess
    image = preprocess_image(inp).to(dtype=torch.float32)

    # inference — no_grad skips autograd bookkeeping we never use
    with torch.no_grad():
        logits = model(image)

    # postprocess: softmax over classes, then take the first (only) batch item
    probabilities = torch.nn.functional.softmax(logits, dim=1)[0].tolist()
    return {name: score for name, score in zip(class_names, probabilities)}
demo = gr.Interface(fn=image_classifier, inputs="image", outputs="label")

# Start the web server only when this file is run as a script, so importing
# the module (e.g. to reuse image_classifier or in tests) has no side effect.
if __name__ == "__main__":
    demo.launch()