RyanMellor commited on
Commit
18b3012
1 Parent(s): ceb4a0e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +60 -60
app.py CHANGED
@@ -1,68 +1,68 @@
1
- import cv2
2
- import sys
3
- import json
4
- import torch
5
- import warnings
6
- import numpy as np
7
- import streamlit as st
8
- import matplotlib.pyplot as plt
9
- from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor
10
- warnings.filterwarnings('ignore')
11
 
12
- @st.cache_data()
13
- def mask_generate():
14
- '''
15
- Generate mask for image segmentation
16
- '''
17
- sam_checkpoint = "assets\model\sam_vit_l_0b3195.pth"
18
- model_type = "vit_l"
19
- device = "cpu"
20
 
21
- sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
22
- sam.to(device=device)
23
- mask_generator = SamAutomaticMaskGenerator(sam)
24
- return mask_generator
25
 
26
 
27
- def show_annot(annot, ax):
28
- '''
29
- Show annotations on image
30
- '''
31
- if len(annot) == 0:
32
- return
33
- sorted_annot = sorted(annot, key=(lambda x: x['area']), reverse=True)
34
- polygons = []
35
- color = []
36
- for ann in sorted_annot:
37
- m = ann['segmentation']
38
- img = np.ones((m.shape[0], m.shape[1], 3))
39
- color_mask = np.random.random((1, 3)).tolist()[0]
40
- for i in range(3):
41
- img[:,:,i] = color_mask[i]
42
- ax.imshow(np.dstack((img, m*0.35)))
43
 
44
- print(torch.cuda.is_available())
45
 
46
- st.title("Segment Anything Model (SAM)")
47
- image_path = st.file_uploader("Upload Image")
48
- if image_path:
49
- with st.spinner("Segmenting image..."):
50
- image = cv2.imdecode(np.fromstring(image_path.read(), np.uint8), 1)
51
- image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
52
 
53
- mask_generator = mask_generate()
54
- masks = mask_generator.generate(image)
55
 
56
- col_original, col_annot = st.columns(2)
57
- with col_original:
58
- st.image(image)
59
- st.caption("Original Image")
60
- with col_annot:
61
- fig, ax = plt.subplots(figsize=(20,20))
62
- ax.imshow(image)
63
- show_annot(masks, ax)
64
- ax.axis('off')
65
- st.pyplot(fig)
66
- st.caption("Output Image")
67
- else:
68
- st.warning('Upload an Image')
 
1
+ import cv2
2
+ import sys
3
+ import json
4
+ import torch
5
+ import warnings
6
+ import numpy as np
7
+ import streamlit as st
8
+ import matplotlib.pyplot as plt
9
+ from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor
10
+ warnings.filterwarnings('ignore')
11
 
12
+ @st.cache_data()
13
+ def mask_generate():
14
+ '''
15
+ Generate mask for image segmentation
16
+ '''
17
+ sam_checkpoint = "assets\model\sam_vit_l_0b3195.pth"
18
+ model_type = "vit_l"
19
+ device = "cpu"
20
 
21
+ sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
22
+ sam.to(device=device)
23
+ mask_generator = SamAutomaticMaskGenerator(sam)
24
+ return mask_generator
25
 
26
 
27
+ def show_annot(annot, ax):
28
+ '''
29
+ Show annotations on image
30
+ '''
31
+ if len(annot) == 0:
32
+ return
33
+ sorted_annot = sorted(annot, key=(lambda x: x['area']), reverse=True)
34
+ polygons = []
35
+ color = []
36
+ for ann in sorted_annot:
37
+ m = ann['segmentation']
38
+ img = np.ones((m.shape[0], m.shape[1], 3))
39
+ color_mask = np.random.random((1, 3)).tolist()[0]
40
+ for i in range(3):
41
+ img[:,:,i] = color_mask[i]
42
+ ax.imshow(np.dstack((img, m*0.35)))
43
 
44
+ print(torch.cuda.is_available())
45
 
46
+ st.title("Segment Anything Model (SAM)")
47
+ image_path = st.file_uploader("Upload Image")
48
+ if image_path:
49
+ with st.spinner("Segmenting image..."):
50
+ image = cv2.imdecode(np.fromstring(image_path.read(), np.uint8), 1)
51
+ image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
52
 
53
+ mask_generator = mask_generate()
54
+ masks = mask_generator.generate(image)
55
 
56
+ col_original, col_annot = st.columns(2)
57
+ with col_original:
58
+ st.image(image)
59
+ st.caption("Original Image")
60
+ with col_annot:
61
+ fig, ax = plt.subplots(figsize=(20,20))
62
+ ax.imshow(image)
63
+ show_annot(masks, ax)
64
+ ax.axis('off')
65
+ st.pyplot(fig)
66
+ st.caption("Output Image")
67
+ else:
68
+ st.warning('Upload an Image')