dillonlaird commited on
Commit
2751f79
1 Parent(s): a0ea772

fixed noise in mask issue

Browse files
app/main.py CHANGED
@@ -208,7 +208,9 @@ async def label_image(image: str, mask_labels: MaskLabels) -> Response:
208
  for i in range(len(mask_labels.masks)):
209
  mask_i = mask_labels.masks[i]
210
  mask_i = mask_i.replace("data:image/png;base64,", "")
211
- save_masks.append(Image.open(io.BytesIO(base64.b64decode(mask_i))).convert("L"))
 
 
212
  bboxes = (
213
  batched_mask_to_box(
214
  torch.as_tensor(np.array([np.array(m) for m in save_masks]))
 
208
  for i in range(len(mask_labels.masks)):
209
  mask_i = mask_labels.masks[i]
210
  mask_i = mask_i.replace("data:image/png;base64,", "")
211
+ mask_i = Image.open(io.BytesIO(base64.b64decode(mask_i))).convert("L")
212
+ mask_i = mask_i.point(lambda p: 0 if <= 1 else p)
213
+ save_masks.append(mask_i)
214
  bboxes = (
215
  batched_mask_to_box(
216
  torch.as_tensor(np.array([np.array(m) for m in save_masks]))
app/per_sam/model.py CHANGED
@@ -8,6 +8,7 @@ import cv2
8
  from torchvision.ops.boxes import batched_nms
9
  from app.mobile_sam import SamPredictor
10
  from app.mobile_sam.utils import batched_mask_to_box
 
11
 
12
 
13
  def point_selection(mask_sim, topk: int = 1):
@@ -139,7 +140,10 @@ def fast_inference(
139
  # Weighted sum three-scale masks
140
  logits_high = logits_high * weights.unsqueeze(-1)
141
  logit_high = logits_high.sum(0)
142
- mask = (logit_high > 0).detach().cpu().numpy()
 
 
 
143
 
144
  logits = logits * weights_np[..., None]
145
  logit = logits.sum(0)
 
8
  from torchvision.ops.boxes import batched_nms
9
  from app.mobile_sam import SamPredictor
10
  from app.mobile_sam.utils import batched_mask_to_box
11
+ from app.sam.postprocessing import clean_mask_torch
12
 
13
 
14
  def point_selection(mask_sim, topk: int = 1):
 
140
  # Weighted sum three-scale masks
141
  logits_high = logits_high * weights.unsqueeze(-1)
142
  logit_high = logits_high.sum(0)
143
+ # mask = (logit_high > 0).detach().cpu().numpy()
144
+
145
+ mask = (logit_high > 0)
146
+ mask = clean_mask_torch(mask).bool()[0, 0, :, :].detach().cpu().numpy()
147
 
148
  logits = logits * weights_np[..., None]
149
  logit = logits.sum(0)
app/sam/predictor.py CHANGED
@@ -10,6 +10,7 @@ import torch
10
  from typing import Optional, Tuple
11
  from .sam import Sam
12
  from .transforms import ResizeLongestSide
 
13
 
14
 
15
  class SamPredictor:
@@ -237,6 +238,7 @@ class SamPredictor:
237
 
238
  if not return_logits:
239
  masks = masks > self.model.mask_threshold
 
240
 
241
  return masks, iou_predictions, low_res_masks
242
 
 
10
  from typing import Optional, Tuple
11
  from .sam import Sam
12
  from .transforms import ResizeLongestSide
13
+ from .postprocess import clean_mask_torch
14
 
15
 
16
  class SamPredictor:
 
238
 
239
  if not return_logits:
240
  masks = masks > self.model.mask_threshold
241
+ masks = clean_mask_torch(masks.int()).bool()
242
 
243
  return masks, iou_predictions, low_res_masks
244
 
instance-labeler/app/canvas.tsx CHANGED
@@ -28,7 +28,7 @@ const maskFilter = (imageData: ImageData, color: RGB) => {
28
  const g = imageData.data[i + 1];
29
  const b = imageData.data[i + 2];
30
 
31
- if (r === 0 && g === 0 && b === 0) {
32
  imageData.data[i + 3] = 0;
33
  } else {
34
  imageData.data[i] = color.r;
@@ -261,8 +261,8 @@ export default function Canvas({ imageUrl, imageName }: { imageUrl: string, imag
261
 
262
  const predLabels = async () => {
263
  const length = useLatest ? groupRef.current.length : groupRef.current.length - 1;
264
- if (groupRef.current.length === 0) {
265
- alert('Please pin an instance');
266
  return
267
  }
268
  const mask = groupRef.current[length - 1].toDataURL({ x: 0, y: 0, width: image?.width, height: image?.height });
 
28
  const g = imageData.data[i + 1];
29
  const b = imageData.data[i + 2];
30
 
31
+ if (r <= 1 && g <= 1 && b <= 1) {
32
  imageData.data[i + 3] = 0;
33
  } else {
34
  imageData.data[i] = color.r;
 
261
 
262
  const predLabels = async () => {
263
  const length = useLatest ? groupRef.current.length : groupRef.current.length - 1;
264
+ if (groupRef.current.length === 0 || classList.length === 0) {
265
+ alert('Please pin an instance first');
266
  return
267
  }
268
  const mask = groupRef.current[length - 1].toDataURL({ x: 0, y: 0, width: image?.width, height: image?.height });
requirements.txt CHANGED
@@ -6,3 +6,4 @@ Pillow
6
  fastapi
7
  uvicorn
8
  timm
 
 
6
  fastapi
7
  uvicorn
8
  timm
9
+ kornia