ianpan commited on
Commit
a1306dd
·
1 Parent(s): be458f5

Initial commit

Browse files
Files changed (7) hide show
  1. .gitattributes +1 -0
  2. README.md +1 -1
  3. app.py +91 -0
  4. fold0.ckpt +3 -0
  5. fold1.ckpt +3 -0
  6. fold2.ckpt +3 -0
  7. requirements.txt +5 -0
.gitattributes CHANGED
@@ -1,6 +1,7 @@
1
  *.7z filter=lfs diff=lfs merge=lfs -text
2
  *.arrow filter=lfs diff=lfs merge=lfs -text
3
  *.bin filter=lfs diff=lfs merge=lfs -text
 
4
  *.bz2 filter=lfs diff=lfs merge=lfs -text
5
  *.ftz filter=lfs diff=lfs merge=lfs -text
6
  *.gz filter=lfs diff=lfs merge=lfs -text
 
1
  *.7z filter=lfs diff=lfs merge=lfs -text
2
  *.arrow filter=lfs diff=lfs merge=lfs -text
3
  *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
5
  *.bz2 filter=lfs diff=lfs merge=lfs -text
6
  *.ftz filter=lfs diff=lfs merge=lfs -text
7
  *.gz filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -1,5 +1,5 @@
1
  ---
2
- title: Bone Age Greulich And Pyle
3
  emoji: 💻
4
  colorFrom: red
5
  colorTo: blue
 
1
  ---
2
+ title: Deep Learning Model for Pediatric Bone Age
3
  emoji: 💻
4
  colorFrom: red
5
  colorTo: blue
app.py ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import timm
3
+ import torch
4
+ import torch.nn as nn
5
+
6
+
7
+ def change_num_input_channels(model, in_channels=1):
8
+ """
9
+ Assumes number of input channels in model is 3.
10
+ """
11
+ for i, m in enumerate(model.modules()):
12
+ if isinstance(m, (nn.Conv2d,nn.Conv3d)) and m.in_channels == 3:
13
+ m.in_channels = in_channels
14
+ # First, sum across channels
15
+ W = m.weight.sum(1, keepdim=True)
16
+ # Then, divide by number of channels
17
+ W = W / in_channels
18
+ # Then, repeat by number of channels
19
+ size = [1] * W.ndim
20
+ size[1] = in_channels
21
+ W = W.repeat(size)
22
+ m.weight = nn.Parameter(W)
23
+ break
24
+ return model
25
+
26
+
27
+ class Net2D(nn.Module):
28
+
29
+ def __init__(self, weights):
30
+ super().__init__()
31
+ self.backbone = timm.create_model("tf_efficientnetv2_s", pretrained=False, global_pool="", num_classes=0)
32
+ self.backbone = change_num_input_channels(self.backbone, 2)
33
+ self.pool_layer = nn.AdaptiveAvgPool2d(1)
34
+ self.dropout = nn.Dropout(0.2)
35
+ self.classifier = nn.Linear(1280, 1)
36
+ self.load_state_dict(weights)
37
+
38
+ def forward(self, x):
39
+ x = self.backbone(x)
40
+ x = self.pool_layer(x).view(x.size(0), -1)
41
+ x = self.dropout(x)
42
+ x = self.classifier(x)
43
+ return x[:, 0] if x.size(1) == 1 else x
44
+
45
+
46
+ class Ensemble(nn.Module):
47
+
48
+ def __init__(self, model_list):
49
+ super().__init__()
50
+ self.model_list = nn.ModuleList(model_list)
51
+
52
+ def forward(self, x):
53
+ return torch.stack([model(x) for model in self.model_list]).mean(0)
54
+
55
+
56
+ checkpoints = ["fold0.ckpt", "fold1.ckpt", "fold2.ckpt"]
57
+ weights = [torch.load(ckpt)["state_dict"] for ckpt in checkpoints]
58
+ weights = [{k.replace("model.", "") : v for k, v in wt.items()} for wt in weights]
59
+ models = [Net2D(wt) for wt in weights]
60
+ ensemble = Ensemble(models).eval()
61
+
62
+ def predict_bone_age(Radiograph, Sex):
63
+ img = torch.from_numpy(Radiograph)
64
+ img = img.unsqueeze(0).unsqueeze(0)
65
+ img = img / img.max()
66
+ img = img - 0.5
67
+ img = img * 2.0
68
+ if Sex == 1:
69
+ img = torch.cat([img, torch.zeros_like(img) + 1], dim=1)
70
+ else:
71
+ img = torch.cat([img, torch.zeros_like(img) - 1], dim=1)
72
+ with torch.no_grad():
73
+ bone_age = ensemble(img.float())[0].item()
74
+ return f"Estimated Bone Age: {int(bone_age)} years, {int(bone_age % int(bone_age) * 12)} months"
75
+
76
+
77
+ image = gr.Image(shape=(512, 512), image_mode="L")
78
+ sex = gr.Radio(["Male", "Female"], type="index")
79
+ label = gr.Label(show_label=True, label="Result")
80
+
81
+ demo = gr.Interface(
82
+ fn=predict_bone_age,
83
+ inputs=[image, sex],
84
+ outputs=label,
85
+ )
86
+
87
+
88
+ if __name__ == "__main__":
89
+ demo.launch()
90
+
91
+
fold0.ckpt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2db6d3fb26a05b916341574c83683017e4a04a1c0df8fda4a97ad2314b33f109
3
+ size 81642981
fold1.ckpt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8c806c2ccd21cb4f1d1102e86d8716ed67583f561d4eea6a1761ac4f9bf6a60b
3
+ size 81642981
fold2.ckpt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:cabdc105bb4c3239d1a57ceaaca4306096a017763c1ec1d23adacf6d8c0713ab
3
+ size 81642981
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ numpy
2
+ timm
3
+ torch
4
+ https://gradio-main-build.s3.amazonaws.com/e30af8813c3d76329cf4869fa87a902b2075c8cd/gradio-3.8.2-py3-none-any.whl
5
+