Spaces:
Runtime error
Runtime error
| import torch | |
| from monai.networks.nets import DynUNet | |
| import os | |
def _extract_state_dict(checkpoint):
    """Return the bare parameter mapping stored in a loaded checkpoint.

    A plain state dict is returned unchanged (same object, so any
    ``_metadata`` attribute survives). Two common training-script layouts are
    also accepted: a dict wrapping the weights under ``"state_dict"`` or
    ``"model_state_dict"``, and keys uniformly prefixed with ``"module."``
    (as saved from ``nn.DataParallel`` / DDP).
    """
    if isinstance(checkpoint, dict):
        for key in ("state_dict", "model_state_dict"):
            nested = checkpoint.get(key)
            if isinstance(nested, dict):
                checkpoint = nested
                break

    prefix = "module."
    if (
        isinstance(checkpoint, dict)
        and checkpoint
        and all(isinstance(k, str) and k.startswith(prefix) for k in checkpoint)
    ):
        checkpoint = {k[len(prefix):]: v for k, v in checkpoint.items()}
    return checkpoint


def load_model(model_path="best_model_large_data.pth", device="cpu"):
    """Build the 2-D DynUNet segmentation network and load trained weights into it.

    Args:
        model_path: Checkpoint file passed to ``torch.load``. It may contain a
            bare state dict, or a dict wrapping one under ``"state_dict"`` /
            ``"model_state_dict"``; a uniform ``"module."`` key prefix is stripped.
        device: Target for ``map_location`` and ``model.to`` (e.g. ``"cpu"``,
            ``"cuda"``).

    Returns:
        The DynUNet instance, moved to ``device`` and switched to eval mode.

    Raises:
        Any exception from building the network, reading the checkpoint, or
        ``load_state_dict`` is printed and then re-raised unchanged.
    """
    try:
        # Architecture must match the one used in training exactly, otherwise
        # load_state_dict (strict by default) rejects the weights.
        model = DynUNet(
            spatial_dims=2,
            in_channels=1,
            out_channels=1,
            kernel_size=[3, 3, 3, 3, 3],
            strides=[1, 2, 2, 2, 2],
            upsample_kernel_size=[2, 2, 2, 2],
            filters=[32, 64, 128, 256, 512],
            norm_name="INSTANCE",
            res_block=True,
            deep_supervision=False,
        )
        # weights_only=True limits unpickling to tensors and plain containers,
        # so a tampered checkpoint cannot run arbitrary code (this is already
        # the default from torch 2.6 on).
        checkpoint = torch.load(model_path, map_location=device, weights_only=True)
        model.load_state_dict(_extract_state_dict(checkpoint))
        model.to(device)
        model.eval()
        return model
    except Exception as e:
        print(f"❌ Model initialization failed: {e}")
        raise
def predict_mask(model, image_tensor):
    """Run ``model`` on a single-channel image batch and return sigmoid probabilities.

    Args:
        model: Callable (e.g. the network returned by ``load_model``) applied to
            ``image_tensor``; its output is treated as logits.
        image_tensor: 4-D tensor with exactly one channel, shape ``[N, 1, H, W]``.
            Any batch size ``N`` is accepted.

    Returns:
        ``torch.sigmoid`` of the model output, computed under ``torch.no_grad()``.

    Raises:
        ValueError: if ``image_tensor`` is not 4-D with a single channel.
        Every exception (including that ValueError) is printed before being
        re-raised.
    """
    try:
        # Only rank and channel count are enforced; the batch dimension is free,
        # so the message must not claim a batch size of 1.
        if image_tensor.dim() != 4 or image_tensor.shape[1] != 1:
            raise ValueError(f"Input tensor must be [N, 1, H, W]. Got {image_tensor.shape}")
        # Inference only: skip building the autograd graph.
        with torch.no_grad():
            return torch.sigmoid(model(image_tensor))
    except Exception as e:
        print(f"Prediction failed: {e}")
        raise