# Hugging Face Spaces page header (scraped): "Spaces: Runtime error"
import gradio as gr | |
import torch | |
from torchvision.transforms import transforms | |
import numpy as np | |
from typing import Optional | |
import torch.nn as nn | |
import os | |
from utils import page_utils | |
class BasicBlock(nn.Module):
    """Two-convolution residual block used to assemble ResNet-18.

    Main path: ``conv3x3(stride) -> BN -> ReLU -> conv3x3 -> BN``. The result
    is added to the block input (passed through ``identity_downsample`` when
    one is given) and sent through a final ReLU.

    Parameters
    ----------
    in_channels : int
        Channel count of the tensor entering the block.
    out_channels : int
        Channel count produced by both convolutions.
    stride : int, optional
        Stride of the first convolution, by default 1.
    identity_downsample : Optional[torch.nn.Module], optional
        Module applied to the shortcut so it matches the main path's shape;
        ``None`` (the default) adds the input unchanged.
    """

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 stride: int = 1,
                 identity_downsample: Optional[torch.nn.Module] = None):
        super().__init__()
        # Attribute names (and creation order) are kept as-is: the names are
        # the state_dict keys that the checkpoint is loaded into.
        self.conv1 = nn.Conv2d(in_channels, out_channels,
                               kernel_size=3, stride=stride, padding=1)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(out_channels, out_channels,
                               kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.identity_downsample = identity_downsample

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``relu(main_path(x) + shortcut(x))``."""
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        # The shortcut is reshaped only when a projection module was supplied,
        # e.g. when conv1 changes the spatial size or channel count.
        shortcut = x if self.identity_downsample is None else self.identity_downsample(x)
        out += shortcut
        return self.relu(out)
class ResNet18(nn.Module):
    """Construct ResNet-18 Model.

    Stem (7x7 conv, BN, ReLU, 3x3 max-pool), four stages of two
    ``BasicBlock`` each (64, 128, 256 and 512 channels), global average
    pooling and a linear classification head.

    Parameters
    ----------
    input_channels : int
        Number of input channels
    num_classes : int
        Number of class outputs
    """

    def __init__(self, input_channels, num_classes):
        super(ResNet18, self).__init__()
        self.conv1 = nn.Conv2d(input_channels, 64,
                               kernel_size=7, stride=2, padding=3)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU()
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(64, 64, stride=1)
        self.layer2 = self._make_layer(64, 128, stride=2)
        self.layer3 = self._make_layer(128, 256, stride=2)
        self.layer4 = self._make_layer(256, 512, stride=2)
        # Last layers
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512, num_classes)

    def identity_downsample(self, in_channels: int, out_channels: int,
                            stride: int = 2) -> nn.Module:
        """Build the shortcut projection (conv + BN) for a residual block.

        Parameters
        ----------
        in_channels : int
            Channels of the block input.
        out_channels : int
            Channels of the block output.
        stride : int, optional
            Must equal the stride of the block's first convolution so the
            shortcut and main path have the same spatial size; by default 2.
        """
        # NOTE: a standard ResNet projection is a 1x1 conv; the 3x3 kernel is
        # kept because changing it would alter the weight shapes that the
        # strict ``load_state_dict`` call expects.
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels,
                      kernel_size=3, stride=stride, padding=1),
            nn.BatchNorm2d(out_channels),
        )

    def _make_layer(self, in_channels: int, out_channels: int, stride: int) -> nn.Module:
        """Create a stage of two ``BasicBlock``s; only the first may downsample."""
        identity_downsample = None
        # The shortcut needs a projection whenever the main path changes the
        # spatial size (stride) OR the channel count; otherwise the residual
        # addition inside BasicBlock fails on a shape mismatch.
        if stride != 1 or in_channels != out_channels:
            identity_downsample = self.identity_downsample(in_channels, out_channels, stride)
        return nn.Sequential(
            BasicBlock(in_channels, out_channels,
                       identity_downsample=identity_downsample, stride=stride),
            BasicBlock(out_channels, out_channels),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return class logits of shape ``(batch, num_classes)``."""
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.avgpool(x)
        # (batch, 512, 1, 1) -> (batch, 512)
        x = torch.flatten(x, 1)
        x = self.fc(x)
        return x
model = ResNet18(1, 5)

# NOTE(review): torch.load unpickles the file, so it must only ever be pointed
# at this trusted, bundled checkpoint. PyTorch >= 2.6 defaults to
# weights_only=True, which rejects arbitrary pickled Python objects; if this
# checkpoint stores any, loading fails there — confirm against the deployed
# torch version.
checkpoint = torch.load('acc=0.94.ckpt', map_location=torch.device('cpu'))

# The state dict contains keys of the form ``net.<layer_name>``, while this
# bare ResNet18 has no ``net`` attribute. Strip only the *leading* prefix; the
# former substring test/replace would also rewrite ``net.`` occurring anywhere
# inside a key. The dict is edited in place, as before.
NET_PREFIX = 'net.'
state_dict = checkpoint['state_dict']
for key in list(state_dict):
    if key.startswith(NET_PREFIX):
        state_dict[key[len(NET_PREFIX):]] = state_dict.pop(key)
# Strict loading: raises if any key or tensor shape does not match ResNet18.
model.load_state_dict(state_dict)
# Switch BatchNorm to use running statistics for inference.
model.eval()
# Output labels, sorted alphabetically; image_classifier pairs them
# positionally with the model's output columns.
# NOTE(review): presumably this matches the label-index order used in
# training — confirm.
class_names = sorted(['abdominal', 'adult', 'others', 'pediatric', 'spine'])
# Directory holding the example images shown under the UI.
examples_dir = "sample"
# Preprocessing applied to every image before inference:
# array -> PIL image -> single-channel grayscale -> central 384x384 crop
# -> float tensor -> normalization with one channel's mean/std
# (presumably dataset statistics from training — confirm).
# NOTE(review): there is no Resize step, so larger images contribute only
# their central 384x384 region and smaller ones are zero-padded by
# CenterCrop — confirm this matches the training preprocessing.
transformation_pipeline = transforms.Compose(
    [
        transforms.ToPILImage(),
        transforms.Grayscale(num_output_channels=1),
        transforms.CenterCrop((384, 384)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.50807575], std=[0.20823]),
    ]
)
def preprocess_image(image: np.ndarray):
    """Convert a raw image array into a batched, normalized model input.

    The array is run through ``transformation_pipeline`` and a leading batch
    dimension of size 1 is added (with the current pipeline the result is
    ``(1, 1, 384, 384)``).

    Parameters
    ----------
    image: np.ndarray
        Image array from the Gradio callback; per the original note it
        arrives in RGB mode.

    Returns
    -------
    torch.Tensor
        The transformed image with a batch dimension prepended.
    """
    return transformation_pipeline(image).unsqueeze(0)
def image_classifier(inp):
    """Image Classifier Function.

    Parameters
    ----------
    inp: Optional[np.ndarray] = None
        Input image from the Gradio callback; ``None`` when no image has
        been provided.

    Returns
    -------
    Optional[Dict[str, float]]
        Mapping from each name in ``class_names`` to its softmax
        probability, or ``None`` when ``inp`` is ``None`` (intended to leave
        the ``gr.Label`` output empty).
    """
    # Nothing to classify. The old placeholder ({'cat': 0.3, 'dog': 0.7})
    # displayed labels that have nothing to do with this model.
    if inp is None:
        return None

    # preprocess
    image = preprocess_image(inp).to(dtype=torch.float32)

    # inference only: skip autograd graph construction and its memory cost
    with torch.no_grad():
        logits = model(image)

    # postprocess: softmax over classes, then take the first (only) batch item
    probabilities = torch.nn.functional.softmax(logits, dim=1)[0].tolist()
    return dict(zip(class_names, probabilities))
# Static HTML rendered at the top of the page.
with open('index.html', encoding="utf-8") as f:
    description = f.read()

# Sample images offered as one-click examples, relative to ``examples_dir``.
example_files = [
    "1.2.392.200036.9125.4.0.1964921730.2349552188.1786966286.dcm.jpeg",
    "1b6a707131f787fe37d3ea40d2011d43.dicom.jpeg",
    "2e3204c2bb7a8fcdd6ec1ed547e2967e.dicom.jpeg",
    "10.127.133.1137.156.1251.20190404101039.dcm.jpeg",
    "badaec3e4d5f382ebf0b51ba2c917cea.dicom.jpeg",
]

# gradio code block for input and output
with gr.Blocks(
    theme=gr.themes.Default(
        primary_hue=page_utils.KALBE_THEME_COLOR,
        secondary_hue=page_utils.KALBE_THEME_COLOR,
    ).set(
        button_primary_background_fill="*primary_600",
        button_primary_background_fill_hover="*primary_500",
        button_primary_text_color="white",
    )
) as app:
    with gr.Column():
        gr.HTML(description)
    with gr.Row():
        # Left column: image input and action buttons.
        with gr.Column():
            inp_img = gr.Image()
            with gr.Row():
                clear_btn = gr.Button(value="Clear")
                process_btn = gr.Button(value="Process", variant="primary")
        # Right column: the three most probable classes.
        with gr.Column():
            out_txt = gr.Label(label="Probabilities", num_top_classes=3)

    process_btn.click(image_classifier, inputs=inp_img, outputs=out_txt)
    # Reset both the uploaded image and the prediction.
    clear_btn.click(
        lambda: (gr.update(value=None), gr.update(value=None)),
        inputs=None,
        outputs=[inp_img, out_txt],
    )

    gr.Markdown("## Image Examples")
    gr.Examples(
        examples=[os.path.join(examples_dir, name) for name in example_files],
        inputs=inp_img,
        outputs=out_txt,
        fn=image_classifier,
        cache_examples=False,
    )

# NOTE(review): share=True publishes a public share link when run locally;
# Gradio does not support share links on Hugging Face Spaces — confirm intent.
app.launch(share=True)