# UpToDataScience, commit c3b5732: "project 1 milestone 4 - project video"
from pathlib import Path
import numpy as np
import torch
from transformers import SamModel, SamConfig, SamProcessor
from PIL import Image
# Directory containing this module; the checkpoint path below is anchored here
# rather than to the current working directory.
DIR = Path(__file__).parent
# Fine-tuned SAM checkpoint (loaded as a state_dict further down in this module).
MODEL_PATH = DIR.joinpath("www/sidewalk_model_checkpoint_20240427_192251.pth")
def _make_prediction_with_bounding_box(
model, image: Image, processor, device, prediction_threshold=0.5
):
"""
Make a prediction using the given model and input image.
Returns:
predicted_mask: numpy array, representing the predicted mask.
predicted_mask_prob: numpy array, representing the predicted mask probabilities.
"""
# get bounding box prompt
prompt = [0, 0, image.width - 20, image.height - 20]
# prepare image + box prompt for the model
inputs = processor(image, input_boxes=[[prompt]], return_tensors="pt")
# Move the input tensor to the GPU if it's not already there
inputs = {k: v.to(device) for k, v in inputs.items()}
model.eval()
# forward pass
with torch.no_grad():
outputs = model(**inputs, multimask_output=False)
# apply sigmoid
medsam_seg_prob = torch.sigmoid(outputs.pred_masks.squeeze(1))
# convert soft mask to hard mask
medsam_seg_prob = medsam_seg_prob.cpu().numpy().squeeze()
medsam_seg = (medsam_seg_prob > prediction_threshold).astype(np.uint8)
return medsam_seg, medsam_seg_prob
def get_sidewalk_prediction(image, prediction_threshold=0.7):
    """
    Return the hard sidewalk mask predicted for ``image``.

    Relies on the module-level ``model``, ``processor`` and ``device`` that are
    created when this module is imported.

    Args:
        image: PIL image to segment.
        prediction_threshold: probability cut-off; pixels whose predicted
            probability is strictly above it are set to 1. Defaults to 0.7.

    Returns:
        numpy ``uint8`` array of 0/1 values. The probability map computed
        alongside it is discarded.
    """
    prediction = _make_prediction_with_bounding_box(
        model=model,
        image=image,
        processor=processor,
        device=device,
        prediction_threshold=prediction_threshold,
    )
    # The helper returns (hard_mask, probabilities); only the hard mask is exposed.
    hard_mask, _ = prediction
    return hard_mask
# --- Module-level model setup: runs once, when this module is imported. ---
# Build the SAM ViT-Base architecture from its config only; the actual weights
# come from the local fine-tuned checkpoint loaded below, not from the hub.
model = SamModel(config=SamConfig.from_pretrained("facebook/sam-vit-base"))
# Matching processor; it prepares the image and box prompt passed to the model.
processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
# Use the GPU when one is available, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# map_location="cpu" lets the checkpoint load even on machines without CUDA;
# the model is moved to `device` right after.
# NOTE(review): torch.load unpickles the file and can execute arbitrary code --
# only load trusted checkpoints; consider weights_only=True if the installed
# torch version supports it.
model.load_state_dict(torch.load(MODEL_PATH, map_location=torch.device("cpu")))
model.to(device)
print("Model loaded successfully!")