Spaces:
Sleeping
Sleeping
from .model import FastSAM | |
import numpy as np | |
from PIL import Image | |
import clip | |
from typing import Optional, List, Tuple, Union | |
class FastSAMDecoder:
    """Prompt-based mask selection on top of a FastSAM model's segmentation output.

    ``run_encoder`` runs the wrapped model once on an image and returns its
    output; ``run_decoder`` then selects mask(s) from such an output using a
    point, box or text prompt. The image passed to ``run_encoder`` is kept on
    ``self.image`` so prompts given in image pixel coordinates can be mapped
    onto the (possibly different) resolution of the mask grid.
    """

    def __init__(
        self,
        model: "FastSAM",
        device: str = 'cpu',
        conf: float = 0.4,
        iou: float = 0.9,
        imgsz: int = 1024,
        retina_masks: bool = True,
    ):
        """Store the model and the keyword options forwarded to it.

        Args:
            model: Callable FastSAM model, invoked by ``run_encoder``.
            device: Forwarded to the model as ``device``.
            conf: Forwarded to the model as ``conf`` (presumably the detection
                confidence threshold — confirm against FastSAM).
            iou: Forwarded to the model as ``iou`` (presumably the NMS IoU
                threshold — confirm against FastSAM).
            imgsz: Forwarded to the model as ``imgsz``.
            retina_masks: Forwarded to the model as ``retina_masks``.
        """
        self.model = model
        self.device = device
        self.retina_masks = retina_masks
        self.imgsz = imgsz
        self.conf = conf
        self.iou = iou
        # Last image given to run_encoder; indexed as (height, width, ...).
        self.image = None
        # Last model output given to run_decoder.
        self.image_embedding = None

    def run_encoder(self, image):
        """Run the model on *image* and return its first result.

        Args:
            image: A file path (opened with PIL and converted to a NumPy
                array) or an image array, which is passed to the model as-is.

        Returns:
            ``self.model(...)[0].numpy()`` — presumably an ultralytics
            ``Results`` whose fields are NumPy arrays; TODO confirm against the
            FastSAM/ultralytics version in use.
        """
        if isinstance(image, str):
            # The context manager closes the file handle once pixels are read.
            with Image.open(image) as pil_image:
                image = np.array(pil_image)
        self.image = image
        image_embedding = self.model(
            self.image,
            device=self.device,
            retina_masks=self.retina_masks,
            imgsz=self.imgsz,
            conf=self.conf,
            iou=self.iou,
        )
        return image_embedding[0].numpy()

    def run_decoder(
        self,
        image_embedding,
        point_prompt: Optional[np.ndarray] = None,
        point_label: Optional[np.ndarray] = None,
        box_prompt: Optional[np.ndarray] = None,
        text_prompt: Optional[str] = None,
    ) -> Optional[np.ndarray]:
        """Select mask(s) from *image_embedding* using the first prompt given.

        Prompts are checked in the order point, box, text; only the first
        non-``None`` one is used.

        Args:
            image_embedding: A model result (as returned by ``run_encoder``)
                or a list/tuple of results, of which the first is used.
            point_prompt: ``[x, y]`` points in ``self.image`` pixel coordinates.
            point_label: One label per point (1 = include, 0 = exclude).
            box_prompt: ``[x1, y1, x2, y2]`` box in ``self.image`` pixel coordinates.
            text_prompt: Text query, handled by a ``text_prompt`` method if one
                is defined on the instance.

        Returns:
            The selected mask array, or ``None`` when no prompt is given.

        Raises:
            NotImplementedError: a text prompt is given but no ``text_prompt``
                handler is defined.
        """
        self.image_embedding = image_embedding
        if point_prompt is not None:
            return self.point_prompt(points=point_prompt, pointlabel=point_label)
        if box_prompt is not None:
            return self.box_prompt(bbox=box_prompt)
        if text_prompt is not None:
            text_handler = getattr(self, 'text_prompt', None)
            if text_handler is None:
                raise NotImplementedError(
                    "text prompts are not supported: no text_prompt method is defined"
                )
            return text_handler(text=text_prompt)
        return None

    @staticmethod
    def _to_numpy(array) -> np.ndarray:
        """Return *array* as a NumPy array, moving tensor-like inputs to the CPU first."""
        # Duck-typed torch.Tensor support: masks may still be (GPU) tensors when
        # a raw model output is passed instead of run_encoder's `.numpy()` result.
        if hasattr(array, "detach"):
            array = array.detach().cpu().numpy()
        return np.asarray(array)

    def _current_result(self):
        """Return the single result stored by ``run_decoder``.

        Accepts either one result object or a list/tuple of results, in which
        case the first element is used.
        """
        result = self.image_embedding
        if isinstance(result, (list, tuple)):
            result = result[0]
        return result

    def box_prompt(self, bbox):
        """Return the mask that best matches a box prompt by IoU.

        Args:
            bbox: Box ``[x1, y1, x2, y2]`` in pixel coordinates of ``self.image``.
                The caller's object is not modified.

        Returns:
            Array of shape ``(1, H, W)`` holding the mask with the highest IoU
            against the box, where ``(H, W)`` is the mask resolution. If the
            result contains no masks, an all-zero array of that shape.

        Raises:
            ValueError: ``x2`` or ``y2`` is zero.
        """
        if bbox[2] == 0 or bbox[3] == 0:
            raise ValueError(
                f"bbox must be [x1, y1, x2, y2] with non-zero x2 and y2, got {bbox!r}"
            )
        masks = self._to_numpy(self._current_result().masks.data)  # (N, H, W)
        target_height, target_width = self.image.shape[:2]
        h, w = masks.shape[1], masks.shape[2]
        if h != target_height or w != target_width:
            # Map image-pixel coordinates onto the mask grid (truncating).
            x1 = int(bbox[0] * w / target_width)
            y1 = int(bbox[1] * h / target_height)
            x2 = int(bbox[2] * w / target_width)
            y2 = int(bbox[3] * h / target_height)
        else:
            # Plain ints: float/NumPy-float coordinates cannot be used as slice bounds.
            x1, y1, x2, y2 = (int(round(v)) for v in bbox[:4])
        # Clamp to the mask grid.
        x1, y1 = max(x1, 0), max(y1, 0)
        x2, y2 = min(x2, w), min(y2, h)
        if masks.shape[0] == 0:
            return np.zeros((1, h, w), dtype=masks.dtype)
        bbox_area = (y2 - y1) * (x2 - x1)
        masks_area = np.sum(masks[:, y1:y2, x1:x2], axis=(1, 2))
        orig_masks_area = np.sum(masks, axis=(1, 2))
        union = bbox_area + orig_masks_area - masks_area
        # A zero union scores 0 instead of producing NaN (which argmax would pick).
        ious = np.divide(
            masks_area,
            union,
            out=np.zeros(masks_area.shape, dtype=float),
            where=union != 0,
        )
        best = int(np.argmax(ious))
        return np.array([masks[best]])

    def point_prompt(self, points, pointlabel):
        """Return the union of masks selected by point prompts.

        Args:
            points: Sequence of ``[x, y]`` points in pixel coordinates of
                ``self.image``.
            pointlabel: One label per point. For every mask containing the
                point, label ``1`` adds the mask and ``0`` removes it. ``None``
                treats every point as ``1``.

        Returns:
            Boolean array of shape ``(1, H, W)`` at mask resolution (all False
            when the result contains no masks).
        """
        result = self._current_result()
        data_shape = result.masks.data.shape  # (N, H, W)
        h, w = int(data_shape[1]), int(data_shape[2])
        target_height, target_width = self.image.shape[:2]
        if pointlabel is None:
            pointlabel = [1] * len(points)
        if h != target_height or w != target_width:
            scaled = [
                (int(p[0] * w / target_width), int(p[1] * h / target_height))
                for p in points
            ]
        else:
            scaled = [(int(p[0]), int(p[1])) for p in points]
        # Clamp onto the grid: a point on the far image edge would otherwise index
        # one past the end, and negative values would silently wrap around.
        scaled = [(min(max(x, 0), w - 1), min(max(y, 0), h - 1)) for x, y in scaled]

        onemask = np.zeros((h, w), dtype=bool)
        # Largest masks first, so writes from smaller masks applied later win.
        annotations = sorted(
            self._format_results(result, 0), key=lambda a: a['area'], reverse=True
        )
        for annotation in annotations:
            mask = annotation['segmentation']
            for (x, y), label in zip(scaled, pointlabel):
                if not mask[y, x]:
                    continue
                if label == 1:
                    onemask[mask] = True
                elif label == 0:
                    onemask[mask] = False
        return np.array([onemask])

    def _format_results(self, result, filter=0):
        """Convert a result's masks into a list of annotation dicts.

        Args:
            result: Object exposing ``masks.data`` (N, H, W) plus
                ``boxes.data`` and ``boxes.conf`` indexable per mask.
            filter: Masks with fewer than this many pixels equal to 1.0 are skipped.

        Returns:
            One dict per kept mask with keys ``id``, ``segmentation`` (boolean
            NumPy mask), ``bbox``, ``score`` and ``area`` (pixel count).
        """
        data = self._to_numpy(result.masks.data)
        annotations = []
        for i, mask_values in enumerate(data):
            mask = mask_values == 1.0
            if np.sum(mask) < filter:
                continue
            annotations.append({
                'id': i,
                'segmentation': mask,
                'bbox': result.boxes.data[i],
                'score': result.boxes.conf[i],
                'area': mask.sum(),
            })
        return annotations