Shilpaj commited on
Commit
ec6dea8
·
1 Parent(s): 00f0762

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +111 -3
app.py CHANGED
@@ -7,7 +7,117 @@ import numpy as np
7
  import gradio as gr
8
  from PIL import ImageDraw
9
  from ultralytics import YOLO
10
- from utils.inference import segment_everything, segment_with_points, get_points_with_draw
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
 
12
 
13
  # Load the pre-trained model
@@ -236,8 +346,6 @@ with gr.Blocks(css=css, title='Fast Segment Anything') as demo:
236
 
237
  segment_btn_t.click(segment_everything,
238
  inputs=[
239
- model,
240
- device,
241
  cond_img_t,
242
  input_size_slider_t,
243
  iou_threshold,
 
7
  import gradio as gr
8
  from PIL import ImageDraw
9
  from ultralytics import YOLO
10
+ from utils.tools_gradio import fast_process
11
+ from utils.tools import format_results, box_prompt, point_prompt, text_prompt
12
+
13
+
14
+ def segment_everything(
15
+ input,
16
+ input_size=1024,
17
+ iou_threshold=0.7,
18
+ conf_threshold=0.25,
19
+ better_quality=False,
20
+ withContours=True,
21
+ use_retina=True,
22
+ text="",
23
+ wider=False,
24
+ mask_random_color=True,
25
+ ):
26
+ input_size = int(input_size)
27
+ w, h = input.size
28
+ scale = input_size / max(w, h)
29
+ new_w = int(w * scale)
30
+ new_h = int(h * scale)
31
+ input = input.resize((new_w, new_h))
32
+
33
+ results = model(input,
34
+ device=device,
35
+ retina_masks=True,
36
+ iou=iou_threshold,
37
+ conf=conf_threshold,
38
+ imgsz=input_size, )
39
+
40
+ if len(text) > 0:
41
+ results = format_results(results[0], 0)
42
+ annotations, _ = text_prompt(results, text, input, device=device, wider=wider)
43
+ annotations = np.array([annotations])
44
+ else:
45
+ annotations = results[0].masks.data
46
+
47
+ fig = fast_process(annotations=annotations,
48
+ image=input,
49
+ device=device,
50
+ scale=(1024 // input_size),
51
+ better_quality=better_quality,
52
+ mask_random_color=mask_random_color,
53
+ bbox=None,
54
+ use_retina=use_retina,
55
+ withContours=withContours, )
56
+ return fig
57
+
58
+
59
+ def segment_with_points(
60
+ input,
61
+ input_size=1024,
62
+ iou_threshold=0.7,
63
+ conf_threshold=0.25,
64
+ better_quality=False,
65
+ withContours=True,
66
+ use_retina=True,
67
+ mask_random_color=True,
68
+ ):
69
+ global global_points
70
+ global global_point_label
71
+
72
+ input_size = int(input_size)
73
+ w, h = input.size
74
+ scale = input_size / max(w, h)
75
+ new_w = int(w * scale)
76
+ new_h = int(h * scale)
77
+ input = input.resize((new_w, new_h))
78
+
79
+ scaled_points = [[int(x * scale) for x in point] for point in global_points]
80
+
81
+ results = model(input,
82
+ device=device,
83
+ retina_masks=True,
84
+ iou=iou_threshold,
85
+ conf=conf_threshold,
86
+ imgsz=input_size, )
87
+
88
+ results = format_results(results[0], 0)
89
+ annotations, _ = point_prompt(results, scaled_points, global_point_label, new_h, new_w)
90
+ annotations = np.array([annotations])
91
+
92
+ fig = fast_process(annotations=annotations,
93
+ image=input,
94
+ device=device,
95
+ scale=(1024 // input_size),
96
+ better_quality=better_quality,
97
+ mask_random_color=mask_random_color,
98
+ bbox=None,
99
+ use_retina=use_retina,
100
+ withContours=withContours, )
101
+
102
+ global_points = []
103
+ global_point_label = []
104
+ return fig, None
105
+
106
+
107
+ def get_points_with_draw(image, label, evt: gr.SelectData):
108
+ global global_points
109
+ global global_point_label
110
+
111
+ x, y = evt.index[0], evt.index[1]
112
+ point_radius, point_color = 15, (255, 255, 0) if label == 'Add Mask' else (255, 0, 255)
113
+ global_points.append([x, y])
114
+ global_point_label.append(1 if label == 'Add Mask' else 0)
115
+
116
+ print(x, y, label == 'Add Mask')
117
+
118
+ draw = ImageDraw.Draw(image)
119
+ draw.ellipse([(x - point_radius, y - point_radius), (x + point_radius, y + point_radius)], fill=point_color)
120
+ return image
121
 
122
 
123
  # Load the pre-trained model
 
346
 
347
  segment_btn_t.click(segment_everything,
348
  inputs=[
 
 
349
  cond_img_t,
350
  input_size_slider_t,
351
  iou_threshold,