Ashoka74 committed on
Commit 04caf6e · verified · 1 Parent(s): d68fbf4

Create merged_app2.py

Files changed (1)
  1. merged_app2.py +1880 -0
merged_app2.py ADDED
@@ -0,0 +1,1880 @@
+ import os
+ import random
+ import sys
+ from typing import Sequence, Mapping, Any, Union, Optional
+ import torch
+ import gradio as gr
+ from PIL import Image, ImageDraw
+ from huggingface_hub import hf_hub_download
+ import spaces
+ import argparse
+ import math
+ import numpy as np
+ import safetensors.torch as sf
+ import datetime
+ from pathlib import Path
+ from io import BytesIO
+
+ from diffusers import StableDiffusionPipeline, StableDiffusionImg2ImgPipeline
+ from diffusers import AutoencoderKL, UNet2DConditionModel, DDIMScheduler, EulerAncestralDiscreteScheduler, DPMSolverMultistepScheduler
+ from diffusers import FluxFillPipeline
+ from diffusers.models.attention_processor import AttnProcessor2_0
+ from diffusers.utils import load_image
+ from transformers import CLIPTextModel, CLIPTokenizer
+ from transformers import AutoModelForImageSegmentation
+ import dds_cloudapi_sdk
+ from dds_cloudapi_sdk import Config, Client, TextPrompt
+ from dds_cloudapi_sdk.tasks.dinox import DinoxTask
+ from dds_cloudapi_sdk.tasks import DetectionTarget
+ from dds_cloudapi_sdk.tasks.detection import DetectionTask
+
+ from enum import Enum
+ from torch.hub import download_url_to_file
+ import tempfile
+
+ from sam2.build_sam import build_sam2
+ from sam2.sam2_image_predictor import SAM2ImagePredictor
+ import cv2
+
+ from inference_i2mv_sdxl import prepare_pipeline, remove_bg, run_pipeline
+ from torchvision import transforms
+ from depth_anything_v2.dpt import DepthAnythingV2
+
+ import httpx
67
+
68
+
69
+
70
+
71
+ client = httpx.Client(timeout=httpx.Timeout(10.0)) # Set timeout to 10 seconds
72
+ NUM_VIEWS = 6
73
+ HEIGHT = 768
74
+ WIDTH = 768
75
+ MAX_SEED = np.iinfo(np.int32).max
76
+
77
+
78
+
79
+ import supervision as sv
80
+ import torch
81
+ from PIL import Image
82
+
83
+ import logging
84
+
85
+ # Configure logging
86
+ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(message)s')
87
+
88
+ transform_image = transforms.Compose(
89
+ [
90
+ transforms.Resize((1024, 1024)),
91
+ transforms.ToTensor(),
92
+ transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
93
+ ]
94
+ )
95
+
96
+ #hf_hub_download(repo_id="YarvixPA/FLUX.1-Fill-dev-gguf", filename="flux1-fill-dev-Q5_K_S.gguf", local_dir="models/")
97
+
98
+
99
+
100
+ # Load
101
+
102
+ # Model paths
103
+ model_path = './models/iclight_sd15_fc.safetensors'
104
+ model_path2 = './checkpoints/depth_anything_v2_vits.pth'
105
+ model_path3 = './checkpoints/sam2_hiera_large.pt'
106
+ model_path4 = './checkpoints/config.json'
107
+ model_path5 = './checkpoints/preprocessor_config.json'
108
+ model_path6 = './configs/sam2_hiera_l.yaml'
109
+ model_path7 = './mvadapter_i2mv_sdxl.safetensors'
110
+
111
+ # Base URL for the repository
112
+ BASE_URL = 'https://huggingface.co/Ashoka74/Placement/resolve/main/'
113
+
114
+ # Model URLs
115
+ model_urls = {
116
+ model_path: 'iclight_sd15_fc.safetensors',
117
+ model_path2: 'depth_anything_v2_vits.pth',
118
+ model_path3: 'sam2_hiera_large.pt',
119
+ model_path4: 'config.json',
120
+ model_path5: 'preprocessor_config.json',
121
+ model_path6: 'sam2_hiera_l.yaml',
122
+ model_path7: 'mvadapter_i2mv_sdxl.safetensors'
123
+ }
124
+
125
+ # Ensure directories exist
126
+ def ensure_directories():
127
+ for path in model_urls.keys():
128
+ os.makedirs(os.path.dirname(path), exist_ok=True)
129
+
130
+ # Download models
131
+ def download_models():
132
+ for local_path, filename in model_urls.items():
133
+ if not os.path.exists(local_path):
134
+ try:
135
+ url = f"{BASE_URL}{filename}"
136
+ print(f"Downloading {filename}")
137
+ download_url_to_file(url, local_path)
138
+ print(f"Successfully downloaded {filename}")
139
+ except Exception as e:
140
+ print(f"Error downloading {filename}: {e}")
141
+
142
+ ensure_directories()
143
+
144
+ download_models()
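+ # The two helpers above fetch every checkpoint listed in model_urls from BASE_URL on
+ # first launch; files that already exist locally are skipped, and a failed download is
+ # only logged so the app can still start with whatever weights are already present.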
145
+
146
+
147
+
148
+
149
+ hf_hub_download(repo_id="black-forest-labs/FLUX.1-Redux-dev", filename="flux1-redux-dev.safetensors", local_dir="models/style_models")
150
+ hf_hub_download(repo_id="black-forest-labs/FLUX.1-Depth-dev", filename="flux1-depth-dev.safetensors", local_dir="models/diffusion_models")
151
+ hf_hub_download(repo_id="Comfy-Org/sigclip_vision_384", filename="sigclip_vision_patch14_384.safetensors", local_dir="models/clip_vision")
152
+ hf_hub_download(repo_id="Kijai/DepthAnythingV2-safetensors", filename="depth_anything_v2_vitl_fp32.safetensors", local_dir="models/depthanything")
153
+ hf_hub_download(repo_id="black-forest-labs/FLUX.1-dev", filename="ae.safetensors", local_dir="models/vae/FLUX1")
154
+ hf_hub_download(repo_id="comfyanonymous/flux_text_encoders", filename="clip_l.safetensors", local_dir="models/text_encoders")
155
+ t5_path = hf_hub_download(repo_id="comfyanonymous/flux_text_encoders", filename="t5xxl_fp16.safetensors", local_dir="models/text_encoders/t5")
156
+
157
+
158
+
159
+ sd15_name = 'stablediffusionapi/realistic-vision-v51'
160
+ tokenizer = CLIPTokenizer.from_pretrained(sd15_name, subfolder="tokenizer")
161
+ text_encoder = CLIPTextModel.from_pretrained(sd15_name, subfolder="text_encoder")
162
+ vae = AutoencoderKL.from_pretrained(sd15_name, subfolder="vae")
163
+ unet = UNet2DConditionModel.from_pretrained(sd15_name, subfolder="unet")
164
+
165
+
166
+
167
+ # fill_pipe = FluxFillPipeline.from_single_file(
168
+ # "https://huggingface.co/SporkySporkness/FLUX.1-Fill-dev-GGUF/flux1-fill-dev-fp16-Q5_0-GGUF.gguf",
169
+ # text_encoder= text_encoder,
170
+ # text_encoder_2 = t5_path,
171
+ # ignore_mismatched_sizes=True,
172
+ # low_cpu_mem_usage=False,
173
+ # torch_dtype=torch.bfloat16
174
+ # ).to("cuda")
175
+
176
+ from diffusers import FluxTransformer2DModel, FluxFillPipeline, GGUFQuantizationConfig
177
+ from transformers import T5EncoderModel
178
+ import torch
179
+
180
+ # transformer = FluxTransformer2DModel.from_pretrained("AlekseyCalvin/FluxFillDev_fp8_Diffusers", subfolder="transformer", torch_dtype=torch.bfloat16).to("cuda")
181
+ # text_encoder_2 = T5EncoderModel.from_pretrained("AlekseyCalvin/FluxFillDev_fp8_Diffusers", subfolder="text_encoder_2", torch_dtype=torch.bfloat16).to("cuda")
182
+ # fill_pipe = FluxFillPipeline.from_pretrained("black-forest-labs/FLUX.1-Fill-dev", transformer=transformer, text_encoder_2=text_encoder_2, torch_dtype=torch.bfloat16).to("cuda")
183
+
184
+
185
+ ckpt_path = (
186
+ "https://huggingface.co/SporkySporkness/FLUX.1-Fill-dev-GGUF/flux1-fill-dev-fp16-Q5_0-GGUF.gguf"
187
+ )
188
+
189
+ transformer = FluxTransformer2DModel.from_single_file(
190
+ ckpt_path,
191
+ quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16),
192
+ torch_dtype=torch.bfloat16,
193
+ )
194
+
195
+ fill_pipe = FluxFillPipeline.from_pretrained(
196
+ "black-forest-labs/FLUX.1-Fill-dev",
197
+ transformer=transformer,
+ # note: a torch.Generator is not a from_pretrained argument; pass it to the pipeline call instead
+ torch_dtype=torch.bfloat16,
200
+ )
201
+
202
+ fill_pipe.enable_model_cpu_offload()
203
+
204
+
205
+ try:
206
+ import xformers
207
+ import xformers.ops
208
+ XFORMERS_AVAILABLE = True
209
+ print("xformers is available - Using memory efficient attention")
210
+ except ImportError:
211
+ XFORMERS_AVAILABLE = False
212
+ print("xformers not available - Using default attention")
213
+
214
+ # Memory optimizations for RTX 2070
215
+ torch.backends.cudnn.benchmark = True
216
+ if torch.cuda.is_available():
217
+ torch.backends.cuda.matmul.allow_tf32 = True
218
+ torch.backends.cudnn.allow_tf32 = True
219
+ # Cap the CUDA caching-allocator split size for smaller GPUs; this is configured via an
+ # environment variable (torch.backends.cuda has no max_split_size_mb attribute)
+ os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "max_split_size_mb:512")
221
+ device = torch.device('cuda')
222
+ else:
223
+ device = torch.device('cpu')
224
+
225
+
226
+ rmbg = AutoModelForImageSegmentation.from_pretrained(
227
+ "ZhengPeng7/BiRefNet", trust_remote_code=True
228
+ )
229
+ rmbg = rmbg.to(device=device, dtype=torch.float32)
230
+
231
+
232
+ model = DepthAnythingV2(encoder='vits', features=64, out_channels=[48, 96, 192, 384])
233
+ model.load_state_dict(torch.load('checkpoints/depth_anything_v2_vits.pth', map_location=device))
234
+ model = model.to(device)
235
+ model.eval()
236
+
237
+
238
+ with torch.no_grad():
239
+ new_conv_in = torch.nn.Conv2d(8, unet.conv_in.out_channels, unet.conv_in.kernel_size, unet.conv_in.stride, unet.conv_in.padding)
240
+ new_conv_in.weight.zero_()
241
+ new_conv_in.weight[:, :4, :, :].copy_(unet.conv_in.weight)
242
+ new_conv_in.bias = unet.conv_in.bias
243
+ unet.conv_in = new_conv_in
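+ # The patched conv_in now accepts 8 channels: the first 4 keep the pretrained weights
+ # for the noisy latent, while the extra 4 (zero-initialized) receive the concatenated
+ # IC-Light conditioning latent, so the model initially behaves like the base UNet.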
244
+
245
+
246
+ unet_original_forward = unet.forward
247
+
248
+
249
+ def can_expand(source_width, source_height, target_width, target_height, alignment):
250
+ if alignment in ("Left", "Right") and source_width >= target_width:
251
+ return False
252
+ if alignment in ("Top", "Bottom") and source_height >= target_height:
253
+ return False
254
+ return True
255
+
256
+ def prepare_image_and_mask(image, width, height, overlap_percentage, resize_option, custom_resize_percentage, alignment, overlap_left, overlap_right, overlap_top, overlap_bottom):
257
+ target_size = (width, height)
258
+
259
+ scale_factor = min(target_size[0] / image.width, target_size[1] / image.height)
260
+ new_width = int(image.width * scale_factor)
261
+ new_height = int(image.height * scale_factor)
262
+
263
+ source = image.resize((new_width, new_height), Image.LANCZOS)
264
+
265
+ if resize_option == "Full":
266
+ resize_percentage = 100
267
+ elif resize_option == "75%":
268
+ resize_percentage = 75
269
+ elif resize_option == "50%":
270
+ resize_percentage = 50
271
+ elif resize_option == "33%":
272
+ resize_percentage = 33
273
+ elif resize_option == "25%":
274
+ resize_percentage = 25
275
+ else: # Custom
276
+ resize_percentage = custom_resize_percentage
277
+
278
+ # Calculate new dimensions based on percentage
279
+ resize_factor = resize_percentage / 100
280
+ new_width = int(source.width * resize_factor)
281
+ new_height = int(source.height * resize_factor)
282
+
283
+ # Ensure minimum size of 64 pixels
284
+ new_width = max(new_width, 64)
285
+ new_height = max(new_height, 64)
286
+
287
+ # Resize the image
288
+ source = source.resize((new_width, new_height), Image.LANCZOS)
289
+
290
+ # Calculate the overlap in pixels based on the percentage
291
+ overlap_x = int(new_width * (overlap_percentage / 100))
292
+ overlap_y = int(new_height * (overlap_percentage / 100))
293
+
294
+ # Ensure minimum overlap of 1 pixel
295
+ overlap_x = max(overlap_x, 1)
296
+ overlap_y = max(overlap_y, 1)
297
+
298
+ # Calculate margins based on alignment
299
+ if alignment == "Middle":
300
+ margin_x = (target_size[0] - new_width) // 2
301
+ margin_y = (target_size[1] - new_height) // 2
302
+ elif alignment == "Left":
303
+ margin_x = 0
304
+ margin_y = (target_size[1] - new_height) // 2
305
+ elif alignment == "Right":
306
+ margin_x = target_size[0] - new_width
307
+ margin_y = (target_size[1] - new_height) // 2
308
+ elif alignment == "Top":
309
+ margin_x = (target_size[0] - new_width) // 2
310
+ margin_y = 0
311
+ elif alignment == "Bottom":
312
+ margin_x = (target_size[0] - new_width) // 2
313
+ margin_y = target_size[1] - new_height
314
+
315
+ # Adjust margins to eliminate gaps
316
+ margin_x = max(0, min(margin_x, target_size[0] - new_width))
317
+ margin_y = max(0, min(margin_y, target_size[1] - new_height))
318
+
319
+ # Create a new background image and paste the resized source image
320
+ background = Image.new('RGB', target_size, (255, 255, 255))
321
+ background.paste(source, (margin_x, margin_y))
322
+
323
+ # Create the mask
324
+ mask = Image.new('L', target_size, 255)
325
+ mask_draw = ImageDraw.Draw(mask)
326
+
327
+ # Calculate overlap areas
328
+ white_gaps_patch = 2
329
+
330
+ left_overlap = margin_x + overlap_x if overlap_left else margin_x + white_gaps_patch
331
+ right_overlap = margin_x + new_width - overlap_x if overlap_right else margin_x + new_width - white_gaps_patch
332
+ top_overlap = margin_y + overlap_y if overlap_top else margin_y + white_gaps_patch
333
+ bottom_overlap = margin_y + new_height - overlap_y if overlap_bottom else margin_y + new_height - white_gaps_patch
334
+
335
+ if alignment == "Left":
336
+ left_overlap = margin_x + overlap_x if overlap_left else margin_x
337
+ elif alignment == "Right":
338
+ right_overlap = margin_x + new_width - overlap_x if overlap_right else margin_x + new_width
339
+ elif alignment == "Top":
340
+ top_overlap = margin_y + overlap_y if overlap_top else margin_y
341
+ elif alignment == "Bottom":
342
+ bottom_overlap = margin_y + new_height - overlap_y if overlap_bottom else margin_y + new_height
343
+
344
+ # Draw the mask
345
+ mask_draw.rectangle([
346
+ (left_overlap, top_overlap),
347
+ (right_overlap, bottom_overlap)
348
+ ], fill=0)
349
+
350
+ return background, mask
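+ # The returned mask is white (255) where new content should be generated and black (0)
+ # over the preserved source region; the overlap flags control how far generation is
+ # allowed to bleed back into the original image. Illustrative call (values assumed):
+ # bg, mask = prepare_image_and_mask(img, 1280, 720, overlap_percentage=10,
+ # resize_option="Full", custom_resize_percentage=100, alignment="Middle",
+ # overlap_left=True, overlap_right=True, overlap_top=True, overlap_bottom=True)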
351
+
352
+ @spaces.GPU(duration=60)
353
+ @torch.inference_mode()
354
+ def inpaint(image, width, height, overlap_percentage, num_inference_steps, resize_option, custom_resize_percentage, prompt_input, alignment, overlap_left, overlap_right, overlap_top, overlap_bottom, progress=gr.Progress(track_tqdm=True)):
355
+ clear_memory()
356
+
357
+ background, mask = prepare_image_and_mask(image, width, height, overlap_percentage, resize_option, custom_resize_percentage, alignment, overlap_left, overlap_right, overlap_top, overlap_bottom)
358
+
359
+ if not can_expand(background.width, background.height, width, height, alignment):
360
+ alignment = "Middle"
361
+
362
+ cnet_image = background.copy()
363
+ cnet_image.paste(0, (0, 0), mask)
364
+
365
+ final_prompt = prompt_input
366
+
367
+ #generator = torch.Generator(device="cuda").manual_seed(42)
368
+
369
+ result = fill_pipe(
370
+ prompt=final_prompt,
371
+ height=height,
372
+ width=width,
373
+ image=cnet_image,
374
+ mask_image=mask,
375
+ num_inference_steps=num_inference_steps,
376
+ guidance_scale=30,
377
+ ).images[0]
378
+
379
+ result = result.convert("RGBA")
380
+ cnet_image.paste(result, (0, 0), mask)
381
+
382
+ return cnet_image #, background
383
+
384
+ def preview_image_and_mask(image, width, height, overlap_percentage, resize_option, custom_resize_percentage, alignment, overlap_left, overlap_right, overlap_top, overlap_bottom):
385
+ background, mask = prepare_image_and_mask(image, width, height, overlap_percentage, resize_option, custom_resize_percentage, alignment, overlap_left, overlap_right, overlap_top, overlap_bottom)
386
+
387
+ preview = background.copy().convert('RGBA')
388
+ red_overlay = Image.new('RGBA', background.size, (255, 0, 0, 64))
389
+ red_mask = Image.new('RGBA', background.size, (0, 0, 0, 0))
390
+ red_mask.paste(red_overlay, (0, 0), mask)
391
+ preview = Image.alpha_composite(preview, red_mask)
392
+
393
+ return preview
394
+
395
+ def clear_result():
396
+ return gr.update(value=None)
397
+
398
+ def preload_presets(target_ratio, ui_width, ui_height):
399
+ if target_ratio == "9:16":
400
+ return 720, 1280, gr.update()
401
+ elif target_ratio == "16:9":
402
+ return 1280, 720, gr.update()
403
+ elif target_ratio == "1:1":
404
+ return 1024, 1024, gr.update()
405
+ elif target_ratio == "Custom":
406
+ return ui_width, ui_height, gr.update(open=True)
407
+
408
+ def select_the_right_preset(user_width, user_height):
409
+ if user_width == 720 and user_height == 1280:
410
+ return "9:16"
411
+ elif user_width == 1280 and user_height == 720:
412
+ return "16:9"
413
+ elif user_width == 1024 and user_height == 1024:
414
+ return "1:1"
415
+ else:
416
+ return "Custom"
417
+
418
+ def toggle_custom_resize_slider(resize_option):
419
+ return gr.update(visible=(resize_option == "Custom"))
420
+
421
+ def update_history(new_image, history):
422
+ if history is None:
423
+ history = []
424
+ history.insert(0, new_image)
425
+ return history
426
+
427
+
428
+ def enable_efficient_attention():
429
+ if XFORMERS_AVAILABLE:
430
+ try:
431
+ # RTX 2070 specific settings
432
+ unet.set_use_memory_efficient_attention_xformers(True)
433
+ vae.set_use_memory_efficient_attention_xformers(True)
434
+ print("Enabled xformers memory efficient attention")
435
+ except Exception as e:
436
+ print(f"Xformers error: {e}")
437
+ print("Falling back to sliced attention")
438
+ # Use sliced attention for RTX 2070
439
+ # unet.set_attention_slice_size(4)
440
+ # vae.set_attention_slice_size(4)
441
+ unet.set_attn_processor(AttnProcessor2_0())
442
+ vae.set_attn_processor(AttnProcessor2_0())
443
+ else:
444
+ # Fallback for when xformers is not available
445
+ print("Using sliced attention")
446
+ # unet.set_attention_slice_size(4)
447
+ # vae.set_attention_slice_size(4)
448
+ unet.set_attn_processor(AttnProcessor2_0())
449
+ vae.set_attn_processor(AttnProcessor2_0())
450
+
451
+ # Add memory clearing function
452
+ def clear_memory():
453
+ if torch.cuda.is_available():
454
+ torch.cuda.empty_cache()
455
+ torch.cuda.synchronize()
456
+
457
+ # Enable efficient attention
458
+ enable_efficient_attention()
459
+
460
+
461
+ def hooked_unet_forward(sample, timestep, encoder_hidden_states, **kwargs):
462
+ c_concat = kwargs['cross_attention_kwargs']['concat_conds'].to(sample)
463
+ c_concat = torch.cat([c_concat] * (sample.shape[0] // c_concat.shape[0]), dim=0)
464
+ new_sample = torch.cat([sample, c_concat], dim=1)
465
+ kwargs['cross_attention_kwargs'] = {}
466
+ return unet_original_forward(new_sample, timestep, encoder_hidden_states, **kwargs)
467
+
468
+
469
+ unet.forward = hooked_unet_forward
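+ # The hooked forward pulls the encoded foreground latents out of
+ # cross_attention_kwargs['concat_conds'], tiles them to the batch size, and concatenates
+ # them to the noisy latents along the channel axis to match the widened conv_in above.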
470
+
471
+
472
+ sd_offset = sf.load_file(model_path)
473
+ sd_origin = unet.state_dict()
474
+ keys = sd_origin.keys()
475
+ sd_merged = {k: sd_origin[k] + sd_offset[k] for k in sd_origin.keys()}
476
+ unet.load_state_dict(sd_merged, strict=True)
477
+ del sd_offset, sd_origin, sd_merged, keys
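+ # The IC-Light checkpoint stores per-parameter offsets rather than full weights, so it
+ # is merged by adding each offset onto the corresponding Realistic Vision UNet tensor.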
478
+
479
+
480
+ # Device and dtype setup
481
+ device = torch.device('cuda')
482
+ #dtype = torch.float16 # RTX 2070 works well with float16
483
+ dtype = torch.bfloat16
484
+
485
+
486
+ pipe = prepare_pipeline(
487
+ base_model="stabilityai/stable-diffusion-xl-base-1.0",
488
+ vae_model="madebyollin/sdxl-vae-fp16-fix",
489
+ unet_model=None,
490
+ lora_model=None,
491
+ adapter_path="huanngzh/mv-adapter",
492
+ scheduler=None,
493
+ num_views=NUM_VIEWS,
494
+ device=device,
495
+ dtype=dtype,
496
+ )
497
+
498
+
499
+ # Move models to device with consistent dtype
500
+ text_encoder = text_encoder.to(device=device, dtype=dtype)
501
+ vae = vae.to(device=device, dtype=dtype)
502
+ unet = unet.to(device=device, dtype=dtype)
503
+ #rmbg = rmbg.to(device=device, dtype=torch.float32) # Keep this as float32
504
+ rmbg = rmbg.to(device)
505
+
506
+ ddim_scheduler = DDIMScheduler(
507
+ num_train_timesteps=1000,
508
+ beta_start=0.00085,
509
+ beta_end=0.012,
510
+ beta_schedule="scaled_linear",
511
+ clip_sample=False,
512
+ set_alpha_to_one=False,
513
+ steps_offset=1,
514
+ )
515
+
516
+ euler_a_scheduler = EulerAncestralDiscreteScheduler(
517
+ num_train_timesteps=1000,
518
+ beta_start=0.00085,
519
+ beta_end=0.012,
520
+ steps_offset=1
521
+ )
522
+
523
+ dpmpp_2m_sde_karras_scheduler = DPMSolverMultistepScheduler(
524
+ num_train_timesteps=1000,
525
+ beta_start=0.00085,
526
+ beta_end=0.012,
527
+ algorithm_type="sde-dpmsolver++",
528
+ use_karras_sigmas=True,
529
+ steps_offset=1
530
+ )
531
+
532
+ # Pipelines
533
+
534
+
535
+ t2i_pipe = StableDiffusionPipeline(
536
+ vae=vae,
537
+ text_encoder=text_encoder,
538
+ tokenizer=tokenizer,
539
+ unet=unet,
540
+ scheduler=dpmpp_2m_sde_karras_scheduler,
541
+ safety_checker=None,
542
+ requires_safety_checker=False,
543
+ feature_extractor=None,
544
+ image_encoder=None
545
+ )
546
+
547
+ i2i_pipe = StableDiffusionImg2ImgPipeline(
548
+ vae=vae,
549
+ text_encoder=text_encoder,
550
+ tokenizer=tokenizer,
551
+ unet=unet,
552
+ scheduler=dpmpp_2m_sde_karras_scheduler,
553
+ safety_checker=None,
554
+ requires_safety_checker=False,
555
+ feature_extractor=None,
556
+ image_encoder=None
557
+ )
558
+
559
+
560
+ @torch.inference_mode()
561
+ def encode_prompt_inner(txt: str):
562
+ max_length = tokenizer.model_max_length
563
+ chunk_length = tokenizer.model_max_length - 2
564
+ id_start = tokenizer.bos_token_id
565
+ id_end = tokenizer.eos_token_id
566
+ id_pad = id_end
567
+
568
+ def pad(x, p, i):
569
+ return x[:i] if len(x) >= i else x + [p] * (i - len(x))
570
+
571
+ tokens = tokenizer(txt, truncation=False, add_special_tokens=False)["input_ids"]
572
+ chunks = [[id_start] + tokens[i: i + chunk_length] + [id_end] for i in range(0, len(tokens), chunk_length)]
573
+ chunks = [pad(ck, id_pad, max_length) for ck in chunks]
574
+
575
+ token_ids = torch.tensor(chunks).to(device=device, dtype=torch.int64)
576
+ conds = text_encoder(token_ids).last_hidden_state
577
+
578
+ return conds
579
+
580
+
581
+ @torch.inference_mode()
582
+ def encode_prompt_pair(positive_prompt, negative_prompt):
583
+ c = encode_prompt_inner(positive_prompt)
584
+ uc = encode_prompt_inner(negative_prompt)
585
+
586
+ c_len = float(len(c))
587
+ uc_len = float(len(uc))
588
+ max_count = max(c_len, uc_len)
589
+ c_repeat = int(math.ceil(max_count / c_len))
590
+ uc_repeat = int(math.ceil(max_count / uc_len))
591
+ max_chunk = max(len(c), len(uc))
592
+
593
+ c = torch.cat([c] * c_repeat, dim=0)[:max_chunk]
594
+ uc = torch.cat([uc] * uc_repeat, dim=0)[:max_chunk]
595
+
596
+ c = torch.cat([p[None, ...] for p in c], dim=1)
597
+ uc = torch.cat([p[None, ...] for p in uc], dim=1)
598
+
599
+ return c, uc
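+ # encode_prompt_inner splits prompts longer than CLIP's 77-token window into padded
+ # chunks; encode_prompt_pair then repeats the positive/negative embeddings so both
+ # sides end up with the same number of chunks before classifier-free guidance.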
600
+
601
+
602
+ @spaces.GPU(duration=60)
603
+ @torch.inference_mode()
604
+ def infer(
605
+ prompt,
606
+ image, # This is already RGBA with background removed
607
+ do_rembg=True,
608
+ seed=42,
609
+ randomize_seed=False,
610
+ guidance_scale=3.0,
611
+ num_inference_steps=30,
612
+ reference_conditioning_scale=1.0,
613
+ negative_prompt="watermark, ugly, deformed, noisy, blurry, low contrast",
614
+ progress=gr.Progress(track_tqdm=True),
615
+ ):
616
+ clear_memory()
617
+
618
+ # Convert input to PIL if needed
619
+ if isinstance(image, np.ndarray):
620
+ if image.shape[-1] == 4: # RGBA
621
+ image = Image.fromarray(image, 'RGBA')
622
+ else: # RGB
623
+ image = Image.fromarray(image, 'RGB')
624
+
625
+ #logging.info(f"Converted to PIL Image mode: {image.mode}")
626
+
627
+ # No need for remove_bg_fn since image is already processed
628
+ remove_bg_fn = None
629
+
630
+ if randomize_seed:
631
+ seed = random.randint(0, MAX_SEED)
632
+
633
+ images, preprocessed_image = run_pipeline(
634
+ pipe,
635
+ num_views=NUM_VIEWS,
636
+ text=prompt,
637
+ image=image,
638
+ height=HEIGHT,
639
+ width=WIDTH,
640
+ num_inference_steps=num_inference_steps,
641
+ guidance_scale=guidance_scale,
642
+ seed=seed,
643
+ remove_bg_fn=remove_bg_fn, # Set to None since preprocessing is done
644
+ reference_conditioning_scale=reference_conditioning_scale,
645
+ negative_prompt=negative_prompt,
646
+ device=device,
647
+ )
648
+
649
+ # logging.info(f"Output images shape: {[img.shape for img in images]}")
650
+ # logging.info(f"Preprocessed image shape: {preprocessed_image.shape if preprocessed_image is not None else None}")
651
+ return images
652
+
653
+
654
+ @spaces.GPU(duration=60)
655
+ @torch.inference_mode()
656
+ def pytorch2numpy(imgs, quant=True):
657
+ results = []
658
+ for x in imgs:
659
+ y = x.movedim(0, -1)
660
+
661
+ if quant:
662
+ y = y * 127.5 + 127.5
663
+ y = y.detach().float().cpu().numpy().clip(0, 255).astype(np.uint8)
664
+ else:
665
+ y = y * 0.5 + 0.5
666
+ y = y.detach().float().cpu().numpy().clip(0, 1).astype(np.float32)
667
+
668
+ results.append(y)
669
+ return results
670
+
671
+ @spaces.GPU(duration=60)
672
+ @torch.inference_mode()
673
+ def numpy2pytorch(imgs):
674
+ h = torch.from_numpy(np.stack(imgs, axis=0)).float() / 127.0 - 1.0 # so that 127 must be strictly 0.0
675
+ h = h.movedim(-1, 1)
676
+ return h.to(device=device, dtype=dtype)
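+ # Conversion conventions: pytorch2numpy maps model outputs from [-1, 1] to uint8
+ # [0, 255] (or float [0, 1] when quant=False); numpy2pytorch divides by 127 so that a
+ # pixel value of 127 lands almost exactly on 0 before being moved to the UNet dtype.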
677
+
678
+
679
+ def resize_and_center_crop(image, target_width, target_height):
680
+ pil_image = Image.fromarray(image)
681
+ original_width, original_height = pil_image.size
682
+ scale_factor = max(target_width / original_width, target_height / original_height)
683
+ resized_width = int(round(original_width * scale_factor))
684
+ resized_height = int(round(original_height * scale_factor))
685
+ resized_image = pil_image.resize((resized_width, resized_height), Image.LANCZOS)
686
+ left = (resized_width - target_width) / 2
687
+ top = (resized_height - target_height) / 2
688
+ right = (resized_width + target_width) / 2
689
+ bottom = (resized_height + target_height) / 2
690
+ cropped_image = resized_image.crop((left, top, right, bottom))
691
+ return np.array(cropped_image)
692
+
693
+
694
+ def resize_without_crop(image, target_width, target_height):
695
+ pil_image = Image.fromarray(image)
696
+ resized_image = pil_image.resize((target_width, target_height), Image.LANCZOS)
697
+ return np.array(resized_image)
698
+
699
+
700
+ @spaces.GPU
701
+ @torch.inference_mode()
702
+ def run_rmbg(image):
703
+ clear_memory()
704
+ image_size = image.size
705
+ input_images = transform_image(image).unsqueeze(0).to(device, dtype=torch.float32)
706
+ # Prediction
707
+ with torch.no_grad():
708
+ preds = rmbg(input_images)[-1].sigmoid().cpu()
709
+ pred = preds[0].squeeze()
710
+ pred_pil = transforms.ToPILImage()(pred)
711
+ mask = pred_pil.resize(image_size)
712
+ image.putalpha(mask)
713
+ return image
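+ # BiRefNet predicts a soft foreground probability map on a 1024x1024 resize of the
+ # input; the map is resized back to the original size and attached as the image's
+ # alpha channel, so downstream steps can crop against transparency.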
714
+
715
+
716
+
717
+ def preprocess_image(image: Image.Image, height=768, width=768):
718
+ image = np.array(image)
719
+ alpha = image[..., 3] > 0
720
+ H, W = alpha.shape
721
+ # get the bounding box of alpha
722
+ y, x = np.where(alpha)
723
+ y0, y1 = max(y.min() - 1, 0), min(y.max() + 1, H)
724
+ x0, x1 = max(x.min() - 1, 0), min(x.max() + 1, W)
725
+ image_center = image[y0:y1, x0:x1]
726
+ # resize the longer side to H * 0.9
727
+ H, W, _ = image_center.shape
728
+ if H > W:
729
+ W = int(W * (height * 0.9) / H)
730
+ H = int(height * 0.9)
731
+ else:
732
+ H = int(H * (width * 0.9) / W)
733
+ W = int(width * 0.9)
734
+ image_center = np.array(Image.fromarray(image_center).resize((W, H)))
735
+ # pad to H, W
736
+ start_h = (height - H) // 2
737
+ start_w = (width - W) // 2
738
+ image = np.zeros((height, width, 4), dtype=np.uint8)
739
+ image[start_h : start_h + H, start_w : start_w + W] = image_center
740
+ image = image.astype(np.float32) / 255.0
741
+ image = image[:, :, :3] * image[:, :, 3:4] + (1 - image[:, :, 3:4]) * 0.5
742
+ image = (image * 255).clip(0, 255).astype(np.uint8)
743
+ image = Image.fromarray(image)
744
+ return image
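+ # preprocess_image crops to the alpha bounding box, rescales the longer side to 90% of
+ # the target resolution, centers the crop on a blank canvas, then composites RGB over a
+ # mid-grey background so the relighting model sees a neutral backdrop.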
745
+
746
+
747
+ @spaces.GPU(duration=60)
748
+ @torch.inference_mode()
749
+ def process(input_fg, prompt, image_width, image_height, num_samples, seed, steps, a_prompt, n_prompt, cfg, highres_scale, highres_denoise, lowres_denoise, bg_source):
750
+ clear_memory()
751
+
752
+ # Get input dimensions
753
+ input_height, input_width = input_fg.shape[:2]
754
+
755
+ bg_source = BGSource(bg_source)
756
+
757
+ # This merged signature does not receive a background image, so default it to None;
+ # the Upload options then fall back to the text-to-image path below.
+ input_bg = None
+ if bg_source == BGSource.UPLOAD:
+ pass
+ elif bg_source == BGSource.UPLOAD_FLIP:
+ input_bg = np.fliplr(input_bg) if input_bg is not None else None
+ elif bg_source == BGSource.NONE:
+ pass
+ elif bg_source == BGSource.GREY:
762
+ input_bg = np.zeros(shape=(input_height, input_width, 3), dtype=np.uint8) + 64
763
+ elif bg_source == BGSource.LEFT:
764
+ gradient = np.linspace(255, 0, input_width)
765
+ image = np.tile(gradient, (input_height, 1))
766
+ input_bg = np.stack((image,) * 3, axis=-1).astype(np.uint8)
767
+ elif bg_source == BGSource.RIGHT:
768
+ gradient = np.linspace(0, 255, input_width)
769
+ image = np.tile(gradient, (input_height, 1))
770
+ input_bg = np.stack((image,) * 3, axis=-1).astype(np.uint8)
771
+ elif bg_source == BGSource.TOP:
772
+ gradient = np.linspace(255, 0, input_height)[:, None]
773
+ image = np.tile(gradient, (1, input_width))
774
+ input_bg = np.stack((image,) * 3, axis=-1).astype(np.uint8)
775
+ elif bg_source == BGSource.BOTTOM:
776
+ gradient = np.linspace(0, 255, input_height)[:, None]
777
+ image = np.tile(gradient, (1, input_width))
778
+ input_bg = np.stack((image,) * 3, axis=-1).astype(np.uint8)
779
+ else:
780
+ raise ValueError('Wrong initial latent!')
781
+
782
+ rng = torch.Generator(device=device).manual_seed(int(seed))
783
+
784
+ # Use input dimensions directly
785
+ fg = resize_without_crop(input_fg, input_width, input_height)
786
+
787
+ concat_conds = numpy2pytorch([fg]).to(device=vae.device, dtype=vae.dtype)
788
+ concat_conds = vae.encode(concat_conds).latent_dist.mode() * vae.config.scaling_factor
789
+
790
+ conds, unconds = encode_prompt_pair(positive_prompt=prompt + ', ' + a_prompt, negative_prompt=n_prompt)
791
+
792
+ if input_bg is None:
793
+ latents = t2i_pipe(
794
+ prompt_embeds=conds,
795
+ negative_prompt_embeds=unconds,
796
+ width=input_width,
797
+ height=input_height,
798
+ num_inference_steps=steps,
799
+ num_images_per_prompt=num_samples,
800
+ generator=rng,
801
+ output_type='latent',
802
+ guidance_scale=cfg,
803
+ cross_attention_kwargs={'concat_conds': concat_conds},
804
+ ).images.to(vae.dtype) / vae.config.scaling_factor
805
+ else:
806
+ bg = resize_without_crop(input_bg, input_width, input_height)
807
+ bg_latent = numpy2pytorch([bg]).to(device=vae.device, dtype=vae.dtype)
808
+ bg_latent = vae.encode(bg_latent).latent_dist.mode() * vae.config.scaling_factor
809
+ latents = i2i_pipe(
810
+ image=bg_latent,
811
+ strength=lowres_denoise,
812
+ prompt_embeds=conds,
813
+ negative_prompt_embeds=unconds,
814
+ width=input_width,
815
+ height=input_height,
816
+ num_inference_steps=int(round(steps / lowres_denoise)),
817
+ num_images_per_prompt=num_samples,
818
+ generator=rng,
819
+ output_type='latent',
820
+ guidance_scale=cfg,
821
+ cross_attention_kwargs={'concat_conds': concat_conds},
822
+ ).images.to(vae.dtype) / vae.config.scaling_factor
823
+
824
+ pixels = vae.decode(latents).sample
825
+ pixels = pytorch2numpy(pixels)
826
+ pixels = [resize_without_crop(
827
+ image=p,
828
+ target_width=int(round(input_width * highres_scale / 64.0) * 64),
829
+ target_height=int(round(input_height * highres_scale / 64.0) * 64))
830
+ for p in pixels]
831
+
832
+ pixels = numpy2pytorch(pixels).to(device=vae.device, dtype=vae.dtype)
833
+ latents = vae.encode(pixels).latent_dist.mode() * vae.config.scaling_factor
834
+ latents = latents.to(device=unet.device, dtype=unet.dtype)
835
+
836
+ highres_height, highres_width = latents.shape[2] * 8, latents.shape[3] * 8
837
+
838
+ fg = resize_without_crop(input_fg, highres_width, highres_height)
839
+ concat_conds = numpy2pytorch([fg]).to(device=vae.device, dtype=vae.dtype)
840
+ concat_conds = vae.encode(concat_conds).latent_dist.mode() * vae.config.scaling_factor
841
+
842
+ latents = i2i_pipe(
843
+ image=latents,
844
+ strength=highres_denoise,
845
+ prompt_embeds=conds,
846
+ negative_prompt_embeds=unconds,
847
+ width=highres_width,
848
+ height=highres_height,
849
+ num_inference_steps=int(round(steps / highres_denoise)),
850
+ num_images_per_prompt=num_samples,
851
+ generator=rng,
852
+ output_type='latent',
853
+ guidance_scale=cfg,
854
+ cross_attention_kwargs={'concat_conds': concat_conds},
855
+ ).images.to(vae.dtype) / vae.config.scaling_factor
856
+
857
+ pixels = vae.decode(latents).sample
858
+ pixels = pytorch2numpy(pixels)
859
+
860
+ # Resize back to input dimensions
861
+ pixels = [resize_without_crop(p, input_width, input_height) for p in pixels]
862
+ pixels = np.stack(pixels)
863
+
864
+ return pixels
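+ # Relighting runs in two passes: a low-resolution pass (text-to-image, or img2img
+ # against the encoded background when one is built from bg_source), an upscale of the
+ # decoded pixels by highres_scale, then an img2img refinement pass at the new size
+ # before resizing back to the original input dimensions.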
865
+
866
+ def extract_foreground(image):
867
+ if image is None:
868
+ return None, gr.update(visible=True), gr.update(visible=True)
869
+ clear_memory()
870
+ #logging.info(f"Input image shape: {image.shape}, dtype: {image.dtype}")
871
+ #result, rgba = run_rmbg(image)
872
+ result = run_rmbg(image)
873
+ result = preprocess_image(result)
874
+ #logging.info(f"Result shape: {result.shape}, dtype: {result.dtype}")
875
+ #logging.info(f"RGBA shape: {rgba.shape}, dtype: {rgba.dtype}")
876
+ return result, gr.update(visible=True), gr.update(visible=True)
877
+
878
+ def update_extracted_fg_height(selected_image: gr.SelectData):
879
+ if selected_image:
880
+ # Get the height of the selected image
881
+ height = selected_image.value['image']['shape'][0] # Assuming the image is in numpy format
882
+ return gr.update(height=height) # Update the height of extracted_fg
883
+ return gr.update(height=480) # Default height if no image is selected
884
+
885
+
886
+
887
+ @torch.inference_mode()
888
+ def process_relight(input_fg, prompt, image_width, image_height, num_samples, seed, steps, a_prompt, n_prompt, cfg, highres_scale, highres_denoise, lowres_denoise, bg_source):
889
+ # Convert input foreground from PIL to NumPy array if it's in PIL format
890
+ if isinstance(input_fg, Image.Image):
891
+ input_fg = np.array(input_fg)
892
+ logging.info(f"Input foreground shape: {input_fg.shape}, dtype: {input_fg.dtype}")
893
+ results = process(input_fg, prompt, image_width, image_height, num_samples, seed, steps, a_prompt, n_prompt, cfg, highres_scale, highres_denoise, lowres_denoise, bg_source)
894
+ logging.info(f"Results shape: {results.shape}, dtype: {results.dtype}")
895
+ return results
896
+
897
+ quick_prompts = [
898
+ 'sunshine from window',
899
+ 'golden time',
900
+ 'natural lighting',
901
+ 'warm atmosphere, at home, bedroom',
902
+ 'shadow from window',
903
+ 'soft studio lighting',
904
+ 'home atmosphere, cozy bedroom illumination',
905
+ ]
906
+ quick_prompts = [[x] for x in quick_prompts]
907
+
908
+ quick_subjects = [
909
+ 'modern sofa, high quality leather',
910
+ 'elegant dining table, polished wood',
911
+ 'luxurious bed, premium mattress',
912
+ 'minimalist office desk, clean design',
913
+ 'vintage wooden cabinet, antique finish',
914
+ ]
915
+ quick_subjects = [[x] for x in quick_subjects]
916
+
917
+ class BGSource(Enum):
918
+ UPLOAD = "Use Background Image"
919
+ UPLOAD_FLIP = "Use Flipped Background Image"
920
+ NONE = "None"
921
+ LEFT = "Left Light"
922
+ RIGHT = "Right Light"
923
+ TOP = "Top Light"
924
+ BOTTOM = "Bottom Light"
925
+ GREY = "Ambient"
926
+
927
+ # Add save function
928
+ def save_images(images, prefix="relight"):
929
+ # Create output directory if it doesn't exist
930
+ output_dir = Path("outputs")
931
+ output_dir.mkdir(exist_ok=True)
932
+
933
+ # Create timestamp for unique filenames
934
+ timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
935
+
936
+ saved_paths = []
937
+ for i, img in enumerate(images):
938
+ if isinstance(img, np.ndarray):
939
+ # Convert to PIL Image if numpy array
940
+ img = Image.fromarray(img)
941
+
942
+ # Create filename with timestamp
943
+ filename = f"{prefix}_{timestamp}_{i+1}.png"
944
+ filepath = output_dir / filename
945
+
946
+ # Save image
947
+ img.save(filepath)
948
+
949
+ # print(f"Saved {len(saved_paths)} images to {output_dir}")
950
+ return saved_paths
951
+
952
+ class MaskMover:
953
+ def __init__(self):
954
+ self.extracted_fg = None
955
+ self.original_fg = None # Store original foreground
956
+
957
+ def set_extracted_fg(self, fg_image):
958
+ """Store the extracted foreground with alpha channel"""
959
+ if isinstance(fg_image, np.ndarray):
960
+ self.extracted_fg = fg_image.copy()
961
+ self.original_fg = fg_image.copy()
962
+ else:
963
+ self.extracted_fg = np.array(fg_image)
964
+ self.original_fg = np.array(fg_image)
965
+ return self.extracted_fg
966
+
967
+ def create_composite(self, background, x_pos, y_pos, scale=1.0):
968
+ """Create composite with foreground at specified position"""
969
+ if self.original_fg is None or background is None:
970
+ return background
971
+
972
+ # Convert inputs to PIL Images
973
+ if isinstance(background, np.ndarray):
974
+ bg = Image.fromarray(background).convert('RGBA')
975
+ else:
976
+ bg = background.convert('RGBA')
977
+
978
+ if isinstance(self.original_fg, np.ndarray):
979
+ fg = Image.fromarray(self.original_fg).convert('RGBA')
980
+ else:
981
+ fg = self.original_fg.convert('RGBA')
982
+
983
+ # Scale the foreground size
984
+ new_width = int(fg.width * scale)
985
+ new_height = int(fg.height * scale)
986
+ fg = fg.resize((new_width, new_height), Image.LANCZOS)
987
+
988
+ # Center the scaled foreground at the position
989
+ x = int(x_pos - new_width / 2)
990
+ y = int(y_pos - new_height / 2)
991
+
992
+ # Create composite
993
+ result = bg.copy()
994
+ result.paste(fg, (x, y), fg) # Use fg as the mask (requires fg to be in 'RGBA' mode)
995
+
996
+ return np.array(result.convert('RGB')) # Convert back to 'RGB' if needed
997
+
998
+ def get_depth(image):
999
+ if image is None:
1000
+ return None
1001
+ # Convert from PIL/gradio format to cv2
1002
+ raw_img = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
1003
+ # Get depth map
1004
+ depth = model.infer_image(raw_img) # HxW raw depth map
1005
+ # Normalize depth for visualization
1006
+ depth = ((depth - depth.min()) / (depth.max() - depth.min()) * 255).astype(np.uint8)
1007
+ # Convert to RGB for display
1008
+ depth_colored = cv2.applyColorMap(depth, cv2.COLORMAP_INFERNO)
1009
+ depth_colored = cv2.cvtColor(depth_colored, cv2.COLOR_BGR2RGB)
1010
+ return Image.fromarray(depth_colored)
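+ # The raw Depth Anything V2 prediction is min-max normalized to 0-255 and colorized
+ # with OpenCV's INFERNO colormap purely for display; the absolute depth scale is lost.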
1011
+
1012
+ from PIL import Image
1013
+
1014
+ def compress_image(image):
1015
+ # Convert Gradio image (numpy array) to PIL Image
1016
+ img = Image.fromarray(image)
1017
+
1018
+ # Resize image if dimensions are too large
1019
+ max_size = 1024 # Maximum dimension size
1020
+ if img.width > max_size or img.height > max_size:
1021
+ ratio = min(max_size/img.width, max_size/img.height)
1022
+ new_size = (int(img.width * ratio), int(img.height * ratio))
1023
+ img = img.resize(new_size, Image.Resampling.LANCZOS)
1024
+
1025
+ quality = 95 # Start with high quality
1026
+ img.save("compressed_image.jpg", "JPEG", quality=quality) # Initial save
1027
+
1028
+ # Check file size and adjust quality if necessary
1029
+ while os.path.getsize("compressed_image.jpg") > 100 * 1024: # 100KB limit
1030
+ quality -= 5 # Decrease quality
1031
+ img.save("compressed_image.jpg", "JPEG", quality=quality)
1032
+ if quality < 20: # Prevent quality from going too low
1033
+ break
1034
+
1035
+ # Convert back to numpy array for Gradio
1036
+ compressed_img = np.array(Image.open("compressed_image.jpg"))
1037
+ return compressed_img
1038
+
1039
+ def use_orientation(selected_image:gr.SelectData):
1040
+ return selected_image.value['image']['path']
1041
+
1042
+
1043
+ @spaces.GPU(duration=60)
1044
+ @torch.inference_mode()
1045
+ def process_image(input_image, input_text):
1046
+ """Main processing function for the Gradio interface"""
1047
+
1048
+ if isinstance(input_image, Image.Image):
1049
+ input_image = np.array(input_image)
1050
+
1051
+
1052
+ clear_memory()
1053
+
1054
+ # Initialize configs
1055
+ # Assumed: read the DDS Cloud API token from the environment rather than hard-coding a secret
+ API_TOKEN = os.environ.get("DDS_API_TOKEN", "")
1056
+ SAM2_CHECKPOINT = "./checkpoints/sam2_hiera_large.pt"
1057
+ SAM2_MODEL_CONFIG = os.path.join(os.path.dirname(os.path.abspath(__file__)), "configs/sam2_hiera_l.yaml")
1058
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
1059
+ OUTPUT_DIR = Path("outputs/grounded_sam2_dinox_demo")
1060
+ OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
1061
+
1062
+ HEIGHT = 768
1063
+ WIDTH = 768
1064
+
1065
+ # Initialize DDS client
1066
+ config = Config(API_TOKEN)
1067
+ client = Client(config)
1068
+
1069
+ # Process classes from text prompt
1070
+ classes = [x.strip().lower() for x in input_text.split('.') if x]
1071
+ class_name_to_id = {name: id for id, name in enumerate(classes)}
1072
+ class_id_to_name = {id: name for name, id in class_name_to_id.items()}
1073
+
1074
+ # Save input image to temp file and get URL
1075
+ with tempfile.NamedTemporaryFile(suffix='.jpg', delete=False) as tmpfile:
1076
+ cv2.imwrite(tmpfile.name, input_image)
1077
+ image_url = client.upload_file(tmpfile.name)
1078
+ os.remove(tmpfile.name)
1079
+
1080
+ # Process detection results
1081
+ input_boxes = []
1082
+ masks = []
1083
+ confidences = []
1084
+ class_names = []
1085
+ class_ids = []
1086
+
1087
+ if len(input_text) == 0:
1088
+ task = DinoxTask(
1089
+ image_url=image_url,
1090
+ prompts=[TextPrompt(text="<prompt_free>")],
1091
+ # targets=[DetectionTarget.BBox, DetectionTarget.Mask]
1092
+ )
1093
+
1094
+ client.run_task(task)
1095
+ predictions = task.result.objects
1096
+ classes = [pred.category for pred in predictions]
1097
+ classes = list(set(classes))
1098
+ class_name_to_id = {name: id for id, name in enumerate(classes)}
1099
+ class_id_to_name = {id: name for name, id in class_name_to_id.items()}
1100
+
1101
+ for idx, obj in enumerate(predictions):
1102
+ input_boxes.append(obj.bbox)
1103
+ masks.append(DetectionTask.rle2mask(DetectionTask.string2rle(obj.mask.counts), obj.mask.size)) # convert mask to np.array using DDS API
1104
+ confidences.append(obj.score)
1105
+ cls_name = obj.category.lower().strip()
1106
+ class_names.append(cls_name)
1107
+ class_ids.append(class_name_to_id[cls_name])
1108
+
1109
+ boxes = np.array(input_boxes)
1110
+ masks = np.array(masks)
1111
+ class_ids = np.array(class_ids)
1112
+ labels = [
1113
+ f"{class_name} {confidence:.2f}"
1114
+ for class_name, confidence
1115
+ in zip(class_names, confidences)
1116
+ ]
1117
+ detections = sv.Detections(
1118
+ xyxy=boxes,
1119
+ mask=masks.astype(bool),
1120
+ class_id=class_ids
1121
+ )
1122
+
1123
+ box_annotator = sv.BoxAnnotator()
1124
+ label_annotator = sv.LabelAnnotator()
1125
+ mask_annotator = sv.MaskAnnotator()
1126
+
1127
+ annotated_frame = input_image.copy()
1128
+ annotated_frame = box_annotator.annotate(scene=annotated_frame, detections=detections)
1129
+ annotated_frame = label_annotator.annotate(scene=annotated_frame, detections=detections, labels=labels)
1130
+ annotated_frame = mask_annotator.annotate(scene=annotated_frame, detections=detections)
1131
+
1132
+ # Create transparent mask for first detected object
1133
+ if len(detections) > 0:
1134
+ # Get first mask
1135
+ first_mask = detections.mask[0]
1136
+
1137
+ # Get original RGB image
1138
+ img = input_image.copy()
1139
+ H, W, C = img.shape
1140
+
1141
+ # Create RGBA image with default 255 alpha
1142
+ alpha = np.zeros((H, W, 1), dtype=np.uint8)
1143
+ alpha[~first_mask] = 0 # 128 # for semi-transparency background
1144
+ alpha[first_mask] = 255 # Make the foreground opaque
1145
+ alpha = alpha.squeeze(-1) # Remove singleton dimension to become 2D
1146
+ rgba = np.dstack((img, alpha)).astype(np.uint8)
1147
+
1148
+ # get the bounding box of alpha
1149
+ y, x = np.where(alpha > 0)
1150
+ y0, y1 = max(y.min() - 1, 0), min(y.max() + 1, H)
1151
+ x0, x1 = max(x.min() - 1, 0), min(x.max() + 1, W)
1152
+
1153
+ image_center = rgba[y0:y1, x0:x1]
1154
+ # resize the longer side to H * 0.9
1155
+ H, W, _ = image_center.shape
1156
+ if H > W:
1157
+ W = int(W * (HEIGHT * 0.9) / H)
1158
+ H = int(HEIGHT * 0.9)
1159
+ else:
1160
+ H = int(H * (WIDTH * 0.9) / W)
1161
+ W = int(WIDTH * 0.9)
1162
+
1163
+ image_center = np.array(Image.fromarray(image_center).resize((W, H), Image.LANCZOS))
1164
+ # pad to H, W
1165
+ start_h = (HEIGHT - H) // 2
1166
+ start_w = (WIDTH - W) // 2
1167
+ image = np.zeros((HEIGHT, WIDTH, 4), dtype=np.uint8)
1168
+ image[start_h : start_h + H, start_w : start_w + W] = image_center
1169
+ image = image.astype(np.float32) / 255.0
1170
+ image = image[:, :, :3] * image[:, :, 3:4] + (1 - image[:, :, 3:4]) * 0.5
1171
+ image = (image * 255).clip(0, 255).astype(np.uint8)
1172
+ image = Image.fromarray(image)
1173
+
1174
+ return annotated_frame, image, gr.update(visible=False), gr.update(visible=False)
1175
+ return annotated_frame, None, gr.update(visible=False), gr.update(visible=False)
1176
+ else:
1177
+ # Run DINO-X detection
1178
+ task = DinoxTask(
1179
+ image_url=image_url,
1180
+ prompts=[TextPrompt(text=input_text)],
1181
+ targets=[DetectionTarget.BBox, DetectionTarget.Mask]
1182
+ )
1183
+
1184
+ client.run_task(task)
1185
+ result = task.result
1186
+ objects = result.objects
1187
+
1188
+ predictions = task.result.objects
1189
+ classes = [x.strip().lower() for x in input_text.split('.') if x]
1190
+ class_name_to_id = {name: id for id, name in enumerate(classes)}
1191
+ class_id_to_name = {id: name for name, id in class_name_to_id.items()}
1192
+
1193
+ boxes = []
1194
+ masks = []
1195
+ confidences = []
1196
+ class_names = []
1197
+ class_ids = []
1198
+
1199
+ for idx, obj in enumerate(predictions):
1200
+ boxes.append(obj.bbox)
1201
+ masks.append(DetectionTask.rle2mask(DetectionTask.string2rle(obj.mask.counts), obj.mask.size)) # convert mask to np.array using DDS API
1202
+ confidences.append(obj.score)
1203
+ cls_name = obj.category.lower().strip()
1204
+ class_names.append(cls_name)
1205
+ class_ids.append(class_name_to_id[cls_name])
1206
+
1207
+ boxes = np.array(boxes)
1208
+ masks = np.array(masks)
1209
+ class_ids = np.array(class_ids)
1210
+ labels = [
1211
+ f"{class_name} {confidence:.2f}"
1212
+ for class_name, confidence
1213
+ in zip(class_names, confidences)
1214
+ ]
1215
+
1216
+ detections = sv.Detections(
1217
+ xyxy=boxes,
1218
+ mask=masks.astype(bool),
1219
+ class_id=class_ids,
1220
+ )
1221
+
1222
+ box_annotator = sv.BoxAnnotator()
1223
+ label_annotator = sv.LabelAnnotator()
1224
+ mask_annotator = sv.MaskAnnotator()
1225
+
1226
+ annotated_frame = input_image.copy()
1227
+ annotated_frame = box_annotator.annotate(scene=annotated_frame, detections=detections)
1228
+ annotated_frame = label_annotator.annotate(scene=annotated_frame, detections=detections, labels=labels)
1229
+ annotated_frame = mask_annotator.annotate(scene=annotated_frame, detections=detections)
1230
+
1231
+ # Create transparent mask for first detected object
1232
+ if len(detections) > 0:
1233
+ # Get first mask
1234
+ first_mask = detections.mask[0]
1235
+
1236
+ # Get original RGB image
1237
+ img = input_image.copy()
1238
+ H, W, C = img.shape
1239
+
1240
+ # Create RGBA image with default 255 alpha
1241
+ alpha = np.zeros((H, W, 1), dtype=np.uint8)
1242
+ alpha[~first_mask] = 0 # 128 for semi-transparency background
1243
+ alpha[first_mask] = 255 # Make the foreground opaque
1244
+ alpha = alpha.squeeze(-1) # Remove singleton dimension to become 2D
1245
+ rgba = np.dstack((img, alpha)).astype(np.uint8)
1246
+ # get the bounding box of alpha
1247
+ y, x = np.where(alpha > 0)
1248
+ y0, y1 = max(y.min() - 1, 0), min(y.max() + 1, H)
1249
+ x0, x1 = max(x.min() - 1, 0), min(x.max() + 1, W)
1250
+
1251
+ image_center = rgba[y0:y1, x0:x1]
1252
+ # resize the longer side to H * 0.9
1253
+ H, W, _ = image_center.shape
1254
+ if H > W:
1255
+ W = int(W * (HEIGHT * 0.9) / H)
1256
+ H = int(HEIGHT * 0.9)
1257
+ else:
1258
+ H = int(H * (WIDTH * 0.9) / W)
1259
+ W = int(WIDTH * 0.9)
1260
+
1261
+ image_center = np.array(Image.fromarray(image_center).resize((W, H), Image.LANCZOS))
1262
+ # pad to H, W
1263
+ start_h = (HEIGHT - H) // 2
1264
+ start_w = (WIDTH - W) // 2
1265
+ image = np.zeros((HEIGHT, WIDTH, 4), dtype=np.uint8)
1266
+ image[start_h : start_h + H, start_w : start_w + W] = image_center
1267
+ image = image.astype(np.float32) / 255.0
1268
+ image = image[:, :, :3] * image[:, :, 3:4] + (1 - image[:, :, 3:4]) * 0.5
1269
+ image = (image * 255).clip(0, 255).astype(np.uint8)
1270
+ image = Image.fromarray(image)
1271
+
1272
+ return annotated_frame, image, gr.update(visible=False), gr.update(visible=False)
1273
+ return annotated_frame, None, gr.update(visible=False), gr.update(visible=False)
1274
+
1275
+
1276
+
1277
+ # Import all the necessary functions from the original script
1278
+ def get_value_at_index(obj: Union[Sequence, Mapping], index: int) -> Any:
1279
+ try:
1280
+ return obj[index]
1281
+ except KeyError:
1282
+ return obj["result"][index]
1283
+
1284
+ # Add all the necessary setup functions from the original script
1285
+ def find_path(name: str, path: str = None) -> str:
1286
+ if path is None:
1287
+ path = os.getcwd()
1288
+ if name in os.listdir(path):
1289
+ path_name = os.path.join(path, name)
1290
+ print(f"{name} found: {path_name}")
1291
+ return path_name
1292
+ parent_directory = os.path.dirname(path)
1293
+ if parent_directory == path:
1294
+ return None
1295
+ return find_path(name, parent_directory)
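+ # find_path walks upward from the current working directory until it finds the named
+ # file or folder; it is used below to locate the ComfyUI checkout and the optional
+ # extra_model_paths.yaml config.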
1296
+
1297
+ def add_comfyui_directory_to_sys_path() -> None:
1298
+ comfyui_path = find_path("ComfyUI")
1299
+ if comfyui_path is not None and os.path.isdir(comfyui_path):
1300
+ sys.path.append(comfyui_path)
1301
+ print(f"'{comfyui_path}' added to sys.path")
1302
+
1303
+ def add_extra_model_paths() -> None:
1304
+ try:
1305
+ from main import load_extra_path_config
1306
+ except ImportError:
1307
+ from utils.extra_config import load_extra_path_config
1308
+ extra_model_paths = find_path("extra_model_paths.yaml")
1309
+ if extra_model_paths is not None:
1310
+ load_extra_path_config(extra_model_paths)
1311
+ else:
1312
+ print("Could not find the extra_model_paths config file.")
1313
+
1314
+ # Initialize paths
1315
+ add_comfyui_directory_to_sys_path()
1316
+ add_extra_model_paths()
1317
+
1318
+ def import_custom_nodes() -> None:
1319
+ import asyncio
1320
+ import execution
1321
+ from nodes import init_extra_nodes
1322
+ import server
1323
+ loop = asyncio.new_event_loop()
1324
+ asyncio.set_event_loop(loop)
1325
+ server_instance = server.PromptServer(loop)
1326
+ execution.PromptQueue(server_instance)
1327
+ init_extra_nodes()
1328
+
1329
+ # Import all necessary nodes
1330
+ from nodes import (
1331
+ StyleModelLoader,
1332
+ VAEEncode,
1333
+ NODE_CLASS_MAPPINGS,
1334
+ LoadImage,
1335
+ CLIPVisionLoader,
1336
+ SaveImage,
1337
+ VAELoader,
1338
+ CLIPVisionEncode,
1339
+ DualCLIPLoader,
1340
+ EmptyLatentImage,
1341
+ VAEDecode,
1342
+ UNETLoader,
1343
+ CLIPTextEncode,
1344
+ )
1345
+
1346
+ # Initialize all constant nodes and models in global context
1347
+ import_custom_nodes()
1348
+
1349
+ # Global variables for preloaded models and constants
1350
+ #with torch.inference_mode():
1351
+ # Initialize constants
1352
+ intconstant = NODE_CLASS_MAPPINGS["INTConstant"]()
1353
+ CONST_1024 = intconstant.get_value(value=1024)
1354
+
+ # Load CLIP
+ dualcliploader = DualCLIPLoader()
+ CLIP_MODEL = dualcliploader.load_clip(
+ clip_name1="t5/t5xxl_fp16.safetensors",
+ clip_name2="clip_l.safetensors",
+ type="flux",
+ )
+
+ # Load VAE
+ vaeloader = VAELoader()
+ VAE_MODEL = vaeloader.load_vae(vae_name="FLUX1/ae.safetensors")
+
+ # Load UNET
+ unetloader = UNETLoader()
+ UNET_MODEL = unetloader.load_unet(
+ unet_name="flux1-depth-dev.safetensors", weight_dtype="default"
+ )
+
+ # Load CLIP Vision
+ clipvisionloader = CLIPVisionLoader()
+ CLIP_VISION_MODEL = clipvisionloader.load_clip(
+ clip_name="sigclip_vision_patch14_384.safetensors"
+ )
+
+ # Load Style Model
+ stylemodelloader = StyleModelLoader()
+ STYLE_MODEL = stylemodelloader.load_style_model(
+ style_model_name="flux1-redux-dev.safetensors"
+ )
+
+ # Initialize samplers
+ ksamplerselect = NODE_CLASS_MAPPINGS["KSamplerSelect"]()
+ SAMPLER = ksamplerselect.get_sampler(sampler_name="euler")
+
+ # Initialize depth model
+ cr_clip_input_switch = NODE_CLASS_MAPPINGS["CR Clip Input Switch"]()
+ downloadandloaddepthanythingv2model = NODE_CLASS_MAPPINGS["DownloadAndLoadDepthAnythingV2Model"]()
+ DEPTH_MODEL = downloadandloaddepthanythingv2model.loadmodel(
+ model="depth_anything_v2_vitl_fp32.safetensors"
+ )
+ cliptextencode = CLIPTextEncode()
+ loadimage = LoadImage()
+ vaeencode = VAEEncode()
+ fluxguidance = NODE_CLASS_MAPPINGS["FluxGuidance"]()
+ instructpixtopixconditioning = NODE_CLASS_MAPPINGS["InstructPixToPixConditioning"]()
+ clipvisionencode = CLIPVisionEncode()
+ stylemodelapplyadvanced = NODE_CLASS_MAPPINGS["StyleModelApplyAdvanced"]()
+ emptylatentimage = EmptyLatentImage()
+ basicguider = NODE_CLASS_MAPPINGS["BasicGuider"]()
+ basicscheduler = NODE_CLASS_MAPPINGS["BasicScheduler"]()
+ randomnoise = NODE_CLASS_MAPPINGS["RandomNoise"]()
+ samplercustomadvanced = NODE_CLASS_MAPPINGS["SamplerCustomAdvanced"]()
+ vaedecode = VAEDecode()
+ cr_text = NODE_CLASS_MAPPINGS["CR Text"]()
+ saveimage = SaveImage()
+ getimagesizeandcount = NODE_CLASS_MAPPINGS["GetImageSizeAndCount"]()
+ depthanything_v2 = NODE_CLASS_MAPPINGS["DepthAnything_V2"]()
+ imageresize = NODE_CLASS_MAPPINGS["ImageResize+"]()
+
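+ # generate_image() wires the cached nodes into a depth + style (Redux) Flux pipeline: the
+ # structure image is resized, run through DepthAnything V2 and fed to InstructPixToPixConditioning
+ # together with the FluxGuidance-weighted prompt, while the style image is encoded with CLIP Vision
+ # and blended in via StyleModelApplyAdvanced. Each ComfyUI node call returns a tuple, hence the
+ # get_value_at_index(..., 0) accessors throughout.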
+ @spaces.GPU
+ def generate_image(prompt, structure_image, style_image, depth_strength=15, style_strength=0.5, progress=gr.Progress(track_tqdm=True)) -> str:
+ """Main generation function that processes inputs and returns the path to the generated image."""
+
+ clear_memory()
+ with torch.inference_mode():
+ # Set up CLIP
+ clip_switch = cr_clip_input_switch.switch(
+ Input=1,
+ clip1=get_value_at_index(CLIP_MODEL, 0),
+ clip2=get_value_at_index(CLIP_MODEL, 0),
+ )
+
+ # Encode text
+ text_encoded = cliptextencode.encode(
+ text=prompt,
+ clip=get_value_at_index(clip_switch, 0),
+ )
+ empty_text = cliptextencode.encode(
+ text="",
+ clip=get_value_at_index(clip_switch, 0),
+ )
+
+ # Process structure image
+ structure_img = loadimage.load_image(image=structure_image)
+
+ # Resize image
+ resized_img = imageresize.execute(
+ width=get_value_at_index(CONST_1024, 0),
+ height=get_value_at_index(CONST_1024, 0),
+ interpolation="bicubic",
+ method="keep proportion",
+ condition="always",
+ multiple_of=16,
+ image=get_value_at_index(structure_img, 0),
+ )
+
+ # Get image size
+ size_info = getimagesizeandcount.getsize(
+ image=get_value_at_index(resized_img, 0)
+ )
+
+ # Encode VAE
+ vae_encoded = vaeencode.encode(
+ pixels=get_value_at_index(size_info, 0),
+ vae=get_value_at_index(VAE_MODEL, 0),
+ )
+
+ # Process depth
+ depth_processed = depthanything_v2.process(
+ da_model=get_value_at_index(DEPTH_MODEL, 0),
+ images=get_value_at_index(size_info, 0),
+ )
+
+ # Apply Flux guidance
+ flux_guided = fluxguidance.append(
+ guidance=depth_strength,
+ conditioning=get_value_at_index(text_encoded, 0),
+ )
+
+ # Process style image
+ style_img = loadimage.load_image(image=style_image)
+
+ # Encode style with CLIP Vision
+ style_encoded = clipvisionencode.encode(
+ crop="center",
+ clip_vision=get_value_at_index(CLIP_VISION_MODEL, 0),
+ image=get_value_at_index(style_img, 0),
+ )
+
+ # Set up conditioning
+ conditioning = instructpixtopixconditioning.encode(
+ positive=get_value_at_index(flux_guided, 0),
+ negative=get_value_at_index(empty_text, 0),
+ vae=get_value_at_index(VAE_MODEL, 0),
+ pixels=get_value_at_index(depth_processed, 0),
+ )
+
+ # Apply style
+ style_applied = stylemodelapplyadvanced.apply_stylemodel(
+ strength=style_strength,
+ conditioning=get_value_at_index(conditioning, 0),
+ style_model=get_value_at_index(STYLE_MODEL, 0),
+ clip_vision_output=get_value_at_index(style_encoded, 0),
+ )
+
+ # Set up empty latent
+ empty_latent = emptylatentimage.generate(
+ width=get_value_at_index(resized_img, 1),
+ height=get_value_at_index(resized_img, 2),
+ batch_size=1,
+ )
+
+ # Set up guidance
+ guided = basicguider.get_guider(
+ model=get_value_at_index(UNET_MODEL, 0),
+ conditioning=get_value_at_index(style_applied, 0),
+ )
+
+ # Set up scheduler
+ schedule = basicscheduler.get_sigmas(
+ scheduler="simple",
+ steps=28,
+ denoise=1,
+ model=get_value_at_index(UNET_MODEL, 0),
+ )
+
+ # Generate random noise
+ # (the seed must fit in an unsigned 64-bit integer, so the upper bound is 2**64 - 1)
+ noise = randomnoise.get_noise(noise_seed=random.randint(0, 2**64 - 1))
+
+ # Sample
+ sampled = samplercustomadvanced.sample(
+ noise=get_value_at_index(noise, 0),
+ guider=get_value_at_index(guided, 0),
+ sampler=get_value_at_index(SAMPLER, 0),
+ sigmas=get_value_at_index(schedule, 0),
+ latent_image=get_value_at_index(empty_latent, 0),
+ )
+
+ # Decode VAE
+ decoded = vaedecode.decode(
+ samples=get_value_at_index(sampled, 0),
+ vae=get_value_at_index(VAE_MODEL, 0),
+ )
+
+ # Save image
+ prefix = cr_text.text_multiline(text="Flux_BFL_Depth_Redux")
+
+ saved = saveimage.save_images(
+ filename_prefix=get_value_at_index(prefix, 0),
+ images=get_value_at_index(decoded, 0),
+ )
+ saved_path = f"output/{saved['ui']['images'][0]['filename']}"
+
+ clear_memory()
+ return saved_path
+
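+ # Illustrative call (a sketch, assuming the checkpoints above are in the ComfyUI model folders
+ # and the example images ship with the Space; the prompt is made up):
+ #   path = generate_image("a chair in a cozy living room",
+ #                         "chair_input_1.jpg", "chair_input_2.png",
+ #                         depth_strength=15, style_strength=0.5)
+ #   # -> path of the saved PNG under output/, e.g. "output/Flux_BFL_Depth_Redux_00001_.png"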
+ # Create Gradio interface
+
+ examples = [
+ ["", "chair_input_1.jpg", "chair_input_2.png", 15, 0.5],
+ ]
+
+ output_image = gr.Image(label="Generated Image")
+
+ with gr.Blocks() as app:
+ with gr.Tab("Relighting"):
+ with gr.Row():
+ gr.Markdown("## Product Placement from Text")
+ with gr.Row():
+ with gr.Column():
+ with gr.Row():
+ input_fg = gr.Image(type="pil", label="Image", height=480)
+ with gr.Row():
+ with gr.Group():
+ find_objects_button = gr.Button(value="(Option 1) Segment Object from text")
+ text_prompt = gr.Textbox(
+ label="Text Prompt",
+ placeholder="Enter object classes separated by periods (e.g. 'car . person .'), leave empty to get all objects",
+ value=""
+ )
+ extract_button = gr.Button(value="Remove Background")
+ with gr.Row():
+ extracted_objects = gr.Image(type="numpy", label="Segmented Objects", height=480)
+ extracted_fg = gr.Image(type="pil", label="Extracted Foreground", height=480)
+ angles_fg = gr.Image(type="pil", label="Converted Foreground", height=480, visible=False)
+
+ # output_bg = gr.Image(type="numpy", label="Preprocessed Foreground", height=480)
+ with gr.Group():
+ run_button = gr.Button("Generate alternative angles")
+ orientation_result = gr.Gallery(
+ label="Result",
+ show_label=False,
+ columns=[3],
+ rows=[2],
+ object_fit="fill",
+ height="auto",
+ allow_preview=False,
+ )
+
+ if orientation_result:
+ orientation_result.select(use_orientation, inputs=None, outputs=extracted_fg)
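+ # Clicking a thumbnail in the gallery feeds it back as the working foreground: gr.Gallery.select
+ # hands the selection to use_orientation (presumably defined earlier in this file), which is
+ # expected to return the clicked image so it replaces extracted_fg.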
+
+ dummy_image_for_outputs = gr.Image(visible=False, label='Result', type='pil')
+
+
+ with gr.Column():
+
+ with gr.Row():
+ with gr.Column(scale=4):
+ result_gallery = gr.Gallery(height=832, label='Outputs', object_fit='contain', selected_index=0)
+ if result_gallery:
+ result_gallery.select(use_orientation, inputs=None, outputs=dummy_image_for_outputs)
+ with gr.Column(scale=1):
+ with gr.Group():
+ gr.Markdown("Outpaint")
+ with gr.Row():
+ with gr.Column(scale=2):
+ prompt_fill = gr.Textbox(label="Prompt (Optional)")
+ with gr.Column(scale=1):
+ fill_button = gr.Button("Generate")
+ target_ratio = gr.Radio(
+ label="Image Ratio",
+ choices=["9:16", "16:9", "1:1", "Custom"],
+ value="9:16",
+ scale=3
+ )
+ alignment_dropdown = gr.Dropdown(
+ choices=["Middle", "Left", "Right", "Top", "Bottom"],
+ value="Middle",
+ label="Alignment",
+ )
+ resize_option = gr.Radio(
+ label="Resize input image",
+ choices=["Full", "75%", "50%", "33%", "25%", "Custom"],
+ value="75%"
+ )
+ custom_resize_percentage = gr.Slider(
+ label="Custom resize (%)",
+ minimum=1,
+ maximum=100,
+ step=1,
+ value=50,
+ visible=False
+ )
+
+ fill_result = gr.Image(
+ interactive=False,
+ label="Generated Image",
+ )
+
+ with gr.Accordion(label="Advanced settings", open=False) as settings_panel:
+ with gr.Column():
+ with gr.Row():
+ width_slider = gr.Slider(
+ label="Target Width",
+ minimum=720,
+ maximum=1536,
+ step=8,
+ value=720,
+ )
+ height_slider = gr.Slider(
+ label="Target Height",
+ minimum=720,
+ maximum=1536,
+ step=8,
+ value=1280,
+ )
+
+ num_inference_steps = gr.Slider(label="Steps", minimum=2, maximum=50, step=1, value=18)
+ with gr.Group():
+ overlap_percentage = gr.Slider(
+ label="Mask overlap (%)",
+ minimum=1,
+ maximum=50,
+ value=10,
+ step=1
+ )
+ with gr.Row():
+ overlap_top = gr.Checkbox(label="Overlap Top", value=True)
+ overlap_right = gr.Checkbox(label="Overlap Right", value=True)
+ with gr.Row():
+ overlap_left = gr.Checkbox(label="Overlap Left", value=True)
+ overlap_bottom = gr.Checkbox(label="Overlap Bottom", value=True)
+
+
+ with gr.Row():
+ with gr.Group():
+ prompt = gr.Textbox(label="Prompt")
+ bg_source = gr.Radio(choices=[e.value for e in list(BGSource)[2:]],
+ value=BGSource.LEFT.value,
+ label="Lighting Preference (Initial Latent)", type='value')
+
+ example_quick_subjects = gr.Dataset(samples=quick_subjects, label='Subject Quick List', samples_per_page=1000, components=[prompt])
+ example_quick_prompts = gr.Dataset(samples=quick_prompts, label='Lighting Quick List', samples_per_page=1000, components=[prompt])
+ with gr.Row():
+ relight_button = gr.Button(value="Relight")
+
+ with gr.Group(visible=False):
+ with gr.Row():
+ num_samples = gr.Slider(label="Images", minimum=1, maximum=12, value=1, step=1)
+ seed = gr.Number(label="Seed", value=12345, precision=0)
+
+ with gr.Row():
+ image_width = gr.Slider(label="Image Width", minimum=256, maximum=1024, value=512, step=64)
+ image_height = gr.Slider(label="Image Height", minimum=256, maximum=1024, value=640, step=64)
+
+ with gr.Accordion("Advanced options", open=False):
+ steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=15, step=1)
+ cfg = gr.Slider(label="CFG Scale", minimum=1.0, maximum=32.0, value=2, step=0.01, visible=False)
+ lowres_denoise = gr.Slider(label="Lowres Denoise (for initial latent)", minimum=0.1, maximum=1.0, value=0.9, step=0.01)
+ highres_scale = gr.Slider(label="Highres Scale", minimum=1.0, maximum=3.0, value=1.5, step=0.01)
+ highres_denoise = gr.Slider(label="Highres Denoise", minimum=0.1, maximum=1.0, value=0.5, step=0.01)
+ a_prompt = gr.Textbox(label="Added Prompt", value='best quality', visible=False)
+ n_prompt = gr.Textbox(label="Negative Prompt", value='lowres, bad anatomy, bad hands, cropped, worst quality', visible=False)
+ x_slider = gr.Slider(
+ minimum=0,
+ maximum=1000,
+ label="X Position",
+ value=500,
+ visible=False
+ )
+ y_slider = gr.Slider(
+ minimum=0,
+ maximum=1000,
+ label="Y Position",
+ value=500,
+ visible=False
+ )
+
+ # with gr.Row():
+
+ # gr.Examples(
+ # fn=lambda *args: ([args[-1]], None),
+ # examples=db_examples.foreground_conditioned_examples,
+ # inputs=[
+ # input_fg, prompt, bg_source, image_width, image_height, seed, dummy_image_for_outputs
+ # ],
+ # outputs=[result_gallery, output_bg],
+ # run_on_click=True, examples_per_page=1024
+ # )
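+ # Note: Gradio passes the `ips` components to process_relight positionally, so the list order
+ # below must stay in sync with process_relight's parameter order.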
+ ips = [extracted_fg, prompt, image_width, image_height, num_samples, seed, steps, a_prompt, n_prompt, cfg, highres_scale, highres_denoise, lowres_denoise, bg_source]
+ relight_button.click(fn=process_relight, inputs=ips, outputs=[result_gallery])
+ example_quick_prompts.click(lambda x, y: ', '.join(y.split(', ')[:2] + [x[0]]), inputs=[example_quick_prompts, prompt], outputs=prompt, show_progress=False, queue=False)
+ example_quick_subjects.click(lambda x: x[0], inputs=example_quick_subjects, outputs=prompt, show_progress=False, queue=False)
+
+ # def use_output_as_input(output_image):
+ # return output_image
+
+ # use_as_input_button.click(
+ # fn=use_output_as_input,
+ # inputs=[fill_result],
+ # outputs=[input_image]
+ # )
+
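+ # Outpaint control wiring: picking an aspect-ratio preset updates the width/height sliders
+ # (preload_presets), editing either slider snaps the preset radio back to the matching value
+ # (select_the_right_preset), and choosing "Custom" resize reveals the percentage slider
+ # (toggle_custom_resize_slider).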
+ target_ratio.change(
+ fn=preload_presets,
+ inputs=[target_ratio, width_slider, height_slider],
+ outputs=[width_slider, height_slider, settings_panel],
+ queue=False
+ )
+
+ width_slider.change(
+ fn=select_the_right_preset,
+ inputs=[width_slider, height_slider],
+ outputs=[target_ratio],
+ queue=False
+ )
+
+ height_slider.change(
+ fn=select_the_right_preset,
+ inputs=[width_slider, height_slider],
+ outputs=[target_ratio],
+ queue=False
+ )
+
+ resize_option.change(
+ fn=toggle_custom_resize_slider,
+ inputs=[resize_option],
+ outputs=[custom_resize_percentage],
+ queue=False
+ )
+
+ fill_button.click(
+ fn=clear_result,
+ inputs=None,
+ outputs=fill_result,
+ ).then(
+ fn=inpaint,
+ inputs=[dummy_image_for_outputs, width_slider, height_slider, overlap_percentage, num_inference_steps,
+ resize_option, custom_resize_percentage, prompt_fill, alignment_dropdown,
+ overlap_left, overlap_right, overlap_top, overlap_bottom],
+ outputs=[fill_result])
+ # ).then(
+ # fn=lambda: gr.update(visible=True),
+ # inputs=None,
+ # outputs=use_as_input_button,
+ # )
+
+ prompt_fill.submit(
+ fn=clear_result,
+ inputs=None,
+ outputs=fill_result,
+ ).then(
+ fn=inpaint,
+ inputs=[dummy_image_for_outputs, width_slider, height_slider, overlap_percentage, num_inference_steps,
+ resize_option, custom_resize_percentage, prompt_fill, alignment_dropdown,
+ overlap_left, overlap_right, overlap_top, overlap_bottom],
+ outputs=[fill_result])
+
+ # Despite its name, convert_to_pil only normalizes numpy inputs to uint8; PIL inputs have no
+ # .astype, so they fall through the except branch and are returned unchanged.
+ def convert_to_pil(image):
+ try:
+ #logging.info(f"Input image shape: {image.shape}, dtype: {image.dtype}")
+ image = image.astype(np.uint8)
+ logging.info(f"Converted image shape: {image.shape}, dtype: {image.dtype}")
+ return image
+ except Exception as e:
+ logging.error(f"Error converting image: {e}")
+ return image
+
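+ # "Generate alternative angles": normalize the extracted foreground, stash it in the hidden
+ # angles_fg image, then hand the prompt and foreground to infer(), which fills the
+ # orientation_result gallery.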
+ run_button.click(
+ fn=convert_to_pil,
+ inputs=extracted_fg, # This is already RGBA with removed background
+ outputs=angles_fg
+ ).then(
+ fn=infer,
+ inputs=[
+ text_prompt,
+ extracted_fg, # Already processed RGBA image
+ ],
+ outputs=[orientation_result],
+ )
+
+ find_objects_button.click(
+ fn=process_image,
+ inputs=[input_fg, text_prompt],
+ outputs=[extracted_objects, extracted_fg]
+ )
+
+ extract_button.click(
+ fn=extract_foreground,
+ inputs=[input_fg],
+ outputs=[extracted_fg, x_slider, y_slider]
+ )
+ with gr.Tab("Style Transfer"):
+ gr.Markdown("## Apply the style of an image to another one")
+ with gr.Row():
+ with gr.Column():
+ prompt_input = gr.Textbox(label="Prompt", placeholder="Enter your prompt here...")
+ with gr.Row():
+ with gr.Group():
+ structure_image = gr.Image(label="Structure Image", type="filepath")
+ depth_strength = gr.Slider(minimum=0, maximum=50, value=15, label="Depth Strength")
+ with gr.Group():
+ style_image = gr.Image(label="Style Image", type="filepath")
+ style_strength = gr.Slider(minimum=0, maximum=1, value=0.5, label="Style Strength")
+ generate_btn = gr.Button("Generate")
+
+ gr.Examples(
+ examples=examples,
+ inputs=[prompt_input, structure_image, style_image, depth_strength, style_strength],
+ outputs=[output_image],
+ fn=generate_image,
+ cache_examples=True,
+ cache_mode="lazy"
+ )
+
+ with gr.Column():
+ output_image.render()
+ transfer_btn = gr.Button("Send to relight")
+
+ def send_img(img_result):
+ return img_result
+
+ transfer_btn.click(send_img, [output_image], [input_fg])
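+ # "Send to relight" copies the style-transfer result straight into the Relighting tab's input
+ # image (input_fg), so it can be segmented and relit without re-uploading.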
+
+
+ generate_btn.click(
+ fn=generate_image,
+ inputs=[prompt_input, structure_image, style_image, depth_strength, style_strength],
+ outputs=[output_image]
+ )
+
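+ # launch(share=True) serves the app locally and also requests a temporary public *.gradio.live
+ # share link; drop share=True if only local access is wanted.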
+ if __name__ == "__main__":
+ app.launch(share=True)