import albumentations as A
import gradio as gr
import numpy as np
import torch
from albumentations.pytorch import ToTensorV2
from models.model_zoo import BoneAgeEstModelZoo
# Device onto which checkpoints are mapped and input tensors are moved.
device = "cpu"
def initialize_model(path):
    """Load a ``BoneAgeEstModelZoo`` checkpoint and put it in eval mode.

    Args:
        path: Filesystem path of the checkpoint to load.

    Returns:
        The restored model, mapped onto ``device`` and switched to
        evaluation mode.
    """
    # ``load_from_checkpoint`` is a classmethod: invoking it on a freshly
    # constructed instance builds (and discards) a whole extra model, and
    # recent Lightning releases reject the instance call outright.
    # NOTE(review): assumes the checkpoint carries the hyperparameters needed
    # to rebuild the model, as the original instance-call relied on too.
    model = BoneAgeEstModelZoo.load_from_checkpoint(path, map_location=device)
    # Switch every sub-module (including model.model, model.classifier and
    # model.gender) to eval mode in one call so none is left in training mode.
    model.eval()
    return model
# Preprocessing applied to every input image before inference: resize to
# 1024x1024, apply CLAHE, normalise with the library-default statistics and
# convert the array to a torch tensor.
transform = A.Compose([
    A.Resize(width=1024, height=1024),
    # CLAHE is an augmentation whose default probability is 0.5; without
    # p=1.0 it is randomly skipped, so the same image can be preprocessed
    # differently from one request to the next.
    # NOTE(review): clip_limit is still sampled from the default range on
    # each call - pin it once the value used for training is confirmed.
    A.CLAHE(p=1.0),
    A.Normalize(),
    ToTensorV2(),
])
# Checkpoints whose predictions are combined into the final estimate.
_CHECKPOINT_PATHS = (
    "/bae/output/inception_1024_new_data/epoch=13-step=9828.ckpt",
    "/bae/output/inception_1024/epoch14_inception_1024_kaggle.ckpt",
)

# Models already loaded by ``_get_model``, keyed by checkpoint path.
_MODEL_CACHE = {}


def _get_model(path):
    """Return the model for *path*, loading it from disk on first use only."""
    if path not in _MODEL_CACHE:
        _MODEL_CACHE[path] = initialize_model(path)
    return _MODEL_CACHE[path]


def predict(image, gender):
    """Run every checkpoint on one scan and summarise the predictions.

    Args:
        image: Input scan; anything ``np.array`` can convert to a uint8
            array (the UI supplies a PIL image).
        gender: Truthy when the scan is of a female subject; encoded as 1,
            otherwise 0.

    Returns:
        Tuple ``(mean, minimum, maximum)`` of the per-checkpoint
        predictions, each truncated to ``int``.

    Raises:
        gr.Error: If no image was supplied.
    """
    if image is None:
        # Surface a readable message instead of a conversion failure
        # inside np.array.
        raise gr.Error("Please provide an input image.")

    pixels = np.array(image, dtype=np.uint8)
    # unsqueeze(0) adds the leading batch dimension to the transformed image.
    processed_image = transform(image=pixels)["image"].unsqueeze(0).to(device)
    # Gender flag as a 1x1 integer tensor (same shape the original built with
    # two unsqueeze calls on a scalar).
    is_female = torch.tensor([[1 if gender else 0]]).to(device)
    scans = {
        "image": processed_image,
        "gender": is_female,
    }

    # Inference only: no autograd graph is needed, and the cached models
    # avoid re-reading both checkpoints on every request.
    with torch.no_grad():
        preds = [_get_model(path)(scans) for path in _CHECKPOINT_PATHS]

    return int(sum(preds) / len(preds)), int(min(preds)), int(max(preds))
def run():
    """Build the Gradio interface around ``predict`` and launch the server.

    The interface takes an image and a "female" checkbox and shows three text
    outputs matching the three values returned by ``predict``.
    """
    # The gr.inputs / gr.outputs namespaces and the ``default=`` keyword are
    # deprecated; use the top-level components and ``value=`` instead.
    image_input = gr.Image(type="pil", label="Input PNG image")
    gender_input = gr.Checkbox(
        label="Is Female?",
        info="Is the scan of a female?",
        value=False,
    )
    output = gr.Textbox(label="Predicted Age")
    min_range = gr.Textbox(label="Minimum Predicted Age")
    max_range = gr.Textbox(label="Maximum Predicted Age")

    # Markdown shown above the form.
    info = (
        "# Usage Warning! \n"
        "This application is built for research purpose only. It is not intended for clinical use.\n"
        "Do not use this application in any commercial or medical setting. \n\n"
        " For any information about the project contact: vyawaharest@gmail.com \n\n"
        " Test data is available [here](https://huggingface.co/spaces/JackRio/bone-age-estimation/tree/main/test_data)"
    )

    bae_interface = gr.Interface(
        fn=predict,
        description=info,
        inputs=[image_input, gender_input],
        outputs=[output, min_range, max_range],
    )
    bae_interface.launch(
        server_name="0.0.0.0",
        server_port=7860,
        debug=True,
        show_error=True,
        share=True,
    )
# Start the web UI only when this file is executed as a script.
if __name__ == "__main__":
    run()