File size: 2,226 Bytes
a75d95d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import gradio as gr
import lightning
import numpy as np
import os
import pandas as pd
import timm
import torch
import torch.nn as nn

from torch.utils.data import Dataset, DataLoader

BACKBONE = "resnet18d"  # timm model name passed to timm.create_model
# Used as the Gradio image component's height/width.
IMAGE_HEIGHT = 512
IMAGE_WIDTH = 512

# Lightning checkpoint (file name presumably produced by a ModelCheckpoint
# callback for epoch 9 - confirm). It is a dict whose "state_dict" entry holds
# the model weights; it is mapped onto the CPU so serving needs no GPU.
# torch.load unpickles the file, so only load trusted checkpoints.
trained_weights_path = "epoch=009.ckpt"
trained_weights = torch.load(
  trained_weights_path,
  map_location="cpu",
)["state_dict"]

# Recreate the training-time LightningModule so the checkpoint's state_dict loads.
class BoneAgeModel(lightning.LightningModule):
  """LightningModule wrapping a backbone ``net`` together with its training setup.

  The optimizer, scheduler and loss are only used by the training/validation
  hooks. This app passes ``None`` for all three and only ever calls ``self.net``.

  Args:
    net: backbone module mapping ``batch["x"]`` to predictions.
    optimizer: optimizer handed back from ``configure_optimizers``.
    scheduler: LR scheduler handed back from ``configure_optimizers``.
    loss_fn: callable ``loss_fn(prediction, target)`` returning a
      single-element tensor (``.item()`` is called on it during validation).
  """

  def __init__(self, net, optimizer, scheduler, loss_fn):
    super().__init__()
    self.net = net
    self.optimizer, self.scheduler, self.loss_fn = optimizer, scheduler, loss_fn
    # Per-batch validation losses; averaged and cleared at the end of each epoch.
    self.val_losses = []

  def training_step(self, batch, batch_index):
    """Return the loss of ``net(batch["x"])`` against ``batch["y"]``."""
    return self.loss_fn(self.net(batch["x"]), batch["y"])

  def validation_step(self, batch, batch_index):
    """Record this batch's validation loss as a Python float."""
    prediction = self.net(batch["x"])
    batch_loss = self.loss_fn(prediction, batch["y"])
    self.val_losses.append(batch_loss.item())

  def on_validation_epoch_end(self, *args, **kwargs):
    """Print the mean of the recorded validation losses, then reset the buffer."""
    epoch_losses, self.val_losses = self.val_losses, []
    print(f"Validation Loss : {np.mean(epoch_losses):0.3f}")

  def configure_optimizers(self):
    """Pair the optimizer with a scheduler stepped after every training step."""
    return {
      "optimizer": self.optimizer,
      "lr_scheduler": {"scheduler": self.scheduler, "interval": "step"},
    }

# Single-channel input, single regression output. ImageNet weights are not
# downloaded (pretrained=False): the strict load_state_dict below requires every
# parameter/buffer key to be present in the checkpoint, so all of them are
# overwritten anyway and fetching pretrained weights would only add a network
# dependency at startup.
net = timm.create_model(BACKBONE, pretrained=False, in_chans=1, num_classes=1)
trained_model = BoneAgeModel(net, None, None, None)  # no optimizer/scheduler/loss needed for inference
trained_model.load_state_dict(trained_weights)
trained_model.eval()  # switch layers such as batch norm / dropout to inference behavior


def predict_bone_age(Radiograph):
  """Predict bone age from a grayscale radiograph submitted through Gradio.

  Args:
    Radiograph: 2-D image array from the ``gr.Image(image_mode="L")`` input
      (presumably uint8 in 0..255 - it is divided by 255 below), or ``None``
      when the form is submitted without an image.

  Returns:
    ``"Predicted Bone Age: X years, Y months"``, or a prompt asking for an
    upload when ``Radiograph`` is ``None``.
  """
  if Radiograph is None:
    # Gradio passes None when nothing was uploaded; torch.from_numpy would raise.
    return "Please upload a radiograph."
  img = torch.from_numpy(Radiograph)
  img = img.unsqueeze(0).unsqueeze(0)  # add channel and batch dimensions
  img = img / 255.  # use same normalization as in the PyTorch dataset
  # NOTE(review): the image is not resized here; confirm the model was trained
  # on images of the size users upload (IMAGE_HEIGHT/IMAGE_WIDTH only size the UI).
  with torch.inference_mode():
    bone_age = trained_model.net(img)[0].item()  # model output is in years
  # Round once at month resolution and split into years/months. Rounding the
  # fractional year on its own could produce e.g. "5 years, 12 months".
  years, months = divmod(round(bone_age * 12), 12)
  return f"Predicted Bone Age: {years} years, {months} months"


# Gradio front end: a grayscale ("L") radiograph upload in, a text label out.
# NOTE(review): height/width are Gradio display options and presumably do not
# resize the array handed to predict_bone_age - confirm.
image = gr.Image(image_mode="L", width=IMAGE_WIDTH, height=IMAGE_HEIGHT)
label = gr.Label(label="Bone Age Prediction", show_label=True)

demo = gr.Interface(
  fn=predict_bone_age,
  inputs=[image],
  outputs=label,
)
demo.launch(debug=True)