from pathlib import Path

import numpy as np
import torch
from transformers import SamModel, SamConfig, SamProcessor
from PIL import Image

DIR = Path(__file__).parent
MODEL_PATH = DIR / "www/sidewalk_model_checkpoint_20240427_192251.pth"


def _make_prediction_with_bounding_box(
    model, image: Image.Image, processor, device, prediction_threshold=0.5
):
    """
    Make a prediction using the given model and input image.

    Returns:
        predicted_mask: numpy array representing the predicted binary mask.
        predicted_mask_prob: numpy array representing the predicted mask probabilities.
    """
    # Bounding-box prompt covering almost the entire image
    prompt = [0, 0, image.width - 20, image.height - 20]
    # Prepare the image and box prompt for the model
    inputs = processor(image, input_boxes=[[prompt]], return_tensors="pt")
    # Move the input tensors to the same device as the model
    inputs = {k: v.to(device) for k, v in inputs.items()}
    model.eval()
    # Forward pass
    with torch.no_grad():
        outputs = model(**inputs, multimask_output=False)
    # Apply sigmoid to turn the mask logits into probabilities
    medsam_seg_prob = torch.sigmoid(outputs.pred_masks.squeeze(1))
    # Convert the soft mask into a hard (binary) mask
    medsam_seg_prob = medsam_seg_prob.cpu().numpy().squeeze()
    medsam_seg = (medsam_seg_prob > prediction_threshold).astype(np.uint8)
    return medsam_seg, medsam_seg_prob


def get_sidewalk_prediction(image, prediction_threshold=0.7):
    """Return the binary sidewalk mask predicted for the given image."""
    mask, _mask_prob = _make_prediction_with_bounding_box(
        model=model,
        image=image,
        processor=processor,
        device=device,
        prediction_threshold=prediction_threshold,
    )
    return mask


# Build a SAM model with the base architecture, then load the fine-tuned
# sidewalk-segmentation weights from the local checkpoint.
model = SamModel(config=SamConfig.from_pretrained("facebook/sam-vit-base"))
processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.load_state_dict(torch.load(MODEL_PATH, map_location=torch.device("cpu")))
model.to(device)
print("Model loaded successfully!")