# bone-age-estimation / predict.py
# Last change: commit c538073 by JackRio ("Removing additional comma"), 1.67 kB.
import functools

import albumentations as A
import gradio as gr
import numpy as np
import torch
from albumentations.pytorch import ToTensorV2

from models.model_zoo import BoneAgeEstModelZoo
# All inference runs on the CPU; the checkpoint and input tensors are mapped here.
device = "cpu"

# Checkpoint loaded by default (absolute path; adjust for your environment).
CHECKPOINT_PATH = "/bae/output/inception_1024_new_data/epoch=13-step=9828.ckpt"


@functools.lru_cache(maxsize=1)
def initialize_model(checkpoint_path=CHECKPOINT_PATH):
    """Load the bone-age model from a checkpoint and switch it to eval mode.

    The result is cached, so repeated calls (e.g. one per request) reuse the
    already-loaded model instead of re-reading the checkpoint from disk.
    Every caller therefore receives the same model object.

    Args:
        checkpoint_path: Path of the checkpoint file to load. Defaults to
            ``CHECKPOINT_PATH``.

    Returns:
        The loaded ``BoneAgeEstModelZoo`` with its ``model``, ``classifier``
        and ``gender`` submodules put into eval mode.
    """
    # Call load_from_checkpoint on the class: in PyTorch Lightning it is a
    # classmethod, so constructing an instance first only wastes work (and
    # Lightning >= 2.0 refuses to be called on an instance).
    # NOTE(review): assumes BoneAgeEstModelZoo is a LightningModule whose
    # checkpoint stores its constructor hyper-parameters — confirm.
    model = BoneAgeEstModelZoo.load_from_checkpoint(checkpoint_path, map_location=device)
    # Eval mode for the three submodules used at inference time.
    model.model.eval()
    model.classifier.eval()
    model.gender.eval()
    return model
# Preprocessing pipeline applied to every input scan, in this order:
#   1. resize to 1024 x 1024,
#   2. CLAHE (contrast-limited adaptive histogram equalization), which runs on
#      the raw uint8 image before normalization,
#   3. normalize pixel values with albumentations' default statistics,
#   4. convert the numpy image into a torch tensor.
transform = A.Compose(
    transforms=[
        A.Resize(height=1024, width=1024),
        A.CLAHE(),
        A.Normalize(),
        ToTensorV2(),
    ]
)
def predict(image, gender):
    """Estimate bone age for a single scan.

    Args:
        image: The uploaded scan. The Gradio UI built in ``run`` passes a PIL
            image; anything ``np.array`` can turn into a uint8 array works.
        gender: Truthy when the scan is of a female patient.

    Returns:
        The model output converted with ``int()``, i.e. truncated toward zero.

    Raises:
        ValueError: If no image was supplied.
    """
    # Gradio passes None when the form is submitted without an upload; fail
    # with a readable message instead of an opaque numpy conversion error.
    if image is None:
        raise ValueError("Please upload an image before submitting.")

    model = initialize_model()

    # The pipeline's CLAHE step requires uint8 input; add a batch dimension of 1.
    pixels = np.array(image, dtype=np.uint8)
    processed_image = transform(image=pixels)["image"].unsqueeze(0).to(device)

    # Gender flag as a (1, 1) integer tensor: 1 if female, else 0.
    is_female = torch.tensor(int(bool(gender))).unsqueeze(0).unsqueeze(1).to(device)

    scans = {
        'image': processed_image,
        'gender': is_female
    }

    # Inference only: skip autograd graph construction to save memory and time.
    with torch.no_grad():
        preds = model(scans)

    # NOTE(review): int() truncates; consider rounding if the output is a
    # continuous age estimate.
    return int(preds)
def run(server_name="0.0.0.0", server_port=7860):
    """Build the Gradio interface around ``predict`` and serve it.

    Args:
        server_name: Address to bind; ``"0.0.0.0"`` listens on all interfaces.
        server_port: TCP port to serve on.
    """
    # gr.Image / gr.Textbox replace the deprecated gr.inputs / gr.outputs
    # namespaces, which were removed in Gradio 4.
    image_input = gr.Image(type="pil", label="Input PNG image")
    # `value` sets the initial state. The legacy `default` keyword is not a
    # Checkbox parameter, so `default=True` left the box unchecked.
    gender_input = gr.Checkbox(value=True, label="Is Female?", info="Is the scan of a female?")
    output = gr.Textbox(label="Predicted Age")

    demo = gr.Interface(
        fn=predict,
        inputs=[image_input, gender_input],
        outputs=output,
    )
    # show_error=True surfaces exceptions raised by predict in the browser.
    demo.launch(server_name=server_name, server_port=server_port, debug=True, show_error=True)


if __name__ == "__main__":
    run()