Rishie Nandhan committed on
Commit 0345fe3 · 1 Parent(s): dc397b4

Using bounding box prompts

Files changed (2)
  1. __pycache__/app.cpython-38.pyc +0 -0
  2. app.py +23 -9
__pycache__/app.cpython-38.pyc CHANGED
Binary files a/__pycache__/app.cpython-38.pyc and b/__pycache__/app.cpython-38.pyc differ
 
app.py CHANGED
@@ -21,7 +21,19 @@ sidewalk_model.load_state_dict(checkpoint["model"])
 device = "cpu"
 sidewalk_model.to(device)
 # print('Status Update: Using GPU.')
+print('Status Update: FindMySidewalk Ready for inference ...')
 
+# Generate bounding box prompt for SAM
+def get_bounding_box(W=256, H=256, x_min=0, y_min=0, x_max=256, y_max=256):
+    # add perturbation to the inputted bounding box coordinates
+    x_min = max(0, x_min - np.random.randint(0, 20))
+    x_max = min(W, x_max + np.random.randint(0, 20))
+    y_min = max(0, y_min - np.random.randint(0, 20))
+    y_max = min(H, y_max + np.random.randint(0, 20))
+
+    bbox = [x_min, y_min, x_max, y_max]
+
+    return bbox
 
 def segment_sidewalk(image):
     test_image = Image.fromarray(image).convert("RGB")
@@ -29,17 +41,19 @@ def segment_sidewalk(image):
     # Keep a copy of original image for display
     original_image = test_image.copy()
 
-    # Create grid of points for prompting
-    array_size = 256
-    grid_size = 7
-    x = np.linspace(0, array_size - 1, grid_size)
-    y = np.linspace(0, array_size - 1, grid_size)
-    xv, yv = np.meshgrid(x, y)
-    input_points = [[[int(x), int(y)] for x, y in zip(x_row, y_row)] for x_row, y_row in zip(xv.tolist(), yv.tolist())]
-    input_points = torch.tensor(input_points).view(1, 1, grid_size * grid_size, 2)
+    # # Create grid of points for prompting
+    # array_size = 256
+    # grid_size = 7
+    # x = np.linspace(0, array_size - 1, grid_size)
+    # y = np.linspace(0, array_size - 1, grid_size)
+    # xv, yv = np.meshgrid(x, y)
+    # input_points = [[[int(x), int(y)] for x, y in zip(x_row, y_row)] for x_row, y_row in zip(xv.tolist(), yv.tolist())]
+    # input_points = torch.tensor(input_points).view(1, 1, grid_size * grid_size, 2)
 
+    # obtain bounding box prompt over entire image
+    prompt = get_bounding_box(test_image.size[0], test_image.size[1], 0, 0, test_image.size[0], test_image.size[1])
     # prepare image for the model
-    inputs = processor(test_image, input_points=input_points, return_tensors="pt")
+    inputs = processor(test_image, input_boxes=[[prompt]], return_tensors="pt")
     # Convert dtype to float32 as the MPS framework doesn't support float64
     inputs = {k: v.to(torch.float32).to(device) for k, v in inputs.items()}
     sidewalk_model.eval()
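
For reference, below is a minimal standalone sketch of the bounding-box prompting pattern this commit switches to, written against the Hugging Face transformers SAM API. The checkpoint name (facebook/sam-vit-base), the input image path, and the post-processing step are illustrative assumptions and are not taken from this repo's app.py, which loads its own fine-tuned sidewalk_model from a checkpoint dict.

# Hypothetical standalone sketch: box-prompted SAM inference with transformers.
# Checkpoint name, image path, and post-processing are assumptions for illustration.
import torch
from PIL import Image
from transformers import SamModel, SamProcessor

device = "cpu"
processor = SamProcessor.from_pretrained("facebook/sam-vit-base")   # assumed checkpoint
model = SamModel.from_pretrained("facebook/sam-vit-base").to(device)
model.eval()

image = Image.open("sidewalk.png").convert("RGB")                   # placeholder input image
W, H = image.size

# One box per image covering the whole frame: shape (batch, num_boxes, 4)
input_boxes = [[[0, 0, W, H]]]

inputs = processor(image, input_boxes=input_boxes, return_tensors="pt").to(device)

with torch.no_grad():
    outputs = model(**inputs)

# Rescale the predicted low-resolution masks back to the original image size
masks = processor.image_processor.post_process_masks(
    outputs.pred_masks.cpu(),
    inputs["original_sizes"].cpu(),
    inputs["reshaped_input_sizes"].cpu(),
)
print(len(masks), masks[0].shape)  # one entry per input image

Note that with a whole-image box, the random perturbation in get_bounding_box() is clamped back to [0, 0, W, H] by the max/min bounds, so the prompt effectively asks SAM to segment within the full frame.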