amaye15 committed
Commit 78d359b
1 Parent(s): 933c40c

Version 1 - Working

Files changed (4)
  1. app.py +166 -138
  2. check.py +10 -0
  3. create_repo.py +9 -0
  4. requirements.txt +1 -0
app.py CHANGED
@@ -1,183 +1,211 @@
- # import gradio as gr
- # from gradio_image_prompter import ImagePrompter
-
- # import os
- # import torch
-
-
- # def prompter(prompts):
- #     image = prompts["image"]  # Get the image from prompts
- #     points = prompts["points"]  # Get the points from prompts
-
- #     # Print the collected inputs for debugging or logging
- #     print("Image received:", image)
- #     print("Points received:", points)
-
- #     import torch
- #     from sam2.sam2_image_predictor import SAM2ImagePredictor
-
- #     device = torch.device("cpu")
-
- #     predictor = SAM2ImagePredictor.from_pretrained(
- #         "facebook/sam2-hiera-base-plus", device=device
- #     )
-
- #     with torch.inference_mode():
- #         predictor.set_image(image)
- #         # masks, _, _ = predictor.predict([[point[0], point[1]] for point in points])
- #         input_point = [[point[0], point[1]] for point in points]
- #         input_label = [1]
- #         masks, _, _ = predictor.predict(
- #             point_coords=input_point, point_labels=input_label
- #         )
- #         print("Predicted Mask:", masks)
-
- #     return image, points
-
-
- # # Define the Gradio interface
- # demo = gr.Interface(
- #     fn=prompter,  # Use the custom prompter function
- #     inputs=ImagePrompter(
- #         show_label=False
- #     ),  # ImagePrompter for image input and point selection
- #     outputs=[
- #         gr.Image(show_label=False),  # Display the image
- #         gr.Dataframe(label="Points"),  # Display the points in a DataFrame
- #     ],
- #     title="Image Point Collector",
- #     description="Upload an image, click on it, and get the coordinates of the clicked points.",
- # )
-
- # # Launch the Gradio app
- # demo.launch()
-
-
- # import gradio as gr
- # from gradio_image_prompter import ImagePrompter
- # import torch
- # from sam2.sam2_image_predictor import SAM2ImagePredictor
-
-
- # def prompter(prompts):
- #     image = prompts["image"]  # Get the image from prompts
- #     points = prompts["points"]  # Get the points from prompts
-
- #     # Print the collected inputs for debugging or logging
- #     print("Image received:", image)
- #     print("Points received:", points)
-
- #     device = torch.device("cpu")
-
- #     # Load the SAM2ImagePredictor model
- #     predictor = SAM2ImagePredictor.from_pretrained(
- #         "facebook/sam2-hiera-base-plus", device=device
- #     )
-
- #     # Perform inference
- #     with torch.inference_mode():
- #         predictor.set_image(image)
- #         input_point = [[point[0], point[1]] for point in points]
- #         input_label = [1] * len(points)  # Assuming all points are foreground
- #         masks, _, _ = predictor.predict(
- #             point_coords=input_point, point_labels=input_label
- #         )
-
- #     # The masks are returned as a list of numpy arrays
- #     print("Predicted Mask:", masks)
-
- #     # Assuming there's only one mask returned, you can adjust if there are multiple
- #     predicted_mask = masks[0]
-
- #     print(len(image))
-
- #     print(len(predicted_mask))
-
- #     # Create annotations for AnnotatedImage
- #     annotations = [(predicted_mask, "Predicted Mask")]
-
- #     return image, annotations
-
-
- # # Define the Gradio interface
- # demo = gr.Interface(
- #     fn=prompter,  # Use the custom prompter function
- #     inputs=ImagePrompter(
- #         show_label=False
- #     ),  # ImagePrompter for image input and point selection
- #     outputs=gr.AnnotatedImage(),  # Display the image with the predicted mask
- #     title="Image Point Collector with Mask Overlay",
- #     description="Upload an image, click on it, and get the predicted mask overlayed on the image.",
- # )
-
- # # Launch the Gradio app
- # demo.launch()
-
-
- import gradio as gr
- from gradio_image_prompter import ImagePrompter
- import torch
- import numpy as np
- from sam2.sam2_image_predictor import SAM2ImagePredictor
- from PIL import Image
-
-
- def prompter(prompts):
-     image = np.array(prompts["image"])  # Convert the image to a numpy array
-     points = prompts["points"]  # Get the points from prompts
-
-     # Print the collected inputs for debugging or logging
-     print("Image received:", image)
-     print("Points received:", points)
-
-     device = torch.device("cpu")
-
-     # Load the SAM2ImagePredictor model
-     predictor = SAM2ImagePredictor.from_pretrained(
-         "facebook/sam2-hiera-base-plus", device=device
-     )
-
-     # Perform inference with multimask_output=True
-     with torch.inference_mode():
-         predictor.set_image(image)
-         input_point = [[point[0], point[1]] for point in points]
-         input_label = [1] * len(points)  # Assuming all points are foreground
-         masks, _, _ = predictor.predict(
-             point_coords=input_point, point_labels=input_label, multimask_output=True
-         )
-
-     # Prepare individual images with separate overlays
-     overlay_images = []
-     for i, mask in enumerate(masks):
-         print(f"Predicted Mask {i+1}:", mask)
-         red_mask = np.zeros_like(image)
-         red_mask[:, :, 0] = mask.astype(np.uint8) * 255  # Apply the red channel
-         red_mask = Image.fromarray(red_mask)
-
-         # Convert the original image to a PIL image
-         original_image = Image.fromarray(image)
-
-         # Blend the original image with the red mask
-         blended_image = Image.blend(original_image, red_mask, alpha=0.5)
-
-         # Add the blended image to the list
-         overlay_images.append(blended_image)
-
-     return overlay_images
-
-
- # Define the Gradio interface
- demo = gr.Interface(
-     fn=prompter,  # Use the custom prompter function
-     inputs=ImagePrompter(
-         show_label=False
-     ),  # ImagePrompter for image input and point selection
-     outputs=[
-         gr.Image(show_label=False) for _ in range(3)
-     ],  # Display up to 3 overlay images
-     title="Image Point Collector with Multiple Separate Mask Overlays",
-     description="Upload an image, click on it, and get each predicted mask overlaid separately in red on individual images.",
- )
-
+ import gradio as gr
+ from gradio_image_prompter import ImagePrompter
+ import torch
+ import numpy as np
+ from sam2.sam2_image_predictor import SAM2ImagePredictor
+ from PIL import Image
+ from uuid import uuid4
+ import os
+ from huggingface_hub import upload_folder
+ import shutil
+
+ MODEL = "facebook/sam2-hiera-large"
+ DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ PREDICTOR = SAM2ImagePredictor.from_pretrained(MODEL, device=DEVICE)
+
+ GLOBALS = {}
+
+ IMAGE = None
+ MASKS = None
+ INDEX = None
+
+
+ def prompter(prompts):
+     image = np.array(prompts["image"])  # Convert the image to a numpy array
+     points = prompts["points"]  # Get the points from prompts
+
+     # Perform inference with multimask_output=True
+     with torch.inference_mode():
+         PREDICTOR.set_image(image)
+         input_point = [[point[0], point[1]] for point in points]
+         input_label = [1] * len(points)  # Assuming all points are foreground
+         masks, _, _ = PREDICTOR.predict(
+             point_coords=input_point, point_labels=input_label, multimask_output=True
+         )
+
+     # Prepare individual images with separate overlays
+     overlay_images = []
+     for i, mask in enumerate(masks):
+         print(f"Predicted Mask {i+1}:", mask.shape)
+         red_mask = np.zeros_like(image)
+         red_mask[:, :, 0] = mask.astype(np.uint8) * 255  # Apply the red channel
+         red_mask = Image.fromarray(red_mask)
+
+         # Convert the original image to a PIL image
+         original_image = Image.fromarray(image)
+
+         # Blend the original image with the red mask
+         blended_image = Image.blend(original_image, red_mask, alpha=0.5)
+
+         # Add the blended image to the list
+         overlay_images.append(blended_image)
+
+     global IMAGE, MASKS
+     IMAGE, MASKS = image, masks
+
+     return overlay_images[0], overlay_images[1], overlay_images[2], masks
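Note that prompter discards the second value returned by PREDICTOR.predict, which carries one predicted-IoU score per candidate mask. If a single best guess were wanted instead of three overlays, those scores could rank the candidates. A minimal sketch under that assumption (best_mask and its wiring are illustrative, not part of the app):

    import numpy as np
    import torch
    from sam2.sam2_image_predictor import SAM2ImagePredictor

    predictor = SAM2ImagePredictor.from_pretrained("facebook/sam2-hiera-large")

    def best_mask(image: np.ndarray, points: list) -> np.ndarray:
        # points arrive as [[x, y, ...], ...], as produced by gradio_image_prompter
        with torch.inference_mode():
            predictor.set_image(image)
            masks, scores, _ = predictor.predict(
                point_coords=[[p[0], p[1]] for p in points],
                point_labels=[1] * len(points),  # treat every click as foreground
                multimask_output=True,
            )
        # scores holds one predicted IoU per candidate; keep the best one
        return masks[int(np.argmax(scores))]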
 
 
+ def select_mask(
+     selected_mask_index,
+     mask1,
+     mask2,
+     mask3,
+ ):
+     masks = [mask1, mask2, mask3]
+     global INDEX
+     INDEX = selected_mask_index
+     return masks[selected_mask_index]
+
+
+ def save_selected_mask(image, mask, output_dir="output"):
+     output_dir = os.path.join(os.getcwd(), output_dir)
+     os.makedirs(output_dir, exist_ok=True)
+
+     # Generate a unique UUID for the folder name
+     folder_id = str(uuid4())
+
+     # Create a path for the new folder
+     folder_path = os.path.join(output_dir, folder_id)
+
+     # Ensure the folder is created
+     os.makedirs(folder_path, exist_ok=True)
+
+     # Define the paths for saving the image and mask
+     image_path = os.path.join(folder_path, "image.npy")
+     mask_path = os.path.join(folder_path, "mask.npy")
+
+     # Save the image and mask to the respective paths
+     with open(image_path, "wb") as f:
+         np.save(f, IMAGE)
+
+     with open(mask_path, "wb") as f:
+         np.save(f, MASKS[INDEX])
+
+     # Upload the folder to the Hugging Face Hub
+     upload_folder(
+         folder_path=output_dir,
+         # path_in_repo=path_in_repo,
+         repo_id="amaye15/object-segmentation",
+         repo_type="dataset",
+         # ignore_patterns="**/logs/*.txt",  # Adjust this if needed
+     )
+
+     shutil.rmtree(folder_path)
+
+     return f"Image and mask saved to {folder_path}."
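A side note on the upload: upload_folder is pointed at output_dir, the parent folder, so it pushes whatever sits under output/ at that moment (for example, folders from a concurrent save) rather than only the freshly written pair. If that matters, scoping the call to the new folder via path_in_repo is a tighter variant; a sketch reusing folder_path and folder_id from above:

    from huggingface_hub import upload_folder

    # Push only the freshly written <uuid> folder, mirrored under the
    # same name inside the dataset repo
    upload_folder(
        folder_path=folder_path,
        path_in_repo=folder_id,
        repo_id="amaye15/object-segmentation",
        repo_type="dataset",
    )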
 
 
 
 
 
+ def save_dataset_name(key, dataset_name):
+     global GLOBALS
+     GLOBALS[key] = dataset_name
+
+     iframe_code = f"""
+     <iframe
+         src="https://huggingface.co/datasets/{dataset_name}/embed/viewer/default/train"
+         frameborder="0"
+         width="100%"
+         height="560px"
+     ></iframe>
+     """
+     return f"Huggingface Dataset: {dataset_name}", iframe_code
+
+
+ # Define the Gradio Blocks app
+ with gr.Blocks() as demo:
+     with gr.Tab("Setup"):
+         with gr.Row():
+             with gr.Column():
+                 source = gr.Textbox(label="Source Dataset")
+                 source_display = gr.Markdown()
+                 iframe_display = gr.HTML()
+
+                 source.change(
+                     save_dataset_name,
+                     inputs=(gr.State("source_dataset"), source),
+                     outputs=(source_display, iframe_display),
+                 )
+
+             with gr.Column():
+                 destination = gr.Textbox(label="Destination Dataset")
+                 destination_display = gr.Markdown()
+
+                 destination.change(
+                     save_dataset_name,
+                     inputs=(gr.State("destination_dataset"), destination),
+                     outputs=destination_display,
+                 )
+
+     with gr.Tab("Object Mask - Point Prompt"):
+         gr.Markdown("# Image Point Collector with Multiple Separate Mask Overlays")
+         gr.Markdown(
+             "Upload an image, click on it, and get each predicted mask overlaid separately in red on individual images."
+         )
+
+         with gr.Row():
+             with gr.Column():
+                 # Input: ImagePrompter
+                 image_input = ImagePrompter(show_label=False)
+                 submit_button = gr.Button("Submit")
+         with gr.Row():
+             with gr.Column():
+                 # Outputs: Up to 3 overlay images
+                 image_output_1 = gr.Image(show_label=False)
+             with gr.Column():
+                 image_output_2 = gr.Image(show_label=False)
+             with gr.Column():
+                 image_output_3 = gr.Image(show_label=False)
+
+         # Radio buttons for selecting the correct mask
+         with gr.Row():
+             mask_selector = gr.Radio(
+                 label="Select the correct mask",
+                 choices=["Mask 1", "Mask 2", "Mask 3"],
+                 type="index",
+             )
+             # selected_mask_output = gr.Image(show_label=False)
+
+         save_button = gr.Button("Save Selected Mask and Image")
+         save_message = gr.Textbox(visible=False)
+
+         # Define the action triggered by the submit button
+         submit_button.click(
+             fn=prompter,
+             inputs=image_input,
+             outputs=[image_output_1, image_output_2, image_output_3, gr.State()],
+         )
+
+         # Define the action triggered by mask selection
+         mask_selector.change(
+             fn=select_mask,
+             inputs=[mask_selector, image_output_1, image_output_2, image_output_3],
+             outputs=gr.State(),
+         )
+
+         # Define the action triggered by the save button
+         save_button.click(
+             fn=save_selected_mask,
+             inputs=[gr.State(), gr.State()],
+             outputs=save_message,
+             show_progress=True,
+         )

  # Launch the Gradio app
  demo.launch()
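One wiring detail worth knowing: every gr.State() above is constructed inline, so the three handlers get unrelated state objects, and save_selected_mask receives None for its image and mask arguments; the app works because it falls back to the IMAGE/MASKS/INDEX globals, which are shared across all users of a running Space. A per-session alternative is to create each state once and pass the same object to every event. A minimal sketch of that pattern with toy handlers (not the app's functions):

    import gradio as gr

    def predict(text):
        # stand-in for prompter: return a display value plus data to remember
        return text.upper(), ["mask-a", "mask-b", "mask-c"]

    def save(masks, index):
        return f"would save {masks[index]}"

    with gr.Blocks() as demo:
        masks_state = gr.State()  # created once, shared by both events
        index_state = gr.State(0)

        inp = gr.Textbox()
        out = gr.Textbox()
        msg = gr.Textbox()
        run = gr.Button("Run")
        keep = gr.Button("Save")  # click Run first so masks_state is populated

        run.click(predict, inputs=inp, outputs=[out, masks_state])
        keep.click(save, inputs=[masks_state, index_state], outputs=msg)

    demo.launch()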
check.py ADDED
@@ -0,0 +1,10 @@
+ import numpy as np
+ import matplotlib.pyplot as plt
+
+ # Load the image data from the .npy file
+ image = np.load("/Users/andrewmayes/Dev/image/image.npy")
+
+ # Display the image using matplotlib
+ plt.imshow(image)
+ plt.axis("off")  # Turn off the axis labels
+ plt.show()  # Show the image
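check.py inspects only the image half of a saved pair. The mask half can be sanity-checked the same way by overlaying it on the image; a minimal sketch assuming the image.npy/mask.npy layout that save_selected_mask writes (the folder path is a placeholder):

    import numpy as np
    import matplotlib.pyplot as plt

    folder = "output/<uuid>"  # placeholder: one folder written by save_selected_mask
    image = np.load(f"{folder}/image.npy")
    mask = np.load(f"{folder}/mask.npy")

    plt.imshow(image)
    plt.imshow(mask, alpha=0.5, cmap="Reds")  # semi-transparent mask on top
    plt.axis("off")
    plt.show()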
create_repo.py ADDED
@@ -0,0 +1,9 @@
+ from huggingface_hub import HfApi
+
+ # Initialize the API
+ api = HfApi()
+
+ # Create a new dataset repository
+ repo_url = api.create_repo(repo_id="amaye15/object-segmentation", repo_type="dataset")
+
+ print(f"Dataset repository created: {repo_url}")
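As written, the script raises if the repository already exists; create_repo accepts exist_ok=True, which makes it safe to re-run:

    from huggingface_hub import HfApi

    api = HfApi()

    # exist_ok=True returns the existing repo's URL instead of raising
    repo_url = api.create_repo(
        repo_id="amaye15/object-segmentation",
        repo_type="dataset",
        exist_ok=True,
    )
    print(f"Dataset repository ready: {repo_url}")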
requirements.txt CHANGED
@@ -1,5 +1,6 @@
  gradio
  gradio-image-prompter
+ huggingface-hub
  Pillow
  opencv-python
  git+https://github.com/facebookresearch/segment-anything-2.git