anilbhatt1 committed on
Commit 8fc40b0
1 Parent(s): 2c9fbb7

Upload 6 files

fastsam/__init__.py ADDED
@@ -0,0 +1,9 @@
+ # Ultralytics YOLO 🚀, AGPL-3.0 license
+
+ from .model import FastSAM
+ from .predict import FastSAMPredictor
+ from .prompt import FastSAMPrompt
+ # from .val import FastSAMValidator
+ from .decoder import FastSAMDecoder
+
+ __all__ = 'FastSAMPredictor', 'FastSAM', 'FastSAMPrompt', 'FastSAMDecoder'
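
The exports above are the public surface of the package. A minimal import sketch, assuming the uploaded fastsam/ directory sits on the Python path and a checkpoint is available locally ('FastSAM.pt' is a hypothetical filename):

    from fastsam import FastSAM, FastSAMPrompt, FastSAMDecoder

    model = FastSAM('FastSAM.pt')  # hypothetical checkpoint path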
fastsam/decoder.py ADDED
@@ -0,0 +1,131 @@
+ from .model import FastSAM
+ import numpy as np
+ from PIL import Image
+ import clip
+ from typing import Optional, List, Tuple, Union
+
+
+ class FastSAMDecoder:
+     def __init__(
+         self,
+         model: FastSAM,
+         device: str = 'cpu',
+         conf: float = 0.4,
+         iou: float = 0.9,
+         imgsz: int = 1024,
+         retina_masks: bool = True,
+     ):
+         self.model = model
+         self.device = device
+         self.retina_masks = retina_masks
+         self.imgsz = imgsz
+         self.conf = conf
+         self.iou = iou
+         self.image = None
+         self.image_embedding = None
+
+     def run_encoder(self, image):
+         if isinstance(image, str):
+             image = np.array(Image.open(image))
+         self.image = image
+         image_embedding = self.model(
+             self.image,
+             device=self.device,
+             retina_masks=self.retina_masks,
+             imgsz=self.imgsz,
+             conf=self.conf,
+             iou=self.iou,
+         )
+         return image_embedding[0].numpy()
+
+     def run_decoder(
+         self,
+         image_embedding,
+         point_prompt: Optional[np.ndarray] = None,
+         point_label: Optional[np.ndarray] = None,
+         box_prompt: Optional[np.ndarray] = None,
+         text_prompt: Optional[str] = None,
+     ) -> np.ndarray:
+         self.image_embedding = image_embedding
+         if point_prompt is not None:
+             ann = self.point_prompt(points=point_prompt, pointlabel=point_label)
+             return ann
+         elif box_prompt is not None:
+             ann = self.box_prompt(bbox=box_prompt)
+             return ann
+         elif text_prompt is not None:
+             ann = self.text_prompt(text=text_prompt)  # note: text_prompt is not defined on this class
+             return ann
+         else:
+             return None
+
+     def box_prompt(self, bbox):
+         assert (bbox[2] != 0 and bbox[3] != 0)
+         masks = self.image_embedding.masks.data
+         target_height = self.image.shape[0]
+         target_width = self.image.shape[1]
+         h = masks.shape[1]
+         w = masks.shape[2]
+         if h != target_height or w != target_width:
+             bbox = [
+                 int(bbox[0] * w / target_width),
+                 int(bbox[1] * h / target_height),
+                 int(bbox[2] * w / target_width),
+                 int(bbox[3] * h / target_height), ]
+         bbox[0] = round(bbox[0]) if round(bbox[0]) > 0 else 0
+         bbox[1] = round(bbox[1]) if round(bbox[1]) > 0 else 0
+         bbox[2] = round(bbox[2]) if round(bbox[2]) < w else w
+         bbox[3] = round(bbox[3]) if round(bbox[3]) < h else h
+
+         # IoUs = torch.zeros(len(masks), dtype=torch.float32)
+         bbox_area = (bbox[3] - bbox[1]) * (bbox[2] - bbox[0])
+
+         masks_area = np.sum(masks[:, bbox[1]:bbox[3], bbox[0]:bbox[2]], axis=(1, 2))
+         orig_masks_area = np.sum(masks, axis=(1, 2))
+
+         union = bbox_area + orig_masks_area - masks_area
+         IoUs = masks_area / union
+         max_iou_index = np.argmax(IoUs)
+
+         return np.array([masks[max_iou_index].cpu().numpy()])
+
+     def point_prompt(self, points, pointlabel):  # points/labels expected as numpy arrays
+
+         masks = self._format_results(self.image_embedding[0], 0)
+         target_height = self.image.shape[0]
+         target_width = self.image.shape[1]
+         h = masks[0]['segmentation'].shape[0]
+         w = masks[0]['segmentation'].shape[1]
+         if h != target_height or w != target_width:
+             points = [[int(point[0] * w / target_width), int(point[1] * h / target_height)] for point in points]
+         onemask = np.zeros((h, w))
+         masks = sorted(masks, key=lambda x: x['area'], reverse=True)
+         for i, annotation in enumerate(masks):
+             if isinstance(annotation, dict):
+                 mask = annotation['segmentation']
+             else:
+                 mask = annotation
+             for i, point in enumerate(points):
+                 if mask[point[1], point[0]] == 1 and pointlabel[i] == 1:
+                     onemask[mask] = 1
+                 if mask[point[1], point[0]] == 1 and pointlabel[i] == 0:
+                     onemask[mask] = 0
+         onemask = onemask >= 1
+         return np.array([onemask])
+
+     def _format_results(self, result, filter=0):
+         annotations = []
+         n = len(result.masks.data)
+         for i in range(n):
+             annotation = {}
+             mask = result.masks.data[i] == 1.0
+
+             if np.sum(mask) < filter:
+                 continue
+             annotation['id'] = i
+             annotation['segmentation'] = mask
+             annotation['bbox'] = result.boxes.data[i]
+             annotation['score'] = result.boxes.conf[i]
+             annotation['area'] = annotation['segmentation'].sum()
+             annotations.append(annotation)
+         return annotations
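
A hedged sketch of how the encoder/decoder pair above is meant to be driven. The checkpoint and image paths are hypothetical, and the box and point coordinates are made up (xyxy and xy in original-image pixels, matching box_prompt and point_prompt above). Whether a given ultralytics Results object survives the numpy conversion in run_encoder depends on the installed ultralytics version, so treat this as the intended call pattern rather than a verified run:

    import numpy as np
    from fastsam import FastSAM, FastSAMDecoder

    model = FastSAM('FastSAM.pt')               # hypothetical checkpoint path
    decoder = FastSAMDecoder(model, device='cpu', conf=0.4, iou=0.9, imgsz=1024)

    embedding = decoder.run_encoder('dog.jpg')  # hypothetical image path
    box_mask = decoder.run_decoder(embedding, box_prompt=np.array([50, 60, 400, 350]))
    point_mask = decoder.run_decoder(embedding,
                                     point_prompt=np.array([[200, 200]]),
                                     point_label=np.array([1]))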
fastsam/model.py ADDED
@@ -0,0 +1,106 @@
+ # Ultralytics YOLO 🚀, AGPL-3.0 license
+ """
+ FastSAM model interface.
+
+ Usage - Predict:
+     from ultralytics import FastSAM
+
+     model = FastSAM('last.pt')
+     results = model.predict('ultralytics/assets/bus.jpg')
+ """
+
+ from ultralytics.yolo.cfg import get_cfg
+ from ultralytics.yolo.engine.exporter import Exporter
+ from ultralytics.yolo.engine.model import YOLO
+ from ultralytics.yolo.utils import DEFAULT_CFG, LOGGER, ROOT, is_git_dir
+ from ultralytics.yolo.utils.checks import check_imgsz
+
+ from ultralytics.yolo.utils.torch_utils import model_info, smart_inference_mode
+ from .predict import FastSAMPredictor
+
+
+ class FastSAM(YOLO):
+
+     @smart_inference_mode()
+     def predict(self, source=None, stream=False, **kwargs):
+         """
+         Perform prediction using the YOLO model.
+
+         Args:
+             source (str | int | PIL | np.ndarray): The source of the image to make predictions on.
+                 Accepts all source types accepted by the YOLO model.
+             stream (bool): Whether to stream the predictions or not. Defaults to False.
+             **kwargs : Additional keyword arguments passed to the predictor.
+                 Check the 'configuration' section in the documentation for all available options.
+
+         Returns:
+             (List[ultralytics.yolo.engine.results.Results]): The prediction results.
+         """
+         if source is None:
+             source = ROOT / 'assets' if is_git_dir() else 'https://ultralytics.com/images/bus.jpg'
+             LOGGER.warning(f"WARNING ⚠️ 'source' is missing. Using 'source={source}'.")
+         overrides = self.overrides.copy()
+         overrides['conf'] = 0.25
+         overrides.update(kwargs)  # prefer kwargs
+         overrides['mode'] = kwargs.get('mode', 'predict')
+         assert overrides['mode'] in ['track', 'predict']
+         overrides['save'] = kwargs.get('save', False)  # do not save by default if called in Python
+         self.predictor = FastSAMPredictor(overrides=overrides)
+         self.predictor.setup_model(model=self.model, verbose=False)
+         try:
+             return self.predictor(source, stream=stream)
+         except Exception:
+             return None  # predictor errors are swallowed and None is returned
+
+     def train(self, **kwargs):
+         """Training is not supported for FastSAM models."""
+         raise NotImplementedError('FastSAM training is not supported yet.')
+
+     def val(self, **kwargs):
+         """Run validation given dataset."""
+         overrides = dict(task='segment', mode='val')
+         overrides.update(kwargs)  # prefer kwargs
+         args = get_cfg(cfg=DEFAULT_CFG, overrides=overrides)
+         args.imgsz = check_imgsz(args.imgsz, max_dim=1)
+         validator = FastSAM(args=args)  # note: FastSAMValidator is not included in this upload (see __init__.py)
+         validator(model=self.model)
+         self.metrics = validator.metrics
+         return validator.metrics
+
+     @smart_inference_mode()
+     def export(self, **kwargs):
+         """
+         Export model.
+
+         Args:
+             **kwargs : Any other args accepted by the predictors. To see all args, check the 'configuration' section in the docs.
+         """
+         overrides = dict(task='detect')
+         overrides.update(kwargs)
+         overrides['mode'] = 'export'
+         args = get_cfg(cfg=DEFAULT_CFG, overrides=overrides)
+         args.task = self.task
+         if args.imgsz == DEFAULT_CFG.imgsz:
+             args.imgsz = self.model.args['imgsz']  # use trained imgsz unless custom value is passed
+         if args.batch == DEFAULT_CFG.batch:
+             args.batch = 1  # default to 1 if not modified
+         return Exporter(overrides=args)(model=self.model)
+
+     def info(self, detailed=False, verbose=True):
+         """
+         Logs model info.
+
+         Args:
+             detailed (bool): Show detailed information about model.
+             verbose (bool): Controls verbosity.
+         """
+         return model_info(self.model, detailed=detailed, verbose=verbose, imgsz=640)
+
+     def __call__(self, source=None, stream=False, **kwargs):
+         """Calls the 'predict' function with given arguments to perform object detection."""
+         return self.predict(source, stream, **kwargs)
+
+     def __getattr__(self, attr):
+         """Raises error if object has no requested attribute."""
+         name = self.__class__.__name__
+         raise AttributeError(f"'{name}' object has no attribute '{attr}'. See valid attributes below.\n{self.__doc__}")
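
The usage docstring at the top of this file shows the minimal call. A slightly fuller sketch with the same keyword arguments that FastSAMDecoder forwards to the model (paths are hypothetical):

    from fastsam import FastSAM

    model = FastSAM('FastSAM.pt')       # hypothetical checkpoint path
    everything_results = model(
        'images/dogs.jpg',              # hypothetical image path
        device='cpu',
        retina_masks=True,
        imgsz=1024,
        conf=0.4,
        iou=0.9,
    )
    model.info(verbose=True)            # log a parameter/GFLOP summary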
fastsam/predict.py ADDED
@@ -0,0 +1,56 @@
+ import torch
+
+ from ultralytics.yolo.engine.results import Results
+ from ultralytics.yolo.utils import DEFAULT_CFG, ops
+ from ultralytics.yolo.v8.detect.predict import DetectionPredictor
+ from .utils import bbox_iou
+
+ class FastSAMPredictor(DetectionPredictor):
+
+     def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
+         super().__init__(cfg, overrides, _callbacks)
+         self.args.task = 'segment'
+
+     def postprocess(self, preds, img, orig_imgs):
+         """TODO: filter by classes."""
+         p = ops.non_max_suppression(preds[0],
+                                     self.args.conf,
+                                     self.args.iou,
+                                     agnostic=self.args.agnostic_nms,
+                                     max_det=self.args.max_det,
+                                     nc=len(self.model.names),
+                                     classes=self.args.classes)
+
+         results = []
+         if len(p) == 0 or len(p[0]) == 0:
+             print("No object detected.")
+             return results
+
+         full_box = torch.zeros_like(p[0][0])
+         full_box[2], full_box[3], full_box[4], full_box[6:] = img.shape[3], img.shape[2], 1.0, 1.0
+         full_box = full_box.view(1, -1)
+         critical_iou_index = bbox_iou(full_box[0][:4], p[0][:, :4], iou_thres=0.9, image_shape=img.shape[2:])
+         if critical_iou_index.numel() != 0:
+             full_box[0][4] = p[0][critical_iou_index][:, 4]
+             full_box[0][6:] = p[0][critical_iou_index][:, 6:]
+             p[0][critical_iou_index] = full_box
+
+         proto = preds[1][-1] if len(preds[1]) == 3 else preds[1]  # second output is len 3 if pt, but only 1 if exported
+         for i, pred in enumerate(p):
+             orig_img = orig_imgs[i] if isinstance(orig_imgs, list) else orig_imgs
+             path = self.batch[0]
+             img_path = path[i] if isinstance(path, list) else path
+             if not len(pred):  # save empty boxes
+                 results.append(Results(orig_img=orig_img, path=img_path, names=self.model.names, boxes=pred[:, :6]))
+                 continue
+             if self.args.retina_masks:
+                 if not isinstance(orig_imgs, torch.Tensor):
+                     pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape)
+                 masks = ops.process_mask_native(proto[i], pred[:, 6:], pred[:, :4], orig_img.shape[:2])  # HWC
+             else:
+                 masks = ops.process_mask(proto[i], pred[:, 6:], pred[:, :4], img.shape[2:], upsample=True)  # HWC
+                 if not isinstance(orig_imgs, torch.Tensor):
+                     pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape)
+             results.append(
+                 Results(orig_img=orig_img, path=img_path, names=self.model.names, boxes=pred[:, :6], masks=masks))
+         return results
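
postprocess returns standard ultralytics Results objects whose first six box columns are xyxy, confidence and class, with any near-full-frame detection snapped to the full image by the bbox_iou check above. A small sketch of consuming them, assuming everything_results comes from a FastSAM call like the one shown after model.py:

    for r in everything_results:
        print(r.boxes.data.shape)   # (num_masks, 6): xyxy, confidence, class
        print(r.masks.data.shape)   # (num_masks, H, W) binary masks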
fastsam/prompt.py ADDED
@@ -0,0 +1,455 @@
+ import os
+ import sys
+ import cv2
+ import matplotlib.pyplot as plt
+ import numpy as np
+ import torch
+ from .utils import image_to_np_ndarray
+ from PIL import Image
+
+ try:
+     import clip  # OpenAI CLIP, used for text prompts
+
+ except (ImportError, AssertionError, AttributeError):
+     from ultralytics.yolo.utils.checks import check_requirements
+
+     check_requirements('git+https://github.com/openai/CLIP.git')  # install CLIP if it is not already available
+     import clip
+
+
+ class FastSAMPrompt:
+
+     def __init__(self, image, results, device='cuda'):
+         if isinstance(image, (str, Image.Image)):
+             image = image_to_np_ndarray(image)
+         self.device = device
+         self.results = results
+         self.img = image
+
+     def _segment_image(self, image, bbox):
+         if isinstance(image, Image.Image):
+             image_array = np.array(image)
+         else:
+             image_array = image
+         segmented_image_array = np.zeros_like(image_array)
+         x1, y1, x2, y2 = bbox
+         segmented_image_array[y1:y2, x1:x2] = image_array[y1:y2, x1:x2]
+         segmented_image = Image.fromarray(segmented_image_array)
+         black_image = Image.new('RGB', image.size, (255, 255, 255))  # note: background is white despite the name
+         # transparency_mask = np.zeros_like((), dtype=np.uint8)
+         transparency_mask = np.zeros((image_array.shape[0], image_array.shape[1]), dtype=np.uint8)
+         transparency_mask[y1:y2, x1:x2] = 255
+         transparency_mask_image = Image.fromarray(transparency_mask, mode='L')
+         black_image.paste(segmented_image, mask=transparency_mask_image)
+         return black_image
+
+     def _format_results(self, result, filter=0):
+         annotations = []
+         n = len(result.masks.data)
+         for i in range(n):
+             annotation = {}
+             mask = result.masks.data[i] == 1.0
+
+             if torch.sum(mask) < filter:
+                 continue
+             annotation['id'] = i
+             annotation['segmentation'] = mask.cpu().numpy()
+             annotation['bbox'] = result.boxes.data[i]
+             annotation['score'] = result.boxes.conf[i]
+             annotation['area'] = annotation['segmentation'].sum()
+             annotations.append(annotation)
+         return annotations
+
+     def filter_masks(self, annotations):  # filter overlapping masks
+         annotations.sort(key=lambda x: x['area'], reverse=True)
+         to_remove = set()
+         for i in range(0, len(annotations)):
+             a = annotations[i]
+             for j in range(i + 1, len(annotations)):
+                 b = annotations[j]
+                 if i != j and j not in to_remove:
+                     # check whether the smaller mask is mostly covered by the larger one
+                     if b['area'] < a['area']:
+                         if (a['segmentation'] & b['segmentation']).sum() / b['segmentation'].sum() > 0.8:
+                             to_remove.add(j)
+
+         return [a for i, a in enumerate(annotations) if i not in to_remove], to_remove
+
+     def _get_bbox_from_mask(self, mask):
+         mask = mask.astype(np.uint8)
+         contours, hierarchy = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
+         x1, y1, w, h = cv2.boundingRect(contours[0])
+         x2, y2 = x1 + w, y1 + h
+         if len(contours) > 1:
+             for b in contours:
+                 x_t, y_t, w_t, h_t = cv2.boundingRect(b)
+                 # Merge multiple bounding boxes into one.
+                 x1 = min(x1, x_t)
+                 y1 = min(y1, y_t)
+                 x2 = max(x2, x_t + w_t)
+                 y2 = max(y2, y_t + h_t)
+             h = y2 - y1
+             w = x2 - x1
+         return [x1, y1, x2, y2]
+
+     def plot_to_result(self,
+                        annotations,
+                        bboxes=None,
+                        points=None,
+                        point_label=None,
+                        mask_random_color=True,
+                        better_quality=True,
+                        retina=False,
+                        withContours=True) -> np.ndarray:
+         if isinstance(annotations[0], dict):
+             annotations = [annotation['segmentation'] for annotation in annotations]
+         image = self.img
+         image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+         original_h = image.shape[0]
+         original_w = image.shape[1]
+         if sys.platform == "darwin":
+             plt.switch_backend("TkAgg")
+         plt.figure(figsize=(original_w / 100, original_h / 100))
+         # Add subplot with no margin.
+         plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
+         plt.margins(0, 0)
+         plt.gca().xaxis.set_major_locator(plt.NullLocator())
+         plt.gca().yaxis.set_major_locator(plt.NullLocator())
+
+         plt.imshow(image)
+         if better_quality:
+             if isinstance(annotations[0], torch.Tensor):
+                 annotations = np.array(annotations.cpu())
+             for i, mask in enumerate(annotations):
+                 mask = cv2.morphologyEx(mask.astype(np.uint8), cv2.MORPH_CLOSE, np.ones((3, 3), np.uint8))
+                 annotations[i] = cv2.morphologyEx(mask.astype(np.uint8), cv2.MORPH_OPEN, np.ones((8, 8), np.uint8))
+         if self.device == 'cpu':
+             annotations = np.array(annotations)
+             self.fast_show_mask(
+                 annotations,
+                 plt.gca(),
+                 random_color=mask_random_color,
+                 bboxes=bboxes,
+                 points=points,
+                 pointlabel=point_label,
+                 retinamask=retina,
+                 target_height=original_h,
+                 target_width=original_w,
+             )
+         else:
+             if isinstance(annotations[0], np.ndarray):
+                 annotations = torch.from_numpy(annotations)
+             self.fast_show_mask_gpu(
+                 annotations,
+                 plt.gca(),
+                 random_color=mask_random_color,
+                 bboxes=bboxes,
+                 points=points,
+                 pointlabel=point_label,
+                 retinamask=retina,
+                 target_height=original_h,
+                 target_width=original_w,
+             )
+         if isinstance(annotations, torch.Tensor):
+             annotations = annotations.cpu().numpy()
+         if withContours:
+             contour_all = []
+             temp = np.zeros((original_h, original_w, 1))
+             for i, mask in enumerate(annotations):
+                 if isinstance(mask, dict):
+                     mask = mask['segmentation']
+                 annotation = mask.astype(np.uint8)
+                 if not retina:
+                     annotation = cv2.resize(
+                         annotation,
+                         (original_w, original_h),
+                         interpolation=cv2.INTER_NEAREST,
+                     )
+                 contours, hierarchy = cv2.findContours(annotation, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
+                 for contour in contours:
+                     contour_all.append(contour)
+             cv2.drawContours(temp, contour_all, -1, (255, 255, 255), 2)
+             color = np.array([0 / 255, 0 / 255, 255 / 255, 0.8])
+             contour_mask = temp / 255 * color.reshape(1, 1, -1)
+             plt.imshow(contour_mask)
+
+         plt.axis('off')
+         fig = plt.gcf()
+         plt.draw()
+
+         try:
+             buf = fig.canvas.tostring_rgb()
+         except AttributeError:
+             fig.canvas.draw()
+             buf = fig.canvas.tostring_rgb()
+         cols, rows = fig.canvas.get_width_height()
+         img_array = np.frombuffer(buf, dtype=np.uint8).reshape(rows, cols, 3)
+         result = cv2.cvtColor(img_array, cv2.COLOR_RGB2BGR)
+         plt.close()
+         return result
+
+     # Remark for refactoring: IMO a function should do one thing only; storing the image and plotting should be separated, and these do not necessarily need to be class methods but could be standalone utility functions that the user can chain in their own scripts for more fine-grained control.
+     def plot(self,
+              annotations,
+              output_path,
+              bboxes=None,
+              points=None,
+              point_label=None,
+              mask_random_color=True,
+              better_quality=True,
+              retina=False,
+              withContours=True):
+         if len(annotations) == 0:
+             return None
+         result = self.plot_to_result(
+             annotations,
+             bboxes,
+             points,
+             point_label,
+             mask_random_color,
+             better_quality,
+             retina,
+             withContours,
+         )
+
+         path = os.path.dirname(os.path.abspath(output_path))
+         if not os.path.exists(path):
+             os.makedirs(path)
+         result = result[:, :, ::-1]
+         cv2.imwrite(output_path, result)
+
+     # CPU post process
+     def fast_show_mask(
+         self,
+         annotation,
+         ax,
+         random_color=False,
+         bboxes=None,
+         points=None,
+         pointlabel=None,
+         retinamask=True,
+         target_height=960,
+         target_width=960,
+     ):
+         num_masks = annotation.shape[0]
+         height = annotation.shape[1]
+         width = annotation.shape[2]
+         # Sort annotations based on area.
+         areas = np.sum(annotation, axis=(1, 2))
+         sorted_indices = np.argsort(areas)
+         annotation = annotation[sorted_indices]
+
+         index = (annotation != 0).argmax(axis=0)
+         if random_color:
+             color = np.random.random((num_masks, 1, 1, 3))
+         else:
+             color = np.ones((num_masks, 1, 1, 3)) * np.array([30 / 255, 144 / 255, 255 / 255])
+         transparency = np.ones((num_masks, 1, 1, 1)) * 0.6
+         visual = np.concatenate([color, transparency], axis=-1)
+         mask_image = np.expand_dims(annotation, -1) * visual
+
+         show = np.zeros((height, width, 4))
+         h_indices, w_indices = np.meshgrid(np.arange(height), np.arange(width), indexing='ij')
+         indices = (index[h_indices, w_indices], h_indices, w_indices, slice(None))
+         # Use vectorized indexing to update the values of 'show'.
+         show[h_indices, w_indices, :] = mask_image[indices]
+         if bboxes is not None:
+             for bbox in bboxes:
+                 x1, y1, x2, y2 = bbox
+                 ax.add_patch(plt.Rectangle((x1, y1), x2 - x1, y2 - y1, fill=False, edgecolor='b', linewidth=1))
+         # draw points
+         if points is not None:
+             plt.scatter(
+                 [point[0] for i, point in enumerate(points) if pointlabel[i] == 1],
+                 [point[1] for i, point in enumerate(points) if pointlabel[i] == 1],
+                 s=20,
+                 c='y',
+             )
+             plt.scatter(
+                 [point[0] for i, point in enumerate(points) if pointlabel[i] == 0],
+                 [point[1] for i, point in enumerate(points) if pointlabel[i] == 0],
+                 s=20,
+                 c='m',
+             )
+
+         if not retinamask:
+             show = cv2.resize(show, (target_width, target_height), interpolation=cv2.INTER_NEAREST)
+         ax.imshow(show)
+
+     def fast_show_mask_gpu(
+         self,
+         annotation,
+         ax,
+         random_color=False,
+         bboxes=None,
+         points=None,
+         pointlabel=None,
+         retinamask=True,
+         target_height=960,
+         target_width=960,
+     ):
+         num_masks = annotation.shape[0]
+         height = annotation.shape[1]
+         width = annotation.shape[2]
+         areas = torch.sum(annotation, dim=(1, 2))
+         sorted_indices = torch.argsort(areas, descending=False)
+         annotation = annotation[sorted_indices]
+         # Find the index of the first non-zero value at each position.
+         index = (annotation != 0).to(torch.long).argmax(dim=0)
+         if random_color:
+             color = torch.rand((num_masks, 1, 1, 3)).to(annotation.device)
+         else:
+             color = torch.ones((num_masks, 1, 1, 3)).to(annotation.device) * torch.tensor([
+                 30 / 255, 144 / 255, 255 / 255]).to(annotation.device)
+         transparency = torch.ones((num_masks, 1, 1, 1)).to(annotation.device) * 0.6
+         visual = torch.cat([color, transparency], dim=-1)
+         mask_image = torch.unsqueeze(annotation, -1) * visual
+         # Select data according to the index. The index indicates which batch's data to choose at each position, converting the mask_image into a single batch form.
+         show = torch.zeros((height, width, 4)).to(annotation.device)
+         try:
+             h_indices, w_indices = torch.meshgrid(torch.arange(height), torch.arange(width), indexing='ij')
+         except TypeError:  # older torch versions have no 'indexing' argument
+             h_indices, w_indices = torch.meshgrid(torch.arange(height), torch.arange(width))
+         indices = (index[h_indices, w_indices], h_indices, w_indices, slice(None))
+         # Use vectorized indexing to update the values of 'show'.
+         show[h_indices, w_indices, :] = mask_image[indices]
+         show_cpu = show.cpu().numpy()
+         if bboxes is not None:
+             for bbox in bboxes:
+                 x1, y1, x2, y2 = bbox
+                 ax.add_patch(plt.Rectangle((x1, y1), x2 - x1, y2 - y1, fill=False, edgecolor='b', linewidth=1))
+         # draw points
+         if points is not None:
+             plt.scatter(
+                 [point[0] for i, point in enumerate(points) if pointlabel[i] == 1],
+                 [point[1] for i, point in enumerate(points) if pointlabel[i] == 1],
+                 s=20,
+                 c='y',
+             )
+             plt.scatter(
+                 [point[0] for i, point in enumerate(points) if pointlabel[i] == 0],
+                 [point[1] for i, point in enumerate(points) if pointlabel[i] == 0],
+                 s=20,
+                 c='m',
+             )
+         if not retinamask:
+             show_cpu = cv2.resize(show_cpu, (target_width, target_height), interpolation=cv2.INTER_NEAREST)
+         ax.imshow(show_cpu)
+
+     # CLIP text-image retrieval
+     @torch.no_grad()
+     def retrieve(self, model, preprocess, elements, search_text: str, device) -> torch.Tensor:
+         preprocessed_images = [preprocess(image).to(device) for image in elements]
+         tokenized_text = clip.tokenize([search_text]).to(device)
+         stacked_images = torch.stack(preprocessed_images)
+         image_features = model.encode_image(stacked_images)
+         text_features = model.encode_text(tokenized_text)
+         image_features /= image_features.norm(dim=-1, keepdim=True)
+         text_features /= text_features.norm(dim=-1, keepdim=True)
+         probs = 100.0 * image_features @ text_features.T
+         return probs[:, 0].softmax(dim=0)
+
+     def _crop_image(self, format_results):
+
+         image = Image.fromarray(cv2.cvtColor(self.img, cv2.COLOR_BGR2RGB))
+         ori_w, ori_h = image.size
+         annotations = format_results
+         mask_h, mask_w = annotations[0]['segmentation'].shape
+         if ori_w != mask_w or ori_h != mask_h:
+             image = image.resize((mask_w, mask_h))
+         cropped_boxes = []
+         cropped_images = []
+         not_crop = []
+         filter_id = []
+         # annotations, _ = filter_masks(annotations)
+         # filter_id = list(_)
+         for idx, mask in enumerate(annotations):
+             if np.sum(mask['segmentation']) <= 100:
+                 filter_id.append(idx)
+                 continue
+             bbox = self._get_bbox_from_mask(mask['segmentation'])  # bounding box of the mask
+             cropped_boxes.append(self._segment_image(image, bbox))
+             # cropped_boxes.append(segment_image(image,mask["segmentation"]))
+             cropped_images.append(bbox)  # Save the bounding box of the cropped image.
+
+         return cropped_boxes, cropped_images, not_crop, filter_id, annotations
+
+     def box_prompt(self, bbox=None, bboxes=None):
+         if self.results is None:
+             return []
+         assert bbox or bboxes
+         if bboxes is None:
+             bboxes = [bbox]
+         max_iou_index = []
+         for bbox in bboxes:
+             assert (bbox[2] != 0 and bbox[3] != 0)
+             masks = self.results[0].masks.data
+             target_height = self.img.shape[0]
+             target_width = self.img.shape[1]
+             h = masks.shape[1]
+             w = masks.shape[2]
+             if h != target_height or w != target_width:
+                 bbox = [
+                     int(bbox[0] * w / target_width),
+                     int(bbox[1] * h / target_height),
+                     int(bbox[2] * w / target_width),
+                     int(bbox[3] * h / target_height), ]
+             bbox[0] = round(bbox[0]) if round(bbox[0]) > 0 else 0
+             bbox[1] = round(bbox[1]) if round(bbox[1]) > 0 else 0
+             bbox[2] = round(bbox[2]) if round(bbox[2]) < w else w
+             bbox[3] = round(bbox[3]) if round(bbox[3]) < h else h
+
+             # IoUs = torch.zeros(len(masks), dtype=torch.float32)
+             bbox_area = (bbox[3] - bbox[1]) * (bbox[2] - bbox[0])
+
+             masks_area = torch.sum(masks[:, bbox[1]:bbox[3], bbox[0]:bbox[2]], dim=(1, 2))
+             orig_masks_area = torch.sum(masks, dim=(1, 2))
+
+             union = bbox_area + orig_masks_area - masks_area
+             IoUs = masks_area / union
+             max_iou_index.append(int(torch.argmax(IoUs)))
+         max_iou_index = list(set(max_iou_index))
+         return np.array(masks[max_iou_index].cpu().numpy())
+
+     def point_prompt(self, points, pointlabel):  # points/labels expected as numpy-style arrays
+         if self.results is None:
+             return []
+         masks = self._format_results(self.results[0], 0)
+         target_height = self.img.shape[0]
+         target_width = self.img.shape[1]
+         h = masks[0]['segmentation'].shape[0]
+         w = masks[0]['segmentation'].shape[1]
+         if h != target_height or w != target_width:
+             points = [[int(point[0] * w / target_width), int(point[1] * h / target_height)] for point in points]
+         onemask = np.zeros((h, w))
+         masks = sorted(masks, key=lambda x: x['area'], reverse=True)
+         for i, annotation in enumerate(masks):
+             if isinstance(annotation, dict):
+                 mask = annotation['segmentation']
+             else:
+                 mask = annotation
+             for i, point in enumerate(points):
+                 if mask[point[1], point[0]] == 1 and pointlabel[i] == 1:
+                     onemask[mask] = 1
+                 if mask[point[1], point[0]] == 1 and pointlabel[i] == 0:
+                     onemask[mask] = 0
+         onemask = onemask >= 1
+         return np.array([onemask])
+
+     def text_prompt(self, text):
+         if self.results is None:
+             return []
+         format_results = self._format_results(self.results[0], 0)
+         cropped_boxes, cropped_images, not_crop, filter_id, annotations = self._crop_image(format_results)
+         clip_model, preprocess = clip.load('ViT-B/32', device=self.device)
+         scores = self.retrieve(clip_model, preprocess, cropped_boxes, text, device=self.device)
+         max_idx = scores.argsort()
+         max_idx = max_idx[-1]
+         max_idx += sum(np.array(filter_id) <= int(max_idx))
+         return np.array([annotations[max_idx]['segmentation']])
+
+     def everything_prompt(self):
+         if self.results is None:
+             return []
+         return self.results[0].masks.data
+
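
A hedged end-to-end sketch of the prompt interface defined above. Paths and coordinates are hypothetical; box_prompt expects xyxy pixel coordinates, point_prompt takes xy points with 1/0 foreground/background labels, and text_prompt loads the CLIP ViT-B/32 weights on first use:

    from fastsam import FastSAM, FastSAMPrompt

    model = FastSAM('FastSAM.pt')                       # hypothetical checkpoint path
    results = model('images/dogs.jpg', device='cpu', retina_masks=True,
                    imgsz=1024, conf=0.4, iou=0.9)
    prompt = FastSAMPrompt('images/dogs.jpg', results, device='cpu')

    ann = prompt.everything_prompt()                    # all masks
    ann = prompt.box_prompt(bbox=[200, 200, 500, 500])  # single xyxy box
    ann = prompt.point_prompt(points=[[320, 240]], pointlabel=[1])
    ann = prompt.text_prompt(text='a photo of a dog')   # CLIP-ranked mask
    prompt.plot(annotations=ann, output_path='output/dog.jpg')  # hypothetical output path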
fastsam/utils.py ADDED
@@ -0,0 +1,86 @@
+ import numpy as np
+ import torch
+ from PIL import Image
+
+
+ def adjust_bboxes_to_image_border(boxes, image_shape, threshold=20):
+     '''Adjust bounding boxes to stick to image border if they are within a certain threshold.
+     Args:
+         boxes: (n, 4)
+         image_shape: (height, width)
+         threshold: pixel threshold
+     Returns:
+         adjusted_boxes: adjusted bounding boxes
+     '''
+
+     # Image dimensions
+     h, w = image_shape
+
+     # Adjust boxes
+     boxes[:, 0] = torch.where(boxes[:, 0] < threshold, torch.tensor(
+         0, dtype=torch.float, device=boxes.device), boxes[:, 0])  # x1
+     boxes[:, 1] = torch.where(boxes[:, 1] < threshold, torch.tensor(
+         0, dtype=torch.float, device=boxes.device), boxes[:, 1])  # y1
+     boxes[:, 2] = torch.where(boxes[:, 2] > w - threshold, torch.tensor(
+         w, dtype=torch.float, device=boxes.device), boxes[:, 2])  # x2
+     boxes[:, 3] = torch.where(boxes[:, 3] > h - threshold, torch.tensor(
+         h, dtype=torch.float, device=boxes.device), boxes[:, 3])  # y2
+
+     return boxes
+
+
+
+
+ def convert_box_xywh_to_xyxy(box):
+     x1 = box[0]
+     y1 = box[1]
+     x2 = box[0] + box[2]
+     y2 = box[1] + box[3]
+     return [x1, y1, x2, y2]
+
+
+ def bbox_iou(box1, boxes, iou_thres=0.9, image_shape=(640, 640), raw_output=False):
+     '''Compute the Intersection-Over-Union of a bounding box with respect to an array of other bounding boxes.
+     Args:
+         box1: (4, )
+         boxes: (n, 4)
+     Returns:
+         high_iou_indices: Indices of boxes with IoU > thres
+     '''
+     boxes = adjust_bboxes_to_image_border(boxes, image_shape)
+     # obtain coordinates for intersections
+     x1 = torch.max(box1[0], boxes[:, 0])
+     y1 = torch.max(box1[1], boxes[:, 1])
+     x2 = torch.min(box1[2], boxes[:, 2])
+     y2 = torch.min(box1[3], boxes[:, 3])
+
+     # compute the area of intersection
+     intersection = (x2 - x1).clamp(0) * (y2 - y1).clamp(0)
+
+     # compute the area of both individual boxes
+     box1_area = (box1[2] - box1[0]) * (box1[3] - box1[1])
+     box2_area = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
+
+     # compute the area of union
+     union = box1_area + box2_area - intersection
+
+     # compute the IoU
+     iou = intersection / union  # Should be shape (n, )
+     if raw_output:
+         if iou.numel() == 0:
+             return 0
+         return iou
+
+     # get indices of boxes with IoU > thres
+     high_iou_indices = torch.nonzero(iou > iou_thres).flatten()
+
+     return high_iou_indices
+
+
+ def image_to_np_ndarray(image):
+     if isinstance(image, str):
+         return np.array(Image.open(image))
+     elif isinstance(image, Image.Image):
+         return np.array(image)
+     elif isinstance(image, np.ndarray):
+         return image
+     return None
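
A small self-contained check of the two box helpers above, using made-up coordinates on a 640x640 image. convert_box_xywh_to_xyxy is pure arithmetic; bbox_iou first snaps box edges within 20 px of the image border onto it, so the first candidate below ends up identical to box1:

    import torch
    from fastsam.utils import bbox_iou, convert_box_xywh_to_xyxy

    print(convert_box_xywh_to_xyxy([10, 20, 100, 50]))    # [10, 20, 110, 70]

    box1 = torch.tensor([0.0, 0.0, 640.0, 640.0])
    boxes = torch.tensor([[5.0, 5.0, 630.0, 635.0],       # snaps to the full image
                          [100.0, 100.0, 200.0, 200.0]])  # small central box
    print(bbox_iou(box1, boxes, iou_thres=0.9, image_shape=(640, 640)))  # tensor([0])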