"""Gradio demo serving a ResNet-18 image classifier restored from a checkpoint."""

import os
from typing import Dict, Optional

import gradio as gr
import numpy as np
import torch
import torch.nn as nn
from torchvision.transforms import transforms

from utils import page_utils


class BasicBlock(nn.Module):
    """
    ResNet Basic Block.

    Two 3x3 convolutions, each followed by batch normalization, with a ReLU
    after the first one and another ReLU after the residual addition. An
    optional ``identity_downsample`` module is applied to the skip connection
    so that its shape matches the output of the second convolution.

    Parameters
    ----------
    in_channels : int
        Number of input channels.
    out_channels : int
        Number of output channels.
    stride : int, optional
        Stride of the first convolution, by default 1.
    identity_downsample : Optional[torch.nn.Module], optional
        Module applied to the skip connection, by default None.

    Methods
    -------
    forward(x: torch.Tensor) -> torch.Tensor:
        Apply forward computation.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        stride: int = 1,
        identity_downsample: Optional[torch.nn.Module] = None,
    ):
        super().__init__()
        # NOTE: attribute names and layer hyper-parameters define the
        # state-dict keys/shapes and must stay in sync with the checkpoint.
        self.conv1 = nn.Conv2d(
            in_channels, out_channels, kernel_size=3, stride=stride, padding=1
        )
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(
            out_channels, out_channels, kernel_size=3, stride=1, padding=1
        )
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.identity_downsample = identity_downsample

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply forward computation."""
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)

        # Reshape the skip connection so it can be added to the conv2 output.
        if self.identity_downsample is not None:
            identity = self.identity_downsample(identity)

        out += identity
        return self.relu(out)


class ResNet18(nn.Module):
    """
    Construct ResNet-18 Model.

    A 7x7 stem convolution with max pooling, four stages of two
    ``BasicBlock`` each (64, 128, 256 and 512 channels), global average
    pooling and a fully connected classification layer.

    Parameters
    ----------
    input_channels : int
        Number of input channels.
    num_classes : int
        Number of class outputs.

    Methods
    -------
    forward(x: torch.Tensor) -> torch.Tensor:
        Apply forward computation.
    """

    def __init__(self, input_channels: int, num_classes: int):
        super().__init__()
        # Stem
        self.conv1 = nn.Conv2d(
            input_channels, 64, kernel_size=7, stride=2, padding=3
        )
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU()
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # Residual stages
        self.layer1 = self._make_layer(64, 64, stride=1)
        self.layer2 = self._make_layer(64, 128, stride=2)
        self.layer3 = self._make_layer(128, 256, stride=2)
        self.layer4 = self._make_layer(256, 512, stride=2)

        # Last layers
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512, num_classes)

    def identity_downsample(
        self, in_channels: int, out_channels: int, stride: int = 2
    ) -> nn.Module:
        """Build the skip-connection projection (3x3 conv + batch norm).

        Parameters
        ----------
        in_channels : int
            Number of channels entering the block.
        out_channels : int
            Number of channels leaving the block.
        stride : int, optional
            Stride of the projection; must equal the stride of the block it
            is paired with. Defaults to 2, the value previously hard-coded.
        """
        # The 3x3 kernel is kept as-is: changing it would alter the weight
        # shapes expected from the checkpoint.
        return nn.Sequential(
            nn.Conv2d(
                in_channels,
                out_channels,
                kernel_size=3,
                stride=stride,
                padding=1,
            ),
            nn.BatchNorm2d(out_channels),
        )

    def _make_layer(
        self, in_channels: int, out_channels: int, stride: int
    ) -> nn.Module:
        """Create a stage made of two sequential basic blocks."""
        identity_downsample = None

        # A projection is needed whenever the block changes the spatial size
        # or the channel count; otherwise the residual addition would fail.
        if stride != 1 or in_channels != out_channels:
            identity_downsample = self.identity_downsample(
                in_channels, out_channels, stride
            )

        return nn.Sequential(
            BasicBlock(
                in_channels,
                out_channels,
                identity_downsample=identity_downsample,
                stride=stride,
            ),
            BasicBlock(out_channels, out_channels),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply forward computation and return the raw class logits."""
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = x.view(x.shape[0], -1)  # flatten to (batch, features)
        x = self.fc(x)
        return x


# ---------------------------------------------------------------------------
# Model loading
# ---------------------------------------------------------------------------
model = ResNet18(3, 3)

# NOTE(review): torch.load unpickles the file, which can execute arbitrary
# code. Only load checkpoints that come from a trusted source.
checkpoint = torch.load(
    'epoch=49-step=1750.ckpt', map_location=torch.device('cpu')
)
state_dict = checkpoint['state_dict']

# Checkpoint keys carry a "net." prefix (presumably the attribute name of the
# wrapped model in the training module -- confirm against the training code).
# Strip only a *leading* prefix so other parts of a key are never rewritten.
CHECKPOINT_KEY_PREFIX = 'net.'
for key in list(state_dict):
    if key.startswith(CHECKPOINT_KEY_PREFIX):
        state_dict[key[len(CHECKPOINT_KEY_PREFIX):]] = state_dict.pop(key)

model.load_state_dict(state_dict)
model.eval()

# Sorted so that the position of each name matches the model's output index.
class_names = ['benign', 'malignant', 'normal']
class_names.sort()

example_dir = "SAMPLES"

# Deterministic inference preprocessing. Random augmentation (previously a
# RandomRotation(20)) is intentionally not applied here: it made repeated
# predictions on the same image differ from one call to the next.
transformation_pipeline = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Grayscale(num_output_channels=3),
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.233827, 0.2338219, 0.23378967],
        std=[0.2016421162328173, 0.20164345656093885, 0.20160390432148026],
    ),
])


def preprocess_image(image: np.ndarray) -> torch.Tensor:
    """Preprocess the input image.

    Note that the input image is in RGB mode.

    Parameters
    ----------
    image: np.ndarray
        Input image from callback.

    Returns
    -------
    torch.Tensor
        The transformed image with a leading batch dimension of size 1.
    """
    tensor = transformation_pipeline(image)
    return torch.unsqueeze(tensor, 0)


def image_classifier(inp: Optional[np.ndarray]) -> Dict[str, float]:
    """Image Classifier Function.

    Parameters
    ----------
    inp: Optional[np.ndarray] = None
        Input image from callback

    Returns
    -------
    Dict
        A dictionary class names and its probability

    Raises
    ------
    gr.Error
        If no image was provided.
    """
    # Raising (rather than returning) is what makes Gradio show the error.
    if inp is None:
        raise gr.Error("No image provided. Please upload an image first.")

    # preprocess
    image = preprocess_image(inp)
    image = image.to(dtype=torch.float32)

    # inference -- gradients are never needed when serving predictions
    with torch.no_grad():
        logits = model(image)
        probabilities = torch.nn.functional.softmax(logits, dim=1)

    # postprocess: take the first (and only) item of the batch
    scores = probabilities[0].tolist()
    return dict(zip(class_names, scores))


# ---------------------------------------------------------------------------
# User interface
# ---------------------------------------------------------------------------
with open('index.html', encoding="utf-8") as f:
    description = f.read()

with open('author.html', encoding="utf-8") as author:
    author_info = author.read()

with gr.Blocks(
    theme=gr.themes.Default(
        primary_hue=page_utils.KALBE_THEME_COLOR,
        secondary_hue=page_utils.KALBE_THEME_COLOR,
    ).set(
        button_primary_background_fill="*primary_600",
        button_primary_background_fill_hover="*primary_500",
        button_primary_text_color="white",
    )
) as demo:
    with gr.Column():
        gr.HTML(description)

    with gr.Row():  # Build a row
        with gr.Column():  # build a column section as the first item
            # build an image as the first column item
            inp = gr.Image(label="image", image_mode="RGB")
            with gr.Row():  # build a row section as the second item
                clear_btn = gr.Button("Clear")
                submit_btn = gr.Button("Submit")

        # build a label as the second item
        out = gr.Label(label="prediction", num_top_classes=3)

    # Define buttons functionalities
    submit_btn.click(fn=image_classifier, inputs=inp, outputs=out)
    clear_btn.click(
        lambda: (
            gr.update(value=None),
            gr.update(value=None),
        ),
        inputs=None,
        outputs=[inp, out],
    )

    # Add examples
    gr.Markdown("## Image Examples")
    gr.Examples(
        example_dir,
        inputs=[inp],
        label="Image Examples",
        cache_examples=False,
    )

    with gr.Column():
        gr.HTML(author_info)

demo.launch(share=True)