# NOTE: file-viewer residue captured together with this file (not program text):
#   anilbhatt1's picture | Upload 6 files | 8fc40b0 | raw | history blame | 4.72 kB
from .model import FastSAM
import numpy as np
from PIL import Image
import clip
from typing import Optional, List, Tuple, Union
class FastSAMDecoder:
    """Two-stage wrapper around a FastSAM model: encode once, prompt many times.

    ``run_encoder`` runs the model on an image and returns its first result
    converted to NumPy.  ``run_decoder`` then selects mask(s) from such a
    result according to a point or box prompt.

    Mask data is treated as an ``(n_masks, h, w)`` array.  Boxes are
    ``[x0, y0, x1, y1]`` and points are ``[x, y]``, both expressed in pixels
    of the image last passed to ``run_encoder``; they are rescaled to the
    mask resolution when the two differ.
    """

    def __init__(
        self,
        model: "FastSAM",
        device: str = 'cpu',
        conf: float = 0.4,
        iou: float = 0.9,
        imgsz: int = 1024,
        retina_masks: bool = True,
    ):
        """Store the model and the keyword arguments forwarded to it.

        Args:
            model: Callable FastSAM model.  The annotation is a string so the
                class can be imported without evaluating the project type.
            device: Passed to the model as ``device``.
            conf: Passed to the model as ``conf``.
            iou: Passed to the model as ``iou``.
            imgsz: Passed to the model as ``imgsz``.
            retina_masks: Passed to the model as ``retina_masks``.
        """
        self.model = model
        self.device = device
        self.retina_masks = retina_masks
        self.imgsz = imgsz
        self.conf = conf
        self.iou = iou
        # Populated by run_encoder / run_decoder respectively.
        self.image = None
        self.image_embedding = None

    def run_encoder(self, image):
        """Run the model on ``image`` and return its first result as NumPy.

        Args:
            image: Either a file path (``str``), which is loaded with PIL, or
                an already-loaded image array.

        Returns:
            ``model(...)[0]`` after calling its ``numpy()`` method.
        """
        if isinstance(image, str):
            # Context manager closes the underlying file handle once the
            # pixel data has been copied into the array.
            with Image.open(image) as handle:
                image = np.array(handle)
        self.image = image
        results = self.model(
            self.image,
            device=self.device,
            retina_masks=self.retina_masks,
            imgsz=self.imgsz,
            conf=self.conf,
            iou=self.iou,
        )
        first = results[0]
        # NOTE(review): presumably the result holds tensors on ``self.device``;
        # move it to host memory first when it offers ``cpu()`` so ``numpy()``
        # also works for non-CPU devices.
        if hasattr(first, 'cpu'):
            first = first.cpu()
        return first.numpy()

    def run_decoder(
        self,
        image_embedding,
        point_prompt: Optional[np.ndarray] = None,
        point_label: Optional[np.ndarray] = None,
        box_prompt: Optional[np.ndarray] = None,
        text_prompt: Optional[str] = None,
    ) -> Optional[np.ndarray]:
        """Select mask(s) from ``image_embedding`` using the first prompt given.

        Prompts are checked in the order point, box, text; only the first
        non-``None`` one is used.

        Args:
            image_embedding: Result object as returned by ``run_encoder``.
            point_prompt: Sequence of ``[x, y]`` points.
            point_label: One label per point (1 adds the hit mask, 0 removes
                it).  ``None`` means every point is treated as label 1.
            box_prompt: ``[x0, y0, x1, y1]`` box.
            text_prompt: Text query; only usable if a ``text_prompt`` method
                is provided (e.g. by a subclass).

        Returns:
            Array of shape ``(1, h, w)``, or ``None`` when no prompt is given.

        Raises:
            NotImplementedError: If ``text_prompt`` is given but this object
                has no ``text_prompt`` method.
        """
        self.image_embedding = image_embedding
        if point_prompt is not None:
            return self.point_prompt(points=point_prompt, pointlabel=point_label)
        if box_prompt is not None:
            return self.box_prompt(bbox=box_prompt)
        if text_prompt is not None:
            # This class defines no text_prompt method; fail with a clear
            # message instead of an opaque AttributeError, while still
            # honouring subclasses that add one.
            handler = getattr(self, 'text_prompt', None)
            if not callable(handler):
                raise NotImplementedError(
                    'text prompts are not supported by FastSAMDecoder'
                )
            return handler(text=text_prompt)
        return None

    def box_prompt(self, bbox):
        """Return the mask whose IoU with ``bbox`` is highest.

        Args:
            bbox: ``[x0, y0, x1, y1]`` in image pixels.  It is not modified.

        Returns:
            Array of shape ``(1, h, w)`` holding the best-matching mask.
        """
        assert (bbox[2] != 0 and bbox[3] != 0), 'bbox x1 and y1 must be non-zero'
        masks = self._to_numpy(self._result().masks.data)
        target_height = self.image.shape[0]
        target_width = self.image.shape[1]
        h = masks.shape[1]
        w = masks.shape[2]

        # Work on local scalars so the caller's bbox is never mutated.
        x0, y0, x1, y1 = bbox[0], bbox[1], bbox[2], bbox[3]
        if h != target_height or w != target_width:
            x0 = int(x0 * w / target_width)
            y0 = int(y0 * h / target_height)
            x1 = int(x1 * w / target_width)
            y1 = int(y1 * h / target_height)
        # Clamp to the mask extent; int() guarantees valid slice bounds even
        # when the box was supplied as floats.
        x0 = max(int(round(x0)), 0)
        y0 = max(int(round(y0)), 0)
        x1 = min(int(round(x1)), w)
        y1 = min(int(round(y1)), h)

        bbox_area = (y1 - y0) * (x1 - x0)
        masks_area = masks[:, y0:y1, x0:x1].sum(axis=(1, 2))
        orig_masks_area = masks.sum(axis=(1, 2))
        union = bbox_area + orig_masks_area - masks_area
        # A non-positive union (empty box and empty mask) scores 0 rather
        # than producing NaN from a division by zero.
        ious = np.divide(
            masks_area,
            union,
            out=np.zeros(union.shape, dtype=np.float64),
            where=union > 0,
        )
        best = np.argmax(ious)
        return np.array([masks[best]])

    def point_prompt(self, points, pointlabel):  # numpy
        """Combine the masks hit by the given points into one boolean mask.

        Masks are visited from largest to smallest area.  For every point
        lying inside a mask, label 1 sets that mask's pixels and label 0
        clears them, so later (smaller) masks override earlier ones.

        Args:
            points: Sequence of ``[x, y]`` in image pixels.
            pointlabel: One label per point; ``None`` means all label 1.

        Returns:
            Boolean array of shape ``(1, h, w)``.
        """
        # box_prompt reads ``.masks`` straight off the embedding, so resolve
        # the result object the same way here.
        # NOTE(review): the previous unconditional ``[0]`` indexed into the
        # result itself - confirm no caller relied on that.
        annotations = self._format_results(self._result(), 0)
        target_height = self.image.shape[0]
        target_width = self.image.shape[1]
        h = annotations[0]['segmentation'].shape[0]
        w = annotations[0]['segmentation'].shape[1]
        if pointlabel is None:
            # run_decoder's default; treat every point as a positive click.
            pointlabel = [1] * len(points)
        if h != target_height or w != target_width:
            points = [
                [int(point[0] * w / target_width), int(point[1] * h / target_height)]
                for point in points
            ]
        onemask = np.zeros((h, w))
        ordered = sorted(annotations, key=lambda item: item['area'], reverse=True)
        for annotation in ordered:
            if isinstance(annotation, dict):
                mask = annotation['segmentation']
            else:
                mask = annotation
            for idx, point in enumerate(points):
                # int() so float coordinates are usable as array indices.
                col, row = int(point[0]), int(point[1])
                if mask[row, col] != 1:
                    continue
                if pointlabel[idx] == 1:
                    onemask[mask] = 1
                elif pointlabel[idx] == 0:
                    onemask[mask] = 0
        onemask = onemask >= 1
        return np.array([onemask])

    def _result(self):
        """Return the result object held in ``self.image_embedding``.

        The embedding is used directly when it exposes ``masks``; otherwise
        it is treated as a list-like container and its first item is used.
        """
        embedding = self.image_embedding
        if hasattr(embedding, 'masks'):
            return embedding
        return embedding[0]

    @staticmethod
    def _to_numpy(data):
        """Return ``data`` as a NumPy array, accepting arrays or tensor-likes."""
        if isinstance(data, np.ndarray):
            return data
        if hasattr(data, 'detach'):
            data = data.detach()
        if hasattr(data, 'cpu'):
            data = data.cpu()
        if hasattr(data, 'numpy'):
            return data.numpy()
        return np.asarray(data)

    def _format_results(self, result, filter=0):
        """Turn a result object into a list of per-mask annotation dicts.

        Args:
            result: Object exposing ``masks.data``, ``boxes.data`` and
                ``boxes.conf``, all indexable by mask number.
            filter: Masks with fewer set pixels than this are skipped.

        Returns:
            List of dicts with keys ``id``, ``segmentation`` (boolean mask),
            ``bbox``, ``score`` and ``area`` (number of set pixels).
        """
        annotations = []
        data = self._to_numpy(result.masks.data)
        for i, raw_mask in enumerate(data):
            mask = raw_mask == 1.0
            if np.sum(mask) < filter:
                continue
            annotations.append({
                'id': i,
                'segmentation': mask,
                'bbox': result.boxes.data[i],
                'score': result.boxes.conf[i],
                'area': mask.sum(),
            })
        return annotations