Spaces:
Build error
Build error
| import gradio as gr | |
| import lightning | |
| import numpy as np | |
| import os | |
| import pandas as pd | |
| import timm | |
| import torch | |
| import torch.nn as nn | |
| from torch.utils.data import Dataset, DataLoader | |
# timm architecture name; must match the backbone used during training.
BACKBONE = "resnet18d"
IMAGE_HEIGHT = 512
IMAGE_WIDTH = 512

# Lightning checkpoint produced by training; only its "state_dict" entry is kept.
# NOTE(review): torch.load unpickles the file — only load checkpoints you trust.
trained_weights_path = "epoch=009.ckpt"
trained_weights = torch.load(trained_weights_path, map_location="cpu")["state_dict"]
# Recreate the LightningModule used during training so the checkpoint's
# state_dict (keyed under the ``net`` submodule) loads into it.
class BoneAgeModel(lightning.LightningModule):
    """LightningModule wrapping a backbone ``net`` for bone-age regression.

    Args:
        net: network applied to ``batch["x"]``.
        optimizer: optimizer handed back from ``configure_optimizers``.
        scheduler: LR scheduler, configured to step every optimizer step.
        loss_fn: callable ``loss_fn(prediction, target)`` returning a tensor.
    """

    def __init__(self, net, optimizer, scheduler, loss_fn):
        super().__init__()
        self.net = net
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.loss_fn = loss_fn
        # Per-batch validation losses, averaged and reset at each epoch end.
        self.val_losses = []

    def _compute_batch_loss(self, batch):
        """Run ``net`` on ``batch["x"]`` and score it against ``batch["y"]``."""
        return self.loss_fn(self.net(batch["x"]), batch["y"])

    def training_step(self, batch, batch_index):
        return self._compute_batch_loss(batch)

    def validation_step(self, batch, batch_index):
        self.val_losses.append(self._compute_batch_loss(batch).item())

    def on_validation_epoch_end(self, *args, **kwargs):
        epoch_loss = np.mean(self.val_losses)
        self.val_losses = []
        print(f"Validation Loss : {epoch_loss:0.3f}")

    def configure_optimizers(self):
        return dict(
            optimizer=self.optimizer,
            lr_scheduler=dict(scheduler=self.scheduler, interval="step"),
        )
# Rebuild the backbone and restore the fine-tuned weights.
# pretrained=False: load_state_dict below is strict (the default), so every
# parameter/buffer in the state_dict is overwritten anyway — fetching ImageNet
# weights first would be wasted startup time and an extra network dependency.
net = timm.create_model(BACKBONE, pretrained=False, in_chans=1, num_classes=1)
# Optimizer, scheduler and loss are training-only; inference does not need them.
trained_model = BoneAgeModel(net, None, None, None)
# Strict load: raises if the checkpoint keys do not match the module exactly.
trained_model.load_state_dict(trained_weights)
# Switch to evaluation mode (affects layers such as BatchNorm/Dropout).
trained_model.eval()
def format_bone_age(bone_age):
    """Render a bone-age prediction as a years/months display string.

    Args:
        bone_age: predicted bone age as a float, interpreted as years (the
            same unit the display string has always assumed).

    Returns:
        ``"Predicted Bone Age: {years} years, {months} months"``.

    The value is rounded to the nearest whole month *before* being split
    into years and months, so 5.99 becomes "6 years, 0 months" rather than
    the impossible "5 years, 12 months". Negative predictions are clamped
    to zero.
    """
    total_months = max(0, round(bone_age * 12))
    years, months = divmod(total_months, 12)
    return f"Predicted Bone Age: {years} years, {months} months"


def predict_bone_age(Radiograph):
    """Predict bone age from a grayscale radiograph supplied by the Gradio UI.

    Args:
        Radiograph: 2-D numpy array of pixel values (assumed uint8 in
            [0, 255] from ``gr.Image(image_mode="L")`` — TODO confirm), or
            ``None`` when no image was provided.

    Returns:
        A display string with the predicted age in years and months, or a
        prompt to upload an image when ``Radiograph`` is ``None``.
    """
    if Radiograph is None:
        # Presumably what Gradio passes for an empty input; avoid a crash.
        return "Please upload a radiograph."
    # NOTE(review): the image is used at its uploaded resolution; if training
    # resized to IMAGE_HEIGHT x IMAGE_WIDTH, inference likely should too.
    img = torch.from_numpy(Radiograph).float()
    img = img.unsqueeze(0).unsqueeze(0)  # add channel and batch dimensions
    img = img / 255.0  # use same normalization as in the PyTorch dataset
    with torch.inference_mode():
        bone_age = trained_model.net(img)[0].item()
    return format_bone_age(bone_age)
# Input: a single-channel upload widget ("L" is PIL's grayscale mode).
image = gr.Image(
    image_mode="L",
    height=IMAGE_HEIGHT,
    width=IMAGE_WIDTH,
)
# Output: the formatted prediction string from predict_bone_age.
label = gr.Label(label="Bone Age Prediction", show_label=True)

demo = gr.Interface(
    fn=predict_bone_age,
    inputs=[image],
    outputs=label,
)
demo.launch(debug=True)