cmseibold commited on
Commit
9cd2e5b
1 Parent(s): ab746bd

Upload 3 files

Browse files
app.py ADDED
@@ -0,0 +1,124 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import subprocess
3
+ import os
4
+ from PIL import Image
5
+ import cv2
6
+ import numpy as np
7
+
8
+ os.environ["CXAS_PATH"] = "./weights"
9
+
10
+
11
+ os.makedirs("tmp", exist_ok=True)
12
+
13
+ # Helper function to run the segmentation command
14
+ def run_segmentation(input_image_path, output_folder, mode="segment", gpu="cpu"):
15
+ command = f"cxas -i {input_image_path} -o {output_folder} --mode {mode} -g {gpu} -s"
16
+ subprocess.run(command, shell=True)
17
+ return output_folder
18
+
19
+ # Helper function to colorize and outline the binary mask
20
+ def colorize_and_outline_mask(mask_image, color=(0, 255, 0)):
21
+ mask_np = np.array(mask_image.convert("L")) # Ensure it is a grayscale image
22
+ _, mask_np = cv2.threshold(mask_np, 127, 255, cv2.THRESH_BINARY)
23
+ colorized_mask = np.zeros((mask_np.shape[0], mask_np.shape[1], 3), dtype=np.uint8)
24
+ colorized_mask[mask_np == 255] = color # Apply the color to mask regions
25
+ edges = cv2.Canny(mask_np, 100, 200) # Detect edges
26
+ colorized_mask[edges == 255] = [255, 255, 255] # Highlight the edges
27
+ return colorized_mask
28
+
29
+ # Helper function to overlay mask on the image
30
+ def overlay_mask_on_image(input_image, mask_image, alpha=0.5):
31
+ input_image_np = np.array(input_image)
32
+ if len(input_image_np.shape) == 2: # Convert grayscale to RGB
33
+ input_image_np = cv2.cvtColor(input_image_np, cv2.COLOR_GRAY2RGB)
34
+ mask_image_resized = cv2.resize(mask_image, (input_image_np.shape[1], input_image_np.shape[0]))
35
+ overlayed_image = cv2.addWeighted(input_image_np, 1-alpha, mask_image_resized, alpha, 0)
36
+ return overlayed_image
37
+
38
+ # Streamlit app
39
+ st.title("Image Segmentation Tool")
40
+
41
+ # Check if session state is initialized
42
+ if "input_image" not in st.session_state:
43
+ st.session_state.input_image = None
44
+ st.session_state.output_folder = None
45
+ st.session_state.mask_files = []
46
+ st.session_state.segmentation_done = False
47
+ st.session_state.selected_mask = None # Store selected mask in session state
48
+
49
+ # File uploader for user to input image
50
+ uploaded_image = st.file_uploader("Upload an image file", type=["png", "jpg", "jpeg"])
51
+
52
+ # If a new image is uploaded, reset the session state
53
+ if uploaded_image is not None:
54
+ if not os.path.isdir(os.path.join("tmp/output", os.path.splitext(uploaded_image.name)[0])):
55
+ os.makedirs("tmp", exist_ok=True)
56
+ st.session_state.input_image = Image.open(uploaded_image) # Store the image in session state
57
+ input_image_path = f"tmp/{uploaded_image.name}"
58
+ st.session_state.input_image.save(input_image_path)
59
+
60
+ input_image_name = os.path.splitext(uploaded_image.name)[0]
61
+ output_folder = os.path.join("tmp/output")
62
+ if not os.path.exists(output_folder):
63
+ os.makedirs(output_folder)
64
+ st.session_state.output_folder = output_folder
65
+ st.session_state.mask_files = []
66
+ st.session_state.segmentation_done = False
67
+ st.session_state.selected_mask = None # Reset mask selection
68
+
69
+ st.image(st.session_state.input_image, caption="Uploaded Image", use_column_width=True)
70
+
71
+ # Run segmentation if not already done
72
+ if not st.session_state.segmentation_done:
73
+ if st.button("Run Segmentation"):
74
+ with st.spinner("Running segmentation..."):
75
+ run_segmentation(input_image_path, st.session_state.output_folder)
76
+ st.session_state.output_folder = os.path.join("tmp/output", input_image_name)
77
+ st.success(f"Segmentation completed. Masks saved in {st.session_state.output_folder}")
78
+
79
+ st.session_state.mask_files = [f for f in os.listdir(st.session_state.output_folder) if f.endswith('.png')]
80
+ st.session_state.segmentation_done = True
81
+
82
+ else:
83
+ input_image_name = os.path.splitext(uploaded_image.name)[0]
84
+ st.session_state.input_image = Image.open(f"tmp/{uploaded_image.name}")
85
+ st.session_state.output_folder = os.path.join("tmp/output", input_image_name)
86
+ st.success(f"Segmentation completed. Masks saved in {st.session_state.output_folder}")
87
+
88
+ st.session_state.mask_files = [f for f in os.listdir(st.session_state.output_folder) if f.endswith('.png')]
89
+ st.session_state.segmentation_done = True
90
+
91
+
92
+ # Display uploaded image
93
+ if st.session_state.input_image is not None:
94
+
95
+ # Only display dropdown and images if segmentation is done
96
+ if st.session_state.segmentation_done and st.session_state.mask_files:
97
+ # Dropdown to select a mask
98
+ selected_mask = st.selectbox("Select a mask to overlay", st.session_state.mask_files,
99
+ index=st.session_state.mask_files.index(st.session_state.selected_mask)
100
+ if st.session_state.selected_mask else 0)
101
+
102
+ # Save the selected mask in session state
103
+ st.session_state.selected_mask = selected_mask
104
+
105
+ # Load the selected mask
106
+ mask_image = Image.open(os.path.join(st.session_state.output_folder, selected_mask))
107
+
108
+ # Colorize the binary mask and add an outline
109
+ colorized_mask = colorize_and_outline_mask(mask_image)
110
+
111
+ # Overlay the selected mask on the input image
112
+ overlayed_image = overlay_mask_on_image(st.session_state.input_image, colorized_mask)
113
+
114
+ # Display the images side by side
115
+ col1, col2 = st.columns(2)
116
+
117
+ with col1:
118
+ st.image(st.session_state.input_image, caption="Original Image", use_column_width=True)
119
+
120
+ with col2:
121
+ st.image(overlayed_image, caption="Overlayed Image with Mask", use_column_width=True)
122
+
123
+ else:
124
+ st.info("Please upload an image to get started.")
requirements.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ cxas
weights/.cxas/UNet_ResNet50_default.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2bdc485da991693860c18b759d63e7404cc6ab01b7a15c987c350212b581e0e2
3
+ size 881081391