import albumentations as A
import gradio as gr
import numpy as np
import torch
from albumentations.pytorch import ToTensorV2

from models.model_zoo import BoneAgeEstModelZoo
device = "cpu"
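# A GPU could be used when available (a sketch; the rest of the app assumes CPU):
#
#   device = "cuda" if torch.cuda.is_available() else "cpu"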
def initialize_model():
    # Load the Lightning checkpoint. load_from_checkpoint is a classmethod, so it is
    # called on the class; the constructor kwargs are forwarded in case the
    # hyperparameters were not saved inside the checkpoint.
    model = BoneAgeEstModelZoo.load_from_checkpoint(
        "/bae/output/inception_1024_new_data/epoch=13-step=9828.ckpt",
        map_location=device,
        branch="gender", pretrained=True, lr=0.001,
    )
    # Put every sub-module used at inference time into eval mode.
    model.model.eval()
    model.classifier.eval()
    model.gender.eval()
    return model
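# Quick smoke test for the loaded checkpoint (a sketch; assumes the model accepts the
# same {"image", "gender"} dict that predict() builds below):
#
#   m = initialize_model()
#   dummy = {"image": torch.zeros(1, 3, 1024, 1024), "gender": torch.tensor([[1]])}
#   with torch.no_grad():
#       print(m(dummy))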
# Preprocessing pipeline: resize, contrast-enhance (CLAHE), normalize, convert to tensor
transform = A.Compose([
    A.Resize(width=1024, height=1024),
    A.CLAHE(),
    A.Normalize(),
    ToTensorV2(),
])
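# Running the pipeline on its own (a sketch; "hand_xray.png" is a placeholder path):
#
#   from PIL import Image
#   img = Image.open("hand_xray.png").convert("RGB")
#   tensor = transform(image=np.array(img, dtype=np.uint8))["image"]  # shape (3, 1024, 1024)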
def predict(image, gender):
    # The checkbox maps directly to the model's binary gender input.
    is_female = 1 if gender else 0

    # Note: the checkpoint is reloaded on every request; loading it once at module
    # level would make repeated predictions much faster.
    model = initialize_model()

    processed_image = transform(image=np.array(image, dtype=np.uint8))["image"]
    processed_image = processed_image.unsqueeze(0).to(device)  # add batch dimension
    is_female = torch.tensor(is_female).unsqueeze(0).unsqueeze(1).to(device)  # shape (1, 1)

    scans = {
        "image": processed_image,
        "gender": is_female,
    }
    with torch.no_grad():
        preds = model(scans)
    return int(preds)
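# Calling predict() directly, outside Gradio (a sketch; "hand_xray.png" is a
# placeholder path, and True marks the scan as female):
#
#   from PIL import Image
#   print(predict(Image.open("hand_xray.png").convert("RGB"), True))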
def run():
    # Gradio 3+ components (the gr.inputs / gr.outputs namespaces and the
    # `default=` kwarg were removed in newer Gradio releases).
    image_input = gr.Image(type="pil", label="Input PNG image")
    gender_input = gr.Checkbox(label="Is Female?", info="Is the scan of a female?", value=True)
    output = gr.Textbox(label="Predicted Age")

    BAE = gr.Interface(
        fn=predict,
        inputs=[image_input, gender_input],
        outputs=output,
    )
    BAE.launch(server_name="0.0.0.0", server_port=7860, debug=True, show_error=True)
if __name__ == "__main__":
    run()