Schrodingers commited on
Commit
bae6c47
1 Parent(s): 72dfc75

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +33 -81
app.py CHANGED
@@ -1,85 +1,37 @@
1
  import gradio as gr
2
  import cv2
3
  import numpy as np
4
- from flask import Flask, render_template
5
 
6
- # A placeholder function; Actual LiveWire code would be much more involved
7
- #def live_wire_segmentation(image, points):
8
- # This is a placeholder. Implement actual live wire segmentation here.
9
- # return image
10
-
11
-
12
- import cv2
13
- import numpy as np
14
-
15
- def compute_cost(image):
16
- gradient_x = cv2.Sobel(image, cv2.CV_64F, 1, 0, ksize=3)
17
- gradient_y = cv2.Sobel(image, cv2.CV_64F, 0, 1, ksize=3)
18
- gradient_magnitude = cv2.magnitude(gradient_x, gradient_y)
19
- return gradient_magnitude
20
-
21
- def dijkstra(cost, start, end):
22
- # Dimensions of the image
23
- h, w = cost.shape
24
- visited = np.zeros((h, w), dtype=np.bool_)
25
- distance = np.full((h, w), np.inf)
26
- parent = np.zeros((h, w, 2), dtype=np.int16) # to store the path
27
-
28
- distance[start[1], start[0]] = 0
29
- for _ in range(h * w):
30
- min_distance = np.inf
31
- for y in range(h):
32
- for x in range(w):
33
- if not visited[y, x] and distance[y, x] < min_distance:
34
- u = (x, y)
35
- min_distance = distance[y, x]
36
-
37
- visited[u[1], u[0]] = True
38
-
39
- # Check neighbors
40
- for i in [-1, 0, 1]:
41
- for j in [-1, 0, 1]:
42
- if 0 <= u[1] + i < h and 0 <= u[0] + j < w:
43
- v = (u[0] + j, u[1] + i)
44
- alt = distance[u[1], u[0]] + cost[v[1], v[0]]
45
- if alt < distance[v[1], v[0]]:
46
- distance[v[1], v[0]] = alt
47
- parent[v[1], v[0]] = u
48
-
49
- # Reconstruct path from end to start by following parents
50
- path = []
51
- while end != start:
52
- path.append(end)
53
- end = tuple(parent[end[1], end[0]])
54
- path.append(start)
55
-
56
- return path
57
-
58
- def live_wire_segmentation(image, start, end):
59
- gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
60
- cost = compute_cost(gray)
61
- path = dijkstra(cost, start, end)
62
-
63
- for point in path:
64
- cv2.circle(image, point, 1, (0, 255, 0), -1) # Draw path on the image
65
-
66
- return image
67
-
68
- def main_app():
69
- interface = gr.Interface(
70
- fn=live_wire_segmentation,
71
- inputs=["image", gr.inputs.Sketchpad()], # You may have to adjust this for latest version of Gradio
72
- outputs="image",
73
- live=True
74
- )
75
- interface.launch()
76
-
77
- app = Flask(__name__)
78
-
79
- @app.route('/')
80
- def index():
81
- main_app()
82
- return "Gradio App Running!"
83
-
84
- if __name__ == "__main__":
85
- app.run(debug=True)
 
1
  import gradio as gr
2
  import cv2
3
  import numpy as np
 
4
 
5
+ # Function to perform image segmentation using OpenCV Watershed Algorithm
6
+ def watershed_segmentation(input_image, scribble_image):
7
+ # Load the input image and scribble image
8
+ image = cv2.cvtColor(input_image.astype('uint8'), cv2.COLOR_RGBA2BGR)
9
+ scribble = cv2.cvtColor(scribble_image.astype('uint8'), cv2.COLOR_RGBA2GRAY)
10
+
11
+ # Convert scribble to markers (0 for background, 1 for unknown, 2 for foreground)
12
+ markers = np.zeros_like(scribble)
13
+ markers[scribble == 0] = 0
14
+ markers[scribble == 255] = 1
15
+ markers[scribble == 128] = 2
16
+
17
+ # Apply watershed algorithm
18
+ cv2.watershed(image, markers)
19
+
20
+ # Create a segmented mask
21
+ segmented_mask = np.zeros_like(image, dtype=np.uint8)
22
+ segmented_mask[markers == 2] = [0, 0, 255] # Red color for segmented regions
23
+
24
+ return segmented_mask
25
+
26
+ # Gradio interface
27
+ input_image = gr.inputs.Image(type='pil', label="Upload an image")
28
+ scribble_image = gr.inputs.Image(type='pil', label="Scribble on the image")
29
+ output_image = gr.outputs.Image(type='pil', label="Segmented Image")
30
+
31
+ gr.Interface(
32
+ fn=watershed_segmentation,
33
+ inputs=[input_image, scribble_image],
34
+ outputs=output_image,
35
+ title="Image Segmentation using Watershed Algorithm",
36
+ description="Upload an image and scribble on it to perform segmentation using the Watershed Algorithm.",
37
+ ).launch()