File size: 1,948 Bytes
18b3012
 
 
 
 
 
 
 
 
 
c032e60
18b3012
 
 
 
 
 
 
 
c032e60
18b3012
 
 
 
c032e60
 
18b3012
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c032e60
18b3012
c032e60
18b3012
 
 
 
 
 
c032e60
18b3012
 
c032e60
18b3012
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import cv2
import sys
import json
import torch
import warnings
import numpy as np
import streamlit as st
import matplotlib.pyplot as plt
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor
# NOTE(review): this silences every warning for the whole process, including
# deprecation warnings from the libraries used below; consider narrowing it to
# specific categories or modules so real problems are not hidden.
warnings.filterwarnings('ignore')

@st.cache_resource
def mask_generate(checkpoint="assets/model/sam_vit_l_0b3195.pth",
                  model_type="vit_l", device="cpu"):
    '''
    Build the SAM automatic mask generator, once per process.

    Cached with ``st.cache_resource`` rather than ``st.cache_data``: the
    return value wraps a torch model, and ``cache_data`` pickles its result
    and returns a fresh copy to every caller, duplicating the model on each
    rerun. ``cache_resource`` hands back the same shared object instead.

    NOTE(review): the shared object is used by every session; confirm that
    ``SamAutomaticMaskGenerator.generate`` is safe to call concurrently (or
    guard it with a lock) if the app serves several users at once.

    Parameters
    ----------
    checkpoint : str
        Path to the SAM weights file. Written with forward slashes, which
        work on both Windows and POSIX; the previous backslash-separated
        literal only resolved on Windows and relied on invalid string
        escape sequences.
    model_type : str
        Key into ``sam_model_registry`` that must match the checkpoint
        (``"vit_l"`` for the default weights).
    device : str
        torch device the model is moved to (default ``"cpu"``).

    Returns
    -------
    SamAutomaticMaskGenerator
        Generator constructed with its default settings.
    '''
    sam = sam_model_registry[model_type](checkpoint=checkpoint)
    sam.to(device=device)
    return SamAutomaticMaskGenerator(sam)


def show_annot(annot, ax, alpha=0.35, rng=None):
    '''
    Overlay randomly coloured, semi-transparent masks on ``ax``.

    Masks are drawn largest-first, so smaller masks end up on top. Rather
    than adding one full-size image artist per mask, all masks are
    composited into a single RGBA layer with the Porter-Duff "over"
    operator and drawn with one ``ax.imshow`` call; the blended result is
    the same as stacking one overlay per mask.

    Parameters
    ----------
    annot : list of dict
        Each entry needs a 2-D ``'segmentation'`` mask (boolean or numeric;
        it scales the per-pixel opacity) and an ``'area'`` used for
        ordering. Presumably the output of
        ``SamAutomaticMaskGenerator.generate`` -- confirm against callers.
    ax : matplotlib Axes
        Axes to draw on; the overlay is added on top of what is already there.
    alpha : float, optional
        Opacity of each mask (default 0.35).
    rng : optional
        Colour source with a ``random(size)`` method, e.g. a
        ``numpy.random.Generator`` for reproducible colours. Defaults to the
        global ``numpy.random`` state.

    Returns None. Does nothing when ``annot`` is empty.
    '''
    if len(annot) == 0:
        return
    if rng is None:
        rng = np.random
    sorted_annot = sorted(annot, key=(lambda x: x['area']), reverse=True)

    first = sorted_annot[0]['segmentation']
    height, width = first.shape[0], first.shape[1]
    premul = np.zeros((height, width, 3))  # colour pre-multiplied by opacity
    coverage = np.zeros((height, width))   # accumulated opacity
    for ann in sorted_annot:
        a = np.asarray(ann['segmentation'], dtype=float) * alpha
        color_mask = rng.random(3)
        # "over": the current mask is painted on top of everything so far.
        premul = color_mask * a[..., None] + premul * (1.0 - a[..., None])
        coverage = a + coverage * (1.0 - a)

    # Un-premultiply; pixels with zero coverage are fully transparent anyway.
    safe = np.where(coverage > 0, coverage, 1.0)
    overlay = np.dstack((premul / safe[..., None], coverage))
    ax.imshow(overlay)

# Streamlit re-executes this whole script on every user interaction.
st.title("Segment Anything Model (SAM)")
# file_uploader returns an in-memory uploaded-file object or None; despite its
# name, image_path holds the uploaded file itself, not a filesystem path.
image_path = st.file_uploader("Upload Image")
if image_path:
    with st.spinner("Segmenting image..."):
        # getvalue() returns the whole upload regardless of the stream
        # position; np.frombuffer replaces np.fromstring, whose binary mode
        # is deprecated.
        raw = image_path.getvalue()
        image = (
            cv2.imdecode(np.frombuffer(raw, np.uint8), cv2.IMREAD_COLOR)
            if raw else None
        )
        if image is None:
            # cv2.imdecode returns None when the bytes are not a decodable
            # image; stop here instead of failing inside cv2.cvtColor.
            st.error("Could not decode the uploaded file as an image.")
            st.stop()
        # OpenCV decodes to BGR; the display code below expects RGB.
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        mask_generator = mask_generate()
        masks = mask_generator.generate(image)

        col_original, col_annot = st.columns(2)
        with col_original:
            st.image(image)
            st.caption("Original Image")
        with col_annot:
            fig, ax = plt.subplots(figsize=(20,20))
            ax.imshow(image)
            show_annot(masks, ax)
            ax.axis('off')
            st.pyplot(fig)
            # pyplot keeps a reference to every figure until it is closed;
            # release it so reruns do not accumulate figures in memory.
            plt.close(fig)
            st.caption("Output Image")
else:
    st.warning('Upload an Image')