Canyu committed on
Commit
19fc903
·
1 Parent(s): f3ad944
Files changed (1) hide show
  1. app.py +75 -7
app.py CHANGED
@@ -4,6 +4,7 @@ from gradio_client import Client, handle_file
4
  from pathlib import Path
5
  from gradio.utils import get_cache_folder
6
 
 
7
 
8
  class Examples(gr.helpers.Examples):
9
  def __init__(self, *args, cached_folder=None, **kwargs):
@@ -21,37 +22,104 @@ client = Client("Canyu/Diception",
21
  hf_token=HF_TOKEN)
22
 
23
 
24
- def process_image_check(path_input):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
  if path_input is None:
26
  raise gr.Error(
27
  "Missing image in the left pane: please upload an image first."
28
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
 
30
- def infer_image_matting(matting_image_input):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
  return client.predict(
32
- prompt=handle_file(matting_image_input),
33
- api_name="/infer_image_matting"
34
  )
35
 
36
  def clear_cache():
37
  return None, None
38
 
39
  def run_demo_server():
 
40
  gradio_theme = gr.themes.Default()
41
  with gr.Blocks(
42
  theme=gradio_theme,
43
  title="Matting",
44
  ) as demo:
45
  with gr.Row():
46
- gr.Markdown("# Matting Demo")
47
  with gr.Row():
48
- gr.Markdown("### Due to the GPU quota limit, if an error occurs, please wait for 5 minutes before retrying.")
 
 
 
49
  with gr.Row():
50
  with gr.Column():
51
  matting_image_input = gr.Image(
52
  label="Input Image",
53
  type="filepath",
54
  )
 
55
  with gr.Row():
56
  matting_image_submit_btn = gr.Button(
57
  value="Estimate Matting", variant="primary"
@@ -80,7 +148,7 @@ def run_demo_server():
80
 
81
  matting_image_submit_btn.click(
82
  fn=process_image_check,
83
- inputs=matting_image_input,
84
  outputs=None,
85
  preprocess=False,
86
  queue=False,
 
4
  from pathlib import Path
5
  from gradio.utils import get_cache_folder
6
 
7
+ from PIL import Image
8
 
9
  class Examples(gr.helpers.Examples):
10
  def __init__(self, *args, cached_folder=None, **kwargs):
 
22
  hf_token=HF_TOKEN)
23
 
24
 
25
# Maps each human-readable task name shown in the UI to the special
# prompt token the Diception backend expects for that prediction type.
map_prompt = {
    'depth': '[[image2depth]]',
    'normal': '[[image2normal]]',
    'pose': '[[image2pose]]',
    'entity segmentation': '[[image2panoptic coarse]]',
    'point segmentation': '[[image2segmentation]]',
    'semantic segmentation': '[[image2semantic]]',
}
33
+
34
def download_additional_params(model_name, filename="add_params.bin"):
    """Download an auxiliary parameter file from the Hugging Face Hub.

    Parameters
    ----------
    model_name : str
        Repo id on the Hub to download from.
    filename : str
        Name of the file inside the repo (default ``"add_params.bin"``).

    Returns
    -------
    str
        Local filesystem path of the cached download.
    """
    # NOTE(review): `hf_hub_download` is not imported anywhere in the visible
    # part of this file — confirm `from huggingface_hub import hf_hub_download`
    # exists at the top of the module.
    # `use_auth_token` is deprecated in huggingface_hub; `token` is the
    # current keyword for authentication.
    return hf_hub_download(repo_id=model_name, filename=filename, token=HF_TOKEN)
38
+
39
# Load the auxiliary parameters shipped alongside the model.
def load_additional_params(model_name):
    """Fetch and deserialize ``add_params.bin`` for *model_name*.

    Parameters
    ----------
    model_name : str
        Repo id on the Hub the file is downloaded from.

    Returns
    -------
    Any
        Whatever object was serialized into the file, loaded onto CPU.
    """
    # Download (or reuse the cached copy of) the params file.
    params_path = download_additional_params(model_name)
    # SECURITY(review): torch.load unpickles arbitrary objects — only load
    # files from repos you trust. If the file holds only tensors/state dicts,
    # pass `weights_only=True` to harden this call.
    return torch.load(params_path, map_location='cpu')
49
+
50
def process_image_check(path_input, prompt):
    """Validate user input before inference is started.

    Parameters
    ----------
    path_input : str | None
        Filepath of the uploaded image (``None`` if nothing was uploaded).
    prompt : list[str] | None
        Prediction types selected in the checkbox group.

    Raises
    ------
    gr.Error
        If no image was uploaded or no prediction type was selected.
    """
    if path_input is None:
        raise gr.Error(
            "Missing image in the left pane: please upload an image first."
        )
    # `not prompt` also covers prompt=None, which `len(prompt) == 0`
    # would crash on with a TypeError.
    if not prompt:
        raise gr.Error(
            "At least 1 prediction type is needed."
        )
59
+
60
+
61
+
62
def process_image_4(image_path, prompt):
    """Build one model-input dict per selected prediction type.

    Parameters
    ----------
    image_path : str
        Path of the uploaded image on disk.
    prompt : list[str]
        Prompt strings, one per requested prediction type.

    Returns
    -------
    list[dict]
        One input dict per prompt, all sharing the same preprocessed image.
    """
    # All preprocessing below is prompt-independent, so do it once instead of
    # re-opening and re-resizing the image on every loop iteration.
    image = Image.open(image_path)
    w, h = image.size
    image = image.resize((768, 768), Image.LANCZOS).convert('RGB')
    to_tensor = transforms.ToTensor()
    image = (to_tensor(image) - 0.5) * 2  # rescale [0, 1] -> [-1, 1]

    # Dummy point prompts — zero-filled placeholders for these tasks.
    coor_point = torch.zeros((1, 5, 2)).to(torch.float32)
    point_labels = torch.zeros((1, 5, 1)).to(torch.float32)

    # NOTE(review): the per-prompt dicts now share the same tensor objects;
    # this assumes downstream inference treats them as read-only — confirm.
    # `generator` and `transforms` are not defined/imported in the visible
    # part of this file — verify they exist at module level.
    inputs = []
    for p in prompt:
        inputs.append({
            'input_images': image.unsqueeze(0),
            'original_size': torch.tensor([[w, h]]),
            'target_size': torch.tensor([[768, 768]]),
            'prompt': [p],
            'coor_point': coor_point,
            'point_labels': point_labels,
            'generator': generator,
        })

    return inputs
89
+
90
+
91
def infer_image_matting(image_path, prompt):
    """Preprocess the image and (eventually) run remote inference.

    Currently a stub: the inputs are built but the remote call is disabled
    and ``None`` is returned.
    """
    inputs = process_image_4(image_path, prompt)
    # TODO(review): the client.predict call below was unreachable dead code
    # sitting after `return None`. Re-enable it once the "/inf" endpoint is
    # ready:
    #     return client.predict(batch=inputs, api_name="/inf")
    return None
98
 
99
def clear_cache():
    """Reset both output slots of the UI to empty."""
    cleared = (None, None)
    return cleared
101
 
102
  def run_demo_server():
103
+ options = ['depth', 'normal', 'entity', 'pose']
104
  gradio_theme = gr.themes.Default()
105
  with gr.Blocks(
106
  theme=gradio_theme,
107
  title="Matting",
108
  ) as demo:
109
  with gr.Row():
110
+ gr.Markdown("# Diception Demo")
111
  with gr.Row():
112
+ gr.Markdown("### All results are generated using the same single model. To facilitate input processing, we separate point-prompted segmentation and semantic segmentation, as they require input points and segmentation targets.")
113
+ with gr.Row():
114
+ checkbox_group = gr.CheckboxGroup(choices=options, label="Select options:")
115
+
116
  with gr.Row():
117
  with gr.Column():
118
  matting_image_input = gr.Image(
119
  label="Input Image",
120
  type="filepath",
121
  )
122
+
123
  with gr.Row():
124
  matting_image_submit_btn = gr.Button(
125
  value="Estimate Matting", variant="primary"
 
148
 
149
  matting_image_submit_btn.click(
150
  fn=process_image_check,
151
+ inputs=[matting_image_input, checkbox_group],
152
  outputs=None,
153
  preprocess=False,
154
  queue=False,