Spaces:
Sleeping
Sleeping
import gradio as gr | |
import lightning | |
import numpy as np | |
import os | |
import pandas as pd | |
import timm | |
import torch | |
import torch.nn as nn | |
from torch.utils.data import Dataset, DataLoader | |
# --- Model / input configuration -------------------------------------------
# timm architecture name; it must match the architecture the checkpoint was
# produced from, otherwise the strict state_dict load further down will fail.
BACKBONE = "resnet18d"
# Passed to gr.Image as the display height/width of the upload widget.
IMAGE_HEIGHT = 512
IMAGE_WIDTH = 512

# Lightning checkpoint file, resolved relative to the current working directory.
trained_weights_path = "epoch=009.ckpt"
# NOTE(review): torch.load unpickles the file, so only ever point this at a
# trusted checkpoint. Tensors are mapped to CPU so the app runs without a GPU.
_checkpoint = torch.load(trained_weights_path, map_location="cpu")
# Only the model weights are needed for inference; drop the rest of the checkpoint.
trained_weights = _checkpoint["state_dict"]
del _checkpoint
# Recreate the LightningModule layout used during training so that the
# checkpoint's state_dict keys line up with this module's attributes.
class BoneAgeModel(lightning.LightningModule):
    """Lightning wrapper pairing a regression backbone with its training objects.

    In this app the class is only instantiated to rebuild the module hierarchy
    the checkpoint was saved from: the backbone lives under the ``net``
    attribute, which is where ``load_state_dict`` will look for its weights.
    The training/validation hooks are retained for parity with training.

    Args:
        net: backbone module; called as ``net(batch["x"])``.
        optimizer: optimizer handed back from ``configure_optimizers``.
        scheduler: learning-rate scheduler, registered with ``"interval": "step"``.
        loss_fn: callable ``loss_fn(prediction, target)``; its result must be a
            single-element tensor because validation calls ``.item()`` on it.
    """

    def __init__(self, net, optimizer, scheduler, loss_fn):
        super().__init__()
        self.net = net
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.loss_fn = loss_fn
        # Per-batch validation losses (Python floats) for the current epoch.
        self.val_losses = []

    def training_step(self, batch, batch_index):
        """Return the loss of the backbone's prediction on ``batch``."""
        prediction = self.net(batch["x"])
        return self.loss_fn(prediction, batch["y"])

    def validation_step(self, batch, batch_index):
        """Record this batch's loss for the end-of-epoch summary."""
        prediction = self.net(batch["x"])
        batch_loss = self.loss_fn(prediction, batch["y"])
        self.val_losses.append(batch_loss.item())

    def on_validation_epoch_end(self, *args, **kwargs):
        """Print the mean recorded validation loss and clear the buffer."""
        epoch_losses = self.val_losses
        self.val_losses = []
        print(f"Validation Loss : {np.mean(epoch_losses):0.3f}")

    def configure_optimizers(self):
        """Hand Lightning the optimizer and a per-step LR scheduler."""
        return {
            "optimizer": self.optimizer,
            "lr_scheduler": {"scheduler": self.scheduler, "interval": "step"},
        }
# Rebuild the backbone with the same architecture/head used for training:
# one input channel (grayscale) and a single regression output.
# pretrained=False: every parameter is overwritten by the strict
# load_state_dict below, so fetching ImageNet weights would only add a
# network download and startup time without changing the result.
net = timm.create_model(BACKBONE, pretrained=False, in_chans=1, num_classes=1)
# Optimizer, scheduler and loss are training-only; inference uses only `.net`.
trained_model = BoneAgeModel(net, None, None, None)
# strict=True (the default, made explicit): fail loudly if any key is missing
# or unexpected, i.e. if the architecture does not match the checkpoint.
trained_model.load_state_dict(trained_weights, strict=True)
# Put layers such as BatchNorm into inference behavior.
trained_model.eval()
def predict_bone_age(Radiograph): | |
img = torch.from_numpy(Radiograph) | |
img = img.unsqueeze(0).unsqueeze(0) # add channel and batch dimensions | |
img = img / 255. # use same normalization as in the PyTorch dataset | |
with torch.inference_mode(): | |
bone_age = trained_model.net(img)[0].item() | |
years = int(bone_age) | |
months = round((bone_age - years) * 12) | |
return f"Predicted Bone Age: {years} years, {months} months" | |
# --- Gradio front end: one grayscale radiograph in, one prediction out -----
image = gr.Image(
    type="numpy",  # predict_bone_age receives a NumPy array
    image_mode="L",  # "L" = single-channel grayscale
    height=IMAGE_HEIGHT,
    width=IMAGE_WIDTH,
)
label = gr.Label(label="Bone Age Prediction", show_label=True)

demo = gr.Interface(
    fn=predict_bone_age,
    inputs=[image],
    outputs=label,
)
# Starts the web server as soon as this module runs.
demo.launch(debug=True)