aiqtech committed
Commit 2eff0d8 · verified · 1 Parent(s): 099d2a4

Create app-backup.py

Files changed (1)
  1. app-backup.py +752 -0
app-backup.py ADDED
@@ -0,0 +1,752 @@
+ import ast
+ import gc
+ import os
+ import tempfile
+ import time
+ import uuid
+ from collections.abc import Sequence
+ from typing import Any, cast
+
+ import gradio as gr
+ import numpy as np
+ import pillow_heif
+ import spaces
+ import torch
+ from diffusers import FluxPipeline
+ from gradio_client import Client, handle_file
+ from gradio_image_annotation import image_annotator
+ from gradio_imageslider import ImageSlider
+ from huggingface_hub import hf_hub_download, login
+ from PIL import Image, ImageDraw, ImageFont
+ from pymatting.foreground.estimate_foreground_ml import estimate_foreground_ml
+ from refiners.fluxion.utils import no_grad
+ from refiners.solutions import BoxSegmenter
+ from transformers import (
+     AutoModelForSeq2SeqLM,
+     AutoTokenizer,
+     GroundingDinoForObjectDetection,
+     GroundingDinoProcessor,
+     pipeline,
+ )
+
+ def clear_memory():
+     """Release Python and CUDA memory."""
+     gc.collect()
+     try:
+         if torch.cuda.is_available():
+             with torch.cuda.device(0):  # explicitly use device 0
+                 torch.cuda.empty_cache()
+     except Exception:
+         pass
+
+ # GPU setup
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")  # pin to cuda:0 explicitly
+
+ # Wrap CUDA configuration in try-except so CPU-only hosts still start
+ if torch.cuda.is_available():
+     try:
+         with torch.cuda.device(0):
+             torch.cuda.empty_cache()
+             torch.backends.cudnn.benchmark = True
+             torch.backends.cuda.matmul.allow_tf32 = True
+     except Exception:
+         print("Warning: Could not configure CUDA settings")
+
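+ # Note: torch.cuda.empty_cache() only returns cached blocks to the driver; it is
+ # best-effort housekeeping between requests, not a guarantee of free VRAM.
+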
+ # Initialize the Korean-to-English translation model
+ model_name = "Helsinki-NLP/opus-mt-ko-en"
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
+ model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to('cpu')
+ translator = pipeline("translation", model=model, tokenizer=tokenizer, device=-1)
+
+ def translate_to_english(text: str) -> str:
+     """Translate Korean text to English; return other text unchanged."""
+     try:
+         if any(ord('가') <= ord(char) <= ord('힣') for char in text):
+             translated = translator(text, max_length=128)[0]['translation_text']
+             print(f"Translated '{text}' to '{translated}'")
+             return translated
+         return text
+     except Exception as e:
+         print(f"Translation error: {str(e)}")
+         return text
+
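+ # Usage sketch (illustrative only; the exact wording depends on the
+ # opus-mt-ko-en checkpoint):
+ #   translate_to_english("노을 지는 해변")  # -> e.g. "A beach at sunset"
+ #   translate_to_english("a red car")      # -> "a red car" (no Hangul, passed through)
+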
+ BoundingBox = tuple[int, int, int, int]
+
+ pillow_heif.register_heif_opener()
+ pillow_heif.register_avif_opener()
+
+ # Hugging Face token setup
+ HF_TOKEN = os.getenv("HF_TOKEN")
+ if HF_TOKEN is None:
+     raise ValueError("Please set the HF_TOKEN environment variable")
+
+ try:
+     login(token=HF_TOKEN)
+ except Exception as e:
+     raise ValueError(f"Failed to login to Hugging Face: {str(e)}")
+
+ # Model initialization
+ segmenter = BoxSegmenter(device="cpu")
+ segmenter.device = device
+ segmenter.model = segmenter.model.to(device=segmenter.device)
+
+ gd_model_path = "IDEA-Research/grounding-dino-base"
+ gd_processor = GroundingDinoProcessor.from_pretrained(gd_model_path)
+ gd_model = GroundingDinoForObjectDetection.from_pretrained(gd_model_path, torch_dtype=torch.float32)
+ gd_model = gd_model.to(device=device)
+ assert isinstance(gd_model, GroundingDinoForObjectDetection)
+
+ # Initialize the FLUX pipeline
+ pipe = FluxPipeline.from_pretrained(
+     "black-forest-labs/FLUX.1-dev",
+     torch_dtype=torch.float16,
+     token=HF_TOKEN
+ )
+ pipe.enable_attention_slicing(slice_size="auto")
+
+ # Load the LoRA weights
+ pipe.load_lora_weights(
+     hf_hub_download(
+         "ByteDance/Hyper-SD",
+         "Hyper-FLUX.1-dev-8steps-lora.safetensors",
+         token=HF_TOKEN
+     )
+ )
+ pipe.fuse_lora(lora_scale=0.125)
+
+ # Wrap the GPU move in try-except so CPU-only hosts still start
+ try:
+     if torch.cuda.is_available():
+         pipe = pipe.to("cuda:0")  # pin to cuda:0 explicitly
+ except Exception as e:
+     print(f"Warning: Could not move pipeline to CUDA: {str(e)}")
+
+ client = Client("NabeelShar/BiRefNet_for_text_writing")
+
+ class timer:
+     """Context manager that logs the wall-clock duration of a block."""
+     def __init__(self, method_name="timed process"):
+         self.method = method_name
+     def __enter__(self):
+         self.start = time.time()
+         print(f"{self.method} starts")
+         return self
+     def __exit__(self, exc_type, exc_val, exc_tb):
+         end = time.time()
+         print(f"{self.method} took {round(end - self.start, 2)}s")
+
+ def bbox_union(bboxes: Sequence[list[int]]) -> BoundingBox | None:
+     if not bboxes:
+         return None
+     for bbox in bboxes:
+         assert len(bbox) == 4
+         assert all(isinstance(x, int) for x in bbox)
+     return (
+         min(bbox[0] for bbox in bboxes),
+         min(bbox[1] for bbox in bboxes),
+         max(bbox[2] for bbox in bboxes),
+         max(bbox[3] for bbox in bboxes),
+     )
+
+ def corners_to_pixels_format(bboxes: torch.Tensor, width: int, height: int) -> torch.Tensor:
+     x1, y1, x2, y2 = bboxes.round().to(torch.int32).unbind(-1)
+     return torch.stack((x1.clamp_(0, width), y1.clamp_(0, height), x2.clamp_(0, width), y2.clamp_(0, height)), dim=-1)
+
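+ # Example: bbox_union([[10, 20, 50, 60], [30, 5, 80, 40]]) returns (10, 5, 80, 60),
+ # the smallest box covering every detection.
+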
+ def gd_detect(img: Image.Image, prompt: str) -> BoundingBox | None:
+     inputs = gd_processor(images=img, text=f"{prompt}.", return_tensors="pt").to(device=device)
+     with no_grad():
+         outputs = gd_model(**inputs)
+     width, height = img.size
+     results: dict[str, Any] = gd_processor.post_process_grounded_object_detection(
+         outputs,
+         inputs["input_ids"],
+         target_sizes=[(height, width)],
+     )[0]
+     assert "boxes" in results and isinstance(results["boxes"], torch.Tensor)
+     bboxes = corners_to_pixels_format(results["boxes"].cpu(), width, height)
+     return bbox_union(bboxes.numpy().tolist())
+
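+ # gd_detect performs open-vocabulary detection: GroundingDINO scores regions of
+ # the image against the text prompt (the appended "." terminates the phrase), and
+ # the union of all returned boxes becomes the single region handed to the segmenter.
+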
+ def apply_mask(img: Image.Image, mask_img: Image.Image, defringe: bool = True) -> Image.Image:
+     assert img.size == mask_img.size
+     img = img.convert("RGB")
+     mask_img = mask_img.convert("L")
+     if defringe:
+         rgb, alpha = np.asarray(img) / 255.0, np.asarray(mask_img) / 255.0
+         foreground = cast(np.ndarray[Any, np.dtype[np.uint8]], estimate_foreground_ml(rgb, alpha))
+         img = Image.fromarray((foreground * 255).astype("uint8"))
+     result = Image.new("RGBA", img.size)
+     result.paste(img, (0, 0), mask_img)
+     return result
+
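+ # With defringe=True, pymatting's estimate_foreground_ml re-estimates the foreground
+ # colors under the soft alpha mask, which suppresses halos from background colors
+ # bleeding into semi-transparent edge pixels of the cutout.
+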
+ def adjust_size_to_multiple_of_8(width: int, height: int) -> tuple[int, int]:
+     """Round image dimensions up to the nearest multiple of 8."""
+     new_width = ((width + 7) // 8) * 8
+     new_height = ((height + 7) // 8) * 8
+     return new_width, new_height
+
+ def calculate_dimensions(aspect_ratio: str, base_size: int = 512) -> tuple[int, int]:
+     """Compute image dimensions for the selected aspect ratio."""
+     if aspect_ratio == "1:1":
+         return base_size, base_size
+     elif aspect_ratio == "16:9":
+         return base_size * 16 // 9, base_size
+     elif aspect_ratio == "9:16":
+         return base_size, base_size * 16 // 9
+     elif aspect_ratio == "4:3":
+         return base_size * 4 // 3, base_size
+     return base_size, base_size
+
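+ # Worked example: calculate_dimensions("16:9") yields 512 * 16 // 9 = 910 x 512,
+ # which adjust_size_to_multiple_of_8 then rounds up to 912 x 512 so both sides
+ # are divisible by 8.
+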
+ @spaces.GPU(duration=20)  # reduced from 40s to 20s
+ def generate_background(prompt: str, aspect_ratio: str) -> Image.Image:
+     try:
+         width, height = calculate_dimensions(aspect_ratio)
+         width, height = adjust_size_to_multiple_of_8(width, height)
+
+         max_size = 768
+         if width > max_size or height > max_size:
+             ratio = max_size / max(width, height)
+             width = int(width * ratio)
+             height = int(height * ratio)
+             width, height = adjust_size_to_multiple_of_8(width, height)
+
+         with timer("Background generation"):
+             try:
+                 with torch.inference_mode():
+                     image = pipe(
+                         prompt=prompt,
+                         width=width,
+                         height=height,
+                         num_inference_steps=8,
+                         guidance_scale=4.0
+                     ).images[0]
+             except Exception as e:
+                 print(f"Pipeline error: {str(e)}")
+                 return Image.new('RGB', (width, height), 'white')
+
+         return image
+     except Exception as e:
+         print(f"Background generation error: {str(e)}")
+         return Image.new('RGB', (512, 512), 'white')
+
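+ # The 8-step schedule matches the fused Hyper-SD LoRA above; on any failure the
+ # function degrades to a plain white canvas instead of raising, so the UI always
+ # receives an image.
+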
+ def create_position_grid():
+     return """
+     <div class="position-grid" style="display: grid; grid-template-columns: repeat(3, 1fr); gap: 10px; width: 150px; margin: auto;">
+         <button class="position-btn" data-pos="top-left">↖</button>
+         <button class="position-btn" data-pos="top-center">↑</button>
+         <button class="position-btn" data-pos="top-right">↗</button>
+         <button class="position-btn" data-pos="middle-left">←</button>
+         <button class="position-btn" data-pos="middle-center">•</button>
+         <button class="position-btn" data-pos="middle-right">→</button>
+         <button class="position-btn" data-pos="bottom-left">↙</button>
+         <button class="position-btn" data-pos="bottom-center" data-default="true">↓</button>
+         <button class="position-btn" data-pos="bottom-right">↘</button>
+     </div>
+     """
+
+ def calculate_object_position(position: str, bg_size: tuple[int, int], obj_size: tuple[int, int]) -> tuple[int, int]:
+     """Compute the paste coordinates of the object on the background."""
+     bg_width, bg_height = bg_size
+     obj_width, obj_height = obj_size
+
+     positions = {
+         "top-left": (0, 0),
+         "top-center": ((bg_width - obj_width) // 2, 0),
+         "top-right": (bg_width - obj_width, 0),
+         "middle-left": (0, (bg_height - obj_height) // 2),
+         "middle-center": ((bg_width - obj_width) // 2, (bg_height - obj_height) // 2),
+         "middle-right": (bg_width - obj_width, (bg_height - obj_height) // 2),
+         "bottom-left": (0, bg_height - obj_height),
+         "bottom-center": ((bg_width - obj_width) // 2, bg_height - obj_height),
+         "bottom-right": (bg_width - obj_width, bg_height - obj_height)
+     }
+
+     return positions.get(position, positions["bottom-center"])
+
+ def resize_object(image: Image.Image, scale_percent: float) -> Image.Image:
+     """Resize the object by a percentage of its original size."""
+     width = int(image.width * scale_percent / 100)
+     height = int(image.height * scale_percent / 100)
+     return image.resize((width, height), Image.Resampling.LANCZOS)
+
+ def combine_with_background(foreground: Image.Image, background: Image.Image,
+                             position: str = "bottom-center", scale_percent: float = 100) -> Image.Image:
+     """Composite the foreground object onto the background."""
+     # Prepare the background image
+     result = background.convert('RGBA')
+
+     # Resize the object
+     scaled_foreground = resize_object(foreground, scale_percent)
+
+     # Compute the object's position
+     x, y = calculate_object_position(position, result.size, scaled_foreground.size)
+
+     # Composite
+     result.paste(scaled_foreground, (x, y), scaled_foreground)
+     return result
+
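+ # Worked example: a 512 x 512 cutout at scale_percent=50 becomes 256 x 256; on a
+ # 1024 x 768 background with position="bottom-center" it is pasted at
+ # x = (1024 - 256) // 2 = 384, y = 768 - 256 = 512, using its own alpha as the mask.
+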
+ @spaces.GPU(duration=30)  # reduced from 120s to 30s
+ def _gpu_process(img: Image.Image, prompt: str | BoundingBox | None) -> tuple[Image.Image, BoundingBox | None, list[str]]:
+     time_log: list[str] = []
+     try:
+         if isinstance(prompt, str):
+             t0 = time.time()
+             bbox = gd_detect(img, prompt)
+             time_log.append(f"detect: {time.time() - t0}")
+             if not bbox:
+                 print(time_log[0])
+                 raise gr.Error("No object detected")
+         else:
+             bbox = prompt
+         t0 = time.time()
+         mask = segmenter(img, bbox)
+         time_log.append(f"segment: {time.time() - t0}")
+         return mask, bbox, time_log
+     except Exception as e:
+         print(f"GPU process error: {str(e)}")
+         raise
+
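+ # _gpu_process accepts either a text prompt (detect first, then segment) or a
+ # ready-made bounding box (segment directly); time_log records the duration of
+ # each stage for debugging.
+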
+ def _process(img: Image.Image, prompt: str | BoundingBox | None, bg_prompt: str | None = None, aspect_ratio: str = "1:1") -> tuple[tuple[Image.Image, Image.Image, Image.Image], gr.DownloadButton]:
+     try:
+         # Cap the input image size
+         max_size = 1024
+         if img.width > max_size or img.height > max_size:
+             ratio = max_size / max(img.width, img.height)
+             new_size = (int(img.width * ratio), int(img.height * ratio))
+             img = img.resize(new_size, Image.Resampling.LANCZOS)
+
+         # CUDA memory management
+         try:
+             if torch.cuda.is_available():
+                 current_device = torch.cuda.current_device()
+                 with torch.cuda.device(current_device):
+                     torch.cuda.empty_cache()
+         except Exception as e:
+             print(f"CUDA memory management failed: {e}")
+
+         with torch.cuda.amp.autocast(enabled=torch.cuda.is_available()):
+             mask, bbox, time_log = _gpu_process(img, prompt)
+             masked_alpha = apply_mask(img, mask, defringe=True)
+
+         if bg_prompt:
+             background = generate_background(bg_prompt, aspect_ratio)
+             combined = background
+         else:
+             combined = Image.alpha_composite(Image.new("RGBA", masked_alpha.size, "white"), masked_alpha)
+
+         clear_memory()
+
+         with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as temp:
+             combined.save(temp.name)
+         return (img, combined, masked_alpha), gr.DownloadButton(value=temp.name, interactive=True)
+     except Exception as e:
+         clear_memory()
+         print(f"Processing error: {str(e)}")
+         raise gr.Error(f"Processing failed: {str(e)}")
+
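+ # _process returns ((original, combined, cutout), download_button): index 1 is the
+ # generated background (or a white composite when no background prompt is given),
+ # and index 2 is the RGBA cutout that process_prompt repositions onto it.
+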
+ def on_change_bbox(prompts: dict[str, Any] | None):
+     return gr.update(interactive=prompts is not None)
+
+ def on_change_prompt(img: Image.Image | None, prompt: str | None, bg_prompt: str | None = None):
+     return gr.update(interactive=bool(img and prompt))
+
+ def process_prompt(img: Image.Image, prompt: str, bg_prompt: str | None = None,
+                    aspect_ratio: str = "1:1", position: str = "bottom-center",
+                    scale_percent: float = 100) -> tuple[Image.Image, Image.Image]:
+     try:
+         if img is None or prompt.strip() == "":
+             raise gr.Error("Please provide both image and prompt")
+
+         print(f"Processing with position: {position}, scale: {scale_percent}")
+
+         try:
+             prompt = translate_to_english(prompt)
+             if bg_prompt:
+                 bg_prompt = translate_to_english(bg_prompt)
+         except Exception as e:
+             print(f"Translation error (continuing with original text): {str(e)}")
+
+         results, _ = _process(img, prompt, bg_prompt, aspect_ratio)
+
+         if bg_prompt:
+             try:
+                 combined = combine_with_background(
+                     foreground=results[2],
+                     background=results[1],
+                     position=position,
+                     scale_percent=scale_percent
+                 )
+                 print(f"Combined image created with position: {position}")
+                 return combined, results[2]
+             except Exception as e:
+                 print(f"Combination error: {str(e)}")
+                 return results[1], results[2]
+
+         return results[1], results[2]
+     except Exception as e:
+         print(f"Error in process_prompt: {str(e)}")
+         raise gr.Error(str(e))
+     finally:
+         clear_memory()
+
+ def process_bbox(img: Image.Image, box_input: str) -> tuple[Image.Image, Image.Image]:
+     try:
+         if img is None or box_input.strip() == "":
+             raise gr.Error("Please provide both image and bounding box coordinates")
+
+         try:
+             # Parse the coordinate list safely; literal_eval rejects arbitrary code
+             coords = ast.literal_eval(box_input)
+             if not isinstance(coords, list) or len(coords) != 4:
+                 raise ValueError("Invalid box format")
+             bbox = tuple(int(x) for x in coords)
+         except Exception:
+             raise gr.Error("Invalid box format. Please provide [xmin, ymin, xmax, ymax]")
+
+         # Process the image
+         results, _ = _process(img, bbox)
+
+         # Return only the combined and extracted images
+         return results[1], results[2]
+     except Exception as e:
+         raise gr.Error(str(e))
+
+ # Event handler functions
+ def update_process_button(img, prompt):
+     return gr.update(
+         interactive=bool(img and prompt),
+         variant="primary" if bool(img and prompt) else "secondary"
+     )
+
+ def update_box_button(img, box_input):
+     try:
+         if img and box_input:
+             coords = ast.literal_eval(box_input)
+             if isinstance(coords, list) and len(coords) == 4:
+                 return gr.update(interactive=True, variant="primary")
+         return gr.update(interactive=False, variant="secondary")
+     except Exception:
+         return gr.update(interactive=False, variant="secondary")
+
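+ # ast.literal_eval only accepts Python literals, so malformed or malicious box
+ # input raises an exception (caught above) instead of being executed as code.
+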
+ # CSS definitions
+ css = """
+ footer {display: none}
+ .main-title {
+     text-align: center;
+     margin: 2em 0;
+     padding: 1em;
+     background: #f7f7f7;
+     border-radius: 10px;
+ }
+ .main-title h1 {
+     color: #2196F3;
+     font-size: 2.5em;
+     margin-bottom: 0.5em;
+ }
+ .main-title p {
+     color: #666;
+     font-size: 1.2em;
+ }
+ .container {
+     max-width: 1200px;
+     margin: auto;
+     padding: 20px;
+ }
+ .tabs {
+     margin-top: 1em;
+ }
+ .input-group {
+     background: white;
+     padding: 1em;
+     border-radius: 8px;
+     box-shadow: 0 2px 4px rgba(0,0,0,0.1);
+ }
+ .output-group {
+     background: white;
+     padding: 1em;
+     border-radius: 8px;
+     box-shadow: 0 2px 4px rgba(0,0,0,0.1);
+ }
+ button.primary {
+     background: #2196F3;
+     border: none;
+     color: white;
+     padding: 0.5em 1em;
+     border-radius: 4px;
+     cursor: pointer;
+     transition: background 0.3s ease;
+ }
+ button.primary:hover {
+     background: #1976D2;
+ }
+ .position-btn {
+     transition: all 0.3s ease;
+ }
+ .position-btn:hover {
+     background-color: #e3f2fd;
+ }
+ .position-btn.selected {
+     background-color: #2196F3;
+     color: white;
+ }
+ """
+
+ def add_text_with_stroke(draw, text, x, y, font, text_color, stroke_width):
+     """Draw the text repeatedly at small offsets to thicken the glyphs."""
+     for adj_x in range(-stroke_width, stroke_width + 1):
+         for adj_y in range(-stroke_width, stroke_width + 1):
+             draw.text((x + adj_x, y + adj_y), text, font=font, fill=text_color)
+
+ def remove_background(image):
+     # Save the image under a collision-free name (UUID) for the client call
+     filename = f"image_{uuid.uuid4()}.png"
+     image.save(filename)
+     try:
+         # Call the BiRefNet Space for background removal
+         result = client.predict(images=handle_file(filename), api_name="/image")
+         return Image.open(result[0])
+     finally:
+         # Clean up the temporary upload file
+         if os.path.exists(filename):
+             os.remove(filename)
+
+ def superimpose(image_with_text, overlay_image):
+     # Open the overlay as RGBA to handle transparency
+     overlay_image = overlay_image.convert("RGBA")
+     # Paste the overlay onto the background using its alpha channel
+     image_with_text.paste(overlay_image, (0, 0), overlay_image)
+     return image_with_text
+
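+ # Note: stamping the glyphs at every (dx, dy) offset with the same fill thickens
+ # the text rather than outlining it; Pillow's stroke_width/stroke_fill arguments
+ # to draw.text() would be an alternative for a true contrasting outline.
+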
+ def add_text_to_image(
+     input_image,
+     text,
+     font_size,
+     color,
+     opacity,
+     x_position,
+     y_position,
+     thickness
+ ):
+     """
+     Add text to an image with customizable properties.
+     """
+     # Convert gradio image (numpy array) to PIL Image
+     if input_image is None:
+         return None
+
+     image = Image.fromarray(input_image)
+     # Remove the background so the subject can be pasted back over the text
+     overlay_image = remove_background(image)
+
+     # Create a transparent overlay for the text
+     txt_overlay = Image.new('RGBA', image.size, (255, 255, 255, 0))
+     draw = ImageDraw.Draw(txt_overlay)
+
+     # Create a font with the specified size
+     try:
+         font = ImageFont.truetype("DejaVuSans.ttf", int(font_size))
+     except OSError:
+         # If DejaVu is not found, try Arial, then fall back to the default font
+         try:
+             font = ImageFont.truetype("arial.ttf", int(font_size))
+         except OSError:
+             print("Using default font as system fonts not found")
+             font = ImageFont.load_default()
+
+     # Convert color name to RGB
+     color_map = {
+         'White': (255, 255, 255),
+         'Black': (0, 0, 0),
+         'Red': (255, 0, 0),
+         'Green': (0, 255, 0),
+         'Blue': (0, 0, 255),
+         'Yellow': (255, 255, 0),
+         'Purple': (128, 0, 128)
+     }
+     rgb_color = color_map.get(color, (255, 255, 255))
+
+     # Get text size for positioning
+     text_bbox = draw.textbbox((0, 0), text, font=font)
+     text_width = text_bbox[2] - text_bbox[0]
+     text_height = text_bbox[3] - text_bbox[1]
+
+     # Convert percentage positions into pixel coordinates
+     actual_x = int((image.width - text_width) * (x_position / 100))
+     actual_y = int((image.height - text_height) * (y_position / 100))
+
+     # Build the final color with opacity
+     text_color = (*rgb_color, int(opacity))
+
+     # Draw the text, stamped at offsets for thickness
+     add_text_with_stroke(
+         draw,
+         text,
+         actual_x,
+         actual_y,
+         font,
+         text_color,
+         int(thickness)
+     )
+
+     # Combine the original image with the text overlay
+     if image.mode != 'RGBA':
+         image = image.convert('RGBA')
+     output_image = Image.alpha_composite(image, txt_overlay)
+
+     # Convert back to RGB for display
+     output_image = output_image.convert('RGB')
+
+     # Paste the background-removed subject back on top of the text
+     output_image = superimpose(output_image, overlay_image)
+
+     # Convert the PIL image back to a numpy array for Gradio
+     return np.array(output_image)
+
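+ # The opacity slider maps to the alpha channel of text_color; alpha_composite
+ # blends the text overlay into the image, and the background-removed subject is
+ # pasted back on top, so the text appears to sit behind the subject.
+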
+ # UI layout
+ with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
+     gr.HTML("""
+     <div class="main-title">
+         <h1>🎨GiniGen Canvas</h1>
+         <p>AI Integrated Image Creator: Extract objects, generate backgrounds, and adjust ratios and positions to create complete images with AI.</p>
+     </div>
+     """)
+
+     with gr.Row():
+         with gr.Column(scale=1):
+             input_image = gr.Image(
+                 type="pil",
+                 label="Upload Image",
+                 interactive=True
+             )
+             text_prompt = gr.Textbox(
+                 label="Object to Extract",
+                 placeholder="Enter what you want to extract...",
+                 interactive=True
+             )
+             with gr.Row():
+                 bg_prompt = gr.Textbox(
+                     label="Background Prompt (optional)",
+                     placeholder="Describe the background...",
+                     interactive=True,
+                     scale=3
+                 )
+                 aspect_ratio = gr.Dropdown(
+                     choices=["1:1", "16:9", "9:16", "4:3"],
+                     value="1:1",
+                     label="Aspect Ratio",
+                     interactive=True,
+                     visible=True,
+                     scale=1
+                 )
+
+             with gr.Row(visible=False) as object_controls:
+                 with gr.Column(scale=1):
+                     with gr.Row():
+                         position = gr.State(value="bottom-center")
+                         btn_top_left = gr.Button("↖")
+                         btn_top_center = gr.Button("↑")
+                         btn_top_right = gr.Button("↗")
+                     with gr.Row():
+                         btn_middle_left = gr.Button("←")
+                         btn_middle_center = gr.Button("•")
+                         btn_middle_right = gr.Button("→")
+                     with gr.Row():
+                         btn_bottom_left = gr.Button("↙")
+                         btn_bottom_center = gr.Button("↓")
+                         btn_bottom_right = gr.Button("↘")
+                 with gr.Column(scale=1):
+                     scale_slider = gr.Slider(
+                         minimum=10,
+                         maximum=200,
+                         value=50,
+                         step=5,
+                         label="Object Size (%)"
+                     )
+
+             process_btn = gr.Button(
+                 "Process",
+                 variant="primary",
+                 interactive=False
+             )
+
+             # Click-event handlers for each position button
+             def update_position(new_position):
+                 return new_position
+
+             btn_top_left.click(fn=lambda: update_position("top-left"), outputs=position)
+             btn_top_center.click(fn=lambda: update_position("top-center"), outputs=position)
+             btn_top_right.click(fn=lambda: update_position("top-right"), outputs=position)
+             btn_middle_left.click(fn=lambda: update_position("middle-left"), outputs=position)
+             btn_middle_center.click(fn=lambda: update_position("middle-center"), outputs=position)
+             btn_middle_right.click(fn=lambda: update_position("middle-right"), outputs=position)
+             btn_bottom_left.click(fn=lambda: update_position("bottom-left"), outputs=position)
+             btn_bottom_center.click(fn=lambda: update_position("bottom-center"), outputs=position)
+             btn_bottom_right.click(fn=lambda: update_position("bottom-right"), outputs=position)
+
+         with gr.Column(scale=1):
+             with gr.Row():
+                 combined_image = gr.Image(
+                     label="Combined Result",
+                     show_download_button=True,
+                     type="pil",
+                     height=512
+                 )
+             with gr.Row():
+                 extracted_image = gr.Image(
+                     label="Extracted Object",
+                     show_download_button=True,
+                     type="pil",
+                     height=256
+                 )
+
+     # Event bindings
+     input_image.change(
+         fn=update_process_button,
+         inputs=[input_image, text_prompt],
+         outputs=process_btn,
+         queue=False
+     )
+
+     text_prompt.change(
+         fn=update_process_button,
+         inputs=[input_image, text_prompt],
+         outputs=process_btn,
+         queue=False
+     )
+
+     def update_controls(bg_prompt):
+         """Show or hide the ratio and position controls depending on whether a background prompt was entered."""
+         is_visible = bool(bg_prompt)
+         return [
+             gr.update(visible=is_visible),  # aspect_ratio
+             gr.update(visible=is_visible),  # object_controls
+         ]
+
+     bg_prompt.change(
+         fn=update_controls,
+         inputs=bg_prompt,
+         outputs=[aspect_ratio, object_controls],
+         queue=False
+     )
+
+     process_btn.click(
+         fn=process_prompt,
+         inputs=[
+             input_image,
+             text_prompt,
+             bg_prompt,
+             aspect_ratio,
+             position,
+             scale_slider
+         ],
+         outputs=[combined_image, extracted_image],
+         queue=True
+     )
+
+ demo.queue(max_size=5)  # limit queue size
+ demo.launch(
+     server_name="0.0.0.0",
+     server_port=7860,
+     share=False,
+     max_threads=2  # limit thread count
+ )