SauravMaheshkar commited on
Commit
bf29adc
·
unverified ·
1 Parent(s): 31bdcc8

feat: initial commit

Browse files
Files changed (7) hide show
  1. .gitattributes +1 -0
  2. app.py +92 -0
  3. assets/img.png +3 -0
  4. requirements.txt +3 -0
  5. segment-anything-2 +1 -0
  6. src/__init__.py +0 -0
  7. src/plot_utils.py +90 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ *.png filter=lfs diff=lfs merge=lfs -text
app.py ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import numpy as np
3
+ import cv2
4
+
5
+ from PIL import Image
6
+
7
+ from src.plot_utils import show_masks
8
+ from gradio_image_annotation import image_annotator
9
+
10
+
11
+ from sam2.build_sam import build_sam2
12
+ from sam2.sam2_image_predictor import SAM2ImagePredictor
13
+
14
+ choice_mapping = {
15
+ "tiny": ["sam2_hiera_t.yaml", "assets/checkpoints/sam2_hiera_tiny.pt"],
16
+ "small": ["sam2_hiera_s.yaml", "assets/checkpoints/sam2_hiera_small.pt"],
17
+ "base_plus": ["sam2_hiera_b+.yaml", "assets/checkpoints/sam2_hiera_base_plus.pt"],
18
+ "large": ["sam2_hiera_l.yaml", "assets/checkpoints/sam2_hiera_large.pt"],
19
+ }
20
+
21
+
22
+ def predict(model_choice: str, annotations, image):
23
+ config_file, ckpt_path = choice_mapping[str(model_choice)]
24
+ sam2_model = build_sam2(config_file, ckpt_path, device="cpu")
25
+ predictor = SAM2ImagePredictor(sam2_model)
26
+ predictor.set_image(image)
27
+ coordinates = np.array(
28
+ [
29
+ int(annotations["boxes"][0]["xmin"]),
30
+ int(annotations["boxes"][0]["ymin"]),
31
+ int(annotations["boxes"][0]["xmax"]),
32
+ int(annotations["boxes"][0]["ymax"]),
33
+ ]
34
+ )
35
+ masks, scores, _ = predictor.predict(
36
+ point_coords=None,
37
+ point_labels=None,
38
+ box=coordinates[None, :],
39
+ multimask_output=False,
40
+ )
41
+ mask = masks.transpose(1, 2, 0)
42
+ mask_image = (mask * 255).astype(np.uint8) # Convert to uint8 format
43
+ cv2.imwrite("mask.png", mask_image)
44
+
45
+ return [
46
+ show_masks(image, masks, scores, box_coords=coordinates),
47
+ gr.DownloadButton("Download Mask", value="mask.png", visible=True),
48
+ ]
49
+
50
+
51
+ with gr.Blocks(delete_cache=(30, 30)) as demo:
52
+ gr.Markdown(
53
+ """
54
+ # 1. Choose Model Checkpoint
55
+ """
56
+ )
57
+ with gr.Row():
58
+ model = gr.Dropdown(
59
+ choices=["tiny", "small", "base_plus", "large"],
60
+ value="tiny",
61
+ label="Model Checkpoint",
62
+ info="Which model checkpoint to load?",
63
+ )
64
+
65
+ gr.Markdown(
66
+ """
67
+ # 2. Upload an Image
68
+ """
69
+ )
70
+
71
+ with gr.Row():
72
+ img = gr.Image(value="./assets/img.png", type="numpy", label="Input Image")
73
+
74
+ gr.Markdown(
75
+ """
76
+ # 3. Draw Bounding Box
77
+ """
78
+ )
79
+
80
+ annotator = image_annotator(
81
+ value={"image": img.value["path"]},
82
+ disable_edit_boxes=True,
83
+ single_box=True,
84
+ label="Draw a bounding box",
85
+ )
86
+ btn = gr.Button("Get Segmentation Mask")
87
+ download_btn = gr.DownloadButton("Download Mask", value="mask.png", visible=False)
88
+ btn.click(
89
+ fn=predict, inputs=[model, annotator, img], outputs=[gr.Plot(), download_btn]
90
+ )
91
+
92
+ demo.launch()
assets/img.png ADDED

Git LFS Details

  • SHA256: 1e2f1e4675eab9280d901413df5216283a8771ddfe07d626ad8968d8ab3ea1de
  • Pointer size: 132 Bytes
  • Size of remote file: 2.71 MB
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ gradio
2
+ gradio_image_annotation
3
+ -e segment-anything-2/
segment-anything-2 ADDED
@@ -0,0 +1 @@
 
 
1
+ Subproject commit 82b026cd5578af78757323ab99a0b5c8dc456cff
src/__init__.py ADDED
File without changes
src/plot_utils.py ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import matplotlib.pyplot as plt
3
+
4
+
5
+ def show_mask(mask, ax, random_color=False, borders=True):
6
+ if random_color:
7
+ color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
8
+ else:
9
+ color = np.array([30 / 255, 144 / 255, 255 / 255, 0.6])
10
+ h, w = mask.shape[-2:]
11
+ mask = mask.astype(np.uint8)
12
+ mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
13
+ if borders:
14
+ import cv2
15
+
16
+ contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
17
+ # Try to smooth contours
18
+ contours = [
19
+ cv2.approxPolyDP(contour, epsilon=0.01, closed=True) for contour in contours
20
+ ]
21
+ mask_image = cv2.drawContours(
22
+ mask_image, contours, -1, (1, 1, 1, 0.5), thickness=2
23
+ )
24
+ ax.imshow(mask_image)
25
+
26
+
27
+ def show_points(coords, labels, ax, marker_size=375):
28
+ pos_points = coords[labels == 1]
29
+ neg_points = coords[labels == 0]
30
+ ax.scatter(
31
+ pos_points[:, 0],
32
+ pos_points[:, 1],
33
+ color="green",
34
+ marker="*",
35
+ s=marker_size,
36
+ edgecolor="white",
37
+ linewidth=1.25,
38
+ )
39
+ ax.scatter(
40
+ neg_points[:, 0],
41
+ neg_points[:, 1],
42
+ color="red",
43
+ marker="*",
44
+ s=marker_size,
45
+ edgecolor="white",
46
+ linewidth=1.25,
47
+ )
48
+
49
+
50
+ def show_box(box, ax):
51
+ x0, y0 = box[0], box[1]
52
+ w, h = box[2] - box[0], box[3] - box[1]
53
+ ax.add_patch(
54
+ plt.Rectangle((x0, y0), w, h, edgecolor="green", facecolor=(0, 0, 0, 0), lw=2)
55
+ )
56
+
57
+
58
+ def show_masks(
59
+ image,
60
+ masks,
61
+ scores,
62
+ point_coords=None,
63
+ box_coords=None,
64
+ input_labels=None,
65
+ borders=True,
66
+ ):
67
+ num_masks = len(masks)
68
+ num_cols = num_masks # Number of columns is equal to the number of masks
69
+
70
+ fig, axes = plt.subplots(1, num_cols, figsize=(5 * num_cols, 5))
71
+
72
+ if num_masks == 1:
73
+ axes = [axes] # Ensure axes is iterable when there's only one mask
74
+
75
+ for i, (mask, score) in enumerate(zip(masks, scores)):
76
+ ax = axes[i]
77
+
78
+ ax.imshow(image)
79
+ show_mask(mask, ax, borders=borders)
80
+ if point_coords is not None:
81
+ assert input_labels is not None
82
+ show_points(point_coords, input_labels, ax)
83
+ if box_coords is not None:
84
+ show_box(box_coords, ax)
85
+ if len(scores) > 1:
86
+ ax.set_title(f"Mask {i+1}, Score: {score:.3f}", fontsize=18)
87
+ ax.axis("off")
88
+
89
+ plt.tight_layout()
90
+ return plt