Add gr state
app.py CHANGED
@@ -19,18 +19,34 @@ def mkstemp(suffix, dir=None):
     return Path(path)
 
 
-
-#
-
-
+def get_sam_feat(img):
+    # predictor.set_image(img)
+    model['sam'].set_image(img)
+    features = model['sam'].features
+    orig_h = model['sam'].orig_h
+    orig_w = model['sam'].orig_w
+    input_h = model['sam'].input_h
+    input_w = model['sam'].input_w
+    return features, orig_h, orig_w, input_h, input_w
 
 
-def get_masked_img(img, w, h):
+def get_masked_img(img, w, h, features, orig_h, orig_w, input_h, input_w):
     point_coords = [w, h]
     point_labels = [1]
     dilate_kernel_size = 15
 
-    model['sam'].set_image(img)
+    # model['sam'].is_image_set = False
+    model['sam'].features = features
+    model['sam'].orig_h = orig_h
+    model['sam'].orig_w = orig_w
+    model['sam'].input_h = input_h
+    model['sam'].input_w = input_w
+    # model['sam'].image_embedding = image_embedding
+    # model['sam'].original_size = original_size
+    # model['sam'].input_size = input_size
+    # model['sam'].is_image_set = True
+
+    # model['sam'].set_image(img)
     # masks, _, _ = predictor.predict(
     masks, _, _ = model['sam'].predict(
         point_coords=np.array([point_coords]),
@@ -98,6 +114,12 @@ model['lama'] = build_lama_model(lama_config, lama_ckpt, device=device)
 
 
 with gr.Blocks() as demo:
+    features = gr.State(None)
+    orig_h = gr.State(None)
+    orig_w = gr.State(None)
+    input_h = gr.State(None)
+    input_w = gr.State(None)
+
     with gr.Row():
         img = gr.Image(label="Image")
         # img_pointed = gr.Image(label='Pointed Image')
@@ -146,9 +168,11 @@ with gr.Blocks() as demo:
     # []
     # )
     # img.change(get_sam_feat, [img], [])
+    img.upload(get_sam_feat, [img], [features, orig_h, orig_w, input_h, input_w])
+
     sam_mask.click(
         get_masked_img,
-        [img, w, h],
+        [img, w, h, features, orig_h, orig_w, input_h, input_w],
         [img_with_mask_0, img_with_mask_1, img_with_mask_2, mask_0, mask_1, mask_2]
     )
 
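For context, the change follows the usual Gradio pattern of computing something expensive once per uploaded image, parking it in gr.State, and handing it back to later event handlers. Below is a minimal, self-contained sketch of that pattern; compute_features and use_features are hypothetical stand-ins for the app's get_sam_feat and get_masked_img, not the actual SAM code.

# Minimal sketch of the gr.State caching pattern used by this commit.
# compute_features / use_features are hypothetical stand-ins; the real app
# caches the SAM embedding plus orig_h/orig_w/input_h/input_w instead of
# a toy mean vector.
import numpy as np
import gradio as gr


def compute_features(img):
    # Expensive per-image step: run once, when the image is uploaded.
    return np.asarray(img).mean(axis=(0, 1))


def use_features(img, features):
    # Cheap per-click step: reuse the cached value instead of recomputing it.
    if features is None:
        return "Upload an image first."
    return f"Cached feature vector: {features}"


with gr.Blocks() as demo:
    features = gr.State(None)        # per-session cache, not rendered in the UI
    img = gr.Image(label="Image")
    out = gr.Textbox(label="Result")
    btn = gr.Button("Use cached features")

    # Fill the State once on upload ...
    img.upload(compute_features, [img], [features])
    # ... and read it back on every click.
    btn.click(use_features, [img, features], [out])

demo.launch()

In the diff itself, img.upload runs get_sam_feat once per uploaded image, and sam_mask.click restores features, orig_h/orig_w, and input_h/input_w onto model['sam'] before predict, so the SAM image embedding is not recomputed on every click.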