Spaces:
Sleeping
Sleeping
File size: 2,226 Bytes
a75d95d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 |
import gradio as gr
import lightning
import numpy as np
import os
import pandas as pd
import timm
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
# timm architecture name; must match the backbone the checkpoint was trained with.
BACKBONE = "resnet18d"
# Height/width handed to the Gradio image component.
IMAGE_HEIGHT = IMAGE_WIDTH = 512
# Training checkpoint (presumably written by Lightning, given the name and the
# "state_dict" entry). It is loaded onto the CPU so no GPU is required, and only
# its "state_dict" entry is kept.
trained_weights_path = "epoch=009.ckpt"
trained_weights = torch.load(trained_weights_path, map_location="cpu")["state_dict"]
# Recreate the training-time LightningModule so the checkpoint's state_dict can
# be loaded into it; attribute names (notably `net`) must match training.
class BoneAgeModel(lightning.LightningModule):
    """LightningModule wrapping a backbone ``net`` with its training configuration.

    Args:
        net: the network that maps ``batch["x"]`` to predictions.
        optimizer: optimizer returned from ``configure_optimizers``.
        scheduler: LR scheduler, stepped every optimizer step ("interval": "step").
        loss_fn: callable ``loss_fn(prediction, target)``.
    """

    def __init__(self, net, optimizer, scheduler, loss_fn):
        super().__init__()
        self.net = net
        self.optimizer, self.scheduler, self.loss_fn = optimizer, scheduler, loss_fn
        # Per-batch validation losses; averaged and cleared at each epoch end.
        self.val_losses = []

    def _batch_loss(self, batch):
        """Run ``net`` on ``batch["x"]`` and score the output against ``batch["y"]``."""
        prediction = self.net(batch["x"])
        return self.loss_fn(prediction, batch["y"])

    def training_step(self, batch, batch_index):
        return self._batch_loss(batch)

    def validation_step(self, batch, batch_index):
        self.val_losses.append(self._batch_loss(batch).item())

    def on_validation_epoch_end(self, *args, **kwargs):
        # Take this epoch's losses and start a fresh list for the next epoch.
        epoch_losses, self.val_losses = self.val_losses, []
        print(f"Validation Loss : {np.mean(epoch_losses):0.3f}")

    def configure_optimizers(self):
        scheduler_config = dict(scheduler=self.scheduler, interval="step")
        return dict(optimizer=self.optimizer, lr_scheduler=scheduler_config)
# Build the architecture only. pretrained=False because load_state_dict below
# (strict by default) overwrites every parameter and buffer, so downloading
# ImageNet weights at startup would be wasted time and a network dependency.
net = timm.create_model(BACKBONE, pretrained=False, in_chans=1, num_classes=1)
# Optimizer, scheduler and loss are training-only; inference needs just `net`.
trained_model = BoneAgeModel(net, None, None, None)
trained_model.load_state_dict(trained_weights)
# Switch layers such as BatchNorm/Dropout (if present) to inference behaviour.
trained_model.eval()
def format_bone_age(bone_age):
    """Render a bone-age estimate, interpreted as years, as "Y years, M months".

    The value is rounded once to the nearest whole month and then split into
    years and months, so the month count always lies in 0..11. Negative
    estimates are clamped to zero.

    Args:
        bone_age: model output as a Python float, in years.

    Returns:
        The display string shown in the Gradio label.
    """
    # Rounding the fractional year separately could yield "N years, 12 months";
    # rounding the total month count first and using divmod cannot.
    total_months = max(0, round(bone_age * 12))
    years, months = divmod(total_months, 12)
    return f"Predicted Bone Age: {years} years, {months} months"


def predict_bone_age(Radiograph):
    """Predict bone age from a grayscale radiograph submitted through Gradio.

    Args:
        Radiograph: 2-D numpy array from the ``gr.Image(image_mode="L")`` input
            (presumably uint8 in [0, 255], shape (H, W) -- TODO confirm), or
            ``None`` when the user submits without uploading an image.

    Returns:
        A human-readable prediction string for the ``gr.Label`` output.
    """
    if Radiograph is None:
        return "Please upload a radiograph."
    # Explicit float32 so the model's weights and the input dtype always agree.
    img = torch.from_numpy(Radiograph).float()
    img = img.unsqueeze(0).unsqueeze(0)  # (H, W) -> (1, 1, H, W): batch and channel dims
    img = img / 255.0  # same normalization as the training dataset
    with torch.inference_mode():
        bone_age = trained_model.net(img)[0].item()
    return format_bone_age(bone_age)
# Input: a single-channel ("L" = 8-bit grayscale) image shown at the configured size.
image = gr.Image(image_mode="L", height=IMAGE_HEIGHT, width=IMAGE_WIDTH)
# Output: the formatted prediction string.
label = gr.Label(label="Bone Age Prediction", show_label=True)
demo = gr.Interface(
    fn=predict_bone_age,
    inputs=[image],
    outputs=label,
)
demo.launch(debug=True)