# rsna_boneage / app.py
# felipekitamura's picture
# Create app.py
# a75d95d verified
import gradio as gr
import lightning
import numpy as np
import os
import pandas as pd
import timm
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
# Name of the timm architecture handed to timm.create_model further down.
BACKBONE = "resnet18d"

# Pixel dimensions passed to the Gradio image component.
IMAGE_HEIGHT = 512
IMAGE_WIDTH = 512

# Checkpoint file name; a relative path, so it is resolved against the
# current working directory.
trained_weights_path = "epoch=009.ckpt"

# Keep only the "state_dict" entry of the checkpoint, with every tensor
# mapped onto the CPU so no GPU is needed to load it.
# NOTE(review): depending on the installed torch version's `weights_only`
# default, torch.load may unpickle arbitrary objects -- only use trusted
# checkpoint files.
trained_weights = torch.load(trained_weights_path, map_location="cpu")["state_dict"]
# Recreate the model class so the checkpoint's state dict can be loaded into it.
class BoneAgeModel(lightning.LightningModule):
    """Lightning module wrapping a network together with its training setup.

    The module stores the network, optimizer, scheduler and loss function it
    is given, and collects per-batch validation losses so their mean can be
    printed at the end of each validation epoch.
    """

    def __init__(self, net, optimizer, scheduler, loss_fn):
        """Store the collaborators and start with an empty loss history.

        Args:
            net: Callable network applied to ``batch["x"]``.
            optimizer: Optimizer returned from ``configure_optimizers``.
            scheduler: LR scheduler returned from ``configure_optimizers``.
            loss_fn: Callable taking ``(prediction, target)``.
        """
        super().__init__()
        self.net = net
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.loss_fn = loss_fn
        # Per-batch validation losses of the epoch currently in progress.
        self.val_losses = []

    def _batch_loss(self, batch):
        """Run the network on ``batch["x"]`` and score it against ``batch["y"]``."""
        prediction = self.net(batch["x"])
        return self.loss_fn(prediction, batch["y"])

    def training_step(self, batch, batch_index):
        """Return the loss for one training batch."""
        return self._batch_loss(batch)

    def validation_step(self, batch, batch_index):
        """Record the scalar loss of one validation batch; returns nothing."""
        self.val_losses.append(self._batch_loss(batch).item())

    def on_validation_epoch_end(self, *args, **kwargs):
        """Print the mean validation loss and reset the loss history."""
        # Swap in a fresh list first, then report on the losses just collected.
        epoch_losses, self.val_losses = self.val_losses, []
        print(f"Validation Loss : {np.mean(epoch_losses):0.3f}")

    def configure_optimizers(self):
        """Return the stored optimizer plus a scheduler stepped every batch."""
        return {
            "optimizer": self.optimizer,
            "lr_scheduler": {"scheduler": self.scheduler, "interval": "step"},
        }
# Rebuild the network with the same architecture as the checkpoint.
# ImageNet weights are deliberately NOT requested: load_state_dict below runs
# in its default strict mode, so every parameter and buffer must be supplied
# by the checkpoint anyway. Asking for pretrained weights would only add a
# network download (and a failure mode when offline) at startup.
net = timm.create_model(BACKBONE, pretrained=False, in_chans=1, num_classes=1)

# Optimizer, scheduler and loss function are only used by the training hooks,
# so None is enough for inference.
trained_model = BoneAgeModel(net, None, None, None)
trained_model.load_state_dict(trained_weights)

# Put the module in evaluation mode before serving predictions.
trained_model.eval()
def predict_bone_age(Radiograph):
    """Predict bone age from a grayscale radiograph and format it as text.

    Args:
        Radiograph: 2-D NumPy array of pixel intensities (height x width) as
            delivered by the Gradio image input, or ``None`` when the form
            was submitted without an image.

    Returns:
        ``"Predicted Bone Age: <years> years, <months> months"`` with
        ``months`` always in ``0..11``, or a short prompt asking for an image
        when ``Radiograph`` is ``None``.
    """
    # Submitting without an image hands the function None; answer with a
    # message instead of crashing inside torch.from_numpy.
    if Radiograph is None:
        return "No radiograph provided. Please upload an image."

    # NOTE(review): the array is fed to the network at its native resolution;
    # the height/width given to the Gradio component may only affect display
    # size -- confirm against the preprocessing used during training.
    img = torch.from_numpy(Radiograph).float()  # float32, regardless of input dtype
    img = img.unsqueeze(0).unsqueeze(0)  # add channel and batch dimensions
    img = img / 255.  # use same normalization as in the PyTorch dataset

    with torch.inference_mode():
        bone_age = trained_model.net(img)[0].item()

    # Round once on the total number of months and then split. Rounding the
    # fractional year separately could yield "12 months" (e.g. 5.99 years).
    # Negative model outputs are clamped to zero.
    total_months = max(round(bone_age * 12), 0)
    years, months = divmod(total_months, 12)
    return f"Predicted Bone Age: {years} years, {months} months"
# Input widget: a single-channel ("L" = grayscale) image.
image = gr.Image(
    image_mode="L",
    height=IMAGE_HEIGHT,
    width=IMAGE_WIDTH,
)

# Output widget showing the text returned by predict_bone_age.
label = gr.Label(label="Bone Age Prediction", show_label=True)

# Wire the prediction function to its input and output widgets.
demo = gr.Interface(
    fn=predict_bone_age,
    inputs=[image],
    outputs=label,
)

# Start the web app with debug output enabled.
demo.launch(debug=True)