Ashoka74 committed on
Commit
6ab2279
·
verified ·
1 Parent(s): 1ae6f4c

Create merged_files3.py

Files changed (1): merged_files3.py (+1995, -0)
merged_files3.py ADDED
@@ -0,0 +1,1995 @@
1
+ import os
2
+ import random
3
+ import sys
4
+ import gc # Add garbage collector import
5
+ from typing import Sequence, Mapping, Any, Union
6
+ import torch
7
+ import gradio as gr
8
+ from PIL import Image, ImageDraw
9
+ from huggingface_hub import hf_hub_download
10
+ import spaces
11
+
12
+ import argparse
13
+ import random
14
+
15
+ import os
16
+ import math
17
+ import gradio as gr
18
+ import numpy as np
19
+ import torch
20
+ import safetensors.torch as sf
21
+ import datetime
22
+ from pathlib import Path
23
+ from io import BytesIO
24
+
25
+ from PIL import Image
26
+ from diffusers import StableDiffusionPipeline, StableDiffusionImg2ImgPipeline
27
+ from diffusers import AutoencoderKL, UNet2DConditionModel, DDIMScheduler, EulerAncestralDiscreteScheduler, DPMSolverMultistepScheduler
28
+ from diffusers.models.attention_processor import AttnProcessor2_0
29
+ from transformers import CLIPTextModel, CLIPTokenizer
30
+ import dds_cloudapi_sdk
31
+ from dds_cloudapi_sdk import Config, Client, TextPrompt
32
+ from dds_cloudapi_sdk.tasks.dinox import DinoxTask
33
+ from dds_cloudapi_sdk.tasks import DetectionTarget
34
+ from dds_cloudapi_sdk.tasks.detection import DetectionTask
35
+ from transformers import AutoModelForImageSegmentation
36
+
37
+
38
+ from enum import Enum
39
+ from torch.hub import download_url_to_file
40
+ import tempfile
41
+
42
+ from sam2.build_sam import build_sam2
43
+
44
+ from sam2.sam2_image_predictor import SAM2ImagePredictor
45
+ import cv2
46
+
47
+ from transformers import AutoModelForImageSegmentation
48
+ from inference_i2mv_sdxl import prepare_pipeline, remove_bg, run_pipeline
49
+ from torchvision import transforms
50
+
51
+
52
+ from typing import Optional
53
+
54
+ from depth_anything_v2.dpt import DepthAnythingV2
55
+
56
+ import httpx
57
+
58
+
59
+ import gradio as gr
60
+ import torch
61
+ from diffusers import FluxFillPipeline
62
+ from diffusers.utils import load_image
63
+ from PIL import Image, ImageDraw
64
+ import numpy as np
65
+ import spaces
66
+ from huggingface_hub import hf_hub_download
67
+ import openai
68
+ from openai import OpenAI
69
+ import gradio as gr
70
+ import os
71
+ from PIL import Image
72
+ import numpy as np
73
+ import io
74
+ import base64
75
+
76
+
77
+ MAX_IMAGE_WIDTH = 2048
78
+ IMAGE_FORMAT = "JPEG"
79
+
80
+
81
+
82
+ client = httpx.Client(timeout=httpx.Timeout(10.0)) # Set timeout to 10 seconds
83
+ NUM_VIEWS = 6
84
+ HEIGHT = 768
85
+ WIDTH = 768
86
+ MAX_SEED = np.iinfo(np.int32).max
87
+
88
+
89
+
90
+ import supervision as sv
91
+ import torch
92
+ from PIL import Image
93
+
94
+ import logging
95
+
96
+ # Configure logging
97
+ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(message)s')
98
+
99
+ transform_image = transforms.Compose(
100
+ [
101
+ transforms.Resize((1024, 1024)),
102
+ transforms.ToTensor(),
103
+ transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
104
+ ]
105
+ )
106
+
107
+ #hf_hub_download(repo_id="YarvixPA/FLUX.1-Fill-dev-gguf", filename="flux1-fill-dev-Q5_K_S.gguf", local_dir="models/")
108
+
109
+
110
+
111
+ # Load
112
+
113
+ # Model paths
114
+ model_path = './models/iclight_sd15_fc.safetensors'
115
+ model_path2 = './checkpoints/depth_anything_v2_vits.pth'
116
+ model_path3 = './checkpoints/sam2_hiera_large.pt'
117
+ model_path4 = './checkpoints/config.json'
118
+ model_path5 = './checkpoints/preprocessor_config.json'
119
+ model_path6 = './configs/sam2_hiera_l.yaml'
120
+ model_path7 = './mvadapter_i2mv_sdxl.safetensors'
121
+
122
+ # Base URL for the repository
123
+ BASE_URL = 'https://huggingface.co/Ashoka74/Placement/resolve/main/'
124
+
125
+ # Model URLs
126
+ model_urls = {
127
+ model_path: 'iclight_sd15_fc.safetensors',
128
+ model_path2: 'depth_anything_v2_vits.pth',
129
+ model_path3: 'sam2_hiera_large.pt',
130
+ model_path4: 'config.json',
131
+ model_path5: 'preprocessor_config.json',
132
+ model_path6: 'sam2_hiera_l.yaml',
133
+ model_path7: 'mvadapter_i2mv_sdxl.safetensors'
134
+ }
135
+
136
+ # Ensure directories exist
137
+ def ensure_directories():
138
+ for path in model_urls.keys():
139
+ os.makedirs(os.path.dirname(path), exist_ok=True)
140
+
141
+ # Download models
142
+ def download_models():
143
+ for local_path, filename in model_urls.items():
144
+ if not os.path.exists(local_path):
145
+ try:
146
+ url = f"{BASE_URL}{filename}"
147
+ print(f"Downloading {filename}")
148
+ download_url_to_file(url, local_path)
149
+ print(f"Successfully downloaded {filename}")
150
+ except Exception as e:
151
+ print(f"Error downloading {filename}: {e}")
152
+
153
+ ensure_directories()
154
+ download_models()
155
+
156
+
157
+ hf_hub_download(repo_id="black-forest-labs/FLUX.1-Redux-dev", filename="flux1-redux-dev.safetensors", local_dir="models/style_models")
158
+ hf_hub_download(repo_id="black-forest-labs/FLUX.1-Depth-dev", filename="flux1-depth-dev.safetensors", local_dir="models/diffusion_models")
159
+ hf_hub_download(repo_id="Comfy-Org/sigclip_vision_384", filename="sigclip_vision_patch14_384.safetensors", local_dir="models/clip_vision")
160
+ hf_hub_download(repo_id="Kijai/DepthAnythingV2-safetensors", filename="depth_anything_v2_vitl_fp32.safetensors", local_dir="models/depthanything")
161
+ hf_hub_download(repo_id="black-forest-labs/FLUX.1-dev", filename="ae.safetensors", local_dir="models/vae/FLUX1")
162
+ hf_hub_download(repo_id="comfyanonymous/flux_text_encoders", filename="clip_l.safetensors", local_dir="models/text_encoders")
163
+ t5_path = hf_hub_download(repo_id="comfyanonymous/flux_text_encoders", filename="t5xxl_fp16.safetensors", local_dir="models/text_encoders/t5")
164
+
165
+
166
+ sd15_name = 'stablediffusionapi/realistic-vision-v51'
167
+ tokenizer = CLIPTokenizer.from_pretrained(sd15_name, subfolder="tokenizer")
168
+
169
+ # Load models in sequence with memory clearing between loads
170
+ def load_models():
171
+ clear_memory()
172
+
173
+ global text_encoder, vae, unet, rmbg
174
+
175
+ # Load text encoder
176
+ text_encoder = CLIPTextModel.from_pretrained(sd15_name, subfolder="text_encoder")
177
+ text_encoder = text_encoder.to(device=device, dtype=dtype)
178
+ clear_memory()
179
+
180
+ # Load VAE
181
+ vae = AutoencoderKL.from_pretrained(sd15_name, subfolder="vae")
182
+ vae = vae.to(device=device, dtype=dtype)
183
+ clear_memory()
184
+
185
+ # Load UNet
186
+ unet = UNet2DConditionModel.from_pretrained(sd15_name, subfolder="unet")
187
+ unet = unet.to(device=device, dtype=dtype)
188
+ clear_memory()
189
+
190
+ # Load RMBG
191
+ rmbg = AutoModelForImageSegmentation.from_pretrained(
192
+ "ZhengPeng7/BiRefNet", trust_remote_code=True
193
+ )
194
+ rmbg = rmbg.to(device=device, dtype=torch.float32)
195
+ clear_memory()
196
+
197
+ # Call load_models after device setup
198
+ load_models()
199
+
200
+ from diffusers import FluxTransformer2DModel, FluxFillPipeline, GGUFQuantizationConfig
201
+ from transformers import T5EncoderModel
202
+ import torch
203
+
204
+
205
+
206
+ ckpt_path = (
207
+ "https://huggingface.co/SporkySporkness/FLUX.1-Fill-dev-GGUF/flux1-fill-dev-fp16-Q5_0-GGUF.gguf"
208
+ )
209
+
210
+ transformer = FluxTransformer2DModel.from_single_file(
211
+ ckpt_path,
212
+ quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16),
213
+ torch_dtype=torch.bfloat16,
214
+ )
215
+
216
+ fill_pipe = FluxFillPipeline.from_pretrained(
217
+ "black-forest-labs/FLUX.1-Fill-dev",
218
+ transformer=transformer,
219
+ # (a generator is not a from_pretrained argument; pass one to the fill_pipe(...) call instead)
220
+ torch_dtype=torch.bfloat16,
221
+ )
222
+
223
+
224
+ try:
225
+ import xformers
226
+ import xformers.ops
227
+ XFORMERS_AVAILABLE = True
228
+ print("xformers is available - Using memory efficient attention")
229
+ except ImportError:
230
+ XFORMERS_AVAILABLE = False
231
+ print("xformers not available - Using default attention")
232
+
233
+ # Memory optimizations for RTX 2070
234
+ torch.backends.cudnn.benchmark = True
235
+ if torch.cuda.is_available():
236
+ torch.backends.cuda.matmul.allow_tf32 = True
237
+ torch.backends.cudnn.allow_tf32 = True
238
+ # Limit CUDA allocator split size for smaller GPUs; max_split_size_mb is read from the
+ # PYTORCH_CUDA_ALLOC_CONF environment variable, not from a torch.backends attribute
239
+ os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "max_split_size_mb:256") # Reduced from 512 for better stability
240
+ device = torch.device('cuda')
241
+ else:
242
+ device = torch.device('cpu')
243
+
244
+
245
+ rmbg = AutoModelForImageSegmentation.from_pretrained(
246
+ "ZhengPeng7/BiRefNet", trust_remote_code=True
247
+ )
248
+ rmbg = rmbg.to(device=device, dtype=torch.float32)
249
+
250
+
251
+ model = DepthAnythingV2(encoder='vits', features=64, out_channels=[48, 96, 192, 384])
252
+ model.load_state_dict(torch.load('checkpoints/depth_anything_v2_vits.pth', map_location=device))
253
+ model = model.to(device)
254
+ model.eval()
255
+
256
+
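+ # IC-Light conditioning: widen the UNet's first conv from 4 to 8 input channels so the
+ # encoded foreground latents can be concatenated with the noisy latents; the new weights
+ # start at zero, so behaviour is unchanged until the offset weights are merged in below.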
257
+ with torch.no_grad():
258
+ new_conv_in = torch.nn.Conv2d(8, unet.conv_in.out_channels, unet.conv_in.kernel_size, unet.conv_in.stride, unet.conv_in.padding)
259
+ new_conv_in.weight.zero_()
260
+ new_conv_in.weight[:, :4, :, :].copy_(unet.conv_in.weight)
261
+ new_conv_in.bias = unet.conv_in.bias
262
+ unet.conv_in = new_conv_in
263
+
264
+
265
+ unet_original_forward = unet.forward
266
+
267
+
268
+ def can_expand(source_width, source_height, target_width, target_height, alignment):
269
+ if alignment in ("Left", "Right") and source_width >= target_width:
270
+ return False
271
+ if alignment in ("Top", "Bottom") and source_height >= target_height:
272
+ return False
273
+ return True
274
+
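+ # Builds the outpainting canvas: the source image is scaled and pasted onto a white
+ # target-size background according to the chosen alignment, and a mask is drawn so that
+ # only the new border region (plus the configured overlap) gets inpainted.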
275
+ def prepare_image_and_mask(image, width, height, overlap_percentage, resize_option, custom_resize_percentage, alignment, overlap_left, overlap_right, overlap_top, overlap_bottom):
276
+ target_size = (width, height)
277
+
278
+ scale_factor = min(target_size[0] / image.width, target_size[1] / image.height)
279
+ new_width = int(image.width * scale_factor)
280
+ new_height = int(image.height * scale_factor)
281
+
282
+ source = image.resize((new_width, new_height), Image.LANCZOS)
283
+
284
+ if resize_option == "Full":
285
+ resize_percentage = 100
286
+ elif resize_option == "75%":
287
+ resize_percentage = 75
288
+ elif resize_option == "50%":
289
+ resize_percentage = 50
290
+ elif resize_option == "33%":
291
+ resize_percentage = 33
292
+ elif resize_option == "25%":
293
+ resize_percentage = 25
294
+ else: # Custom
295
+ resize_percentage = custom_resize_percentage
296
+
297
+ # Calculate new dimensions based on percentage
298
+ resize_factor = resize_percentage / 100
299
+ new_width = int(source.width * resize_factor)
300
+ new_height = int(source.height * resize_factor)
301
+
302
+ # Ensure minimum size of 64 pixels
303
+ new_width = max(new_width, 64)
304
+ new_height = max(new_height, 64)
305
+
306
+ # Resize the image
307
+ source = source.resize((new_width, new_height), Image.LANCZOS)
308
+
309
+ # Calculate the overlap in pixels based on the percentage
310
+ overlap_x = int(new_width * (overlap_percentage / 100))
311
+ overlap_y = int(new_height * (overlap_percentage / 100))
312
+
313
+ # Ensure minimum overlap of 1 pixel
314
+ overlap_x = max(overlap_x, 1)
315
+ overlap_y = max(overlap_y, 1)
316
+
317
+ # Calculate margins based on alignment
318
+ if alignment == "Middle":
319
+ margin_x = (target_size[0] - new_width) // 2
320
+ margin_y = (target_size[1] - new_height) // 2
321
+ elif alignment == "Left":
322
+ margin_x = 0
323
+ margin_y = (target_size[1] - new_height) // 2
324
+ elif alignment == "Right":
325
+ margin_x = target_size[0] - new_width
326
+ margin_y = (target_size[1] - new_height) // 2
327
+ elif alignment == "Top":
328
+ margin_x = (target_size[0] - new_width) // 2
329
+ margin_y = 0
330
+ elif alignment == "Bottom":
331
+ margin_x = (target_size[0] - new_width) // 2
332
+ margin_y = target_size[1] - new_height
333
+
334
+ # Adjust margins to eliminate gaps
335
+ margin_x = max(0, min(margin_x, target_size[0] - new_width))
336
+ margin_y = max(0, min(margin_y, target_size[1] - new_height))
337
+
338
+ # Create a new background image and paste the resized source image
339
+ background = Image.new('RGB', target_size, (255, 255, 255))
340
+ background.paste(source, (margin_x, margin_y))
341
+
342
+ # Create the mask
343
+ mask = Image.new('L', target_size, 255)
344
+ mask_draw = ImageDraw.Draw(mask)
345
+
346
+ # Calculate overlap areas
347
+ white_gaps_patch = 2
348
+
349
+ left_overlap = margin_x + overlap_x if overlap_left else margin_x + white_gaps_patch
350
+ right_overlap = margin_x + new_width - overlap_x if overlap_right else margin_x + new_width - white_gaps_patch
351
+ top_overlap = margin_y + overlap_y if overlap_top else margin_y + white_gaps_patch
352
+ bottom_overlap = margin_y + new_height - overlap_y if overlap_bottom else margin_y + new_height - white_gaps_patch
353
+
354
+ if alignment == "Left":
355
+ left_overlap = margin_x + overlap_x if overlap_left else margin_x
356
+ elif alignment == "Right":
357
+ right_overlap = margin_x + new_width - overlap_x if overlap_right else margin_x + new_width
358
+ elif alignment == "Top":
359
+ top_overlap = margin_y + overlap_y if overlap_top else margin_y
360
+ elif alignment == "Bottom":
361
+ bottom_overlap = margin_y + new_height - overlap_y if overlap_bottom else margin_y + new_height
362
+
363
+ # Draw the mask
364
+ mask_draw.rectangle([
365
+ (left_overlap, top_overlap),
366
+ (right_overlap, bottom_overlap)
367
+ ], fill=0)
368
+
369
+ return background, mask
370
+
371
+ @spaces.GPU(duration=60)
372
+ @torch.inference_mode()
373
+ def inpaint(image, width, height, overlap_percentage, num_inference_steps, resize_option, custom_resize_percentage, prompt_input, alignment, overlap_left, overlap_right, overlap_top, overlap_bottom, progress=gr.Progress(track_tqdm=True)):
374
+ clear_memory()
375
+
376
+ background, mask = prepare_image_and_mask(image, width, height, overlap_percentage, resize_option, custom_resize_percentage, alignment, overlap_left, overlap_right, overlap_top, overlap_bottom)
377
+
378
+ if not can_expand(background.width, background.height, width, height, alignment):
379
+ alignment = "Middle"
380
+
381
+ cnet_image = background.copy()
382
+ cnet_image.paste(0, (0, 0), mask)
383
+
384
+ final_prompt = prompt_input
385
+
386
+ #generator = torch.Generator(device="cuda").manual_seed(42)
387
+
388
+ result = fill_pipe(
389
+ prompt=final_prompt,
390
+ height=height,
391
+ width=width,
392
+ image=cnet_image,
393
+ mask_image=mask,
394
+ num_inference_steps=num_inference_steps,
395
+ guidance_scale=30,
396
+ ).images[0]
397
+
398
+ result = result.convert("RGBA")
399
+ cnet_image.paste(result, (0, 0), mask)
400
+
401
+ return cnet_image #, background
402
+
403
+ def preview_image_and_mask(image, width, height, overlap_percentage, resize_option, custom_resize_percentage, alignment, overlap_left, overlap_right, overlap_top, overlap_bottom):
404
+ background, mask = prepare_image_and_mask(image, width, height, overlap_percentage, resize_option, custom_resize_percentage, alignment, overlap_left, overlap_right, overlap_top, overlap_bottom)
405
+
406
+ preview = background.copy().convert('RGBA')
407
+ red_overlay = Image.new('RGBA', background.size, (255, 0, 0, 64))
408
+ red_mask = Image.new('RGBA', background.size, (0, 0, 0, 0))
409
+ red_mask.paste(red_overlay, (0, 0), mask)
410
+ preview = Image.alpha_composite(preview, red_mask)
411
+
412
+ return preview
413
+
414
+ def clear_result():
415
+ return gr.update(value=None)
416
+
417
+ def preload_presets(target_ratio, ui_width, ui_height):
418
+ if target_ratio == "9:16":
419
+ return 720, 1280, gr.update()
420
+ elif target_ratio == "16:9":
421
+ return 1280, 720, gr.update()
422
+ elif target_ratio == "1:1":
423
+ return 1024, 1024, gr.update()
424
+ elif target_ratio == "Custom":
425
+ return ui_width, ui_height, gr.update(open=True)
426
+
427
+ def select_the_right_preset(user_width, user_height):
428
+ if user_width == 720 and user_height == 1280:
429
+ return "9:16"
430
+ elif user_width == 1280 and user_height == 720:
431
+ return "16:9"
432
+ elif user_width == 1024 and user_height == 1024:
433
+ return "1:1"
434
+ else:
435
+ return "Custom"
436
+
437
+ def toggle_custom_resize_slider(resize_option):
438
+ return gr.update(visible=(resize_option == "Custom"))
439
+
440
+ def update_history(new_image, history):
441
+ if history is None:
442
+ history = []
443
+ history.insert(0, new_image)
444
+ return history
445
+
446
+
447
+ def enable_efficient_attention():
448
+ if XFORMERS_AVAILABLE:
449
+ try:
450
+ unet.set_use_memory_efficient_attention_xformers(True)
451
+ vae.set_use_memory_efficient_attention_xformers(True)
452
+ print("Enabled xformers memory efficient attention")
453
+ except Exception as e:
454
+ print(f"Xformers error: {e}")
455
+ print("Falling back to sliced attention")
456
+ unet.set_attn_processor(AttnProcessor2_0())
457
+ vae.set_attn_processor(AttnProcessor2_0())
458
+ else:
459
+ print("Using sliced attention")
460
+ unet.set_attn_processor(AttnProcessor2_0())
461
+ vae.set_attn_processor(AttnProcessor2_0())
462
+
463
+ # Add memory clearing function
464
+ def clear_memory():
465
+ if torch.cuda.is_available():
466
+ torch.cuda.empty_cache()
467
+ torch.cuda.synchronize()
468
+ gc.collect() # Added garbage collection
469
+
470
+ # Configure pipelines with optimized settings
471
+ fill_pipe.enable_model_cpu_offload()
472
+ fill_pipe.enable_vae_slicing()
473
+
474
+ # pipe, t2i_pipe, and i2i_pipe are not constructed yet at this point; their
+ # cpu-offload / VAE-slicing settings are applied further below, right after
+ # each of those pipelines is created.
482
+
483
+ # Enable efficient attention after all models are loaded
484
+ enable_efficient_attention()
485
+
486
+
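+ # Hook the UNet forward so the conditioning latents stored in
+ # cross_attention_kwargs['concat_conds'] are tiled to the batch size and stacked onto
+ # the input sample along the channel dimension before the original forward runs.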
487
+ def hooked_unet_forward(sample, timestep, encoder_hidden_states, **kwargs):
488
+ c_concat = kwargs['cross_attention_kwargs']['concat_conds'].to(sample)
489
+ c_concat = torch.cat([c_concat] * (sample.shape[0] // c_concat.shape[0]), dim=0)
490
+ new_sample = torch.cat([sample, c_concat], dim=1)
491
+ kwargs['cross_attention_kwargs'] = {}
492
+ return unet_original_forward(new_sample, timestep, encoder_hidden_states, **kwargs)
493
+
494
+
495
+ unet.forward = hooked_unet_forward
496
+
497
+
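+ # Merge the IC-Light offset checkpoint into the base UNet by element-wise addition
+ # of the two state dicts, then free the intermediate tensors.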
498
+ sd_offset = sf.load_file(model_path)
499
+ sd_origin = unet.state_dict()
500
+ keys = sd_origin.keys()
501
+ sd_merged = {k: sd_origin[k] + sd_offset[k] for k in sd_origin.keys()}
502
+ unet.load_state_dict(sd_merged, strict=True)
503
+ del sd_offset, sd_origin, sd_merged, keys
504
+
505
+
506
+ # Device and dtype setup
507
+ device = torch.device('cuda')
508
+ dtype = torch.float16 # Use float16 consistently for all models
509
+
510
+ pipe = prepare_pipeline(
511
+ base_model="stabilityai/stable-diffusion-xl-base-1.0",
512
+ vae_model="madebyollin/sdxl-vae-fp16-fix",
513
+ unet_model=None,
514
+ lora_model=None,
515
+ adapter_path="huanngzh/mv-adapter",
516
+ scheduler=None,
517
+ num_views=NUM_VIEWS,
518
+ device=device,
519
+ dtype=dtype,
520
+ )
521
+
522
+ pipe.enable_model_cpu_offload()
523
+ pipe.enable_vae_slicing()
524
+ #pipe.enable_xformers_memory_efficient_attention()
525
+
526
+ # Move models to device with consistent dtype
527
+ text_encoder = text_encoder.to(device=device, dtype=dtype)
528
+ vae = vae.to(device=device, dtype=dtype)
529
+ unet = unet.to(device=device, dtype=dtype)
530
+ rmbg = rmbg.to(device=device, dtype=torch.float32) # Keep this as float32 for accuracy
531
+
532
+ ddim_scheduler = DDIMScheduler(
533
+ num_train_timesteps=1000,
534
+ beta_start=0.00085,
535
+ beta_end=0.012,
536
+ beta_schedule="scaled_linear",
537
+ clip_sample=False,
538
+ set_alpha_to_one=False,
539
+ steps_offset=1,
540
+ )
541
+
542
+ euler_a_scheduler = EulerAncestralDiscreteScheduler(
543
+ num_train_timesteps=1000,
544
+ beta_start=0.00085,
545
+ beta_end=0.012,
546
+ steps_offset=1
547
+ )
548
+
549
+ dpmpp_2m_sde_karras_scheduler = DPMSolverMultistepScheduler(
550
+ num_train_timesteps=1000,
551
+ beta_start=0.00085,
552
+ beta_end=0.012,
553
+ algorithm_type="sde-dpmsolver++",
554
+ use_karras_sigmas=True,
555
+ steps_offset=1
556
+ )
557
+
558
+ # Pipelines
559
+
560
+
561
+ t2i_pipe = StableDiffusionPipeline(
562
+ vae=vae,
563
+ text_encoder=text_encoder,
564
+ tokenizer=tokenizer,
565
+ unet=unet,
566
+ scheduler=dpmpp_2m_sde_karras_scheduler,
567
+ safety_checker=None,
568
+ requires_safety_checker=False,
569
+ feature_extractor=None,
570
+ image_encoder=None
571
+ )
572
+
573
+ t2i_pipe.enable_model_cpu_offload()
574
+ t2i_pipe.enable_vae_slicing()
575
+ #t2i_pipe.enable_xformers_memory_efficient_attention()
576
+
577
+ i2i_pipe = StableDiffusionImg2ImgPipeline(
578
+ vae=vae,
579
+ text_encoder=text_encoder,
580
+ tokenizer=tokenizer,
581
+ unet=unet,
582
+ scheduler=dpmpp_2m_sde_karras_scheduler,
583
+ safety_checker=None,
584
+ requires_safety_checker=False,
585
+ feature_extractor=None,
586
+ image_encoder=None
587
+ )
588
+
589
+ i2i_pipe.enable_model_cpu_offload()
590
+ i2i_pipe.enable_vae_slicing()
591
+ #i2i_pipe.enable_xformers_memory_efficient_attention()
592
+
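+ # Prompt encoding splits long prompts into (model_max_length - 2)-token chunks, wraps
+ # each with BOS/EOS, pads to max length, and encodes every chunk, so prompts longer
+ # than CLIP's 77-token window are still usable.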
593
+ @torch.inference_mode()
594
+ def encode_prompt_inner(txt: str):
595
+ max_length = tokenizer.model_max_length
596
+ chunk_length = tokenizer.model_max_length - 2
597
+ id_start = tokenizer.bos_token_id
598
+ id_end = tokenizer.eos_token_id
599
+ id_pad = id_end
600
+
601
+ def pad(x, p, i):
602
+ return x[:i] if len(x) >= i else x + [p] * (i - len(x))
603
+
604
+ tokens = tokenizer(txt, truncation=False, add_special_tokens=False)["input_ids"]
605
+ chunks = [[id_start] + tokens[i: i + chunk_length] + [id_end] for i in range(0, len(tokens), chunk_length)]
606
+ chunks = [pad(ck, id_pad, max_length) for ck in chunks]
607
+
608
+ token_ids = torch.tensor(chunks).to(device=device, dtype=torch.int64)
609
+ conds = text_encoder(token_ids).last_hidden_state
610
+
611
+ return conds
612
+
613
+
614
+ @torch.inference_mode()
615
+ def encode_prompt_pair(positive_prompt, negative_prompt):
616
+ c = encode_prompt_inner(positive_prompt)
617
+ uc = encode_prompt_inner(negative_prompt)
618
+
619
+ c_len = float(len(c))
620
+ uc_len = float(len(uc))
621
+ max_count = max(c_len, uc_len)
622
+ c_repeat = int(math.ceil(max_count / c_len))
623
+ uc_repeat = int(math.ceil(max_count / uc_len))
624
+ max_chunk = max(len(c), len(uc))
625
+
626
+ c = torch.cat([c] * c_repeat, dim=0)[:max_chunk]
627
+ uc = torch.cat([uc] * uc_repeat, dim=0)[:max_chunk]
628
+
629
+ c = torch.cat([p[None, ...] for p in c], dim=1)
630
+ uc = torch.cat([p[None, ...] for p in uc], dim=1)
631
+
632
+ return c, uc
633
+
634
+
635
+ @spaces.GPU(duration=60)
636
+ @torch.inference_mode()
637
+ def infer(
638
+ prompt,
639
+ image, # This is already RGBA with background removed
640
+ do_rembg=True,
641
+ seed=42,
642
+ randomize_seed=False,
643
+ guidance_scale=3.0,
644
+ num_inference_steps=30,
645
+ reference_conditioning_scale=1.0,
646
+ negative_prompt="watermark, ugly, deformed, noisy, blurry, low contrast",
647
+ progress=gr.Progress(track_tqdm=True),
648
+ ):
649
+ clear_memory()
650
+
651
+ # Convert input to PIL if needed
652
+ if isinstance(image, np.ndarray):
653
+ if image.shape[-1] == 4: # RGBA
654
+ image = Image.fromarray(image, 'RGBA')
655
+ else: # RGB
656
+ image = Image.fromarray(image, 'RGB')
657
+
658
+ #logging.info(f"Converted to PIL Image mode: {image.mode}")
659
+
660
+ # No need for remove_bg_fn since image is already processed
661
+ remove_bg_fn = None
662
+
663
+ if randomize_seed:
664
+ seed = random.randint(0, MAX_SEED)
665
+
666
+ images, preprocessed_image = run_pipeline(
667
+ pipe,
668
+ num_views=NUM_VIEWS,
669
+ text=prompt,
670
+ image=image,
671
+ height=HEIGHT,
672
+ width=WIDTH,
673
+ num_inference_steps=num_inference_steps,
674
+ guidance_scale=guidance_scale,
675
+ seed=seed,
676
+ remove_bg_fn=remove_bg_fn, # Set to None since preprocessing is done
677
+ reference_conditioning_scale=reference_conditioning_scale,
678
+ negative_prompt=negative_prompt,
679
+ device=device,
680
+ )
681
+
682
+ # logging.info(f"Output images shape: {[img.shape for img in images]}")
683
+ # logging.info(f"Preprocessed image shape: {preprocessed_image.shape if preprocessed_image is not None else None}")
684
+ return images
685
+
686
+
687
+ @spaces.GPU(duration=60)
688
+ @torch.inference_mode()
689
+ def pytorch2numpy(imgs, quant=True):
690
+ results = []
691
+ for x in imgs:
692
+ y = x.movedim(0, -1)
693
+
694
+ if quant:
695
+ y = y * 127.5 + 127.5
696
+ y = y.detach().float().cpu().numpy().clip(0, 255).astype(np.uint8)
697
+ else:
698
+ y = y * 0.5 + 0.5
699
+ y = y.detach().float().cpu().numpy().clip(0, 1).astype(np.float32)
700
+
701
+ results.append(y)
702
+ return results
703
+
704
+ @spaces.GPU(duration=60)
705
+ @torch.inference_mode()
706
+ def numpy2pytorch(imgs):
707
+ h = torch.from_numpy(np.stack(imgs, axis=0)).float() / 127.0 - 1.0 # so that 127 must be strictly 0.0
708
+ h = h.movedim(-1, 1)
709
+ return h.to(device=device, dtype=dtype)
710
+
711
+
712
+ def resize_and_center_crop(image, target_width, target_height):
713
+ pil_image = Image.fromarray(image)
714
+ original_width, original_height = pil_image.size
715
+ scale_factor = max(target_width / original_width, target_height / original_height)
716
+ resized_width = int(round(original_width * scale_factor))
717
+ resized_height = int(round(original_height * scale_factor))
718
+ resized_image = pil_image.resize((resized_width, resized_height), Image.LANCZOS)
719
+ left = (resized_width - target_width) / 2
720
+ top = (resized_height - target_height) / 2
721
+ right = (resized_width + target_width) / 2
722
+ bottom = (resized_height + target_height) / 2
723
+ cropped_image = resized_image.crop((left, top, right, bottom))
724
+ return np.array(cropped_image)
725
+
726
+
727
+ def resize_without_crop(image, target_width, target_height):
728
+ pil_image = Image.fromarray(image)
729
+ resized_image = pil_image.resize((target_width, target_height), Image.LANCZOS)
730
+ return np.array(resized_image)
731
+
732
+
733
+ @spaces.GPU
734
+ @torch.inference_mode()
735
+ def run_rmbg(image):
736
+ clear_memory()
737
+ image_size = image.size
738
+ input_images = transform_image(image).unsqueeze(0).to(device, dtype=torch.float32)
739
+ # Prediction
740
+ with torch.no_grad():
741
+ preds = rmbg(input_images)[-1].sigmoid().cpu()
742
+ pred = preds[0].squeeze()
743
+ pred_pil = transforms.ToPILImage()(pred)
744
+ mask = pred_pil.resize(image_size)
745
+ image.putalpha(mask)
746
+ return image
747
+
748
+
749
+
750
+ def preprocess_image(image: Image.Image, height=768, width=768):
751
+ image = np.array(image)
752
+ alpha = image[..., 3] > 0
753
+ H, W = alpha.shape
754
+ # get the bounding box of alpha
755
+ y, x = np.where(alpha)
756
+ y0, y1 = max(y.min() - 1, 0), min(y.max() + 1, H)
757
+ x0, x1 = max(x.min() - 1, 0), min(x.max() + 1, W)
758
+ image_center = image[y0:y1, x0:x1]
759
+ # resize the longer side to height * 0.9
760
+ H, W, _ = image_center.shape
761
+ if H > W:
762
+ W = int(W * (height * 0.9) / H)
763
+ H = int(height * 0.9)
764
+ else:
765
+ H = int(H * (width * 0.9) / W)
766
+ W = int(width * 0.9)
767
+ image_center = np.array(Image.fromarray(image_center).resize((W, H)))
768
+ # pad to H, W
769
+ start_h = (height - H) // 2
770
+ start_w = (width - W) // 2
771
+ image = np.zeros((height, width, 4), dtype=np.uint8)
772
+ image[start_h : start_h + H, start_w : start_w + W] = image_center
773
+ image = image.astype(np.float32) / 255.0
774
+ image = image[:, :, :3] * image[:, :, 3:4] + (1 - image[:, :, 3:4]) * 0.5
775
+ image = (image * 255).clip(0, 255).astype(np.uint8)
776
+ image = Image.fromarray(image)
777
+ return image
778
+
779
+
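+ # Two-pass relighting: a low-resolution pass conditioned on the encoded foreground
+ # latents (text-to-image, or img2img when a background is available), followed by an
+ # img2img refinement at highres_scale using the same concat conditioning.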
780
+ @spaces.GPU(duration=60)
781
+ @torch.inference_mode()
782
+ def process(input_fg, prompt, image_width, image_height, num_samples, seed, steps, a_prompt, n_prompt, cfg, highres_scale, highres_denoise, lowres_denoise, bg_source):
783
+ clear_memory()
784
+
785
+ # Get input dimensions
786
+ input_height, input_width = input_fg.shape[:2]
787
+
788
+ bg_source = BGSource(bg_source)
789
+ input_bg = None # no uploaded background by default; the gradient branches below create one
790
+ if bg_source == BGSource.UPLOAD:
791
+ pass
792
+ elif bg_source == BGSource.UPLOAD_FLIP:
793
+ input_bg = np.fliplr(input_bg)
794
+ elif bg_source == BGSource.GREY:
795
+ input_bg = np.zeros(shape=(input_height, input_width, 3), dtype=np.uint8) + 64
796
+ elif bg_source == BGSource.LEFT:
797
+ gradient = np.linspace(255, 0, input_width)
798
+ image = np.tile(gradient, (input_height, 1))
799
+ input_bg = np.stack((image,) * 3, axis=-1).astype(np.uint8)
800
+ elif bg_source == BGSource.RIGHT:
801
+ gradient = np.linspace(0, 255, input_width)
802
+ image = np.tile(gradient, (input_height, 1))
803
+ input_bg = np.stack((image,) * 3, axis=-1).astype(np.uint8)
804
+ elif bg_source == BGSource.TOP:
805
+ gradient = np.linspace(255, 0, input_height)[:, None]
806
+ image = np.tile(gradient, (1, input_width))
807
+ input_bg = np.stack((image,) * 3, axis=-1).astype(np.uint8)
808
+ elif bg_source == BGSource.BOTTOM:
809
+ gradient = np.linspace(0, 255, input_height)[:, None]
810
+ image = np.tile(gradient, (1, input_width))
811
+ input_bg = np.stack((image,) * 3, axis=-1).astype(np.uint8)
812
+ else:
813
+ raise ValueError('Wrong initial latent!')
814
+
815
+ rng = torch.Generator(device=device).manual_seed(int(seed))
816
+
817
+ # Use input dimensions directly
818
+ fg = resize_without_crop(input_fg, input_width, input_height)
819
+
820
+ concat_conds = numpy2pytorch([fg]).to(device=vae.device, dtype=vae.dtype)
821
+ concat_conds = vae.encode(concat_conds).latent_dist.mode() * vae.config.scaling_factor
822
+
823
+ conds, unconds = encode_prompt_pair(positive_prompt=prompt + ', ' + a_prompt, negative_prompt=n_prompt)
824
+
825
+ if input_bg is None:
826
+ latents = t2i_pipe(
827
+ prompt_embeds=conds,
828
+ negative_prompt_embeds=unconds,
829
+ width=input_width,
830
+ height=input_height,
831
+ num_inference_steps=steps,
832
+ num_images_per_prompt=num_samples,
833
+ generator=rng,
834
+ output_type='latent',
835
+ guidance_scale=cfg,
836
+ cross_attention_kwargs={'concat_conds': concat_conds},
837
+ ).images.to(vae.dtype) / vae.config.scaling_factor
838
+ else:
839
+ bg = resize_without_crop(input_bg, input_width, input_height)
840
+ bg_latent = numpy2pytorch([bg]).to(device=vae.device, dtype=vae.dtype)
841
+ bg_latent = vae.encode(bg_latent).latent_dist.mode() * vae.config.scaling_factor
842
+ latents = i2i_pipe(
843
+ image=bg_latent,
844
+ strength=lowres_denoise,
845
+ prompt_embeds=conds,
846
+ negative_prompt_embeds=unconds,
847
+ width=input_width,
848
+ height=input_height,
849
+ num_inference_steps=int(round(steps / lowres_denoise)),
850
+ num_images_per_prompt=num_samples,
851
+ generator=rng,
852
+ output_type='latent',
853
+ guidance_scale=cfg,
854
+ cross_attention_kwargs={'concat_conds': concat_conds},
855
+ ).images.to(vae.dtype) / vae.config.scaling_factor
856
+
857
+ pixels = vae.decode(latents).sample
858
+ pixels = pytorch2numpy(pixels)
859
+ pixels = [resize_without_crop(
860
+ image=p,
861
+ target_width=int(round(input_width * highres_scale / 64.0) * 64),
862
+ target_height=int(round(input_height * highres_scale / 64.0) * 64))
863
+ for p in pixels]
864
+
865
+ pixels = numpy2pytorch(pixels).to(device=vae.device, dtype=vae.dtype)
866
+ latents = vae.encode(pixels).latent_dist.mode() * vae.config.scaling_factor
867
+ latents = latents.to(device=unet.device, dtype=unet.dtype)
868
+
869
+ highres_height, highres_width = latents.shape[2] * 8, latents.shape[3] * 8
870
+
871
+ fg = resize_without_crop(input_fg, highres_width, highres_height)
872
+ concat_conds = numpy2pytorch([fg]).to(device=vae.device, dtype=vae.dtype)
873
+ concat_conds = vae.encode(concat_conds).latent_dist.mode() * vae.config.scaling_factor
874
+
875
+ latents = i2i_pipe(
876
+ image=latents,
877
+ strength=highres_denoise,
878
+ prompt_embeds=conds,
879
+ negative_prompt_embeds=unconds,
880
+ width=highres_width,
881
+ height=highres_height,
882
+ num_inference_steps=int(round(steps / highres_denoise)),
883
+ num_images_per_prompt=num_samples,
884
+ generator=rng,
885
+ output_type='latent',
886
+ guidance_scale=cfg,
887
+ cross_attention_kwargs={'concat_conds': concat_conds},
888
+ ).images.to(vae.dtype) / vae.config.scaling_factor
889
+
890
+ pixels = vae.decode(latents).sample
891
+ pixels = pytorch2numpy(pixels)
892
+
893
+ # Resize back to input dimensions
894
+ pixels = [resize_without_crop(p, input_width, input_height) for p in pixels]
895
+ pixels = np.stack(pixels)
896
+
897
+ return pixels
898
+
899
+ @spaces.GPU(duration=60)
900
+ @torch.inference_mode()
901
+ def extract_foreground(image):
902
+ if image is None:
903
+ return None, gr.update(visible=True), gr.update(visible=True)
904
+ clear_memory()
905
+ #logging.info(f"Input image shape: {image.shape}, dtype: {image.dtype}")
906
+ #result, rgba = run_rmbg(image)
907
+ result = run_rmbg(image)
908
+ result = preprocess_image(result)
909
+ #logging.info(f"Result shape: {result.shape}, dtype: {result.dtype}")
910
+ #logging.info(f"RGBA shape: {rgba.shape}, dtype: {rgba.dtype}")
911
+ return result, gr.update(visible=True), gr.update(visible=True)
912
+
913
+ def update_extracted_fg_height(selected_image: gr.SelectData):
914
+ if selected_image:
915
+ # Get the height of the selected image
916
+ height = selected_image.value['image']['shape'][0] # Assuming the image is in numpy format
917
+ return gr.update(height=height) # Update the height of extracted_fg
918
+ return gr.update(height=480) # Default height if no image is selected
919
+
920
+
921
+
922
+ @torch.inference_mode()
923
+ def process_relight(input_fg, prompt, image_width, image_height, num_samples, seed, steps, a_prompt, n_prompt, cfg, highres_scale, highres_denoise, lowres_denoise, bg_source):
924
+ # Convert input foreground from PIL to NumPy array if it's in PIL format
925
+ if isinstance(input_fg, Image.Image):
926
+ input_fg = np.array(input_fg)
927
+ logging.info(f"Input foreground shape: {input_fg.shape}, dtype: {input_fg.dtype}")
928
+ results = process(input_fg, prompt, image_width, image_height, num_samples, seed, steps, a_prompt, n_prompt, cfg, highres_scale, highres_denoise, lowres_denoise, bg_source)
929
+ logging.info(f"Results shape: {results.shape}, dtype: {results.dtype}")
930
+ return results
931
+
932
+ quick_prompts = [
933
+ 'sunshine from window',
934
+ 'golden time',
935
+ 'natural lighting',
936
+ 'warm atmosphere, at home, bedroom',
937
+ 'shadow from window',
938
+ 'soft studio lighting',
939
+ 'home atmosphere, cozy bedroom illumination',
940
+ ]
941
+ quick_prompts = [[x] for x in quick_prompts]
942
+
943
+ quick_subjects = [
944
+ 'modern sofa, high quality leather',
945
+ 'elegant dining table, polished wood',
946
+ 'luxurious bed, premium mattress',
947
+ 'minimalist office desk, clean design',
948
+ 'vintage wooden cabinet, antique finish',
949
+ ]
950
+ quick_subjects = [[x] for x in quick_subjects]
951
+
952
+ class BGSource(Enum):
953
+ UPLOAD = "Use Background Image"
954
+ UPLOAD_FLIP = "Use Flipped Background Image"
955
+ NONE = "None"
956
+ LEFT = "Left Light"
957
+ RIGHT = "Right Light"
958
+ TOP = "Top Light"
959
+ BOTTOM = "Bottom Light"
960
+ GREY = "Ambient"
961
+
962
+ # Add save function
963
+ def save_images(images, prefix="relight"):
964
+ # Create output directory if it doesn't exist
965
+ output_dir = Path("outputs")
966
+ output_dir.mkdir(exist_ok=True)
967
+
968
+ # Create timestamp for unique filenames
969
+ timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
970
+
971
+ saved_paths = []
972
+ for i, img in enumerate(images):
973
+ if isinstance(img, np.ndarray):
974
+ # Convert to PIL Image if numpy array
975
+ img = Image.fromarray(img)
976
+
977
+ # Create filename with timestamp
978
+ filename = f"{prefix}_{timestamp}_{i+1}.png"
979
+ filepath = output_dir / filename
980
+
981
+ # Save image
982
+ img.save(filepath)
+ saved_paths.append(filepath)
983
+
984
+ # print(f"Saved {len(saved_paths)} images to {output_dir}")
985
+ return saved_paths
986
+
987
+ class MaskMover:
988
+ def __init__(self):
989
+ self.extracted_fg = None
990
+ self.original_fg = None # Store original foreground
991
+
992
+ def set_extracted_fg(self, fg_image):
993
+ """Store the extracted foreground with alpha channel"""
994
+ if isinstance(fg_image, np.ndarray):
995
+ self.extracted_fg = fg_image.copy()
996
+ self.original_fg = fg_image.copy()
997
+ else:
998
+ self.extracted_fg = np.array(fg_image)
999
+ self.original_fg = np.array(fg_image)
1000
+ return self.extracted_fg
1001
+
1002
+ def create_composite(self, background, x_pos, y_pos, scale=1.0):
1003
+ """Create composite with foreground at specified position"""
1004
+ if self.original_fg is None or background is None:
1005
+ return background
1006
+
1007
+ # Convert inputs to PIL Images
1008
+ if isinstance(background, np.ndarray):
1009
+ bg = Image.fromarray(background).convert('RGBA')
1010
+ else:
1011
+ bg = background.convert('RGBA')
1012
+
1013
+ if isinstance(self.original_fg, np.ndarray):
1014
+ fg = Image.fromarray(self.original_fg).convert('RGBA')
1015
+ else:
1016
+ fg = self.original_fg.convert('RGBA')
1017
+
1018
+ # Scale the foreground size
1019
+ new_width = int(fg.width * scale)
1020
+ new_height = int(fg.height * scale)
1021
+ fg = fg.resize((new_width, new_height), Image.LANCZOS)
1022
+
1023
+ # Center the scaled foreground at the position
1024
+ x = int(x_pos - new_width / 2)
1025
+ y = int(y_pos - new_height / 2)
1026
+
1027
+ # Create composite
1028
+ result = bg.copy()
1029
+ result.paste(fg, (x, y), fg) # Use fg as the mask (requires fg to be in 'RGBA' mode)
1030
+
1031
+ return np.array(result.convert('RGB')) # Convert back to 'RGB' if needed
1032
+
1033
+ def get_depth(image):
1034
+ if image is None:
1035
+ return None
1036
+ # Convert from PIL/gradio format to cv2
1037
+ raw_img = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
1038
+ # Get depth map
1039
+ depth = model.infer_image(raw_img) # HxW raw depth map
1040
+ # Normalize depth for visualization
1041
+ depth = ((depth - depth.min()) / (depth.max() - depth.min()) * 255).astype(np.uint8)
1042
+ # Convert to RGB for display
1043
+ depth_colored = cv2.applyColorMap(depth, cv2.COLORMAP_INFERNO)
1044
+ depth_colored = cv2.cvtColor(depth_colored, cv2.COLOR_BGR2RGB)
1045
+ return Image.fromarray(depth_colored)
1046
+
1047
+ from PIL import Image
1048
+
1049
+ def compress_image(image):
1050
+ # Convert Gradio image (numpy array) to PIL Image
1051
+ img = Image.fromarray(image)
1052
+
1053
+ # Resize image if dimensions are too large
1054
+ max_size = 1024 # Maximum dimension size
1055
+ if img.width > max_size or img.height > max_size:
1056
+ ratio = min(max_size/img.width, max_size/img.height)
1057
+ new_size = (int(img.width * ratio), int(img.height * ratio))
1058
+ img = img.resize(new_size, Image.Resampling.LANCZOS)
1059
+
1060
+ quality = 95 # Start with high quality
1061
+ img.save("compressed_image.jpg", "JPEG", quality=quality) # Initial save
1062
+
1063
+ # Check file size and adjust quality if necessary
1064
+ while os.path.getsize("compressed_image.jpg") > 100 * 1024: # 100KB limit
1065
+ quality -= 5 # Decrease quality
1066
+ img.save("compressed_image.jpg", "JPEG", quality=quality)
1067
+ if quality < 20: # Prevent quality from going too low
1068
+ break
1069
+
1070
+ # Convert back to numpy array for Gradio
1071
+ compressed_img = np.array(Image.open("compressed_image.jpg"))
1072
+ return compressed_img
1073
+
1074
+ def use_orientation(selected_image:gr.SelectData):
1075
+ return selected_image.value['image']['path']
1076
+
1077
+
1078
+
1079
+ def generate_description(object_description,image, detail="high", max_tokens=250):
1080
+ openai_api_key = os.getenv("OPENAI_API_KEY")
1081
+ client = OpenAI(api_key=openai_api_key)
1082
+
1083
+ if image is not None:
1084
+ try:
1085
+ img = image # No need to open, directly use the PIL Image object
1086
+
1087
+ buffered = io.BytesIO()
1088
+ img.save(buffered, format=IMAGE_FORMAT)
1089
+ img_base64 = base64.b64encode(buffered.getvalue()).decode()
1090
+
1091
+ prompt = f"As if you were describing the interior design, make a detailed caption of this image in one large paragraph. Highlighting textures, furnitures, locations. This object should be included in the description :{object_description}"
1092
+
1093
+ payload = {
1094
+ "model": "gpt-4o-mini",
1095
+ "messages": [{
1096
+ "role": "user",
1097
+ "content": [
1098
+ {"type": "text", "text": prompt},
1099
+ {"type": "image_url",
1100
+ "image_url": {"url": f"data:image/jpeg;base64,{img_base64}", "detail": detail}}
1101
+ ]
1102
+ }],
1103
+ "max_tokens": max_tokens
1104
+ }
1105
+
1106
+ response = client.chat.completions.create(**payload)
1107
+ return response.choices[0].message.content
1108
+ except Exception as e:
1109
+ print(e)
1110
+ else:
1111
+ try:
1112
+ prompt = f"Description: {object_description}. As if you were designing an interior, improve this sentence in one large paragraph. Highlighting textures, furnitures, locations, such that you create a coherent, visually pleasing setting."
1113
+
1114
+ payload = {
1115
+ "model": "gpt-4o-mini",
1116
+ "messages": [{
1117
+ "role": "user",
1118
+ "content": [
1119
+ {"type": "text", "text": prompt},
1120
+ ]
1121
+ }],
1122
+ "max_tokens": max_tokens
1123
+ }
1124
+
1125
+ response = client.chat.completions.create(**payload)
1126
+ return response.choices[0].message.content
1127
+ except Exception as e:
1128
+ print(e)
1129
+
1130
+
1131
+ @spaces.GPU(duration=60)
1132
+ @torch.inference_mode()
1133
+ def process_image(input_image, input_text):
1134
+ """Main processing function for the Gradio interface"""
1135
+
1136
+ if isinstance(input_image, Image.Image):
1137
+ input_image = np.array(input_image)
1138
+
1139
+
1140
+ clear_memory()
1141
+
1142
+ # Initialize configs
1143
+ API_TOKEN = "9c8c865e10ec1821bea79d9fa9dc8720"
1144
+ SAM2_CHECKPOINT = "./checkpoints/sam2_hiera_large.pt"
1145
+ SAM2_MODEL_CONFIG = os.path.join(os.path.dirname(os.path.abspath(__file__)), "configs/sam2_hiera_l.yaml")
1146
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
1147
+ OUTPUT_DIR = Path("outputs/grounded_sam2_dinox_demo")
1148
+ OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
1149
+
1150
+ HEIGHT = 768
1151
+ WIDTH = 768
1152
+
1153
+ # Initialize DDS client
1154
+ config = Config(API_TOKEN)
1155
+ client = Client(config)
1156
+
1157
+ # Process classes from text prompt
1158
+ classes = [x.strip().lower() for x in input_text.split('.') if x]
1159
+ class_name_to_id = {name: id for id, name in enumerate(classes)}
1160
+ class_id_to_name = {id: name for name, id in class_name_to_id.items()}
1161
+
1162
+ # Save input image to temp file and get URL
1163
+ with tempfile.NamedTemporaryFile(suffix='.jpg', delete=False) as tmpfile:
1164
+ cv2.imwrite(tmpfile.name, input_image)
1165
+ image_url = client.upload_file(tmpfile.name)
1166
+ os.remove(tmpfile.name)
1167
+
1168
+ # Process detection results
1169
+ input_boxes = []
1170
+ masks = []
1171
+ confidences = []
1172
+ class_names = []
1173
+ class_ids = []
1174
+
1175
+ if len(input_text) == 0:
1176
+ task = DinoxTask(
1177
+ image_url=image_url,
1178
+ prompts=[TextPrompt(text="<prompt_free>")],
1179
+ # targets=[DetectionTarget.BBox, DetectionTarget.Mask]
1180
+ )
1181
+
1182
+ client.run_task(task)
1183
+ predictions = task.result.objects
1184
+ classes = [pred.category for pred in predictions]
1185
+ classes = list(set(classes))
1186
+ class_name_to_id = {name: id for id, name in enumerate(classes)}
1187
+ class_id_to_name = {id: name for name, id in class_name_to_id.items()}
1188
+
1189
+ for idx, obj in enumerate(predictions):
1190
+ input_boxes.append(obj.bbox)
1191
+ masks.append(DetectionTask.rle2mask(DetectionTask.string2rle(obj.mask.counts), obj.mask.size)) # convert mask to np.array using DDS API
1192
+ confidences.append(obj.score)
1193
+ cls_name = obj.category.lower().strip()
1194
+ class_names.append(cls_name)
1195
+ class_ids.append(class_name_to_id[cls_name])
1196
+
1197
+ boxes = np.array(input_boxes)
1198
+ masks = np.array(masks)
1199
+ class_ids = np.array(class_ids)
1200
+ labels = [
1201
+ f"{class_name} {confidence:.2f}"
1202
+ for class_name, confidence
1203
+ in zip(class_names, confidences)
1204
+ ]
1205
+ detections = sv.Detections(
1206
+ xyxy=boxes,
1207
+ mask=masks.astype(bool),
1208
+ class_id=class_ids
1209
+ )
1210
+
1211
+ box_annotator = sv.BoxAnnotator()
1212
+ label_annotator = sv.LabelAnnotator()
1213
+ mask_annotator = sv.MaskAnnotator()
1214
+
1215
+ annotated_frame = input_image.copy()
1216
+ annotated_frame = box_annotator.annotate(scene=annotated_frame, detections=detections)
1217
+ annotated_frame = label_annotator.annotate(scene=annotated_frame, detections=detections, labels=labels)
1218
+ annotated_frame = mask_annotator.annotate(scene=annotated_frame, detections=detections)
1219
+
1220
+ # Create transparent mask for first detected object
1221
+ if len(detections) > 0:
1222
+ # Get first mask
1223
+ first_mask = detections.mask[0]
1224
+
1225
+ # Get original RGB image
1226
+ img = input_image.copy()
1227
+ H, W, C = img.shape
1228
+
1229
+ # Create RGBA image with default 255 alpha
1230
+ alpha = np.zeros((H, W, 1), dtype=np.uint8)
1231
+ alpha[~first_mask] = 0 # 128 # for semi-transparency background
1232
+ alpha[first_mask] = 255 # Make the foreground opaque
1233
+ alpha = alpha.squeeze(-1) # Remove singleton dimension to become 2D
1234
+ rgba = np.dstack((img, alpha)).astype(np.uint8)
1235
+
1236
+ # get the bounding box of alpha
1237
+ y, x = np.where(alpha > 0)
1238
+ y0, y1 = max(y.min() - 1, 0), min(y.max() + 1, H)
1239
+ x0, x1 = max(x.min() - 1, 0), min(x.max() + 1, W)
1240
+
1241
+ image_center = rgba[y0:y1, x0:x1]
1242
+ # resize the longer side to H * 0.9
1243
+ H, W, _ = image_center.shape
1244
+ if H > W:
1245
+ W = int(W * (HEIGHT * 0.9) / H)
1246
+ H = int(HEIGHT * 0.9)
1247
+ else:
1248
+ H = int(H * (WIDTH * 0.9) / W)
1249
+ W = int(WIDTH * 0.9)
1250
+
1251
+ image_center = np.array(Image.fromarray(image_center).resize((W, H), Image.LANCZOS))
1252
+ # pad to H, W
1253
+ start_h = (HEIGHT - H) // 2
1254
+ start_w = (WIDTH - W) // 2
1255
+ image = np.zeros((HEIGHT, WIDTH, 4), dtype=np.uint8)
1256
+ image[start_h : start_h + H, start_w : start_w + W] = image_center
1257
+ image = image.astype(np.float32) / 255.0
1258
+ image = image[:, :, :3] * image[:, :, 3:4] + (1 - image[:, :, 3:4]) * 0.5
1259
+ image = (image * 255).clip(0, 255).astype(np.uint8)
1260
+ image = Image.fromarray(image)
1261
+
1262
+ return annotated_frame, image, gr.update(visible=False), gr.update(visible=False)
1263
+ return annotated_frame, None, gr.update(visible=False), gr.update(visible=False)
1264
+ else:
1265
+ # Run DINO-X detection
1266
+ task = DinoxTask(
1267
+ image_url=image_url,
1268
+ prompts=[TextPrompt(text=input_text)],
1269
+ targets=[DetectionTarget.BBox, DetectionTarget.Mask]
1270
+ )
1271
+
1272
+ client.run_task(task)
1273
+ result = task.result
1274
+ objects = result.objects
1275
+
1276
+ predictions = task.result.objects
1277
+ classes = [x.strip().lower() for x in input_text.split('.') if x]
1278
+ class_name_to_id = {name: id for id, name in enumerate(classes)}
1279
+ class_id_to_name = {id: name for name, id in class_name_to_id.items()}
1280
+
1281
+ boxes = []
1282
+ masks = []
1283
+ confidences = []
1284
+ class_names = []
1285
+ class_ids = []
1286
+
1287
+ for idx, obj in enumerate(predictions):
1288
+ boxes.append(obj.bbox)
1289
+ masks.append(DetectionTask.rle2mask(DetectionTask.string2rle(obj.mask.counts), obj.mask.size)) # convert mask to np.array using DDS API
1290
+ confidences.append(obj.score)
1291
+ cls_name = obj.category.lower().strip()
1292
+ class_names.append(cls_name)
1293
+ class_ids.append(class_name_to_id[cls_name])
1294
+
1295
+ boxes = np.array(boxes)
1296
+ masks = np.array(masks)
1297
+ class_ids = np.array(class_ids)
1298
+ labels = [
1299
+ f"{class_name} {confidence:.2f}"
1300
+ for class_name, confidence
1301
+ in zip(class_names, confidences)
1302
+ ]
1303
+
1304
+ detections = sv.Detections(
1305
+ xyxy=boxes,
1306
+ mask=masks.astype(bool),
1307
+ class_id=class_ids,
1308
+ )
1309
+
1310
+ box_annotator = sv.BoxAnnotator()
1311
+ label_annotator = sv.LabelAnnotator()
1312
+ mask_annotator = sv.MaskAnnotator()
1313
+
1314
+ annotated_frame = input_image.copy()
1315
+ annotated_frame = box_annotator.annotate(scene=annotated_frame, detections=detections)
1316
+ annotated_frame = label_annotator.annotate(scene=annotated_frame, detections=detections, labels=labels)
1317
+ annotated_frame = mask_annotator.annotate(scene=annotated_frame, detections=detections)
1318
+
1319
+ # Create transparent mask for first detected object
1320
+ if len(detections) > 0:
1321
+ # Get first mask
1322
+ first_mask = detections.mask[0]
1323
+
1324
+ # Get original RGB image
1325
+ img = input_image.copy()
1326
+ H, W, C = img.shape
1327
+
1328
+ # Create RGBA image with default 255 alpha
1329
+ alpha = np.zeros((H, W, 1), dtype=np.uint8)
1330
+ alpha[~first_mask] = 0 # 128 for semi-transparency background
1331
+ alpha[first_mask] = 255 # Make the foreground opaque
1332
+ alpha = alpha.squeeze(-1) # Remove singleton dimension to become 2D
1333
+ rgba = np.dstack((img, alpha)).astype(np.uint8)
1334
+ # get the bounding box of alpha
1335
+ y, x = np.where(alpha > 0)
1336
+ y0, y1 = max(y.min() - 1, 0), min(y.max() + 1, H)
1337
+ x0, x1 = max(x.min() - 1, 0), min(x.max() + 1, W)
1338
+
1339
+ image_center = rgba[y0:y1, x0:x1]
1340
+ # resize the longer side to H * 0.9
1341
+ H, W, _ = image_center.shape
1342
+ if H > W:
1343
+ W = int(W * (HEIGHT * 0.9) / H)
1344
+ H = int(HEIGHT * 0.9)
1345
+ else:
1346
+ H = int(H * (WIDTH * 0.9) / W)
1347
+ W = int(WIDTH * 0.9)
1348
+
1349
+ image_center = np.array(Image.fromarray(image_center).resize((W, H), Image.LANCZOS))
1350
+ # pad to H, W
1351
+ start_h = (HEIGHT - H) // 2
1352
+ start_w = (WIDTH - W) // 2
1353
+ image = np.zeros((HEIGHT, WIDTH, 4), dtype=np.uint8)
1354
+ image[start_h : start_h + H, start_w : start_w + W] = image_center
1355
+ image = image.astype(np.float32) / 255.0
1356
+ image = image[:, :, :3] * image[:, :, 3:4] + (1 - image[:, :, 3:4]) * 0.5
1357
+ image = (image * 255).clip(0, 255).astype(np.uint8)
1358
+ image = Image.fromarray(image)
1359
+
1360
+ return annotated_frame, image, gr.update(visible=False), gr.update(visible=False)
1361
+ return annotated_frame, None, gr.update(visible=False), gr.update(visible=False)
1362
+
1363
+
1364
+
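+ # ComfyUI integration: locate the ComfyUI checkout, extend sys.path, load custom nodes
+ # headlessly, and keep the heavyweight node objects (CLIP, VAE, UNET, style and depth
+ # models) preloaded at module level so generate_image() only wires them together.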
1365
+ # Import all the necessary functions from the original script
1366
+ def get_value_at_index(obj: Union[Sequence, Mapping], index: int) -> Any:
1367
+ try:
1368
+ return obj[index]
1369
+ except KeyError:
1370
+ return obj["result"][index]
1371
+
1372
+ # Add all the necessary setup functions from the original script
1373
+ def find_path(name: str, path: str = None) -> str:
1374
+ if path is None:
1375
+ path = os.getcwd()
1376
+ if name in os.listdir(path):
1377
+ path_name = os.path.join(path, name)
1378
+ print(f"{name} found: {path_name}")
1379
+ return path_name
1380
+ parent_directory = os.path.dirname(path)
1381
+ if parent_directory == path:
1382
+ return None
1383
+ return find_path(name, parent_directory)
1384
+
1385
+ def add_comfyui_directory_to_sys_path() -> None:
1386
+ comfyui_path = find_path("ComfyUI")
1387
+ if comfyui_path is not None and os.path.isdir(comfyui_path):
1388
+ sys.path.append(comfyui_path)
1389
+ print(f"'{comfyui_path}' added to sys.path")
1390
+
1391
+ def add_extra_model_paths() -> None:
1392
+ try:
1393
+ from main import load_extra_path_config
1394
+ except ImportError:
1395
+ from utils.extra_config import load_extra_path_config
1396
+ extra_model_paths = find_path("extra_model_paths.yaml")
1397
+ if extra_model_paths is not None:
1398
+ load_extra_path_config(extra_model_paths)
1399
+ else:
1400
+ print("Could not find the extra_model_paths config file.")
1401
+
1402
+ # Initialize paths
1403
+ add_comfyui_directory_to_sys_path()
1404
+ add_extra_model_paths()
1405
+
1406
+ def import_custom_nodes() -> None:
1407
+ import asyncio
1408
+ import execution
1409
+ from nodes import init_extra_nodes
1410
+ import server
1411
+ loop = asyncio.new_event_loop()
1412
+ asyncio.set_event_loop(loop)
1413
+ server_instance = server.PromptServer(loop)
1414
+ execution.PromptQueue(server_instance)
1415
+ init_extra_nodes()
1416
+
+ # Import all necessary nodes
+ from nodes import (
+ StyleModelLoader,
+ VAEEncode,
+ NODE_CLASS_MAPPINGS,
+ LoadImage,
+ CLIPVisionLoader,
+ SaveImage,
+ VAELoader,
+ CLIPVisionEncode,
+ DualCLIPLoader,
+ EmptyLatentImage,
+ VAEDecode,
+ UNETLoader,
+ CLIPTextEncode,
+ )
+
+ # Initialize all constant nodes and models in global context
+ import_custom_nodes()
+
+ # Global variables for preloaded models and constants
+ #with torch.inference_mode():
+ # Initialize constants
+ intconstant = NODE_CLASS_MAPPINGS["INTConstant"]()
+ CONST_1024 = intconstant.get_value(value=1024)
+
+ # Load CLIP
+ dualcliploader = DualCLIPLoader()
+ CLIP_MODEL = dualcliploader.load_clip(
+ clip_name1="t5/t5xxl_fp16.safetensors",
+ clip_name2="clip_l.safetensors",
+ type="flux",
+ )
+
+ # Load VAE
+ vaeloader = VAELoader()
+ VAE_MODEL = vaeloader.load_vae(vae_name="FLUX1/ae.safetensors")
+
+ # Load UNET
+ unetloader = UNETLoader()
+ UNET_MODEL = unetloader.load_unet(
+ unet_name="flux1-depth-dev.safetensors", weight_dtype="default"
+ )
+
+ # Load CLIP Vision
+ clipvisionloader = CLIPVisionLoader()
+ CLIP_VISION_MODEL = clipvisionloader.load_clip(
+ clip_name="sigclip_vision_patch14_384.safetensors"
+ )
+
+ # Load Style Model
+ stylemodelloader = StyleModelLoader()
+ STYLE_MODEL = stylemodelloader.load_style_model(
+ style_model_name="flux1-redux-dev.safetensors"
+ )
+
+ # Initialize samplers
+ ksamplerselect = NODE_CLASS_MAPPINGS["KSamplerSelect"]()
+ SAMPLER = ksamplerselect.get_sampler(sampler_name="euler")
+
+ # Initialize depth model
+ cr_clip_input_switch = NODE_CLASS_MAPPINGS["CR Clip Input Switch"]()
+ downloadandloaddepthanythingv2model = NODE_CLASS_MAPPINGS["DownloadAndLoadDepthAnythingV2Model"]()
+ DEPTH_MODEL = downloadandloaddepthanythingv2model.loadmodel(
+ model="depth_anything_v2_vitl_fp32.safetensors"
+ )
+ cliptextencode = CLIPTextEncode()
+ loadimage = LoadImage()
+ vaeencode = VAEEncode()
+ fluxguidance = NODE_CLASS_MAPPINGS["FluxGuidance"]()
+ instructpixtopixconditioning = NODE_CLASS_MAPPINGS["InstructPixToPixConditioning"]()
+ clipvisionencode = CLIPVisionEncode()
+ stylemodelapplyadvanced = NODE_CLASS_MAPPINGS["StyleModelApplyAdvanced"]()
+ emptylatentimage = EmptyLatentImage()
+ basicguider = NODE_CLASS_MAPPINGS["BasicGuider"]()
+ basicscheduler = NODE_CLASS_MAPPINGS["BasicScheduler"]()
+ randomnoise = NODE_CLASS_MAPPINGS["RandomNoise"]()
+ samplercustomadvanced = NODE_CLASS_MAPPINGS["SamplerCustomAdvanced"]()
+ vaedecode = VAEDecode()
+ cr_text = NODE_CLASS_MAPPINGS["CR Text"]()
+ saveimage = SaveImage()
+ getimagesizeandcount = NODE_CLASS_MAPPINGS["GetImageSizeAndCount"]()
+ depthanything_v2 = NODE_CLASS_MAPPINGS["DepthAnything_V2"]()
+ imageresize = NODE_CLASS_MAPPINGS["ImageResize+"]()
+
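+ # All loaders and node instances above are created once at module import time, so the heavy checkpoints (CLIP, VAE, UNet, CLIP-Vision, style and depth models) stay in memory and are reused across generate_image() calls.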
+ @spaces.GPU
+ def generate_image(prompt, structure_image, style_image, depth_strength=15, style_strength=0.5, progress=gr.Progress(track_tqdm=True)) -> str:
+ """Main generation function that processes inputs and returns the path to the generated image."""
+
+ clear_memory()
+ with torch.inference_mode():
+ # Set up CLIP
+ clip_switch = cr_clip_input_switch.switch(
+ Input=1,
+ clip1=get_value_at_index(CLIP_MODEL, 0),
+ clip2=get_value_at_index(CLIP_MODEL, 0),
+ )
+
+ # Encode text
+ text_encoded = cliptextencode.encode(
+ text=prompt,
+ clip=get_value_at_index(clip_switch, 0),
+ )
+ empty_text = cliptextencode.encode(
+ text="",
+ clip=get_value_at_index(clip_switch, 0),
+ )
+
+ # Process structure image
+ structure_img = loadimage.load_image(image=structure_image)
+
+ # Resize image
+ resized_img = imageresize.execute(
+ width=get_value_at_index(CONST_1024, 0),
+ height=get_value_at_index(CONST_1024, 0),
+ interpolation="bicubic",
+ method="keep proportion",
+ condition="always",
+ multiple_of=16,
+ image=get_value_at_index(structure_img, 0),
+ )
+
+ # Get image size
+ size_info = getimagesizeandcount.getsize(
+ image=get_value_at_index(resized_img, 0)
+ )
+
+ # Encode VAE
+ vae_encoded = vaeencode.encode(
+ pixels=get_value_at_index(size_info, 0),
+ vae=get_value_at_index(VAE_MODEL, 0),
+ )
+
+ # Process depth
+ depth_processed = depthanything_v2.process(
+ da_model=get_value_at_index(DEPTH_MODEL, 0),
+ images=get_value_at_index(size_info, 0),
+ )
+
+ # Apply Flux guidance
+ flux_guided = fluxguidance.append(
+ guidance=depth_strength,
+ conditioning=get_value_at_index(text_encoded, 0),
+ )
+
+ # Process style image
+ style_img = loadimage.load_image(image=style_image)
+
+ # Encode style with CLIP Vision
+ style_encoded = clipvisionencode.encode(
+ crop="center",
+ clip_vision=get_value_at_index(CLIP_VISION_MODEL, 0),
+ image=get_value_at_index(style_img, 0),
+ )
+
+ # Set up conditioning
+ conditioning = instructpixtopixconditioning.encode(
+ positive=get_value_at_index(flux_guided, 0),
+ negative=get_value_at_index(empty_text, 0),
+ vae=get_value_at_index(VAE_MODEL, 0),
+ pixels=get_value_at_index(depth_processed, 0),
+ )
+
+ # Apply style
+ style_applied = stylemodelapplyadvanced.apply_stylemodel(
+ strength=style_strength,
+ conditioning=get_value_at_index(conditioning, 0),
+ style_model=get_value_at_index(STYLE_MODEL, 0),
+ clip_vision_output=get_value_at_index(style_encoded, 0),
+ )
+
+ # Set up empty latent
+ empty_latent = emptylatentimage.generate(
+ width=get_value_at_index(resized_img, 1),
+ height=get_value_at_index(resized_img, 2),
+ batch_size=1,
+ )
+
+ # Set up guidance
+ guided = basicguider.get_guider(
+ model=get_value_at_index(UNET_MODEL, 0),
+ conditioning=get_value_at_index(style_applied, 0),
+ )
+
+ # Set up scheduler
+ schedule = basicscheduler.get_sigmas(
+ scheduler="simple",
+ steps=28,
+ denoise=1,
+ model=get_value_at_index(UNET_MODEL, 0),
+ )
+
+ # Generate random noise (keep the seed within the 64-bit range expected by the node)
+ noise = randomnoise.get_noise(noise_seed=random.randint(1, 2**64 - 1))
+
+ # Sample
+ sampled = samplercustomadvanced.sample(
+ noise=get_value_at_index(noise, 0),
+ guider=get_value_at_index(guided, 0),
+ sampler=get_value_at_index(SAMPLER, 0),
+ sigmas=get_value_at_index(schedule, 0),
+ latent_image=get_value_at_index(empty_latent, 0),
+ )
+
+ # Decode VAE
+ decoded = vaedecode.decode(
+ samples=get_value_at_index(sampled, 0),
+ vae=get_value_at_index(VAE_MODEL, 0),
+ )
+
+ # Save image
+ prefix = cr_text.text_multiline(text="Flux_BFL_Depth_Redux")
+
+ saved = saveimage.save_images(
+ filename_prefix=get_value_at_index(prefix, 0),
+ images=get_value_at_index(decoded, 0),
+ )
+ saved_path = f"output/{saved['ui']['images'][0]['filename']}"
+
+ clear_memory()
+ return saved_path
+
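+ # Note: saved_path assumes SaveImage writes into a local "output/" folder (ComfyUI's default); adjust if the output directory is configured differently.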
+ # Create Gradio interface
+
+ examples = [
+ ["", "chair_input_1.jpg", "chair_input_2.png", 15, 0.5],
+ ]
+
+ output_image = gr.Image(label="Generated Image")
+
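+ # output_image is created outside the gr.Blocks context and rendered later in the "Style Transfer" tab via output_image.render().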
+ with gr.Blocks() as app:
+ with gr.Tab("Relighting"):
+ with gr.Row():
+ gr.Markdown("## Product Placement from Text")
+ with gr.Row():
+ with gr.Column():
+ with gr.Row():
+ input_fg = gr.Image(type="pil", label="Image", height=480)
+ with gr.Row():
+ with gr.Group():
+ find_objects_button = gr.Button(value="(Option 1) Segment Object from text")
+ text_prompt = gr.Textbox(
+ label="Text Prompt",
+ placeholder="Enter object classes separated by periods (e.g. 'car . person .'), leave empty to get all objects",
+ value=""
+ )
+ extract_button = gr.Button(value="Remove Background")
+ with gr.Row():
+ extracted_objects = gr.Image(type="numpy", label="Extracted Foreground", height=480)
+ extracted_fg = gr.Image(type="pil", label="Extracted Foreground", height=480)
+ angles_fg = gr.Image(type="pil", label="Converted Foreground", height=480, visible=False)
+
+ # output_bg = gr.Image(type="numpy", label="Preprocessed Foreground", height=480)
+ with gr.Group():
+ run_button = gr.Button("Generate alternative angles")
+ orientation_result = gr.Gallery(
+ label="Result",
+ show_label=False,
+ columns=[3],
+ rows=[2],
+ object_fit="fill",
+ height="auto",
+ allow_preview=False,
+ )
+
+ if orientation_result:
+ orientation_result.select(use_orientation, inputs=None, outputs=extracted_fg)
+
+ dummy_image_for_outputs = gr.Image(visible=False, label='Result', type='pil')
+
+ with gr.Column():
+
+ with gr.Row():
+ with gr.Column(scale=4):
+ result_gallery = gr.Gallery(height=832, label='Outputs', object_fit='contain', selected_index=0)
+ if result_gallery:
+ result_gallery.select(use_orientation, inputs=None, outputs=dummy_image_for_outputs)
+ with gr.Column(scale=1):
+ with gr.Group():
+ gr.Markdown("Outpaint")
+ with gr.Row():
+ with gr.Column(scale=2):
+ prompt_fill = gr.Textbox(label="Prompt (Optional)")
+ with gr.Column(scale=1):
+ fill_button = gr.Button("Generate")
+ target_ratio = gr.Radio(
+ label="Image Ratio",
+ choices=["9:16", "16:9", "1:1", "Custom"],
+ value="9:16",
+ scale=3
+ )
+ alignment_dropdown = gr.Dropdown(
+ choices=["Middle", "Left", "Right", "Top", "Bottom"],
+ value="Middle",
+ label="Alignment",
+ )
+ resize_option = gr.Radio(
+ label="Resize input image",
+ choices=["Full", "75%", "50%", "33%", "25%", "Custom"],
+ value="75%"
+ )
+ custom_resize_percentage = gr.Slider(
+ label="Custom resize (%)",
+ minimum=1,
+ maximum=100,
+ step=1,
+ value=50,
+ visible=False
+ )
+
+ fill_result = gr.Image(
+ interactive=False,
+ label="Generated Image",
+ )
+
+ with gr.Accordion(label="Advanced settings", open=False) as settings_panel:
+ with gr.Column():
+ with gr.Row():
+ width_slider = gr.Slider(
+ label="Target Width",
+ minimum=720,
+ maximum=1536,
+ step=8,
+ value=720,
+ )
+ height_slider = gr.Slider(
+ label="Target Height",
+ minimum=720,
+ maximum=1536,
+ step=8,
+ value=1280,
+ )
+
+ num_inference_steps = gr.Slider(label="Steps", minimum=2, maximum=50, step=1, value=18)
+ with gr.Group():
+ overlap_percentage = gr.Slider(
+ label="Mask overlap (%)",
+ minimum=1,
+ maximum=50,
+ value=10,
+ step=1
+ )
+ with gr.Row():
+ overlap_top = gr.Checkbox(label="Overlap Top", value=True)
+ overlap_right = gr.Checkbox(label="Overlap Right", value=True)
+ with gr.Row():
+ overlap_left = gr.Checkbox(label="Overlap Left", value=True)
+ overlap_bottom = gr.Checkbox(label="Overlap Bottom", value=True)
+
+ with gr.Row():
+ with gr.Group():
+ prompt = gr.Textbox(label="Prompt")
+ bg_source = gr.Radio(choices=[e.value for e in list(BGSource)[2:]],
+ value=BGSource.LEFT.value,
+ label="Lighting Preference (Initial Latent)", type='value')
+
+ example_quick_subjects = gr.Dataset(samples=quick_subjects, label='Subject Quick List', samples_per_page=1000, components=[prompt])
+ example_quick_prompts = gr.Dataset(samples=quick_prompts, label='Lighting Quick List', samples_per_page=1000, components=[prompt])
+ with gr.Row():
+ relight_button = gr.Button(value="Relight")
+
+ with gr.Group(visible=False):
+ with gr.Row():
+ num_samples = gr.Slider(label="Images", minimum=1, maximum=12, value=1, step=1)
+ seed = gr.Number(label="Seed", value=12345, precision=0)
+
+ with gr.Row():
+ image_width = gr.Slider(label="Image Width", minimum=256, maximum=1024, value=512, step=64)
+ image_height = gr.Slider(label="Image Height", minimum=256, maximum=1024, value=640, step=64)
+
+ with gr.Accordion("Advanced options", open=False):
+ steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=15, step=1)
+ cfg = gr.Slider(label="CFG Scale", minimum=1.0, maximum=32.0, value=2, step=0.01, visible=False)
+ lowres_denoise = gr.Slider(label="Lowres Denoise (for initial latent)", minimum=0.1, maximum=1.0, value=0.9, step=0.01)
+ highres_scale = gr.Slider(label="Highres Scale", minimum=1.0, maximum=3.0, value=1.5, step=0.01)
+ highres_denoise = gr.Slider(label="Highres Denoise", minimum=0.1, maximum=1.0, value=0.5, step=0.01)
+ a_prompt = gr.Textbox(label="Added Prompt", value='best quality', visible=False)
+ n_prompt = gr.Textbox(label="Negative Prompt", value='lowres, bad anatomy, bad hands, cropped, worst quality', visible=False)
+ x_slider = gr.Slider(
+ minimum=0,
+ maximum=1000,
+ label="X Position",
+ value=500,
+ visible=False
+ )
+ y_slider = gr.Slider(
+ minimum=0,
+ maximum=1000,
+ label="Y Position",
+ value=500,
+ visible=False
+ )
+
+ # with gr.Row():
+
+ # gr.Examples(
+ # fn=lambda *args: ([args[-1]], None),
+ # examples=db_examples.foreground_conditioned_examples,
+ # inputs=[
+ # input_fg, prompt, bg_source, image_width, image_height, seed, dummy_image_for_outputs
+ # ],
+ # outputs=[result_gallery, output_bg],
+ # run_on_click=True, examples_per_page=1024
+ # )
+ ips = [extracted_fg, prompt, image_width, image_height, num_samples, seed, steps, a_prompt, n_prompt, cfg, highres_scale, highres_denoise, lowres_denoise, bg_source]
+ relight_button.click(fn=process_relight, inputs=ips, outputs=[result_gallery]).then(clear_memory, inputs=[], outputs=[])
+ example_quick_prompts.click(lambda x, y: ', '.join(y.split(', ')[:2] + [x[0]]), inputs=[example_quick_prompts, prompt], outputs=prompt, show_progress=False, queue=False)
+ example_quick_subjects.click(lambda x: x[0], inputs=example_quick_subjects, outputs=prompt, show_progress=False, queue=False)
+
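+ # Relight flow: process_relight consumes the extracted foreground plus the prompt/lighting settings in ips, fills result_gallery, and clear_memory() then frees GPU memory.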
+ # def use_output_as_input(output_image):
+ # return output_image
+
+ # use_as_input_button.click(
+ # fn=use_output_as_input,
+ # inputs=[fill_result],
+ # outputs=[input_image]
+ # )
+
+ target_ratio.change(
+ fn=preload_presets,
+ inputs=[target_ratio, width_slider, height_slider],
+ outputs=[width_slider, height_slider, settings_panel],
+ queue=False
+ )
+
+ width_slider.change(
+ fn=select_the_right_preset,
+ inputs=[width_slider, height_slider],
+ outputs=[target_ratio],
+ queue=False
+ )
+
+ height_slider.change(
+ fn=select_the_right_preset,
+ inputs=[width_slider, height_slider],
+ outputs=[target_ratio],
+ queue=False
+ )
+
+ resize_option.change(
+ fn=toggle_custom_resize_slider,
+ inputs=[resize_option],
+ outputs=[custom_resize_percentage],
+ queue=False
+ )
+
+ fill_button.click(
+ fn=clear_result,
+ inputs=None,
+ outputs=fill_result,
+ ).then(
+ fn=inpaint,
+ inputs=[dummy_image_for_outputs, width_slider, height_slider, overlap_percentage, num_inference_steps,
+ resize_option, custom_resize_percentage, prompt_fill, alignment_dropdown,
+ overlap_left, overlap_right, overlap_top, overlap_bottom],
+ outputs=[fill_result]).then(clear_memory, inputs=[], outputs=[])
+ # ).then(
+ # fn=lambda: gr.update(visible=True),
+ # inputs=None,
+ # outputs=use_as_input_button,
+ # )
+
+ prompt_fill.submit(
+ fn=clear_result,
+ inputs=None,
+ outputs=fill_result,
+ ).then(
+ fn=inpaint,
+ inputs=[dummy_image_for_outputs, width_slider, height_slider, overlap_percentage, num_inference_steps,
+ resize_option, custom_resize_percentage, prompt_fill, alignment_dropdown,
+ overlap_left, overlap_right, overlap_top, overlap_bottom],
+ outputs=[fill_result]).then(clear_memory, inputs=[], outputs=[])
+
+ def convert_to_pil(image):
+ # Note: despite its name, this returns a uint8 NumPy array rather than a PIL.Image.
+ try:
+ #logging.info(f"Input image shape: {image.shape}, dtype: {image.dtype}")
+ image = image.astype(np.uint8)
+ logging.info(f"Converted image shape: {image.shape}, dtype: {image.dtype}")
+ return image
+ except Exception as e:
+ logging.error(f"Error converting image: {e}")
+ return image
+
+ run_button.click(
+ fn=convert_to_pil,
+ inputs=extracted_fg, # This is already RGBA with removed background
+ outputs=angles_fg
+ ).then(
+ fn=infer,
+ inputs=[
+ text_prompt,
+ extracted_fg, # Already processed RGBA image
+ ],
+ outputs=[orientation_result],
+ ).then(clear_memory, inputs=[], outputs=[])
+
+ find_objects_button.click(
+ fn=process_image,
+ inputs=[input_fg, text_prompt],
+ outputs=[extracted_objects, extracted_fg]
+ ).then(clear_memory, inputs=[], outputs=[])
+
+ extract_button.click(
+ fn=extract_foreground,
+ inputs=[input_fg],
+ outputs=[extracted_fg, x_slider, y_slider]
+ ).then(clear_memory, inputs=[], outputs=[])
+
+ with gr.Tab("Style Transfer"):
+ gr.Markdown("## Apply the style of an image to another one")
+ with gr.Row():
+ with gr.Column():
+ prompt_input = gr.Textbox(label="Prompt", placeholder="Enter your prompt here...")
+ with gr.Row():
+ with gr.Group():
+ structure_image = gr.Image(label="Structure Image", type="filepath")
+ depth_strength = gr.Slider(minimum=0, maximum=50, value=15, label="Depth Strength")
+ with gr.Group():
+ style_image = gr.Image(label="Style Image", type="filepath")
+ style_strength = gr.Slider(minimum=0, maximum=1, value=0.5, label="Style Strength")
+ generate_btn = gr.Button("Generate")
+
+ gr.Examples(
+ examples=examples,
+ inputs=[prompt_input, structure_image, style_image, depth_strength, style_strength],
+ outputs=[output_image],
+ fn=generate_image,
+ cache_examples=True,
+ cache_mode="lazy"
+ )
+
+ with gr.Column():
+ output_image.render()
+ transfer_btn = gr.Button("Send to relight")
+
+ with gr.Tab("Caption"):
+ with gr.Row():
+ gr.Markdown("## Describe Image")
+ with gr.Row():
+ with gr.Column():
+ with gr.Row():
+ image_to_describe = gr.Image(type="pil", label="Image", height=480)
+ with gr.Row():
+ with gr.Group():
+ with gr.Column():
+ describe_button = gr.Button(value="Describe Image")
+ text_to_describe = gr.Textbox(value="Describe object or scene")
+ description_text = gr.Textbox(
+ label="Output",
+ placeholder="",
+ value=""
+ )
+
+ def send_img(img_result):
+ return img_result
+
+ transfer_btn.click(send_img, [output_image], [input_fg]).then(clear_memory, inputs=[], outputs=[])
+
+ # describe_button.click(describe_image, [image_to_describe], [description_text])
+
+ describe_button.click(
+ fn=generate_description,
+ inputs=[text_to_describe, image_to_describe],
+ outputs=description_text
+ ).then(clear_memory, inputs=[], outputs=[])
+
+ generate_btn.click(
+ fn=generate_image,
+ inputs=[prompt_input, structure_image, style_image, depth_strength, style_strength],
+ outputs=[output_image]
+ ).then(clear_memory, inputs=[], outputs=[])
+
+ if __name__ == "__main__":
+ app.launch(share=True)