ginipick commited on
Commit
2447d97
·
verified ·
1 Parent(s): 4fef9b6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +126 -282
app.py CHANGED
@@ -1,219 +1,4 @@
1
- import tempfile
2
- import time
3
- from collections.abc import Sequence
4
- from typing import Any, cast
5
- import os
6
- from huggingface_hub import login, hf_hub_download # hf_hub_download 추가
7
-
8
- import gradio as gr
9
- import numpy as np
10
- import pillow_heif
11
- import spaces
12
- import torch
13
- from gradio_image_annotation import image_annotator
14
- from gradio_imageslider import ImageSlider
15
- from PIL import Image
16
- from pymatting.foreground.estimate_foreground_ml import estimate_foreground_ml
17
- from refiners.fluxion.utils import no_grad
18
- from refiners.solutions import BoxSegmenter
19
- from transformers import GroundingDinoForObjectDetection, GroundingDinoProcessor
20
- from diffusers import FluxPipeline
21
-
22
- BoundingBox = tuple[int, int, int, int]
23
-
24
- # 초기화 및 설정
25
- pillow_heif.register_heif_opener()
26
- pillow_heif.register_avif_opener()
27
-
28
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
29
-
30
- # HF 토큰 설정
31
- HF_TOKEN = os.getenv("HF_TOKEN")
32
- if HF_TOKEN is None:
33
- raise ValueError("Please set the HF_TOKEN environment variable")
34
-
35
- try:
36
- login(token=HF_TOKEN)
37
- except Exception as e:
38
- raise ValueError(f"Failed to login to Hugging Face: {str(e)}")
39
-
40
- # 모델 초기화
41
- segmenter = BoxSegmenter(device="cpu")
42
- segmenter.device = device
43
- segmenter.model = segmenter.model.to(device=segmenter.device)
44
-
45
- gd_model_path = "IDEA-Research/grounding-dino-base"
46
- gd_processor = GroundingDinoProcessor.from_pretrained(gd_model_path)
47
- gd_model = GroundingDinoForObjectDetection.from_pretrained(gd_model_path, torch_dtype=torch.float32)
48
- gd_model = gd_model.to(device=device)
49
- assert isinstance(gd_model, GroundingDinoForObjectDetection)
50
-
51
- # FLUX 파이프라인 초기화
52
- pipe = FluxPipeline.from_pretrained(
53
- "black-forest-labs/FLUX.1-dev",
54
- torch_dtype=torch.bfloat16,
55
- use_auth_token=HF_TOKEN
56
- )
57
- pipe.load_lora_weights(
58
- hf_hub_download(
59
- "ByteDance/Hyper-SD",
60
- "Hyper-FLUX.1-dev-8steps-lora.safetensors",
61
- use_auth_token=HF_TOKEN
62
- )
63
- )
64
- pipe.fuse_lora(lora_scale=0.125)
65
- pipe.to(device="cuda", dtype=torch.bfloat16)
66
-
67
- def bbox_union(bboxes: Sequence[list[int]]) -> BoundingBox | None:
68
- if not bboxes:
69
- return None
70
- for bbox in bboxes:
71
- assert len(bbox) == 4
72
- assert all(isinstance(x, int) for x in bbox)
73
- return (
74
- min(bbox[0] for bbox in bboxes),
75
- min(bbox[1] for bbox in bboxes),
76
- max(bbox[2] for bbox in bboxes),
77
- max(bbox[3] for bbox in bboxes),
78
- )
79
-
80
- def corners_to_pixels_format(bboxes: torch.Tensor, width: int, height: int) -> torch.Tensor:
81
- x1, y1, x2, y2 = bboxes.round().to(torch.int32).unbind(-1)
82
- return torch.stack((x1.clamp_(0, width), y1.clamp_(0, height), x2.clamp_(0, width), y2.clamp_(0, height)), dim=-1)
83
-
84
- def gd_detect(img: Image.Image, prompt: str) -> BoundingBox | None:
85
- inputs = gd_processor(images=img, text=f"{prompt}.", return_tensors="pt").to(device=device)
86
- with no_grad():
87
- outputs = gd_model(**inputs)
88
- width, height = img.size
89
- results: dict[str, Any] = gd_processor.post_process_grounded_object_detection(
90
- outputs,
91
- inputs["input_ids"],
92
- target_sizes=[(height, width)],
93
- )[0]
94
- assert "boxes" in results and isinstance(results["boxes"], torch.Tensor)
95
- bboxes = corners_to_pixels_format(results["boxes"].cpu(), width, height)
96
- return bbox_union(bboxes.numpy().tolist())
97
-
98
- def apply_mask(img: Image.Image, mask_img: Image.Image, defringe: bool = True) -> Image.Image:
99
- assert img.size == mask_img.size
100
- img = img.convert("RGB")
101
- mask_img = mask_img.convert("L")
102
- if defringe:
103
- rgb, alpha = np.asarray(img) / 255.0, np.asarray(mask_img) / 255.0
104
- foreground = cast(np.ndarray[Any, np.dtype[np.uint8]], estimate_foreground_ml(rgb, alpha))
105
- img = Image.fromarray((foreground * 255).astype("uint8"))
106
- result = Image.new("RGBA", img.size)
107
- result.paste(img, (0, 0), mask_img)
108
- return result
109
-
110
- def generate_background(prompt: str, width: int, height: int) -> Image.Image:
111
- """배경 이미지 생성 함수"""
112
- try:
113
- with timer("Background generation"):
114
- image = pipe(
115
- prompt=prompt,
116
- width=width,
117
- height=height,
118
- num_inference_steps=8,
119
- guidance_scale=4.0,
120
- ).images[0]
121
- return image
122
- except Exception as e:
123
- raise gr.Error(f"Background generation failed: {str(e)}")
124
-
125
- def combine_with_background(foreground: Image.Image, background: Image.Image) -> Image.Image:
126
- """전경과 배경 합성 함수"""
127
- background = background.resize(foreground.size)
128
- return Image.alpha_composite(background.convert('RGBA'), foreground)
129
-
130
- @spaces.GPU
131
- def _gpu_process(img: Image.Image, prompt: str | BoundingBox | None) -> tuple[Image.Image, BoundingBox | None, list[str]]:
132
- time_log: list[str] = []
133
- if isinstance(prompt, str):
134
- t0 = time.time()
135
- bbox = gd_detect(img, prompt)
136
- time_log.append(f"detect: {time.time() - t0}")
137
- if not bbox:
138
- print(time_log[0])
139
- raise gr.Error("No object detected")
140
- else:
141
- bbox = prompt
142
- t0 = time.time()
143
- mask = segmenter(img, bbox)
144
- time_log.append(f"segment: {time.time() - t0}")
145
- return mask, bbox, time_log
146
-
147
- def _process(img: Image.Image, prompt: str | BoundingBox | None, bg_prompt: str | None = None) -> tuple[tuple[Image.Image, Image.Image, Image.Image], gr.DownloadButton]:
148
- if img.width > 2048 or img.height > 2048:
149
- orig_res = max(img.width, img.height)
150
- img.thumbnail((2048, 2048))
151
- if isinstance(prompt, tuple):
152
- x0, y0, x1, y1 = (int(x * 2048 / orig_res) for x in prompt)
153
- prompt = (x0, y0, x1, y1)
154
-
155
- mask, bbox, time_log = _gpu_process(img, prompt)
156
- masked_alpha = apply_mask(img, mask, defringe=True)
157
-
158
- if bg_prompt:
159
- try:
160
- background = generate_background(bg_prompt, img.width, img.height)
161
- combined = combine_with_background(masked_alpha, background)
162
- except Exception as e:
163
- raise gr.Error(f"Background processing failed: {str(e)}")
164
- else:
165
- combined = Image.alpha_composite(Image.new("RGBA", masked_alpha.size, "white"), masked_alpha)
166
-
167
- thresholded = mask.point(lambda p: 255 if p > 10 else 0)
168
- bbox = thresholded.getbbox()
169
- to_dl = masked_alpha.crop(bbox)
170
-
171
- temp = tempfile.NamedTemporaryFile(delete=False, suffix=".png")
172
- to_dl.save(temp, format="PNG")
173
- temp.close()
174
-
175
- return (img, combined, masked_alpha), gr.DownloadButton(value=temp.name, interactive=True)
176
-
177
- def process_bbox(prompts: dict[str, Any]) -> tuple[tuple[Image.Image, Image.Image], gr.DownloadButton]:
178
- assert isinstance(img := prompts["image"], Image.Image)
179
- assert isinstance(boxes := prompts["boxes"], list)
180
- if len(boxes) == 1:
181
- assert isinstance(box := boxes[0], dict)
182
- bbox = tuple(box[k] for k in ["xmin", "ymin", "xmax", "ymax"])
183
- else:
184
- assert len(boxes) == 0
185
- bbox = None
186
- return _process(img, bbox)
187
-
188
- def on_change_bbox(prompts: dict[str, Any] | None):
189
- return gr.update(interactive=prompts is not None)
190
-
191
- def process_prompt(img: Image.Image, prompt: str, bg_prompt: str | None = None) -> tuple[tuple[Image.Image, Image.Image], gr.DownloadButton]:
192
- return _process(img, prompt, bg_prompt)
193
-
194
- def on_change_prompt(img: Image.Image | None, prompt: str | None, bg_prompt: str | None = None):
195
- return gr.update(interactive=bool(img and prompt))
196
-
197
- # CSS 스타일 정의
198
- css = """
199
- footer {display: none}
200
- .main-title {
201
- text-align: center;
202
- margin: 2em 0;
203
- }
204
- .main-title h1 {
205
- color: #2196F3;
206
- font-size: 2.5em;
207
- }
208
- .container {
209
- max-width: 1200px;
210
- margin: auto;
211
- padding: 20px;
212
- }
213
- """
214
-
215
- # Gradio UI
216
- # Gradio UI
217
  with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
218
  gr.HTML("""
219
  <div class="main-title">
@@ -224,86 +9,87 @@ with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
224
 
225
  with gr.Tabs() as tabs:
226
  # Text 탭
227
- with gr.Tab("Extract by Text", id="tab_prompt"):
228
  with gr.Row(equal_height=True):
229
  with gr.Column(scale=1, min_width=400):
230
- gr.HTML("<h3>📥 Input Section</h3>")
231
- iimg = gr.Image(
232
  type="pil",
233
- label="Upload Image"
 
234
  )
235
  with gr.Group():
236
- prompt = gr.Textbox(
237
- label="🎯 Object to Extract",
238
- placeholder="Enter what you want to extract..."
 
239
  )
240
- bg_prompt = gr.Textbox(
241
- label="🖼️ Background Generation Prompt (optional)",
242
- placeholder="Describe the background you want..."
 
243
  )
244
- btn = gr.Button(
245
- "🚀 Process Image",
246
- variant="primary",
247
- interactive=False
248
  )
249
 
250
  with gr.Column(scale=1, min_width=400):
251
- gr.HTML("<h3>📤 Output Section</h3>")
252
- oimg = ImageSlider(
253
  label="Results Preview",
254
  show_download_button=False
255
  )
256
- dlbt = gr.DownloadButton(
257
- "💾 Download Result",
258
- interactive=False
259
  )
260
 
261
- with gr.Accordion("📚 Examples", open=False):
262
- examples = [
 
263
  ["examples/text.jpg", "text", "white background"],
264
  ["examples/black-lamp.jpg", "black lamp", "minimalist interior"]
265
  ]
266
- ex = gr.Examples(
267
- examples=examples,
268
- inputs=[iimg, prompt, bg_prompt],
269
- outputs=[oimg, dlbt],
270
  fn=process_prompt,
271
  cache_examples=True
272
  )
273
 
274
  # Bounding Box 탭
275
- with gr.Tab("📏 Extract by Box", id="tab_bb"):
276
  with gr.Row(equal_height=True):
277
  with gr.Column(scale=1, min_width=400):
278
- gr.HTML("<h3>📥 Input Section</h3>")
279
- annotator = image_annotator(
280
  image_type="pil",
281
- disable_edit_boxes=True,
282
  show_download_button=False,
283
  show_share_button=False,
284
  single_box=True,
285
- label="Draw Box Around Object"
 
286
  )
287
- btn_bb = gr.Button(
288
- "✂️ Extract Selection",
289
- variant="primary",
290
- interactive=False
291
  )
292
 
293
  with gr.Column(scale=1, min_width=400):
294
- gr.HTML("<h3>📤 Output Section</h3>")
295
- oimg_bb = ImageSlider(
296
  label="Results Preview",
297
  show_download_button=False
298
  )
299
- dlbt_bb = gr.DownloadButton(
300
- "💾 Download Result",
301
- interactive=False
302
  )
303
 
304
  # Bounding Box Examples
305
- with gr.Accordion("📚 Examples", open=False):
306
- examples_bb = [
307
  {
308
  "image": "examples/text.jpg",
309
  "boxes": [{"xmin": 51, "ymin": 511, "xmax": 639, "ymax": 1255}]
@@ -313,42 +99,101 @@ with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
313
  "boxes": [{"xmin": 88, "ymin": 148, "xmax": 700, "ymax": 1414}]
314
  }
315
  ]
316
- ex_bb = gr.Examples(
317
- examples=examples_bb,
318
- inputs=[annotator], # annotator가 이미 정의된 후에 사용
319
- outputs=[oimg_bb, dlbt_bb],
320
  fn=process_bbox,
321
  cache_examples=True
322
  )
323
- # Event handlers 부분 수정
 
 
 
 
324
  # Text 탭 이벤트
325
- for inp in [iimg, prompt]:
326
- inp.change(
327
- fn=on_change_prompt,
328
- inputs=[iimg, prompt, bg_prompt],
329
- outputs=[btn],
330
- )
331
- btn.click(
 
 
 
 
332
  fn=process_prompt,
333
- inputs=[iimg, prompt, bg_prompt],
334
- outputs=[oimg, dlbt],
335
- api_name=False,
336
  )
337
 
338
  # Bounding Box 탭 이벤트
339
- annotator.change(
340
- fn=on_change_bbox,
341
- inputs=[annotator],
342
- outputs=[btn_bb],
343
  )
344
- btn_bb.click(
345
  fn=process_bbox,
346
- inputs=[annotator],
347
- outputs=[oimg_bb, dlbt_bb],
348
- api_name=False,
349
  )
350
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
351
 
 
352
  demo.queue(max_size=30, api_open=False)
353
  demo.launch(
354
  show_api=False,
@@ -356,5 +201,4 @@ demo.launch(
356
  server_name="0.0.0.0",
357
  server_port=7860,
358
  show_error=True
359
- )
360
-
 
1
+ # Gradio UI 부분 수정
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
  with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
3
  gr.HTML("""
4
  <div class="main-title">
 
9
 
10
  with gr.Tabs() as tabs:
11
  # Text 탭
12
+ with gr.Tab("Extract by Text", id="tab_prompt"):
13
  with gr.Row(equal_height=True):
14
  with gr.Column(scale=1, min_width=400):
15
+ gr.HTML("<h3>Input Section</h3>")
16
+ input_image = gr.Image(
17
  type="pil",
18
+ label="Upload Image",
19
+ interactive=True
20
  )
21
  with gr.Group():
22
+ text_prompt = gr.Textbox(
23
+ label="Object to Extract",
24
+ placeholder="Enter what you want to extract...",
25
+ interactive=True
26
  )
27
+ background_prompt = gr.Textbox(
28
+ label="Background Generation Prompt (optional)",
29
+ placeholder="Describe the background you want...",
30
+ interactive=True
31
  )
32
+ process_btn = gr.Button(
33
+ "Process Image",
34
+ variant="primary"
 
35
  )
36
 
37
  with gr.Column(scale=1, min_width=400):
38
+ gr.HTML("<h3>Output Section</h3>")
39
+ output_image = ImageSlider(
40
  label="Results Preview",
41
  show_download_button=False
42
  )
43
+ download_btn = gr.DownloadButton(
44
+ "Download Result"
 
45
  )
46
 
47
+ # Text Examples
48
+ with gr.Accordion("Examples", open=False):
49
+ text_examples = [
50
  ["examples/text.jpg", "text", "white background"],
51
  ["examples/black-lamp.jpg", "black lamp", "minimalist interior"]
52
  ]
53
+ gr.Examples(
54
+ examples=text_examples,
55
+ inputs=[input_image, text_prompt, background_prompt],
56
+ outputs=[output_image, download_btn],
57
  fn=process_prompt,
58
  cache_examples=True
59
  )
60
 
61
  # Bounding Box 탭
62
+ with gr.Tab("Extract by Box", id="tab_bb"):
63
  with gr.Row(equal_height=True):
64
  with gr.Column(scale=1, min_width=400):
65
+ gr.HTML("<h3>Input Section</h3>")
66
+ box_annotator = image_annotator(
67
  image_type="pil",
68
+ disable_edit_boxes=False, # 편집 가능하도록 변경
69
  show_download_button=False,
70
  show_share_button=False,
71
  single_box=True,
72
+ label="Draw Box Around Object",
73
+ interactive=True
74
  )
75
+ box_process_btn = gr.Button(
76
+ "Extract Selection",
77
+ variant="primary"
 
78
  )
79
 
80
  with gr.Column(scale=1, min_width=400):
81
+ gr.HTML("<h3>Output Section</h3>")
82
+ box_output_image = ImageSlider(
83
  label="Results Preview",
84
  show_download_button=False
85
  )
86
+ box_download_btn = gr.DownloadButton(
87
+ "Download Result"
 
88
  )
89
 
90
  # Bounding Box Examples
91
+ with gr.Accordion("Examples", open=False):
92
+ box_examples = [
93
  {
94
  "image": "examples/text.jpg",
95
  "boxes": [{"xmin": 51, "ymin": 511, "xmax": 639, "ymax": 1255}]
 
99
  "boxes": [{"xmin": 88, "ymin": 148, "xmax": 700, "ymax": 1414}]
100
  }
101
  ]
102
+ gr.Examples(
103
+ examples=box_examples,
104
+ inputs=[box_annotator],
105
+ outputs=[box_output_image, box_download_btn],
106
  fn=process_bbox,
107
  cache_examples=True
108
  )
109
+
110
+ # Event handlers
111
+ def update_button_state(img, prompt):
112
+ return gr.Button.update(interactive=bool(img and prompt))
113
+
114
  # Text 탭 이벤트
115
+ input_image.change(
116
+ fn=update_button_state,
117
+ inputs=[input_image, text_prompt],
118
+ outputs=[process_btn]
119
+ )
120
+ text_prompt.change(
121
+ fn=update_button_state,
122
+ inputs=[input_image, text_prompt],
123
+ outputs=[process_btn]
124
+ )
125
+ process_btn.click(
126
  fn=process_prompt,
127
+ inputs=[input_image, text_prompt, background_prompt],
128
+ outputs=[output_image, download_btn]
 
129
  )
130
 
131
  # Bounding Box 탭 이벤트
132
+ box_annotator.change(
133
+ fn=lambda x: gr.Button.update(interactive=bool(x)),
134
+ inputs=[box_annotator],
135
+ outputs=[box_process_btn]
136
  )
137
+ box_process_btn.click(
138
  fn=process_bbox,
139
+ inputs=[box_annotator],
140
+ outputs=[box_output_image, box_download_btn]
 
141
  )
142
 
143
+ # CSS 스타일 업데이트
144
+ css = """
145
+ footer {display: none}
146
+ .main-title {
147
+ text-align: center;
148
+ margin: 2em 0;
149
+ padding: 1em;
150
+ background: #f7f7f7;
151
+ border-radius: 10px;
152
+ }
153
+ .main-title h1 {
154
+ color: #2196F3;
155
+ font-size: 2.5em;
156
+ margin-bottom: 0.5em;
157
+ }
158
+ .main-title p {
159
+ color: #666;
160
+ font-size: 1.2em;
161
+ }
162
+ .container {
163
+ max-width: 1200px;
164
+ margin: auto;
165
+ padding: 20px;
166
+ }
167
+ .tabs {
168
+ margin-top: 1em;
169
+ }
170
+ .input-group {
171
+ background: white;
172
+ padding: 1em;
173
+ border-radius: 8px;
174
+ box-shadow: 0 2px 4px rgba(0,0,0,0.1);
175
+ }
176
+ .output-group {
177
+ background: white;
178
+ padding: 1em;
179
+ border-radius: 8px;
180
+ box-shadow: 0 2px 4px rgba(0,0,0,0.1);
181
+ }
182
+ button.primary {
183
+ background: #2196F3;
184
+ border: none;
185
+ color: white;
186
+ padding: 0.5em 1em;
187
+ border-radius: 4px;
188
+ cursor: pointer;
189
+ transition: background 0.3s ease;
190
+ }
191
+ button.primary:hover {
192
+ background: #1976D2;
193
+ }
194
+ """
195
 
196
+ # Launch settings
197
  demo.queue(max_size=30, api_open=False)
198
  demo.launch(
199
  show_api=False,
 
201
  server_name="0.0.0.0",
202
  server_port=7860,
203
  show_error=True
204
+ )