"""Run a pretrained U-Net segmentation model on one image and save the mask."""

import inspect

import cv2
import segmentation_models_pytorch as smp
import torch

# NOTE(review): star import is expected to supply Dataset,
# get_validation_augmentation and get_preprocessing; kept as-is because the
# full set of names this module relies on from it is not known here.
from dataloader import *
from download import attempt_download_from_hub

DEFAULT_OUTPUT_PATH = "output.png"
ENCODER_NAME = "efficientnet-b6"
ENCODER_WEIGHTS = "imagenet"


def _resolve_device(device=None):
    """Return *device* as a torch.device, defaulting to CUDA if available, else CPU."""
    if device is not None:
        return torch.device(device)
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")


def _load_model(model_path, device):
    """Load a fully pickled model object from *model_path* and move it to *device*."""
    # SECURITY: torch.load unpickles the file, which can execute arbitrary
    # code. Only load checkpoints that come from a source you trust.
    kwargs = {"map_location": device}
    # The result is used as a model object (``.predict`` is called on it), not
    # a state_dict. PyTorch >= 2.6 defaults to weights_only=True, which refuses
    # to unpickle such objects, so opt out explicitly where the keyword exists
    # (it is absent before PyTorch 1.13).
    if "weights_only" in inspect.signature(torch.load).parameters:
        kwargs["weights_only"] = False
    model = torch.load(model_path, **kwargs)
    return model.to(device)


def unet_prediction(input_path, model_path, output_path=DEFAULT_OUTPUT_PATH, device=None):
    """Predict a segmentation mask for the image at *input_path* and save it as an image.

    Args:
        input_path: Path to the input image, passed straight to ``Dataset``.
        model_path: Model location; resolved to a local file via
            ``attempt_download_from_hub`` before loading.
        output_path: Where the mask image is written (default ``"output.png"``).
        device: Torch device to run on. ``None`` picks CUDA when available,
            otherwise CPU.

    Returns:
        ``output_path``, the path of the written mask image.

    Raises:
        OSError: if OpenCV fails to write the mask image.
    """
    device = _resolve_device(device)
    model_path = attempt_download_from_hub(model_path)
    best_model = _load_model(model_path, device)

    preprocessing_fn = smp.encoders.get_preprocessing_fn(ENCODER_NAME, ENCODER_WEIGHTS)
    test_dataset = Dataset(
        input_path,
        augmentation=get_validation_augmentation(),
        preprocessing=get_preprocessing(preprocessing_fn),
    )
    image = test_dataset.get()  # a NumPy array (it is fed to torch.from_numpy)

    # Prepend a batch dimension of size 1.
    x_tensor = torch.from_numpy(image).to(device).unsqueeze(0)
    # NOTE(review): presumably smp's SegmentationModel.predict, which switches
    # to eval mode and disables autograd itself — confirm for this checkpoint.
    pr_mask = best_model.predict(x_tensor)

    # Round to a binary 0/1 mask, scale to 0/255, then convert to uint8:
    # OpenCV's PNG writer expects 8-bit (or 16-bit) data, not float64.
    # Clipping guards against outputs outside [0, 1] wrapping on the cast.
    mask = pr_mask.squeeze().cpu().numpy().round() * 255
    mask = mask.clip(0, 255).astype("uint8")

    # cv2.imwrite signals failure by returning False rather than raising.
    if not cv2.imwrite(output_path, mask):
        raise OSError(f"cv2.imwrite failed to write the mask to {output_path!r}")
    return output_path