import cv2 import numpy as np import torch import torchvision.transforms as T class SegmentationModel: """The segmentation model.""" def __init__(self, path, device='cuda' if torch.cuda.is_available() else 'cpu'): """ Initializes the segmentation model, with the given path. Args: path: The path to the model file. device: The device to run the model on. """ self.model = torch.load(path, map_location=device) self.model.to(device) self.model.eval() self.device = device @torch.no_grad() def inference(self, data): """Performs inference on the given data. Args: data (dict): A dictionary containing the data to perform inference on. The dictionary should have the following keys: - 'mask': A torch.Tensor of shape (num_masks, height, width) containing the masks for the input image. - 'input': A torch.Tensor of shape (num_masks, 3, height, width) containing the input image. - 'original_shape': A tuple of (height, width) representing the original shape of the input image. Returns: A torch.Tensor of shape (num_masks, height, width) containing the predictions for each mask. """ # Return early if mask is not present if data['mask'] is None: return data, None predictions = [] # Compute predictions for each mask for img in data['input']: prediction = self.model(img[None,:,:,:].to(self.device)).cpu().detach().numpy()[0] predictions.append(prediction) filtered_segmentation = [] bboxes = [] # Apply mask to prediction in order to filter out noise outside of the mask for mask, prediction in zip(data['mask'], predictions): filtered_segmentation.append(mask * prediction) bboxes.append(self.get_bbox_from_mask(mask[0])) filtered_segmentation = torch.cat(filtered_segmentation) # Resize filtered_segmentation to original image size if data['original_shape'][0] > data['original_shape'][1]: resize = T.Resize(data['original_shape'][0], interpolation=T.InterpolationMode.NEAREST) resized_segmentation = resize(filtered_segmentation) resized_segmentation = resized_segmentation[:,:,(resized_segmentation.shape[2] - data['original_shape'][1]) // 2: (resized_segmentation.shape[2] - data['original_shape'][1]) // 2 + data['original_shape'][1]] else: resize = T.Resize(data['original_shape'][1], interpolation=T.InterpolationMode.NEAREST) resized_segmentation = resize(filtered_segmentation) resized_segmentation = resized_segmentation[:, (resized_segmentation.shape[1] - data['original_shape'][0]) // 2: (resized_segmentation.shape[1] - data['original_shape'][0]) // 2 + data['original_shape'][0]] return resized_segmentation def single_inference(self, img, data, transforms=None, transforms_mask=None, desired_size=352): """Performs inference on the given data. Args: img: A numpy array of shape (height, width, 3) containing the image to perform inference on. data: A dictionary containing the data to perform inference on. The dictionary should have the following keys - 'bbox': A list of bounding boxes of shape (num_masks, 4) containing the bounding boxes for the input image. - 'conf': A list of confidence scores of shape (num_masks) containing the confidence scores for the input image. transforms: A torchvision.transforms object to apply to the input image. transforms_mask: A torchvision.transforms object to apply to the mask. desired_size: The desired size of the input image. Returns: A torch.Tensor of shape (num_masks, height, width) containing the predictions for each mask. """ # Return early if mask is not present if len(data['conf']) == 0: return None # Create default transforms if none are provided if transforms is None: transforms = T.Compose([ T.ToTensor(), T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) ]) if transforms_mask is None: transforms_mask = T.Compose([ T.ToTensor(), ]) mask = np.zeros((len(data['bbox']), img.shape[0], img.shape[1])) # Resize image to desired size, and pad with black pixels old_size = img.shape[:2] ratio = float(desired_size) / max(old_size) new_size = tuple([int(x * ratio) for x in old_size]) img = cv2.resize(img, (new_size[1], new_size[0])) delta_w = desired_size - new_size[1] delta_h = desired_size - new_size[0] top, bottom = delta_h // 2, delta_h - (delta_h // 2) left, right = delta_w // 2, delta_w - (delta_w // 2) img = cv2.copyMakeBorder( img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=[0,0,0] ) # Create mask for each bounding box new_mask = [] for i, box in enumerate(data['bbox']): box[box < 0] = 0 box = box.int() mask[i, box[1] : box[3], box[0] : box[2]] = 1 new_mask.append( cv2.copyMakeBorder( cv2.resize(mask[i], (new_size[1], new_size[0])), top, bottom, left, right, cv2.BORDER_CONSTANT, value=[0,0,0], ) ) masks = np.array(new_mask) # Apply transforms to image and mask if transforms_mask is not None: transformed_mask = [] for i, mask in enumerate(masks): transformed_mask.append(transforms_mask(mask)) masks = torch.stack(transformed_mask) if transforms is not None: img = transforms(img) # Apply mask to image input_to_model = torch.zeros((len(data['bbox']), 3, img.shape[1], img.shape[2])) for i, mask in enumerate(masks): input_to_model[i] = img * mask # Run inference on image segmentation = self.inference( {'mask': masks, 'original_shape': old_size, 'input': input_to_model, 'data': data}) return segmentation def get_bbox_from_mask(self, mask, mask_value=1): """Returns the bounding box of the given mask.""" # Make sure the mask is a numpy array mask = np.array(mask) # Make sure mask values are non-negative mask[mask < 0] = 0 # Get the indices of the mask that are equal to the mask value if mask_value is None: indices = np.where(mask != 0) else: indices = np.where(mask == mask_value) # Return a zero size box if there are no indices in the mask if indices[0].size <= 0 or indices[1].size <= 0: return np.zeros((4,), dtype=int) # Get the min and max values of the indices min_x = np.min(indices[1]) min_y = np.min(indices[0]) max_x = np.max(indices[1]) max_y = np.max(indices[0]) # Return the bounding box return [min_x, min_y, max_x, max_y]