# StreamlitSAM / app.py
# Author: RyanMellor. Last commit 18b3012 ("Update app.py"), 1.95 kB.
import cv2
import sys
import json
import torch
import warnings
import numpy as np
import streamlit as st
import matplotlib.pyplot as plt
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor
# Install a process-wide "ignore" filter for every warning category.
# NOTE(review): this also hides deprecation and resource warnings from
# numpy / matplotlib / torch; consider narrowing it to specific categories.
warnings.simplefilter('ignore')
@st.cache_resource
def mask_generate(
    sam_checkpoint="assets/model/sam_vit_l_0b3195.pth",
    model_type="vit_l",
    device="cpu",
):
    '''
    Build and return a SAM automatic mask generator.

    The result is cached with ``st.cache_resource``, so the model is
    loaded once and the same object is reused across reruns. ``st.cache_data``
    would instead serialize the model and return a fresh copy on every call.

    Parameters
    ----------
    sam_checkpoint : str
        Path to the SAM checkpoint file. Forward slashes are used so the
        default path resolves on both Windows and POSIX hosts.
    model_type : str
        Key into ``sam_model_registry`` selecting the architecture; it
        must match the checkpoint (the default file is a ``vit_l`` one).
    device : str
        Torch device the model is moved to, e.g. ``"cpu"`` (default) or
        ``"cuda"`` when a GPU is available.

    Returns
    -------
    SamAutomaticMaskGenerator
        Object whose ``generate(image)`` produces the mask records
        consumed by ``show_annot``.
    '''
    sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
    sam.to(device=device)
    return SamAutomaticMaskGenerator(sam)
def show_annot(annot, ax, alpha=0.35):
    '''
    Overlay each annotation mask on ``ax`` as a randomly coloured,
    semi-transparent layer.

    Parameters
    ----------
    annot : list of dict
        Mask records; each must provide ``'area'`` (used for draw order)
        and ``'segmentation'`` (a 2-D mask array). An empty list draws
        nothing.
    ax : matplotlib.axes.Axes
        Axes to draw on; only its ``imshow`` method is used.
    alpha : float, optional
        Opacity applied inside each mask (0.35 by default).

    Masks are drawn largest-area first, so smaller masks land on top and
    stay visible. Colours come from NumPy's global random state, so they
    differ between calls unless the caller seeds it.
    '''
    for ann in sorted(annot, key=lambda a: a['area'], reverse=True):
        mask = ann['segmentation']
        # One RGBA layer per mask: a constant random RGB colour everywhere,
        # with non-zero alpha only where the mask is set.
        layer = np.empty((mask.shape[0], mask.shape[1], 4))
        layer[..., :3] = np.random.random(3)
        layer[..., 3] = mask * alpha
        ax.imshow(layer)
# Debug aid: prints whether torch can see a GPU to the server's stdout.
# NOTE(review): mask_generate() is still called with its default device.
print(torch.cuda.is_available())
st.title("Segment Anything Model (SAM)")
uploaded_file = st.file_uploader("Upload Image")
if uploaded_file:
    # getvalue() returns the whole upload regardless of the stream position;
    # np.frombuffer replaces the deprecated binary mode of np.fromstring.
    buffer = np.frombuffer(uploaded_file.getvalue(), np.uint8)
    # imdecode returns None for undecodable data and raises on an empty
    # buffer, so treat an empty upload as undecodable too.
    image = cv2.imdecode(buffer, cv2.IMREAD_COLOR) if buffer.size else None
    if image is None:
        st.error("Could not decode the uploaded file as an image.")
    else:
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        with st.spinner("Segmenting image..."):
            mask_generator = mask_generate()
            masks = mask_generator.generate(image)
        col_original, col_annot = st.columns(2)
        with col_original:
            st.image(image)
            st.caption("Original Image")
        with col_annot:
            fig, ax = plt.subplots(figsize=(20, 20))
            ax.imshow(image)
            show_annot(masks, ax)
            ax.axis('off')
            st.pyplot(fig)
            # pyplot keeps every figure alive until it is closed; without
            # this, each rerun would leak another 20x20-inch figure.
            plt.close(fig)
            st.caption("Output Image")
else:
    st.warning('Upload an Image')