Spaces:
Running
Running
commit
Browse files
app.py
CHANGED
@@ -4,6 +4,7 @@ from gradio_client import Client, handle_file
|
|
4 |
from pathlib import Path
|
5 |
from gradio.utils import get_cache_folder
|
6 |
|
|
|
7 |
|
8 |
class Examples(gr.helpers.Examples):
|
9 |
def __init__(self, *args, cached_folder=None, **kwargs):
|
@@ -21,37 +22,104 @@ client = Client("Canyu/Diception",
|
|
21 |
hf_token=HF_TOKEN)
|
22 |
|
23 |
|
24 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
25 |
if path_input is None:
|
26 |
raise gr.Error(
|
27 |
"Missing image in the left pane: please upload an image first."
|
28 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
29 |
|
30 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
31 |
return client.predict(
|
32 |
-
|
33 |
-
api_name="/
|
34 |
)
|
35 |
|
36 |
def clear_cache():
|
37 |
return None, None
|
38 |
|
39 |
def run_demo_server():
|
|
|
40 |
gradio_theme = gr.themes.Default()
|
41 |
with gr.Blocks(
|
42 |
theme=gradio_theme,
|
43 |
title="Matting",
|
44 |
) as demo:
|
45 |
with gr.Row():
|
46 |
-
gr.Markdown("#
|
47 |
with gr.Row():
|
48 |
-
gr.Markdown("###
|
|
|
|
|
|
|
49 |
with gr.Row():
|
50 |
with gr.Column():
|
51 |
matting_image_input = gr.Image(
|
52 |
label="Input Image",
|
53 |
type="filepath",
|
54 |
)
|
|
|
55 |
with gr.Row():
|
56 |
matting_image_submit_btn = gr.Button(
|
57 |
value="Estimate Matting", variant="primary"
|
@@ -80,7 +148,7 @@ def run_demo_server():
|
|
80 |
|
81 |
matting_image_submit_btn.click(
|
82 |
fn=process_image_check,
|
83 |
-
inputs=matting_image_input,
|
84 |
outputs=None,
|
85 |
preprocess=False,
|
86 |
queue=False,
|
|
|
4 |
from pathlib import Path
|
5 |
from gradio.utils import get_cache_folder
|
6 |
|
7 |
+
from PIL import Image
|
8 |
|
9 |
class Examples(gr.helpers.Examples):
|
10 |
def __init__(self, *args, cached_folder=None, **kwargs):
|
|
|
22 |
hf_token=HF_TOKEN)
|
23 |
|
24 |
|
25 |
+
map_prompt = {
|
26 |
+
'depth': '[[image2depth]]',
|
27 |
+
'normal': '[[image2normal]]',
|
28 |
+
'pose': '[[image2pose]]',
|
29 |
+
'entity segmentation': '[[image2panoptic coarse]]',
|
30 |
+
'point segmentation': '[[image2segmentation]]',
|
31 |
+
'semantic segmentation': '[[image2semantic]]',
|
32 |
+
}
|
33 |
+
|
34 |
+
def download_additional_params(model_name, filename="add_params.bin"):
|
35 |
+
# 下载文件并返回文件路径
|
36 |
+
file_path = hf_hub_download(repo_id=model_name, filename=filename, use_auth_token=HF_TOKEN)
|
37 |
+
return file_path
|
38 |
+
|
39 |
+
# 加载 additional_params.bin 文件
|
40 |
+
def load_additional_params(model_name):
|
41 |
+
# 下载 additional_params.bin
|
42 |
+
params_path = download_additional_params(model_name)
|
43 |
+
|
44 |
+
# 使用 torch.load() 加载文件内容
|
45 |
+
additional_params = torch.load(params_path, map_location='cpu')
|
46 |
+
|
47 |
+
# 返回加载的参数内容
|
48 |
+
return additional_params
|
49 |
+
|
50 |
+
def process_image_check(path_input, prompt):
|
51 |
if path_input is None:
|
52 |
raise gr.Error(
|
53 |
"Missing image in the left pane: please upload an image first."
|
54 |
)
|
55 |
+
if len(prompt) == 0:
|
56 |
+
raise gr.Error(
|
57 |
+
"At least 1 prediction type is needed."
|
58 |
+
)
|
59 |
+
|
60 |
+
|
61 |
+
|
62 |
+
def process_image_4(image_path, prompt):
|
63 |
+
|
64 |
+
inputs = []
|
65 |
+
for p in prompt:
|
66 |
+
image = Image.open(image_path)
|
67 |
+
|
68 |
+
w, h = image.size
|
69 |
+
|
70 |
+
coor_point = torch.zeros((1,5,2)).to(torch.float32)
|
71 |
+
point_labels = torch.zeros((1,5,1)).to(torch.float32)
|
72 |
|
73 |
+
image = image.resize((768, 768), Image.LANCZOS).convert('RGB')
|
74 |
+
to_tensor = transforms.ToTensor()
|
75 |
+
image = (to_tensor(image) - 0.5) * 2
|
76 |
+
|
77 |
+
cur_input = {
|
78 |
+
'input_images': image.unsqueeze(0),
|
79 |
+
'original_size': torch.tensor([[w,h]]),
|
80 |
+
'target_size': torch.tensor([[768, 768]]),
|
81 |
+
'prompt': [p],
|
82 |
+
'coor_point': coor_point,
|
83 |
+
'point_labels': point_labels,
|
84 |
+
'generator': generator
|
85 |
+
}
|
86 |
+
inputs.append(cur_input)
|
87 |
+
|
88 |
+
return inputs
|
89 |
+
|
90 |
+
|
91 |
+
def infer_image_matting(image_path, prompt):
|
92 |
+
inputs = process_image_4(image_path, prompt)
|
93 |
+
return None
|
94 |
return client.predict(
|
95 |
+
batch=inputs,
|
96 |
+
api_name="/inf"
|
97 |
)
|
98 |
|
99 |
def clear_cache():
|
100 |
return None, None
|
101 |
|
102 |
def run_demo_server():
|
103 |
+
options = ['depth', 'normal', 'entity', 'pose']
|
104 |
gradio_theme = gr.themes.Default()
|
105 |
with gr.Blocks(
|
106 |
theme=gradio_theme,
|
107 |
title="Matting",
|
108 |
) as demo:
|
109 |
with gr.Row():
|
110 |
+
gr.Markdown("# Diception Demo")
|
111 |
with gr.Row():
|
112 |
+
gr.Markdown("### All results are generated using the same single model. To facilitate input processing, we separate point-prompted segmentation and semantic segmentation, as they require input points and segmentation targets.")
|
113 |
+
with gr.Row():
|
114 |
+
checkbox_group = gr.CheckboxGroup(choices=options, label="Select options:")
|
115 |
+
|
116 |
with gr.Row():
|
117 |
with gr.Column():
|
118 |
matting_image_input = gr.Image(
|
119 |
label="Input Image",
|
120 |
type="filepath",
|
121 |
)
|
122 |
+
|
123 |
with gr.Row():
|
124 |
matting_image_submit_btn = gr.Button(
|
125 |
value="Estimate Matting", variant="primary"
|
|
|
148 |
|
149 |
matting_image_submit_btn.click(
|
150 |
fn=process_image_check,
|
151 |
+
inputs=[matting_image_input, checkbox_group],
|
152 |
outputs=None,
|
153 |
preprocess=False,
|
154 |
queue=False,
|