ginipick committed
Commit a453440
1 Parent(s): 4969726

Update app.py

Files changed (1)
  1. app.py +1 -398
app.py CHANGED
@@ -1,399 +1,2 @@
  import os
- import random
- import sys
- from typing import Sequence, Mapping, Any, Union
- import torch
- import gradio as gr
- from PIL import Image
- from huggingface_hub import hf_hub_download, login
- import spaces
-
- # Log in with the Hugging Face token
- HF_TOKEN = os.getenv("HF_TOKEN")
- if HF_TOKEN is None:
-     raise ValueError("Please set the HF_TOKEN environment variable")
- login(token=HF_TOKEN)
-
- # Then download the models
- hf_hub_download(
-     repo_id="black-forest-labs/FLUX.1-Redux-dev",
-     filename="flux1-redux-dev.safetensors",
-     local_dir="models/style_models",
-     token=HF_TOKEN
- )
- hf_hub_download(
-     repo_id="black-forest-labs/FLUX.1-Depth-dev",
-     filename="flux1-depth-dev.safetensors",
-     local_dir="models/diffusion_models",
-     token=HF_TOKEN
- )
- hf_hub_download(
-     repo_id="Comfy-Org/sigclip_vision_384",
-     filename="sigclip_vision_patch14_384.safetensors",
-     local_dir="models/clip_vision",
-     token=HF_TOKEN
- )
- hf_hub_download(
-     repo_id="Kijai/DepthAnythingV2-safetensors",
-     filename="depth_anything_v2_vitl_fp32.safetensors",
-     local_dir="models/depthanything",
-     token=HF_TOKEN
- )
- hf_hub_download(
-     repo_id="black-forest-labs/FLUX.1-dev",
-     filename="ae.safetensors",
-     local_dir="models/vae/FLUX1",
-     token=HF_TOKEN
- )
- hf_hub_download(
-     repo_id="comfyanonymous/flux_text_encoders",
-     filename="clip_l.safetensors",
-     local_dir="models/text_encoders",
-     token=HF_TOKEN
- )
- t5_path = hf_hub_download(
-     repo_id="comfyanonymous/flux_text_encoders",
-     filename="t5xxl_fp16.safetensors",
-     local_dir="models/text_encoders/t5",
-     token=HF_TOKEN
- )
-
- def get_value_at_index(obj: Union[Sequence, Mapping], index: int) -> Any:
-     try:
-         return obj[index]
-     except KeyError:
-         return obj["result"][index]
-
- def find_path(name: str, path: str = None) -> str:
-     if path is None:
-         path = os.getcwd()
-     if name in os.listdir(path):
-         path_name = os.path.join(path, name)
-         print(f"{name} found: {path_name}")
-         return path_name
-     parent_directory = os.path.dirname(path)
-     if parent_directory == path:
-         return None
-     return find_path(name, parent_directory)
-
- def add_comfyui_directory_to_sys_path() -> None:
-     comfyui_path = find_path("ComfyUI")
-     if comfyui_path is not None and os.path.isdir(comfyui_path):
-         sys.path.append(comfyui_path)
-         print(f"'{comfyui_path}' added to sys.path")
-
- def add_extra_model_paths() -> None:
-     try:
-         from main import load_extra_path_config
-     except ImportError:
-         from utils.extra_config import load_extra_path_config
-     extra_model_paths = find_path("extra_model_paths.yaml")
-     if extra_model_paths is not None:
-         load_extra_path_config(extra_model_paths)
-     else:
-         print("Could not find the extra_model_paths config file.")
-
- # Initialize paths
- add_comfyui_directory_to_sys_path()
- add_extra_model_paths()
-
- def import_custom_nodes() -> None:
-     import asyncio
-     import execution
-     from nodes import init_extra_nodes
-     import server
-     loop = asyncio.new_event_loop()
-     asyncio.set_event_loop(loop)
-     server_instance = server.PromptServer(loop)
-     execution.PromptQueue(server_instance)
-     init_extra_nodes()
-
- # Import all necessary nodes
- from nodes import (
-     StyleModelLoader,
-     VAEEncode,
-     NODE_CLASS_MAPPINGS,
-     LoadImage,
-     CLIPVisionLoader,
-     SaveImage,
-     VAELoader,
-     CLIPVisionEncode,
-     DualCLIPLoader,
-     EmptyLatentImage,
-     VAEDecode,
-     UNETLoader,
-     CLIPTextEncode,
- )
-
- # Initialize all constant nodes and models in global context
- import_custom_nodes()
-
- # Global variables for preloaded models and constants
- intconstant = NODE_CLASS_MAPPINGS["INTConstant"]()
- CONST_1024 = intconstant.get_value(value=1024)
-
- # Load CLIP
- dualcliploader = DualCLIPLoader()
- CLIP_MODEL = dualcliploader.load_clip(
-     clip_name1="t5/t5xxl_fp16.safetensors",
-     clip_name2="clip_l.safetensors",
-     type="flux",
- )
-
- # Load VAE
- vaeloader = VAELoader()
- VAE_MODEL = vaeloader.load_vae(vae_name="FLUX1/ae.safetensors")
-
- # Load UNET
- unetloader = UNETLoader()
- UNET_MODEL = unetloader.load_unet(
-     unet_name="flux1-depth-dev.safetensors", weight_dtype="default"
- )
-
- # Load CLIP Vision
- clipvisionloader = CLIPVisionLoader()
- CLIP_VISION_MODEL = clipvisionloader.load_clip(
-     clip_name="sigclip_vision_patch14_384.safetensors"
- )
-
- # Load Style Model
- stylemodelloader = StyleModelLoader()
- STYLE_MODEL = stylemodelloader.load_style_model(
-     style_model_name="flux1-redux-dev.safetensors"
- )
-
- # Initialize samplers
- ksamplerselect = NODE_CLASS_MAPPINGS["KSamplerSelect"]()
- SAMPLER = ksamplerselect.get_sampler(sampler_name="euler")
-
- # Initialize depth model
- cr_clip_input_switch = NODE_CLASS_MAPPINGS["CR Clip Input Switch"]()
- downloadandloaddepthanythingv2model = NODE_CLASS_MAPPINGS["DownloadAndLoadDepthAnythingV2Model"]()
- DEPTH_MODEL = downloadandloaddepthanythingv2model.loadmodel(
-     model="depth_anything_v2_vitl_fp32.safetensors"
- )
-
- # Initialize other nodes
- cliptextencode = CLIPTextEncode()
- loadimage = LoadImage()
- vaeencode = VAEEncode()
- fluxguidance = NODE_CLASS_MAPPINGS["FluxGuidance"]()
- instructpixtopixconditioning = NODE_CLASS_MAPPINGS["InstructPixToPixConditioning"]()
- clipvisionencode = CLIPVisionEncode()
- stylemodelapplyadvanced = NODE_CLASS_MAPPINGS["StyleModelApplyAdvanced"]()
- emptylatentimage = EmptyLatentImage()
- basicguider = NODE_CLASS_MAPPINGS["BasicGuider"]()
- basicscheduler = NODE_CLASS_MAPPINGS["BasicScheduler"]()
- randomnoise = NODE_CLASS_MAPPINGS["RandomNoise"]()
- samplercustomadvanced = NODE_CLASS_MAPPINGS["SamplerCustomAdvanced"]()
- vaedecode = VAEDecode()
- cr_text = NODE_CLASS_MAPPINGS["CR Text"]()
- saveimage = SaveImage()
- getimagesizeandcount = NODE_CLASS_MAPPINGS["GetImageSizeAndCount"]()
- depthanything_v2 = NODE_CLASS_MAPPINGS["DepthAnything_V2"]()
- imageresize = NODE_CLASS_MAPPINGS["ImageResize+"]()
-
-
- @spaces.GPU
- def generate_image(structure_image, style_image, depth_strength=15, style_strength=0.5, progress=gr.Progress(track_tqdm=True)) -> str:
-     """Main generation function that processes inputs and returns the path to the generated image."""
-     with torch.inference_mode():
-         # Set up CLIP
-         clip_switch = cr_clip_input_switch.switch(
-             Input=1,
-             clip1=get_value_at_index(CLIP_MODEL, 0),
-             clip2=get_value_at_index(CLIP_MODEL, 0),
-         )
-
-         # Encode text with default prompt
-         text_encoded = cliptextencode.encode(
-             text="person wearing fashionable clothing",
-             clip=get_value_at_index(clip_switch, 0),
-         )
-         empty_text = cliptextencode.encode(
-             text="",
-             clip=get_value_at_index(clip_switch, 0),
-         )
-
-         # Process structure image
-         structure_img = loadimage.load_image(image=structure_image)
-
-         # Resize image
-         resized_img = imageresize.execute(
-             width=get_value_at_index(CONST_1024, 0),
-             height=get_value_at_index(CONST_1024, 0),
-             interpolation="bicubic",
-             method="keep proportion",
-             condition="always",
-             multiple_of=16,
-             image=get_value_at_index(structure_img, 0),
-         )
-
-         # Get image size
-         size_info = getimagesizeandcount.getsize(
-             image=get_value_at_index(resized_img, 0)
-         )
-
-         # Encode VAE
-         vae_encoded = vaeencode.encode(
-             pixels=get_value_at_index(size_info, 0),
-             vae=get_value_at_index(VAE_MODEL, 0),
-         )
-
-         # Process depth
-         depth_processed = depthanything_v2.process(
-             da_model=get_value_at_index(DEPTH_MODEL, 0),
-             images=get_value_at_index(size_info, 0),
-         )
-
-         # Apply Flux guidance
-         flux_guided = fluxguidance.append(
-             guidance=depth_strength,
-             conditioning=get_value_at_index(text_encoded, 0),
-         )
-
-         # Process style image
-         style_img = loadimage.load_image(image=style_image)
-
-         # Encode style with CLIP Vision
-         style_encoded = clipvisionencode.encode(
-             crop="center",
-             clip_vision=get_value_at_index(CLIP_VISION_MODEL, 0),
-             image=get_value_at_index(style_img, 0),
-         )
-
-         # Set up conditioning
-         conditioning = instructpixtopixconditioning.encode(
-             positive=get_value_at_index(flux_guided, 0),
-             negative=get_value_at_index(empty_text, 0),
-             vae=get_value_at_index(VAE_MODEL, 0),
-             pixels=get_value_at_index(depth_processed, 0),
-         )
-
-         # Apply style
-         style_applied = stylemodelapplyadvanced.apply_stylemodel(
-             strength=style_strength,
-             conditioning=get_value_at_index(conditioning, 0),
-             style_model=get_value_at_index(STYLE_MODEL, 0),
-             clip_vision_output=get_value_at_index(style_encoded, 0),
-         )
-
-         # Set up empty latent
-         empty_latent = emptylatentimage.generate(
-             width=get_value_at_index(resized_img, 1),
-             height=get_value_at_index(resized_img, 2),
-             batch_size=1,
-         )
-
-         # Set up guidance
-         guided = basicguider.get_guider(
-             model=get_value_at_index(UNET_MODEL, 0),
-             conditioning=get_value_at_index(style_applied, 0),
-         )
-
-         # Set up scheduler
-         schedule = basicscheduler.get_sigmas(
-             scheduler="simple",
-             steps=28,
-             denoise=1,
-             model=get_value_at_index(UNET_MODEL, 0),
-         )
-
-         # Generate random noise
-         noise = randomnoise.get_noise(noise_seed=random.randint(1, 2**64))
-
-         # Sample
-         sampled = samplercustomadvanced.sample(
-             noise=get_value_at_index(noise, 0),
-             guider=get_value_at_index(guided, 0),
-             sampler=get_value_at_index(SAMPLER, 0),
-             sigmas=get_value_at_index(schedule, 0),
-             latent_image=get_value_at_index(empty_latent, 0),
-         )
-
-         # Decode VAE
-         decoded = vaedecode.decode(
-             samples=get_value_at_index(sampled, 0),
-             vae=get_value_at_index(VAE_MODEL, 0),
-         )
-
-         # Save image
-         prefix = cr_text.text_multiline(text="Virtual_TryOn")
-
-         saved = saveimage.save_images(
-             filename_prefix=get_value_at_index(prefix, 0),
-             images=get_value_at_index(decoded, 0),
-         )
-         saved_path = f"output/{saved['ui']['images'][0]['filename']}"
-         return saved_path
-
- # Create Gradio interface
- examples = [
-     ["f1.webp", "f11.webp", 15, 0.6],
-     ["f2.webp", "f21.webp", 15, 0.5],
-     ["f3.webp", "f31.webp", 15, 0.5],
-     ["qq1.webp", "ww1.webp", 15, 0.5],
-     ["qq2.webp", "ww2.webp", 15, 0.5],
-     ["qq3.webp", "ww3.webp", 15, 0.5]
- ]
-
- # Create the Gradio interface
- demo = gr.Blocks(theme="Yntec/HaleyCH_Theme_Orange")
-
- with demo:
-     gr.Markdown("# 🎭 StyleGen : Flux Inpainting")
-     gr.Markdown("Generate fashion images and try on virtual clothing using AI")
-
-     with gr.Tabs():
-         # Virtual Try-On tab
-         with gr.TabItem("👔 Virtual Try-On"):
-             with gr.Row():
-                 with gr.Column():
-                     with gr.Row():
-                         with gr.Group():
-                             structure_image = gr.Image(
-                                 label="Your Photo (Full-body)",
-                                 type="filepath"
-                             )
-                             gr.Markdown("*Upload a clear, well-lit full-body photo*")
-                             depth_strength = gr.Slider(
-                                 minimum=0,
-                                 maximum=50,
-                                 value=15,
-                                 label="Fitting Strength"
-                             )
-                         with gr.Group():
-                             style_image = gr.Image(
-                                 label="Clothing Item",
-                                 type="filepath"
-                             )
-                             gr.Markdown("*Upload the clothing item you want to try on*")
-                             style_strength = gr.Slider(
-                                 minimum=0,
-                                 maximum=1,
-                                 value=0.5,
-                                 label="Style Transfer Strength"
-                             )
-
-                 with gr.Column():
-                     output_image = gr.Image(label="Virtual Try-On Result")
-
-             generate_button = gr.Button("Generate Try-On")
-
-             gr.Examples(
-                 examples=examples,
-                 inputs=[structure_image, style_image, depth_strength, style_strength],
-                 outputs=output_image,
-                 fn=generate_image,
-                 cache_examples=False
-             )
-
-             # Connect the button to the generation function
-             generate_button.click(
-                 fn=generate_image,
-                 inputs=[structure_image, style_image, depth_strength, style_strength],
-                 outputs=output_image
-             )
-
- if __name__ == "__main__":
-     demo.launch(share=True)
+ exec(os.environ.get('APP'))