|
import types |
|
|
|
import numpy as np |
|
import streamlit as st |
|
import torch |
|
from distinctipy import distinctipy |
|
from segment_anything import (SamAutomaticMaskGenerator, SamPredictor, |
|
sam_model_registry) |
|
from torch.nn import functional as F |
|
|
|
|
|
def get_color(n_colors=200):
    """Return a palette of ``n_colors`` visually distinct colors from distinctipy.

    Args:
        n_colors: How many colors to generate. Defaults to 200, the value that
            was previously hard-coded, so existing ``get_color()`` calls are
            unchanged.

    Returns:
        Whatever ``distinctipy.get_colors(n_colors)`` returns (per the
        distinctipy docs, a list of RGB tuples).
    """
    return distinctipy.get_colors(n_colors)
|
|
|
|
|
def medsam_preprocess(self, x: torch.Tensor) -> torch.Tensor:
    """Min-max scale ``x`` into [0, 1], then zero-pad it to a square input.

    Written as a free function taking ``self`` so ``get_model`` can bind it onto
    the SAM model with ``types.MethodType``, replacing its ``preprocess``.

    The scaling uses the global min/max of the whole tensor (no per-channel or
    per-image statistics). The denominator is clipped to at least 1e-8, so a
    constant input becomes all zeros instead of dividing by zero.

    Padding is added only on the bottom and right of the last two dimensions,
    up to ``self.image_encoder.img_size``.
    """
    lowest = x.min()
    value_range = torch.clip(x.max() - lowest, min=1e-8, max=None)
    scaled = (x - lowest) / value_range

    side = self.image_encoder.img_size
    rows, cols = scaled.shape[-2:]
    # F.pad takes (left, right, top, bottom) for the last two dims.
    return F.pad(scaled, (0, side - cols, 0, side - rows))
|
|
|
|
|
@st.cache_resource
def get_model(checkpoint='checkpoint/sam_vit_b_01ec64.pth'):
    """Load SAM ViT-B and return a point/box predictor plus an automatic mask generator.

    The model is built from ``checkpoint`` via ``sam_model_registry['vit_b']``.
    Its ``preprocess`` is replaced by ``medsam_preprocess`` (min-max scaling and
    padding) and its ``mask_threshold`` is set to 0.5. It is then moved to CUDA
    when available, otherwise to the CPU.

    ``st.cache_resource`` makes Streamlit reuse the returned objects across reruns
    instead of reloading the checkpoint each time.

    Returns:
        ``(predictor, mask_generator)`` -- a ``SamPredictor`` and a
        ``SamAutomaticMaskGenerator`` that share the same model instance.
    """
    use_cuda = torch.cuda.is_available()
    target_device = torch.device("cuda" if use_cuda else "cpu")

    sam = sam_model_registry['vit_b'](checkpoint=checkpoint)
    # Swap in the MedSAM-style preprocessing as a bound method of this model.
    sam.preprocess = types.MethodType(medsam_preprocess, sam)
    # NOTE(review): SAM compares this threshold against predicted mask values;
    # confirm 0.5 is intended for this checkpoint.
    sam.mask_threshold = 0.5

    sam = sam.to(target_device)
    if use_cuda:
        torch.cuda.empty_cache()

    return SamPredictor(sam), SamAutomaticMaskGenerator(sam)
|
|
|
|
|
def show_everything(sorted_anns):
    """Render annotation masks as a single RGBA uint8 overlay.

    Args:
        sorted_anns: Sequence of dicts, each with a ``'segmentation'`` mask whose
            last two dimensions are (height, width). The first entry's mask sets
            the overlay size. Presumably the output of the automatic mask
            generator, already sorted by the caller -- TODO confirm the sort order.

    Returns:
        An ``(h, w, 4)`` uint8 array. Each mask is painted with a random RGB
        color and alpha 0.6. Uncovered pixels are fully transparent (all zeros).
        Returns ``np.array([])`` when ``sorted_anns`` is empty.

    Later annotations are painted over earlier ones where masks overlap.
    Previously overlapping colors were summed, which pushed values past 1.0 and
    overflowed the float -> uint8 cast (e.g. alpha 1.2 * 255 = 306).
    """
    if len(sorted_anns) == 0:
        return np.array([])

    h, w = sorted_anns[0]['segmentation'].shape[-2:]
    overlay = np.zeros((h, w, 4))
    for ann in sorted_anns:
        region = np.asarray(ann['segmentation']).reshape(h, w).astype(bool)
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
        # Assign rather than add, so every channel stays within [0, 1].
        overlay[region] = color
    return (overlay * 255).astype(np.uint8)
|
|
|
|
|
def show_click(masks, colors):
    """Compose per-prompt masks into one RGBA uint8 overlay.

    Args:
        masks: Sequence of mask arrays whose last two dimensions are
            (height, width). Empty arrays (the ``np.array([])`` "no prediction"
            value returned by the predict helpers) are skipped.
        colors: One color per mask. Each is scaled by 255 and written into the
            4 channels of the covered pixels, so it must broadcast to 4 channels
            -- presumably RGBA floats in [0, 1]; TODO confirm with caller.

    Returns:
        An ``(h, w, 4)`` uint8 array sized from the first non-empty mask, or
        ``np.array([])`` if there is no non-empty mask to draw.

    Where masks overlap, the later mask's color wins. Previously the colors were
    added in uint8, which silently wrapped around (255 + 255 -> 254).
    """
    drawable = [m for m in masks if np.size(m) > 0]
    if not drawable:
        # Nothing to draw, and no mask to take a size from.
        return np.array([])

    h, w = np.shape(drawable[0])[-2:]
    masks_total = np.zeros((h, w, 4), dtype=np.uint8)

    for mask, color in zip(masks, colors):
        if np.size(mask) == 0:
            continue
        # Cast to uint8 before bool, as the original did, to decide coverage.
        region = np.asarray(mask).reshape(h, w).astype(np.uint8).astype(bool)
        shade = (255 * np.asarray(color).reshape(-1)).astype(np.uint8)
        masks_total[region] = shade
    return masks_total
|
|
|
def model_predict_masks_click(model, input_points, input_labels):
    """Predict a single mask from point prompts.

    Args:
        model: Object exposing ``predict(point_coords=..., point_labels=...,
            multimask_output=...)`` that returns a 3-tuple whose first item is
            the mask array (e.g. the ``SamPredictor`` from ``get_model``,
            presumably with an image already set -- TODO confirm in caller).
        input_points: N x 2 point coordinates, as a list or an ndarray.
        input_labels: One label per point, as a list or an ndarray.

    Returns:
        The mask array from ``model.predict`` (single-mask output), or
        ``np.array([])`` when no points are given.
    """
    # ``len`` works for lists and ndarrays alike. ``input_points == []`` raises
    # on modern NumPy when input_points is a non-empty ndarray.
    if input_points is None or len(input_points) == 0:
        return np.array([])

    masks, _, _ = model.predict(
        point_coords=np.array(input_points),
        point_labels=np.array(input_labels),
        multimask_output=False,
    )

    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    return masks
|
|
|
def model_predict_masks_box(model, center_point, center_label, input_box):
    """Predict one mask per box prompt and return their union.

    Args:
        model: Object exposing ``predict(point_coords=..., point_labels=...,
            box=..., multimask_output=...)`` that returns a 3-tuple whose first
            item is the mask array (e.g. the ``SamPredictor`` from ``get_model``).
        center_point: Per-box prompt point. An empty entry (``[]``, an empty
            array, or ``None``) skips that box.
        center_label: Per-box point labels. Its length sets how many boxes are
            processed; ``center_point`` and ``input_box`` are indexed in step.
        input_box: Per-box box coordinates, passed to ``predict`` as an array.

    Returns:
        The element-wise sum of all predicted masks (a logical OR for boolean
        masks), or ``np.array([])`` if every box was skipped.

    Raises:
        ValueError: If predicted masks have incompatible shapes. Previously a
            bare ``except`` hid this and silently dropped the earlier masks.
    """
    masks = None
    for i, label in enumerate(center_label):
        point = center_point[i]
        # ``len`` instead of ``== []`` so ndarray points work as well as lists.
        if point is None or len(point) == 0:
            continue
        mask, _, _ = model.predict(
            point_coords=np.array([point]),
            point_labels=np.array(label),
            box=np.array(input_box[i]),
            multimask_output=False,
        )
        # Seed from the first prediction explicitly. Adding to an empty array
        # can silently broadcast to a zero-width result.
        masks = mask if masks is None else masks + mask

    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    if masks is None:
        return np.array([])
    return masks
|
|
|
|
|
def model_predict_masks_everything(mask_generator, image):
    """Run automatic ("segment everything") mask generation on ``image``.

    Args:
        mask_generator: Object exposing ``generate(image)`` (e.g. the
            ``SamAutomaticMaskGenerator`` from ``get_model``).
        image: Passed through unchanged to ``mask_generator.generate``.

    Returns:
        Exactly what ``mask_generator.generate(image)`` returned.
    """
    annotations = mask_generator.generate(image)

    # Hand cached-but-unused CUDA memory back once generation has finished.
    cuda_present = torch.cuda.is_available()
    if cuda_present:
        torch.cuda.empty_cache()

    return annotations