import os
from typing import Optional

import gradio as gr
import numpy as np
import torch
import torch.nn as nn
from torchvision import transforms

from utils import page_utils

class BasicBlock(nn.Module):
    """ResNet basic block.

    Parameters
    ----------
    in_channels : int
        Number of input channels.
    out_channels : int
        Number of output channels.
    stride : int, optional
        Convolution stride size, by default 1.
    identity_downsample : Optional[torch.nn.Module], optional
        Downsampling layer for the identity branch, by default None.

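    Examples
    --------
    A minimal shape check, as a hedged sketch; the 56x56 feature size is
    illustrative and not taken from this app's pipeline:

    >>> block = BasicBlock(64, 64)
    >>> block(torch.randn(1, 64, 56, 56)).shape
    torch.Size([1, 64, 56, 56])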
"""
    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 stride: int = 1,
                 identity_downsample: Optional[torch.nn.Module] = None):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels,
                               out_channels,
                               kernel_size=3,
                               stride=stride,
                               padding=1)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(out_channels,
                               out_channels,
                               kernel_size=3,
                               stride=1,
                               padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.identity_downsample = identity_downsample

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply forward computation."""
        identity = x
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.conv2(x)
        x = self.bn2(x)
        # Downsample the identity branch so its shape matches the conv2
        # output before the residual addition.
        if self.identity_downsample is not None:
            identity = self.identity_downsample(identity)
        x += identity
        x = self.relu(x)
        return x


class ResNet18(nn.Module):
    """Construct a ResNet-18 model.

    Parameters
    ----------
    input_channels : int
        Number of input channels.
    num_classes : int
        Number of output classes.

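    Examples
    --------
    A hedged end-to-end shape check; the 384x384 input mirrors the
    CenterCrop used later in this file:

    >>> net = ResNet18(1, 5)
    >>> net(torch.randn(1, 1, 384, 384)).shape
    torch.Size([1, 5])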
"""
    def __init__(self, input_channels: int, num_classes: int):
        super().__init__()
        self.conv1 = nn.Conv2d(input_channels,
                               64,
                               kernel_size=7,
                               stride=2,
                               padding=3)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU()
        self.maxpool = nn.MaxPool2d(kernel_size=3,
                                    stride=2,
                                    padding=1)
        self.layer1 = self._make_layer(64, 64, stride=1)
        self.layer2 = self._make_layer(64, 128, stride=2)
        self.layer3 = self._make_layer(128, 256, stride=2)
        self.layer4 = self._make_layer(256, 512, stride=2)
        # Classification head: global average pooling followed by a
        # fully connected layer.
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512, num_classes)

    def identity_downsample(self, in_channels: int, out_channels: int) -> nn.Module:
        """Build a downsampling projection for the identity branch.

        Halves the spatial size (stride 2) and matches the channel count
        of the block output.
        """
        return nn.Sequential(
            nn.Conv2d(in_channels,
                      out_channels,
                      kernel_size=3,
                      stride=2,
                      padding=1),
            nn.BatchNorm2d(out_channels)
        )

    def _make_layer(self, in_channels: int, out_channels: int, stride: int) -> nn.Module:
        """Create a stage of two sequential basic blocks."""
        identity_downsample = None
        # A strided stage shrinks the feature map, so the identity branch
        # needs a matching downsampling projection.
        if stride != 1:
            identity_downsample = self.identity_downsample(in_channels, out_channels)
        return nn.Sequential(
            BasicBlock(in_channels, out_channels, identity_downsample=identity_downsample, stride=stride),
            BasicBlock(out_channels, out_channels)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply forward computation."""
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.avgpool(x)
        # Flatten to (batch, 512) before the classifier.
        x = x.view(x.shape[0], -1)
        x = self.fc(x)
        return x


model = ResNet18(1, 5)
checkpoint = torch.load('acc=0.94.ckpt', map_location=torch.device('cpu'))

# The checkpoint's state dict keys carry a `net.` prefix (e.g.
# 'net.conv1.weight'), while this model expects bare names such as
# 'conv1.weight', so strip the prefix before loading.
state_dict = checkpoint['state_dict']
for key in list(state_dict.keys()):
    if key.startswith('net.'):
        state_dict[key.replace('net.', '', 1)] = state_dict.pop(key)

model.load_state_dict(state_dict)
model.eval()

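# A hedged smoke test of the restored weights: with a 384x384 single-channel
# input (the size produced by the CenterCrop below, assumed here), the
# network should emit five logits.
with torch.no_grad():
    _logits = model(torch.randn(1, 1, 384, 384))
assert _logits.shape == (1, 5), f"unexpected output shape: {_logits.shape}"
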
# Class labels, kept in alphabetical order; this is assumed to match the
# label encoding used when the checkpoint was trained.
class_names = ['abdominal', 'adult', 'others', 'pediatric', 'spine']
class_names.sort()

examples_dir = "sample"

# Preprocessing used at inference: grayscale, center crop to 384x384, and
# normalization with the dataset statistics baked into this file.
transformation_pipeline = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Grayscale(num_output_channels=1),
    transforms.CenterCrop((384, 384)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.50807575], std=[0.20823])
])

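# Illustrative geometry check (input size assumed): a 512x512 RGB array
# becomes a single-channel 384x384 tensor after the pipeline above.
assert transformation_pipeline(
    np.zeros((512, 512, 3), dtype=np.uint8)
).shape == (1, 384, 384)

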
def preprocess_image(image: np.ndarray) -> torch.Tensor:
    """Preprocess the input image.

    Note that the input image is expected in RGB mode.

    Parameters
    ----------
    image : np.ndarray
        Input image from the Gradio callback.

    Returns
    -------
    torch.Tensor
        A normalized tensor of shape (1, 1, 384, 384).

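    Examples
    --------
    A hedged sketch with a synthetic input (real inputs come from Gradio):

    >>> preprocess_image(np.zeros((512, 512, 3), dtype=np.uint8)).shape
    torch.Size([1, 1, 384, 384])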
"""
image = transformation_pipeline(image)
image = torch.unsqueeze(image, 0)
return image
def image_classifier(inp: Optional[np.ndarray]) -> dict:
    """Image classifier function.

    Parameters
    ----------
    inp : Optional[np.ndarray]
        Input image from the Gradio callback.

    Returns
    -------
    dict
        A dictionary mapping each class name to its predicted probability.

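    Examples
    --------
    A hedged sketch; actual probabilities depend on the loaded checkpoint,
    so only the keys are checked here:

    >>> probs = image_classifier(np.zeros((512, 512, 3), dtype=np.uint8))
    >>> sorted(probs) == class_names
    True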
"""
# If input not valid, return dummy data or raise error
if inp is None:
return {'cat': 0.3, 'dog': 0.7}
# preprocess
image = preprocess_image(inp)
image = image.to(dtype=torch.float32)
# inference
result = model(image)
# postprocess
result = torch.nn.functional.softmax(result, dim=1) # apply softmax
result = result[0].detach().numpy().tolist() # take the first batch
labeled_result = {name:score for name, score in zip(class_names, result)}
return labeled_result
# Page description rendered at the top of the app.
with open('index.html', encoding="utf-8") as f:
    description = f.read()

# Gradio code block for input and output.
with gr.Blocks(theme=gr.themes.Default(primary_hue=page_utils.KALBE_THEME_COLOR,
                                       secondary_hue=page_utils.KALBE_THEME_COLOR).set(
    button_primary_background_fill="*primary_600",
    button_primary_background_fill_hover="*primary_500",
    button_primary_text_color="white",
)) as app:
    with gr.Column():
        gr.HTML(description)

    with gr.Row():
        with gr.Column():
            inp_img = gr.Image()
            with gr.Row():
                clear_btn = gr.Button(value="Clear")
                process_btn = gr.Button(value="Process", variant="primary")
        with gr.Column():
            out_txt = gr.Label(label="Probabilities", num_top_classes=3)

    process_btn.click(image_classifier, inputs=inp_img, outputs=out_txt)
    clear_btn.click(lambda: (
        gr.update(value=None),
        gr.update(value=None),
    ),
        inputs=None,
        outputs=[inp_img, out_txt])

    gr.Markdown("## Image Examples")
    gr.Examples(
        examples=[
            os.path.join(examples_dir, "1.2.392.200036.9125.4.0.1964921730.2349552188.1786966286.dcm.jpeg"),
            os.path.join(examples_dir, "1b6a707131f787fe37d3ea40d2011d43.dicom.jpeg"),
            os.path.join(examples_dir, "2e3204c2bb7a8fcdd6ec1ed547e2967e.dicom.jpeg"),
            os.path.join(examples_dir, "10.127.133.1137.156.1251.20190404101039.dcm.jpeg"),
            os.path.join(examples_dir, "badaec3e4d5f382ebf0b51ba2c917cea.dicom.jpeg"),
        ],
        inputs=inp_img,
        outputs=out_txt,
        fn=image_classifier,
        cache_examples=False,
    )

app.launch(share=True)