haritsahm
Fix when no mask result found
4e097eb
import types
import numpy as np
import streamlit as st
import torch
from distinctipy import distinctipy
from segment_anything import (SamAutomaticMaskGenerator, SamPredictor,
sam_model_registry)
from torch.nn import functional as F
def get_color():
    """Return a palette of 200 mutually distinct colours from distinctipy."""
    palette_size = 200
    palette = distinctipy.get_colors(palette_size)
    return palette
def medsam_preprocess(self, x: torch.Tensor) -> torch.Tensor:
    """Min-max scale ``x`` to [0, 1], then zero-pad it to a square input.

    Bound onto the SAM model as ``preprocess`` (see ``get_model``).

    Args:
        self: the model; only ``self.image_encoder.img_size`` is read.
        x: image tensor whose last two dims are (H, W), i.e. channel-first
            (..., H, W). Assumes H and W do not exceed ``img_size``
            (TODO confirm with caller); otherwise the padding is negative.

    Returns:
        Tensor with the same leading dims and last two dims
        (img_size, img_size), padded with zeros on the bottom/right.
    """
    # Scale with a single global min/max over the whole tensor. The 1e-8
    # floor keeps a constant image from dividing by zero.
    lowest = x.min()
    value_range = torch.clip(x.max() - lowest, min=1e-8, max=None)
    scaled = (x - lowest) / value_range

    side = self.image_encoder.img_size
    height, width = scaled.shape[-2:]
    # F.pad takes (left, right, top, bottom) for the last two dims.
    return F.pad(scaled, (0, side - width, 0, side - height))
@st.cache_resource
def get_model(checkpoint='checkpoint/sam_vit_b_01ec64.pth'):
    """Load a ViT-B SAM checkpoint and wrap it for prompt and automatic use.

    Decorated with ``st.cache_resource``, so Streamlit reuses the returned
    objects across reruns instead of reloading the checkpoint.

    Args:
        checkpoint: path to the ``vit_b`` weights file.

    Returns:
        ``(SamPredictor, SamAutomaticMaskGenerator)``, both built on the same
        model instance.
    """
    cuda_ok = torch.cuda.is_available()
    device = torch.device("cuda" if cuda_ok else "cpu")

    sam = sam_model_registry['vit_b'](checkpoint=checkpoint)
    # Bind medsam_preprocess as this instance's preprocess method, replacing
    # the model's default preprocessing.
    sam.preprocess = types.MethodType(medsam_preprocess, sam)
    # NOTE(review): presumably the cut-off applied to mask outputs by the
    # predictor; confirm 0.5 is intended for this checkpoint.
    sam.mask_threshold = 0.5
    sam = sam.to(device)

    if cuda_ok:
        torch.cuda.empty_cache()

    return SamPredictor(sam), SamAutomaticMaskGenerator(sam)
def show_everything(sorted_anns):
    """Render automatic-mask annotations as a single RGBA overlay.

    Each annotation gets a random RGB colour drawn from ``np.random`` (the
    global RNG) and a fixed alpha of 0.6. Overlapping annotations are added
    together and the result is clipped, so overlaps saturate instead of
    wrapping around in the uint8 cast.

    Args:
        sorted_anns: list of annotation dicts, each with a ``'segmentation'``
            array whose last two dims are (H, W). Every entry must reshape to
            the first entry's (H, W).

    Returns:
        ``uint8`` array of shape (H, W, 4) with values in [0, 255], or an
        empty array when ``sorted_anns`` is empty.
    """
    if len(sorted_anns) == 0:
        return np.array([])

    h, w = sorted_anns[0]['segmentation'].shape[-2:]
    overlay = np.zeros((h, w, 4))
    for ann in sorted_anns:
        m = ann['segmentation']
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
        overlay += m.reshape(h, w, 1) * color.reshape(1, 1, -1)

    # Summed overlaps can exceed 1.0. Casting values above 255 to uint8 wraps
    # (or is undefined), so saturate first.
    overlay = np.clip(overlay, 0.0, 1.0) * 255
    return overlay.astype(np.uint8)
def show_click(masks, colors):
    """Composite binary masks into one RGBA overlay, one colour per mask.

    Args:
        masks: sequence of mask arrays, e.g. the outputs of
            ``model_predict_masks_click`` / ``model_predict_masks_box``. Empty
            entries (size 0) are skipped. Every other entry must reshape to
            (H, W, 1), where (H, W) come from the first non-empty entry.
        colors: colours paired positionally with ``masks`` (extra entries on
            either side are ignored, as with ``zip``). Each must be an array
            that broadcasts against the 4 RGBA channels after ``reshape(1, 1,
            -1)``; presumably 4 components in [0, 1] (TODO confirm with
            caller).

    Returns:
        ``uint8`` array of shape (H, W, 4). Overlapping layers saturate at
        255. Returns an empty array when no mask is non-empty.
    """
    # Take the canvas size from the first non-empty mask. The predictors
    # return np.array([]) when they have no prompts, so masks[0] may be empty.
    first = next((m for m in masks if np.size(m) != 0), None)
    if first is None:
        return np.array([])
    h, w = first.shape[-2:]

    # Accumulate in a wide dtype so overlapping layers don't wrap in uint8.
    total = np.zeros((h, w, 4), dtype=np.int32)
    for mask, color in zip(masks, colors):
        if np.size(mask) == 0:
            continue
        # Cast to uint8 first, then treat any non-zero pixel as "on".
        on = (mask.reshape(h, w, 1).astype(np.uint8) != 0).astype(np.uint8)
        layer = (on * 255 * color.reshape(1, 1, -1)).astype(np.uint8)
        total += layer
    return np.clip(total, 0, 255).astype(np.uint8)
def model_predict_masks_click(model, input_points, input_labels):
    """Predict a single mask from point prompts.

    Args:
        model: predictor exposing ``predict(point_coords=..., point_labels=...,
            multimask_output=...)`` (a ``SamPredictor`` from ``get_model``).
        input_points: sequence of point coordinates; converted with
            ``np.array``.
        input_labels: per-point labels; converted with ``np.array``.

    Returns:
        The ``masks`` output of ``model.predict`` with
        ``multimask_output=False``, or an empty array when there are no points.
    """
    # Use len() rather than ``== []``: on a numpy array that comparison is
    # elementwise or fails outright instead of testing emptiness.
    if input_points is None or len(input_points) == 0:
        return np.array([])

    masks, _, _ = model.predict(
        point_coords=np.array(input_points),
        point_labels=np.array(input_labels),
        multimask_output=False,
    )
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    return masks
def model_predict_masks_box(model, center_point, center_label, input_box):
    """Predict one mask per (point, box) prompt and merge the results.

    Args:
        model: predictor exposing ``predict(point_coords=..., point_labels=...,
            box=..., multimask_output=...)`` (a ``SamPredictor`` from
            ``get_model``).
        center_point: per-prompt point coordinates. Empty entries are skipped.
        center_label: per-prompt labels. Its length sets how many prompts are
            read.
        input_box: per-prompt box, passed through ``np.array``.

    Returns:
        Sum of the per-prompt masks (a logical OR for boolean masks), or an
        empty array when no prompt was usable.
    """
    combined = None
    for i, label in enumerate(center_label):
        point = center_point[i]
        if point is None or len(point) == 0:
            continue
        mask, _, _ = model.predict(
            point_coords=np.array([point]),
            point_labels=np.array(label),
            box=np.array(input_box[i]),
            multimask_output=False,
        )
        # An explicit sentinel replaces the old ``np.array([]) + mask`` /
        # bare-except trick. That trick swallowed every exception and gave an
        # empty result for width-1 masks, because (0,) broadcasts against
        # (1, H, 1).
        combined = mask if combined is None else combined + mask

    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    return np.array([]) if combined is None else combined
def model_predict_masks_everything(mask_generator, image):
    """Run automatic mask generation over a whole image.

    Args:
        mask_generator: object exposing ``generate(image)`` (a
            ``SamAutomaticMaskGenerator`` from ``get_model``).
        image: the image to segment, passed through unchanged.

    Returns:
        Whatever ``mask_generator.generate`` returns, unchanged.
    """
    generated = mask_generator.generate(image)
    # Release cached GPU memory after inference when CUDA is present.
    cuda_present = torch.cuda.is_available()
    if cuda_present:
        torch.cuda.empty_cache()
    return generated