"""Gradio demo that serves an ensemble of bone-age estimation checkpoints.

The uploaded scan is preprocessed, passed through every checkpoint listed in
``CHECKPOINT_PATHS``, and the mean / min / max of the predictions is shown.
"""
import functools

import albumentations as A
import gradio as gr
import numpy as np
import torch
from albumentations.pytorch import ToTensorV2

from models.model_zoo import BoneAgeEstModelZoo

device = "cpu"

# Checkpoints whose predictions are combined into the ensemble, in this order.
CHECKPOINT_PATHS = (
    "/bae/output/inception_1024_new_data/epoch=13-step=9828.ckpt",
    "/bae/output/inception_1024/epoch14_inception_1024_kaggle.ckpt",
)


def initialize_model(path):
    """Load a ``BoneAgeEstModelZoo`` checkpoint from *path* and put it in eval mode.

    Args:
        path: Filesystem path of the checkpoint file.

    Returns:
        The loaded model, with its ``model``, ``classifier`` and ``gender``
        sub-modules switched to evaluation mode.
    """
    # NOTE(review): if BoneAgeEstModelZoo is a LightningModule, load_from_checkpoint
    # is a classmethod and this throwaway instance (built with pretrained=True)
    # could be skipped -- confirm the checkpoint stores its hyperparameters first.
    model = BoneAgeEstModelZoo(branch="gender", pretrained=True, lr=0.001).load_from_checkpoint(
        path, map_location=device)
    model.model.eval()
    model.classifier.eval()
    model.gender.eval()
    return model


@functools.lru_cache(maxsize=1)
def _load_ensemble():
    """Load every checkpoint in CHECKPOINT_PATHS once and reuse them across requests."""
    return tuple(initialize_model(path) for path in CHECKPOINT_PATHS)


# Preprocessing: resize to 1024x1024, apply CLAHE, normalize, convert to a tensor.
# Presumably this must match the training-time pipeline -- confirm before editing.
transform = A.Compose([
    A.Resize(width=1024, height=1024),
    A.CLAHE(),
    A.Normalize(),
    ToTensorV2(),
])


def predict(image, gender):
    """Run the checkpoint ensemble on one scan.

    Args:
        image: The uploaded image (a PIL image from the Gradio input), or
            ``None`` when nothing was uploaded.
        gender: Truthy when the scan is of a female.

    Returns:
        ``(mean, min, max)`` of the per-checkpoint predictions, each truncated
        with ``int``.

    Raises:
        gr.Error: If no image was supplied.
    """
    if image is None:
        raise gr.Error("Please upload an image.")

    processed_image = transform(image=np.array(image, dtype=np.uint8))["image"]
    processed_image = processed_image.unsqueeze(0).to(device)  # add batch dim
    # Shape (1, 1): one sample, one gender flag.
    is_female = torch.tensor([[1 if gender else 0]], device=device)
    scans = {
        "image": processed_image,
        "gender": is_female,
    }

    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        preds = [float(model(scans)) for model in _load_ensemble()]

    return int(sum(preds) / len(preds)), int(min(preds)), int(max(preds))


def run():
    """Build the Gradio interface around ``predict`` and launch the server."""
    image_input = gr.Image(type="pil", label="Input PNG image")
    gender_input = gr.Checkbox(value=False, label="Is Female?", info="Is the scan of a female?")
    output = gr.Textbox(label="Predicted Age")
    min_range = gr.Textbox(label="Minimum Predicted Age")
    max_range = gr.Textbox(label="Maximum Predicted Age")
    info = (
        "# Usage Warning! \n"
        "This application is built for research purpose only. It is not intended for clinical use.\n"
        "Do not use this application in any commercial or medical setting. "
        "\n\n For any information about the project contact: vyawaharest@gmail.com "
        "\n\n Test data is available "
        "[here](https://huggingface.co/spaces/JackRio/bone-age-estimation/tree/main/test_data)"
    )
    BAE = gr.Interface(
        fn=predict,
        description=info,
        inputs=[image_input, gender_input],
        outputs=[output, min_range, max_range],
    )
    BAE.launch(server_name="0.0.0.0", server_port=7860, debug=True, show_error=True, share=True)


if __name__ == "__main__":
    run()