Spaces:
Sleeping
Sleeping
import cv2 | |
import sys | |
import json | |
import torch | |
import warnings | |
import numpy as np | |
import streamlit as st | |
import matplotlib.pyplot as plt | |
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor | |
# Suppress every warning category process-wide (presumably to keep noisy
# library deprecation notices out of the app's console output).
warnings.simplefilter("ignore")
def mask_generate(checkpoint="assets/model/sam_vit_l_0b3195.pth",
                  model_type="vit_l",
                  device="cpu"):
    '''
    Generate mask for image segmentation.

    Builds the SAM model registered under ``model_type`` in
    ``sam_model_registry``, loads its weights from ``checkpoint``, moves it
    to ``device`` and wraps it in a ``SamAutomaticMaskGenerator``.

    Args:
        checkpoint: Path to the SAM weights file. Forward slashes are used so
            the path resolves on both Windows and POSIX hosts; a backslash
            path would also contain invalid string escapes ("\\m", "\\s").
        model_type: Registry key for the architecture. It must match the
            checkpoint (the default file name indicates the ViT-L variant).
        device: Torch device string the model is moved to.

    Returns:
        A ``SamAutomaticMaskGenerator`` wrapping the loaded model.
    '''
    # NOTE(review): loading the checkpoint is expensive; callers that invoke
    # this repeatedly may want to cache the returned generator.
    sam = sam_model_registry[model_type](checkpoint=checkpoint)
    sam.to(device=device)
    return SamAutomaticMaskGenerator(sam)
def show_annot(annot, ax, alpha=0.35):
    '''
    Show annotations on image.

    Draws each mask as its own RGBA layer on ``ax``. A layer is filled with
    one random colour and has opacity ``alpha`` inside the mask and is fully
    transparent outside it. Masks are drawn largest-area first, so smaller
    masks end up on top of larger ones.

    Args:
        annot: Mask records (the output of ``SamAutomaticMaskGenerator.generate``
            in this app). Each record needs an ``'area'`` value and a
            ``'segmentation'`` array indexed as (rows, cols). An empty
            sequence or ``None`` draws nothing.
        ax: Axes-like object exposing ``imshow``.
        alpha: Opacity of each mask overlay (default 0.35).
    '''
    if not annot:
        return
    for ann in sorted(annot, key=lambda a: a['area'], reverse=True):
        mask = ann['segmentation']
        # One solid colour per mask; broadcasting fills all three channels.
        color_mask = np.random.random(3)
        rgb = np.ones((mask.shape[0], mask.shape[1], 3)) * color_mask
        # The mask scaled by alpha becomes the alpha channel, so pixels
        # outside the mask stay transparent.
        ax.imshow(np.dstack((rgb, mask * alpha)))
# Debug output: reports whether a CUDA device is visible to torch. The model
# itself runs on whatever device mask_generate() is configured with.
print(torch.cuda.is_available())

st.title("Segment Anything Model (SAM)")
image_path = st.file_uploader("Upload Image")  # UploadedFile (a BytesIO) or None
if image_path:
    with st.spinner("Segmenting image..."):
        # getvalue() returns the whole upload regardless of the stream's read
        # position (read() yields b"" once consumed). np.frombuffer replaces
        # the deprecated binary mode of np.fromstring.
        raw = np.frombuffer(image_path.getvalue(), dtype=np.uint8)
        # imdecode returns None for data it cannot decode. An empty upload is
        # skipped up front because OpenCV does not accept an empty buffer.
        image = cv2.imdecode(raw, cv2.IMREAD_COLOR) if raw.size else None
        if image is None:
            st.error("Could not decode the uploaded file as an image.")
        else:
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)  # OpenCV decodes as BGR
            # NOTE(review): the SAM checkpoint is reloaded on every upload;
            # consider caching the generator (e.g. st.cache_resource) if the
            # installed Streamlit version supports it.
            mask_generator = mask_generate()
            masks = mask_generator.generate(image)
            col_original, col_annot = st.columns(2)
            with col_original:
                st.image(image)
                st.caption("Original Image")
            with col_annot:
                fig, ax = plt.subplots(figsize=(20, 20))
                ax.imshow(image)
                show_annot(masks, ax)
                ax.axis('off')
                st.pyplot(fig)
                # pyplot keeps every figure alive until it is closed; without
                # this, each rerun of the script would leak one large figure.
                plt.close(fig)
                st.caption("Output Image")
else:
    st.warning('Upload an Image')