# src/segmentation.py
from dataclasses import dataclass
from typing import Any, List, Dict, Optional, Union, Tuple
import os
import cv2
import torch
import requests
import numpy as np
from PIL import Image
from transformers import AutoModelForMaskGeneration, AutoProcessor, pipeline
@dataclass
class BoundingBox:
xmin: int
ymin: int
xmax: int
ymax: int
    @property
    def xyxy(self) -> List[int]:
        return [self.xmin, self.ymin, self.xmax, self.ymax]
@dataclass
class DetectionResult:
score: float
label: str
box: BoundingBox
    mask: Optional[np.ndarray] = None
@classmethod
def from_dict(cls, detection_dict: Dict) -> 'DetectionResult':
return cls(score=detection_dict['score'],
label=detection_dict['label'],
box=BoundingBox(xmin=detection_dict['box']['xmin'],
ymin=detection_dict['box']['ymin'],
xmax=detection_dict['box']['xmax'],
ymax=detection_dict['box']['ymax']))
def mask_to_polygon(mask: np.ndarray) -> List[List[int]]:
    # Find contours in the binary mask
    contours, _ = cv2.findContours(mask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    # Guard against empty masks, where no contour exists
    if not contours:
        return []
    # Keep only the contour with the largest area
    largest_contour = max(contours, key=cv2.contourArea)
    # Extract the (x, y) vertices of the contour
    polygon = largest_contour.reshape(-1, 2).tolist()
    return polygon
def polygon_to_mask(polygon: List[Tuple[int, int]], image_shape: Tuple[int, int]) -> np.ndarray:
"""
Convert a polygon to a segmentation mask.
Args:
- polygon (list): List of (x, y) coordinates representing the vertices of the polygon.
- image_shape (tuple): Shape of the image (height, width) for the mask.
Returns:
- np.ndarray: Segmentation mask with the polygon filled.
"""
    # Create an empty mask
    mask = np.zeros(image_shape, dtype=np.uint8)
    # An empty polygon (no contour found) yields an empty mask
    if not polygon:
        return mask
# Convert polygon to an array of points
pts = np.array(polygon, dtype=np.int32)
# Fill the polygon with white color (255)
cv2.fillPoly(mask, [pts], color=(255,))
return mask
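
# Illustrative round-trip check (a sketch, not part of the original pipeline):
# for a single convex region, mask -> polygon -> mask should be near-lossless.
# The toy rectangle below is an assumption used purely for demonstration.
def _polygon_roundtrip_demo() -> float:
    demo = np.zeros((100, 100), dtype=np.uint8)
    demo[20:80, 30:70] = 255
    restored = polygon_to_mask(mask_to_polygon(demo), demo.shape)
    intersection = np.logical_and(demo > 0, restored > 0).sum()
    union = np.logical_or(demo > 0, restored > 0).sum()
    return float(intersection) / float(union)  # ~1.0 for this rectangle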
def load_image(image_str: str) -> Image.Image:
if image_str.startswith("http"):
image = Image.open(requests.get(image_str, stream=True).raw).convert("RGB")
else:
image = Image.open(image_str).convert("RGB")
return image
def get_boxes(results: List[DetectionResult]) -> List[List[List[int]]]:
    boxes = []
    for result in results:
        xyxy = result.box.xyxy
        boxes.append(xyxy)
    # SAM's processor expects one list of boxes per image, hence the extra nesting
    return [boxes]
def refine_masks(masks: torch.BoolTensor, polygon_refinement: bool = False) -> List[np.ndarray]:
    masks = masks.cpu().float()
    # (num_masks, C, H, W) -> (num_masks, H, W, C)
    masks = masks.permute(0, 2, 3, 1)
    # Collapse channels: a pixel is foreground if any channel votes for it
    masks = masks.mean(dim=-1)
    masks = (masks > 0).int()
    masks = masks.numpy().astype(np.uint8)
masks = list(masks)
if polygon_refinement:
for idx, mask in enumerate(masks):
shape = mask.shape
polygon = mask_to_polygon(mask)
mask = polygon_to_mask(polygon, shape)
masks[idx] = mask
return masks
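
# Shape sketch for refine_masks (toy sizes, an assumption for illustration):
# SAM's post_process_masks yields one (num_masks, channels, H, W) bool tensor
# per image; refine_masks collapses the channels into binary H x W masks.
def _refine_masks_demo() -> List[np.ndarray]:
    toy = torch.rand(2, 3, 8, 8) > 0.5  # 2 masks, 3 channels, 8x8 pixels
    return refine_masks(toy, polygon_refinement=False)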
def detect(
image: Image.Image,
labels: List[str],
threshold: float = 0.3,
detector_id: Optional[str] = None
) -> List[Dict[str, Any]]:
"""
Use Grounding DINO to detect a set of labels in an image in a zero-shot fashion.
"""
device = "cuda" if torch.cuda.is_available() else "cpu"
detector_id = detector_id if detector_id is not None else "IDEA-Research/grounding-dino-tiny"
object_detector = pipeline(model=detector_id, task="zero-shot-object-detection", device=device)
    # Grounding DINO expects each text prompt to end with a period
    labels = [label if label.endswith(".") else label + "." for label in labels]
results = object_detector(image, candidate_labels=labels, threshold=threshold)
results = [DetectionResult.from_dict(result) for result in results]
return results
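
# Example call (a sketch: the URL and labels are placeholders, and the
# detector checkpoint is downloaded on first use):
#
#   image = load_image("https://example.com/cats.jpg")
#   for det in detect(image, ["a cat", "a remote control"], threshold=0.3):
#       print(det.label, round(det.score, 3), det.box.xyxy)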
def segment(
image: Image.Image,
    detection_results: List[DetectionResult],
polygon_refinement: bool = False,
segmenter_id: Optional[str] = None
) -> List[DetectionResult]:
"""
Use Segment Anything (SAM) to generate masks given an image + a set of bounding boxes.
"""
device = "cuda" if torch.cuda.is_available() else "cpu"
segmenter_id = segmenter_id if segmenter_id is not None else "facebook/sam-vit-base"
    segmenter = AutoModelForMaskGeneration.from_pretrained(segmenter_id).to(device)
    processor = AutoProcessor.from_pretrained(segmenter_id)
    boxes = get_boxes(detection_results)
    inputs = processor(images=image, input_boxes=boxes, return_tensors="pt").to(device)
    # Inference only, so skip gradient tracking
    with torch.no_grad():
        outputs = segmenter(**inputs)
masks = processor.post_process_masks(
masks=outputs.pred_masks,
original_sizes=inputs.original_sizes,
reshaped_input_sizes=inputs.reshaped_input_sizes
)[0]
masks = refine_masks(masks, polygon_refinement)
for detection_result, mask in zip(detection_results, masks):
detection_result.mask = mask
return detection_results
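
# Usage note (sketch): segment() attaches a mask to each DetectionResult in
# place and returns the same list, so the results can be chained directly:
#
#   detections = detect(image, ["a cat"])
#   detections = segment(image, detections, polygon_refinement=True)
#   assert all(det.mask is not None for det in detections)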
def grounded_segmentation(
image: Union[Image.Image, str],
labels: List[str],
threshold: float = 0.3,
polygon_refinement: bool = False,
detector_id: Optional[str] = None,
segmenter_id: Optional[str] = None
) -> Tuple[Image.Image, List[DetectionResult]]:
if isinstance(image, str):
image = load_image(image)
detections = detect(image, labels, threshold, detector_id)
detections = segment(image, detections, polygon_refinement, segmenter_id)
return image, detections
# Save clipped images: mask out the background and crop to the detection box
def cut_image(image: Image.Image, mask: np.ndarray, box: BoundingBox) -> np.ndarray:
    np_image = np.array(image)
    # Keep only pixels where the mask is nonzero
    binary_mask = (mask > 0).astype(np.uint8) * 255
    cut = cv2.bitwise_and(np_image, np_image, mask=binary_mask)
    # Crop to the bounding box
    x0, y0, x1, y1 = map(int, box.xyxy)
    cropped = cut[y0:y1, x0:x1]
    # Convert RGB (PIL) to BGR (OpenCV) so cv2.imwrite stores correct colors
    cropped_bgr = cv2.cvtColor(cropped, cv2.COLOR_RGB2BGR)
    return cropped_bgr
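

# End-to-end sketch. Assumptions: "input.jpg" and the "a cat" label are
# placeholders, and running this downloads both model checkpoints.
if __name__ == "__main__":
    image, detections = grounded_segmentation(
        image="input.jpg",          # placeholder path
        labels=["a cat"],           # placeholder label
        threshold=0.3,
        polygon_refinement=True,
    )
    os.makedirs("crops", exist_ok=True)
    for i, det in enumerate(detections):
        crop_bgr = cut_image(image, det.mask, det.box)
        cv2.imwrite(os.path.join("crops", f"{det.label.rstrip('.')}_{i}.png"), crop_bgr)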