Spaces:
Runtime error
Runtime error
Rishie Nandhan
committed on
Commit
·
0345fe3
1
Parent(s):
dc397b4
Using bounding box prompts
Browse files- __pycache__/app.cpython-38.pyc +0 -0
- app.py +23 -9
__pycache__/app.cpython-38.pyc
CHANGED
Binary files a/__pycache__/app.cpython-38.pyc and b/__pycache__/app.cpython-38.pyc differ
|
|
app.py
CHANGED
@@ -21,7 +21,19 @@ sidewalk_model.load_state_dict(checkpoint["model"])
|
|
21 |
device = "cpu"
|
22 |
sidewalk_model.to(device)
|
23 |
# print('Status Update: Using GPU.')
|
|
|
24 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
25 |
|
26 |
def segment_sidewalk(image):
|
27 |
test_image = Image.fromarray(image).convert("RGB")
|
@@ -29,17 +41,19 @@ def segment_sidewalk(image):
|
|
29 |
# Keep a copy of original image for display
|
30 |
original_image = test_image.copy()
|
31 |
|
32 |
-
# Create grid of points for prompting
|
33 |
-
array_size = 256
|
34 |
-
grid_size = 7
|
35 |
-
x = np.linspace(0, array_size - 1, grid_size)
|
36 |
-
y = np.linspace(0, array_size - 1, grid_size)
|
37 |
-
xv, yv = np.meshgrid(x, y)
|
38 |
-
input_points = [[[int(x), int(y)] for x, y in zip(x_row, y_row)] for x_row, y_row in zip(xv.tolist(), yv.tolist())]
|
39 |
-
input_points = torch.tensor(input_points).view(1, 1, grid_size * grid_size, 2)
|
40 |
|
|
|
|
|
41 |
# prepare image for the model
|
42 |
-
inputs = processor(test_image,
|
43 |
# Convert dtype to float32 as the MPS framework doesn't support float64
|
44 |
inputs = {k: v.to(torch.float32).to(device) for k, v in inputs.items()}
|
45 |
sidewalk_model.eval()
|
|
|
21 |
device = "cpu"
|
22 |
sidewalk_model.to(device)
|
23 |
# print('Status Update: Using GPU.')
|
24 |
+
print('Status Update: FindMySidewalk Ready for inference ...')
|
25 |
|
26 |
# Build a bounding-box prompt for SAM.
def get_bounding_box(W = 256, H = 256, x_min = 0, y_min = 0, x_max = 256, y_max = 256):
    """Return a SAM box prompt ``[x_min, y_min, x_max, y_max]``.

    Each edge of the supplied box is pushed outward by a random 0-19 px
    jitter (``np.random.randint(0, 20)`` — upper bound exclusive) and then
    clamped to the image extent, so the result always stays inside
    ``[0, W] x [0, H]``.
    """
    jitter = lambda: np.random.randint(0, 20)

    # Draw order (x_min, x_max, y_min, y_max) matches the original code so
    # seeded RNG runs remain reproducible.
    left = max(0, x_min - jitter())
    right = min(W, x_max + jitter())
    top = max(0, y_min - jitter())
    bottom = min(H, y_max + jitter())

    return [left, top, right, bottom]
37 |
|
38 |
def segment_sidewalk(image):
|
39 |
test_image = Image.fromarray(image).convert("RGB")
|
|
|
41 |
# Keep a copy of original image for display
|
42 |
original_image = test_image.copy()
|
43 |
|
44 |
+
# # Create grid of points for prompting
|
45 |
+
# array_size = 256
|
46 |
+
# grid_size = 7
|
47 |
+
# x = np.linspace(0, array_size - 1, grid_size)
|
48 |
+
# y = np.linspace(0, array_size - 1, grid_size)
|
49 |
+
# xv, yv = np.meshgrid(x, y)
|
50 |
+
# input_points = [[[int(x), int(y)] for x, y in zip(x_row, y_row)] for x_row, y_row in zip(xv.tolist(), yv.tolist())]
|
51 |
+
# input_points = torch.tensor(input_points).view(1, 1, grid_size * grid_size, 2)
|
52 |
|
53 |
+
# obtain bounding box prompt over entire image
|
54 |
+
prompt = get_bounding_box(test_image.size[0], test_image.size[1], 0, 0, test_image.size[0], test_image.size[1])
|
55 |
# prepare image for the model
|
56 |
+
inputs = processor(test_image, input_boxes=[[prompt]], return_tensors="pt")
|
57 |
# Convert dtype to float32 as the MPS framework doesn't support float64
|
58 |
inputs = {k: v.to(torch.float32).to(device) for k, v in inputs.items()}
|
59 |
sidewalk_model.eval()
|