JackRio commited on
Commit
83e02d3
·
1 Parent(s): c538073

Multiple model

Browse files
Files changed (2) hide show
  1. Dockerfile +1 -1
  2. predict.py +12 -7
Dockerfile CHANGED
@@ -1,4 +1,4 @@
1
- FROM docker.io/jackrio/bae_repo:v3
2
 
3
  USER root
4
 
 
1
+ FROM docker.io/jackrio/bae_repo:v4
2
 
3
  USER root
4
 
predict.py CHANGED
@@ -7,10 +7,10 @@ from models.model_zoo import BoneAgeEstModelZoo
7
  device = "cpu"
8
 
9
 
10
- def initialize_model():
11
  # Load model
12
  model = BoneAgeEstModelZoo(branch="gender", pretrained=True, lr=0.001).load_from_checkpoint(
13
- "/bae/output/inception_1024_new_data/epoch=13-step=9828.ckpt", map_location=device)
14
  model.model.eval()
15
  model.classifier.eval()
16
  model.gender.eval()
@@ -31,8 +31,8 @@ def predict(image, gender):
31
  is_female = 1
32
  else:
33
  is_female = 0
34
-
35
- model = initialize_model()
36
 
37
  processed_image = transform(image=np.array(image, dtype=np.uint8))['image']
38
  processed_image = processed_image.unsqueeze(0)
@@ -43,13 +43,18 @@ def predict(image, gender):
43
  'image': processed_image,
44
  'gender': is_female
45
  }
46
- preds = model(scans)
47
- return int(preds)
 
 
 
 
 
48
 
49
 
50
  def run():
51
  image_input = gr.inputs.Image(type="pil", label="Input PNG image")
52
- gender_input = gr.Checkbox(label="Is Female?", info="Is the scan of a female?", default=True)
53
  output = gr.outputs.Textbox(label="Predicted Age")
54
 
55
  BAE = gr.Interface(
 
7
  device = "cpu"
8
 
9
 
10
+ def initialize_model(path):
11
  # Load model
12
  model = BoneAgeEstModelZoo(branch="gender", pretrained=True, lr=0.001).load_from_checkpoint(
13
+ path, map_location=device)
14
  model.model.eval()
15
  model.classifier.eval()
16
  model.gender.eval()
 
31
  is_female = 1
32
  else:
33
  is_female = 0
34
+ path = "/bae/output/inception_1024_new_data/epoch=13-step=9828.ckpt"
35
+ model = initialize_model(path)
36
 
37
  processed_image = transform(image=np.array(image, dtype=np.uint8))['image']
38
  processed_image = processed_image.unsqueeze(0)
 
43
  'image': processed_image,
44
  'gender': is_female
45
  }
46
+ preds_1 = model(scans)
47
+
48
+ path = "/bae/output/inception_1024/epoch14_inception_1024_kaggle.ckpt"
49
+ model = initialize_model(path)
50
+ preds_2 = model(scans)
51
+
52
+ return int((preds_1 + preds_2) / 2)
53
 
54
 
55
  def run():
56
  image_input = gr.inputs.Image(type="pil", label="Input PNG image")
57
+ gender_input = gr.Checkbox(label="Is Female?", info="Is the scan of a female?", default=False)
58
  output = gr.outputs.Textbox(label="Predicted Age")
59
 
60
  BAE = gr.Interface(