felipekitamura commited on
Commit
a75d95d
1 Parent(s): 4630474

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +73 -0
app.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import lightning
3
+ import numpy as np
4
+ import os
5
+ import pandas as pd
6
+ import timm
7
+ import torch
8
+ import torch.nn as nn
9
+
10
+ from torch.utils.data import Dataset, DataLoader
11
+
12
+ BACKBONE = "resnet18d"
13
+ IMAGE_HEIGHT, IMAGE_WIDTH = 512, 512
14
+
15
+ trained_weights_path = "epoch=009.ckpt"
16
+ trained_weights = torch.load(trained_weights_path, map_location=torch.device('cpu'))["state_dict"]
17
+
18
+ # recreate the model
19
+ class BoneAgeModel(lightning.LightningModule):
20
+
21
+ def __init__(self, net, optimizer, scheduler, loss_fn):
22
+ super().__init__()
23
+ self.net = net
24
+ self.optimizer = optimizer
25
+ self.scheduler = scheduler
26
+ self.loss_fn = loss_fn
27
+
28
+ self.val_losses = []
29
+
30
+ def training_step(self, batch, batch_index):
31
+ out = self.net(batch["x"])
32
+ loss = self.loss_fn(out, batch["y"])
33
+ return loss
34
+
35
+ def validation_step(self, batch, batch_index):
36
+ out = self.net(batch["x"])
37
+ loss = self.loss_fn(out, batch["y"])
38
+ self.val_losses.append(loss.item())
39
+
40
+ def on_validation_epoch_end(self, *args, **kwargs):
41
+ val_loss = np.mean(self.val_losses)
42
+ self.val_losses = []
43
+ print(f"Validation Loss : {val_loss:0.3f}")
44
+
45
+ def configure_optimizers(self):
46
+ lr_scheduler = {"scheduler": self.scheduler, "interval": "step"}
47
+ return {"optimizer": self.optimizer, "lr_scheduler": lr_scheduler}
48
+
49
+ net = timm.create_model(BACKBONE, pretrained=True, in_chans=1, num_classes=1)
50
+ trained_model = BoneAgeModel(net, None, None, None)
51
+ trained_model.load_state_dict(trained_weights)
52
+ trained_model.eval()
53
+
54
+
55
+ def predict_bone_age(Radiograph):
56
+ img = torch.from_numpy(Radiograph)
57
+ img = img.unsqueeze(0).unsqueeze(0) # add channel and batch dimensions
58
+ img = img / 255. # use same normalization as in the PyTorch dataset
59
+ with torch.inference_mode():
60
+ bone_age = trained_model.net(img)[0].item()
61
+ years = int(bone_age)
62
+ months = round((bone_age - years) * 12)
63
+ return f"Predicted Bone Age: {years} years, {months} months"
64
+
65
+
66
+ image = gr.Image(height=IMAGE_HEIGHT, width=IMAGE_WIDTH, image_mode="L") # L for grayscale
67
+ label = gr.Label(show_label=True, label="Bone Age Prediction")
68
+
69
+ demo = gr.Interface(fn=predict_bone_age,
70
+ inputs=[image],
71
+ outputs=label)
72
+
73
+ demo.launch(debug=True)