mattmdjaga commited on
Commit
cc7fbfd
1 Parent(s): 7dbfa30
Files changed (1) hide show
  1. app.py +74 -0
app.py ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import numpy as np
3
+ import torch
4
+ from PIL import Image, ImageDraw
5
+ import requests
6
+ from transformers import SamModel, SamProcessor
7
+ import cv2
8
+
9
+ device = "cuda" if torch.cuda.is_available() else "cpu"
10
+
11
+ # Load model and processor
12
+ model = SamModel.from_pretrained("facebook/sam-vit-base").to(device)
13
+ processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
14
+
15
+ def mask_2_dots(mask):
16
+ gray = cv2.cvtColor(mask, cv2.COLOR_RGB2GRAY)
17
+ _, thresh = cv2.threshold(gray, 127, 255, 0)
18
+ kernel = np.ones((5,5),np.uint8)
19
+ closed = cv2.morphologyEx(thresh, cv2.MORPH_CLOSE, kernel)
20
+ contours, _ = cv2.findContours(closed, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
21
+ points = []
22
+ for contour in contours:
23
+ moments = cv2.moments(contour)
24
+ cx = int(moments['m10']/moments['m00'])
25
+ cy = int(moments['m01']/moments['m00'])
26
+ points.append([cx, cy])
27
+ return [points]
28
+
29
+ def main_func(inputs):
30
+ dots = inputs['mask']
31
+ points = mask_2_dots(dots)
32
+
33
+ image_input = inputs['image']
34
+ image_input = Image.fromarray(image_input)
35
+
36
+ inputs = processor(image_input, input_points=points, return_tensors="pt").to(device)
37
+ # Forward pass
38
+ outputs = model(**inputs)
39
+
40
+ # Postprocess outputs
41
+ draw = ImageDraw.Draw(image_input)
42
+ for point in points[0]:
43
+ draw.ellipse((point[0] - 10, point[1] - 10, point[0] + 10, point[1] + 10), fill="red")
44
+
45
+
46
+ masks = processor.image_processor.post_process_masks(
47
+ outputs.pred_masks.cpu(), inputs["original_sizes"].cpu(), inputs["reshaped_input_sizes"].cpu()
48
+ )
49
+ #scores = outputs.iou_scores
50
+
51
+ mask = masks[0].squeeze(0).numpy().transpose(1, 2, 0)
52
+
53
+ pred_masks = [image_input]
54
+ for i in range(mask.shape[2]):
55
+ #mask[:,:,i] = mask[:,:,i] * scores[0][i].item()
56
+ pred_masks.append(Image.fromarray((mask[:,:,i] * 255).astype(np.uint8)))
57
+
58
+ return pred_masks
59
+
60
+
61
+ with gr.Blocks() as demo:
62
+ gr.Markdown("# Demo to run Segment Anything base model")
63
+ gr.Markdown("""This app uses the [Segment Anything](https://huggingface.co/facebook/sam-vit-base) model from Meta to get a mask from a points in an image.
64
+ Currently it only works for creating dots for one object. But, I'm planning to add extra features to make it work for multiple objects.
65
+ The output shows the image with the dots then the 3 predicted masks.
66
+ """)
67
+ with gr.Tab("Flip Image"):
68
+ with gr.Row():
69
+ image_input = gr.Image(tool='sketch')
70
+ image_output = gr.Gallery()
71
+
72
+ image_button = gr.Button("Segment Image")
73
+
74
+ image_button.click(main_func, inputs=image_input, outputs=image_output)