"""Gradio web demo that runs a BoneAgeEstModelZoo checkpoint on an uploaded image."""
import functools

import albumentations as A
import gradio as gr
import numpy as np
import torch
from albumentations.pytorch import ToTensorV2

from models.model_zoo import BoneAgeEstModelZoo

device = "cpu"


def initialize_model():
    """Load the checkpoint onto ``device`` and switch its sub-modules to eval mode.

    Every call reads the checkpoint from disk again; request handlers should
    go through ``_get_model()`` so the load happens only once per process.

    Returns:
        The loaded model with ``model``, ``classifier`` and ``gender`` in eval mode.
    """
    # Load model
    # NOTE(review): load_from_checkpoint is invoked on a freshly constructed
    # instance. If this is a Lightning-style classmethod, that instance is
    # thrown away, and newer Lightning releases may reject instance calls --
    # confirm against models.model_zoo before changing it.
    model = BoneAgeEstModelZoo(
        branch="gender", pretrained=True, lr=0.001
    ).load_from_checkpoint(
        "/bae/output/inception_1024_new_data/epoch=13-step=9828.ckpt",
        map_location=device,
    )
    model.model.eval()
    model.classifier.eval()
    model.gender.eval()
    return model


@functools.lru_cache(maxsize=1)
def _get_model():
    """Return the process-wide model, loading it on first use only."""
    return initialize_model()


# Preprocessing and postprocessing
transform = A.Compose([
    A.Resize(width=1024, height=1024),
    A.CLAHE(),
    A.Normalize(),
    ToTensorV2(),
])


def predict(image, gender):
    """Run the model on one image and return its prediction as an integer.

    Args:
        image: The uploaded image (the UI passes a PIL image); it is converted
            to a ``uint8`` NumPy array before preprocessing.
        gender: Truthy when the "Is Female?" box is ticked.

    Returns:
        The model output converted with ``int()``.

    Raises:
        ValueError: If no image was supplied.
    """
    if image is None:
        # Fail with a readable message instead of an opaque NumPy TypeError.
        raise ValueError("No image provided; please upload an image first.")

    is_female = 1 if gender else 0

    # Reuse the cached model rather than re-reading the checkpoint per request.
    model = _get_model()

    processed_image = transform(image=np.array(image, dtype=np.uint8))['image']
    processed_image = processed_image.unsqueeze(0)  # add the batch dimension
    processed_image = processed_image.to(device)

    # Two unsqueezes turn the scalar flag into a (1, 1) tensor.
    is_female = torch.tensor(is_female).unsqueeze(0).unsqueeze(1).to(device)

    scans = {
        'image': processed_image,
        'gender': is_female
    }

    # Inference only: skip autograd bookkeeping for the forward pass.
    with torch.no_grad():
        preds = model(scans)
    return int(preds)


def run():
    """Build the Gradio interface and serve it on 0.0.0.0:7860."""
    image_input = gr.inputs.Image(type="pil", label="Input PNG image")
    gender_input = gr.Checkbox(label="Is Female?", info="Is the scan of a female?", default=True)
    output = gr.outputs.Textbox(label="Predicted Age")

    demo = gr.Interface(
        fn=predict,
        inputs=[image_input, gender_input],
        outputs=output,
    )
    demo.launch(server_name="0.0.0.0", server_port=7860, debug=True, show_error=True)


if __name__ == "__main__":
    run()