Files changed (1) hide show
  1. app.py +42 -0
app.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import numpy as np
4
+ from PIL import Image
5
+ import cv2
6
+ from segment_anything import sam_model_registry, SamPredictor
7
+
8
+ # Load model
9
+ checkpoint = "sam_vit_h_4b8939.pth"
10
+ device = "cuda" if torch.cuda.is_available() else "cpu"
11
+ model_type = "vit_h"
12
+
13
+ sam = sam_model_registry[model_type](checkpoint=checkpoint)
14
+ sam.to(device)
15
+ predictor = SamPredictor(sam)
16
+
17
+ def segment_image(input_img):
18
+ np_img = np.array(input_img)
19
+ image = cv2.cvtColor(np_img, cv2.COLOR_RGB2BGR)
20
+
21
+ predictor.set_image(image)
22
+ h, w, _ = image.shape
23
+ input_point = np.array([[w // 2, h // 2]])
24
+ input_label = np.array([1])
25
+
26
+ masks, scores, logits = predictor.predict(
27
+ point_coords=input_point,
28
+ point_labels=input_label,
29
+ multimask_output=False
30
+ )
31
+
32
+ mask = masks[0].astype(np.uint8) * 255
33
+ return Image.fromarray(mask)
34
+
35
+ # UI
36
+ iface = gr.Interface(fn=segment_image,
37
+ inputs=gr.Image(type="pil"),
38
+ outputs=gr.Image(type="pil"),
39
+ title="Segment Anything Model",
40
+ description="Upload an image and get a segmentation mask.")
41
+
42
+ iface.launch()