merve (HF staff) committed
Commit b4e8f1d
1 parent: 6086700

Update app.py

Files changed (1): app.py (+25 −14)
app.py CHANGED
@@ -4,17 +4,25 @@ import torch
 from PIL import Image
 from transformers import SamModel, SamProcessor
 from gradio_image_prompter import ImagePrompter
-
+import spaces
 
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-sam_model = SamModel.from_pretrained("facebook/sam-vit-huge").to(device)
+sam_model = SamModel.from_pretrained("facebook/sam-vit-huge").to("cuda")
 sam_processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")
-slimsam_model = SamModel.from_pretrained("nielsr/slimsam-50-uniform").to(device)
+slimsam_model = SamModel.from_pretrained("nielsr/slimsam-50-uniform").to("cuda")
 slimsam_processor = SamProcessor.from_pretrained("nielsr/slimsam-50-uniform")
 
-def sam_box_inference(image, model, x_min, y_min, x_max, y_max):
+def get_processor_and_model(slim: bool):
+    if slim:
+        return slimsam_processor, slimsam_model
+    return sam_processor, sam_model
 
-    inputs = sam_processor(
+@spaces.GPU
+def sam_box_inference(image, x_min, y_min, x_max, y_max, *, slim=False):
+
+    processor, model = get_processor_and_model(slim)
+
+    inputs = processor(
         Image.fromarray(image),
         input_boxes=[[[[x_min, y_min, x_max, y_max]]]],
         return_tensors="pt"
@@ -23,7 +31,7 @@ def sam_box_inference(image, model, x_min, y_min, x_max, y_max):
     with torch.no_grad():
         outputs = model(**inputs)
 
-    mask = sam_processor.image_processor.post_process_masks(
+    mask = processor.image_processor.post_process_masks(
         outputs.pred_masks.cpu(),
         inputs["original_sizes"].cpu(),
         inputs["reshaped_input_sizes"].cpu()
@@ -33,17 +41,20 @@ def sam_box_inference(image, model, x_min, y_min, x_max, y_max):
     print(mask.shape)
     return [(mask, "mask")]
 
-
-def sam_point_inference(image, model, x, y):
-    inputs = sam_processor(
+@spaces.GPU
+def sam_point_inference(image, x, y, *, slim=False):
+
+    processor, model = get_processor_and_model(slim)
+
+    inputs = processor(
         image,
         input_points=[[[x, y]]],
         return_tensors="pt").to(device)
 
     with torch.no_grad():
-        outputs = sam_model(**inputs)
+        outputs = model(**inputs)
 
-    mask = sam_processor.post_process_masks(
+    mask = processor.post_process_masks(
         outputs.pred_masks.cpu(),
         inputs["original_sizes"].cpu(),
         inputs["reshaped_input_sizes"].cpu()
@@ -72,8 +83,8 @@ def infer_point(img):
     center_x = int(np.mean(nonzero_indices[1]))
     center_y = int(np.mean(nonzero_indices[0]))
     print("Point inference returned.")
-    return ((image, sam_point_inference(image, slimsam_model, center_x, center_y)),
-            (image, sam_point_inference(image, sam_model, center_x, center_y)))
+    return ((image, sam_point_inference(image, center_x, center_y, slim=True)),
+            (image, sam_point_inference(image, center_x, center_y)))
 
 def infer_box(prompts):
     # background (original image) layers[0] ( point prompt) composite (total image)
@@ -86,8 +97,8 @@ def infer_box(prompts):
     print(points)
 
     # x_min = points[0] x_max = points[3] y_min = points[1] y_max = points[4]
-    return ((image, sam_box_inference(image, slimsam_model, points[0], points[1], points[3], points[4])),
-            (image, sam_box_inference(image, sam_model, points[0], points[1], points[3], points[4])))
+    return ((image, sam_box_inference(image, points[0], points[1], points[3], points[4], slim=True)),
+            (image, sam_box_inference(image, points[0], points[1], points[3], points[4])))
 with gr.Blocks(title="SlimSAM") as demo:
     gr.Markdown("# SlimSAM")
     gr.Markdown("SlimSAM is the pruned-distilled version of SAM that is smaller.")