"""Streamlit demo: automatic image segmentation with Segment Anything (SAM).

Upload an image; the app runs SAM's automatic mask generator on it and shows
the original next to a copy with every mask drawn as a translucent overlay.
"""
import json
import sys
import warnings
from pathlib import Path

import cv2
import torch
import numpy as np
import streamlit as st
import matplotlib.pyplot as plt
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor

warnings.filterwarnings('ignore')

# Built with pathlib so the separator is right on every OS; the previous
# backslash literal only resolved on Windows and relied on invalid escapes.
SAM_CHECKPOINT = Path("assets") / "model" / "sam_vit_l_0b3195.pth"
SAM_MODEL_TYPE = "vit_l"
SAM_DEVICE = "cpu"

# Opacity applied to every mask pixel when it is drawn over the image.
MASK_ALPHA = 0.35


# cache_resource keeps ONE shared model instance across reruns and sessions.
# cache_data would pickle the generator and hand back a fresh copy per call.
@st.cache_resource
def mask_generate():
    """Load the SAM checkpoint and return an automatic mask generator.

    Returns:
        SamAutomaticMaskGenerator: generator wrapping the ``vit_l`` SAM model
        loaded from ``SAM_CHECKPOINT`` and moved to ``SAM_DEVICE``.
    """
    sam = sam_model_registry[SAM_MODEL_TYPE](checkpoint=str(SAM_CHECKPOINT))
    sam.to(device=SAM_DEVICE)
    return SamAutomaticMaskGenerator(sam)


def show_annot(annot, ax):
    """Draw each annotation's mask on ``ax`` as a randomly coloured overlay.

    Args:
        annot: sequence of dicts, each with an ``'area'`` value used for
            ordering and a 2-D ``'segmentation'`` array (presumably the
            boolean masks produced by ``SamAutomaticMaskGenerator.generate``).
        ax: matplotlib axes to draw on; one ``imshow`` call is issued per
            annotation.

    Returns:
        None. Nothing is drawn when ``annot`` is empty.
    """
    if not annot:
        return

    # Largest masks first, so smaller ones are painted on top and stay visible.
    for ann in sorted(annot, key=lambda a: a['area'], reverse=True):
        mask = ann['segmentation']
        rgba = np.empty((mask.shape[0], mask.shape[1], 4))
        rgba[..., :3] = np.random.random(3)  # one random colour per mask
        rgba[..., 3] = mask * MASK_ALPHA     # transparent outside the mask
        ax.imshow(rgba)


def _decode_rgb(data):
    """Decode encoded image bytes to an RGB array, or None if undecodable."""
    buffer = np.frombuffer(data, dtype=np.uint8)
    if buffer.size == 0:
        # cv2.imdecode raises on an empty buffer instead of returning None.
        return None
    bgr = cv2.imdecode(buffer, cv2.IMREAD_COLOR)
    if bgr is None:
        return None
    return cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)


print(torch.cuda.is_available())

st.title("Segment Anything Model (SAM)")
image_path = st.file_uploader("Upload Image")

if image_path:
    with st.spinner("Segmenting image..."):
        image = _decode_rgb(image_path.read())
        masks = mask_generate().generate(image) if image is not None else None

    if image is None:
        st.error("Could not decode the uploaded file as an image.")
    else:
        col_original, col_annot = st.columns(2)

        with col_original:
            st.image(image)
            st.caption("Original Image")

        with col_annot:
            fig, ax = plt.subplots(figsize=(20, 20))
            ax.imshow(image)
            show_annot(masks, ax)
            ax.axis('off')
            st.pyplot(fig)
            # pyplot keeps a reference to every figure it creates; close it
            # so figures do not accumulate across Streamlit reruns.
            plt.close(fig)
            st.caption("Output Image")
else:
    st.warning('Upload an Image')