|
import torch |
|
import numpy as np |
|
from torchvision.transforms import ToTensor |
|
|
|
GPU_EFFICIENT_SAM_CHECKPOINT = "efficient_sam_s_gpu.jit" |
|
CPU_EFFICIENT_SAM_CHECKPOINT = "efficient_sam_s_cpu.jit" |
|
|
|
|
|
def load(device: torch.device) -> torch.jit.ScriptModule: |
|
if device.type == "cuda": |
|
model = torch.jit.load(GPU_EFFICIENT_SAM_CHECKPOINT) |
|
else: |
|
model = torch.jit.load(CPU_EFFICIENT_SAM_CHECKPOINT) |
|
model.eval() |
|
return model |
|
|
|
|
|
def inference_with_box( |
|
image: np.ndarray, |
|
box: np.ndarray, |
|
model: torch.jit.ScriptModule, |
|
device: torch.device |
|
) -> np.ndarray: |
|
bbox = torch.reshape(torch.tensor(box), [1, 1, 2, 2]) |
|
bbox_labels = torch.reshape(torch.tensor([2, 3]), [1, 1, 2]) |
|
img_tensor = ToTensor()(image) |
|
|
|
predicted_logits, predicted_iou = model( |
|
img_tensor[None, ...].to(device), |
|
bbox.to(device), |
|
bbox_labels.to(device), |
|
) |
|
predicted_logits = predicted_logits.cpu() |
|
all_masks = torch.ge(torch.sigmoid(predicted_logits[0, 0, :, :, :]), 0.5).numpy() |
|
predicted_iou = predicted_iou[0, 0, ...].cpu().detach().numpy() |
|
|
|
max_predicted_iou = -1 |
|
selected_mask_using_predicted_iou = None |
|
for m in range(all_masks.shape[0]): |
|
curr_predicted_iou = predicted_iou[m] |
|
if ( |
|
curr_predicted_iou > max_predicted_iou |
|
or selected_mask_using_predicted_iou is None |
|
): |
|
max_predicted_iou = curr_predicted_iou |
|
selected_mask_using_predicted_iou = all_masks[m] |
|
return selected_mask_using_predicted_iou |
|
|
|
|
|
def inference_with_boxes( |
|
image: np.ndarray, |
|
xyxy: np.ndarray, |
|
model: torch.jit.ScriptModule, |
|
device: torch.device |
|
) -> np.ndarray: |
|
masks = [] |
|
for [x_min, y_min, x_max, y_max] in xyxy: |
|
box = np.array([[x_min, y_min], [x_max, y_max]]) |
|
mask = inference_with_box(image, box, model, device) |
|
masks.append(mask) |
|
return np.array(masks) |
|
|