multimodalart HF staff commited on
Commit
fd1c741
·
verified ·
1 Parent(s): fb702b0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +65 -63
app.py CHANGED
@@ -6,6 +6,7 @@ import torch
6
  import gradio as gr
7
  from PIL import Image
8
  from huggingface_hub import hf_hub_download
 
9
 
10
  hf_hub_download(repo_id="black-forest-labs/FLUX.1-Redux-dev", filename="flux1-redux-dev.safetensors", local_dir="models/style_models")
11
  hf_hub_download(repo_id="black-forest-labs/FLUX.1-Depth-dev", filename="flux1-depth-dev.safetensors", local_dir="models/diffusion_models")
@@ -88,72 +89,73 @@ from nodes import (
88
  import_custom_nodes()
89
 
90
  # Global variables for preloaded models and constants
91
- with torch.inference_mode():
92
  # Initialize constants
93
- intconstant = NODE_CLASS_MAPPINGS["INTConstant"]()
94
- CONST_1024 = intconstant.get_value(value=1024)
95
-
96
- # Load CLIP
97
- dualcliploader = DualCLIPLoader()
98
- CLIP_MODEL = dualcliploader.load_clip(
99
- clip_name1="t5/t5xxl_fp16.safetensors",
100
- clip_name2="clip_l.safetensors",
101
- type="flux",
102
- )
103
-
104
- # Load VAE
105
- vaeloader = VAELoader()
106
- VAE_MODEL = vaeloader.load_vae(vae_name="FLUX1/ae.safetensors")
107
-
108
- # Load UNET
109
- unetloader = UNETLoader()
110
- UNET_MODEL = unetloader.load_unet(
111
- unet_name="flux1-depth-dev.safetensors", weight_dtype="default"
112
- )
113
-
114
- # Load CLIP Vision
115
- clipvisionloader = CLIPVisionLoader()
116
- CLIP_VISION_MODEL = clipvisionloader.load_clip(
117
- clip_name="sigclip_vision_patch14_384.safetensors"
118
- )
119
-
120
- # Load Style Model
121
- stylemodelloader = StyleModelLoader()
122
- STYLE_MODEL = stylemodelloader.load_style_model(
123
- style_model_name="flux1-redux-dev.safetensors"
124
- )
125
-
126
- # Initialize samplers
127
- ksamplerselect = NODE_CLASS_MAPPINGS["KSamplerSelect"]()
128
- SAMPLER = ksamplerselect.get_sampler(sampler_name="euler")
129
-
130
- # Initialize depth model
131
- cr_clip_input_switch = NODE_CLASS_MAPPINGS["CR Clip Input Switch"]()
132
- downloadandloaddepthanythingv2model = NODE_CLASS_MAPPINGS["DownloadAndLoadDepthAnythingV2Model"]()
133
- DEPTH_MODEL = downloadandloaddepthanythingv2model.loadmodel(
134
- model="depth_anything_v2_vitl_fp32.safetensors"
135
- )
136
- cliptextencode = CLIPTextEncode()
137
- loadimage = LoadImage()
138
- vaeencode = VAEEncode()
139
- fluxguidance = NODE_CLASS_MAPPINGS["FluxGuidance"]()
140
- instructpixtopixconditioning = NODE_CLASS_MAPPINGS["InstructPixToPixConditioning"]()
141
- clipvisionencode = CLIPVisionEncode()
142
- stylemodelapplyadvanced = NODE_CLASS_MAPPINGS["StyleModelApplyAdvanced"]()
143
- emptylatentimage = EmptyLatentImage()
144
- basicguider = NODE_CLASS_MAPPINGS["BasicGuider"]()
145
- basicscheduler = NODE_CLASS_MAPPINGS["BasicScheduler"]()
146
- randomnoise = NODE_CLASS_MAPPINGS["RandomNoise"]()
147
- samplercustomadvanced = NODE_CLASS_MAPPINGS["SamplerCustomAdvanced"]()
148
- vaedecode = VAEDecode()
149
- cr_text = NODE_CLASS_MAPPINGS["CR Text"]()
150
- saveimage = SaveImage()
151
- getimagesizeandcount = NODE_CLASS_MAPPINGS["GetImageSizeAndCount"]()
152
- depthanything_v2 = NODE_CLASS_MAPPINGS["DepthAnything_V2"]()
153
- imageresize = NODE_CLASS_MAPPINGS["ImageResize+"]()
 
 
154
  def generate_image(prompt: str, structure_image: str, depth_strength: float, style_image: str, style_strength: float, progress=gr.Progress(track_tqdm=True)) -> str:
155
  """Main generation function that processes inputs and returns the path to the generated image."""
156
-
157
  with torch.inference_mode():
158
  # Set up CLIP
159
  clip_switch = cr_clip_input_switch.switch(
 
6
  import gradio as gr
7
  from PIL import Image
8
  from huggingface_hub import hf_hub_download
9
+ import spaces
10
 
11
  hf_hub_download(repo_id="black-forest-labs/FLUX.1-Redux-dev", filename="flux1-redux-dev.safetensors", local_dir="models/style_models")
12
  hf_hub_download(repo_id="black-forest-labs/FLUX.1-Depth-dev", filename="flux1-depth-dev.safetensors", local_dir="models/diffusion_models")
 
89
  import_custom_nodes()
90
 
91
  # Global variables for preloaded models and constants
92
+ #with torch.inference_mode():
93
  # Initialize constants
94
+ intconstant = NODE_CLASS_MAPPINGS["INTConstant"]()
95
+ CONST_1024 = intconstant.get_value(value=1024)
96
+
97
+ # Load CLIP
98
+ dualcliploader = DualCLIPLoader()
99
+ CLIP_MODEL = dualcliploader.load_clip(
100
+ clip_name1="t5/t5xxl_fp16.safetensors",
101
+ clip_name2="clip_l.safetensors",
102
+ type="flux",
103
+ )
104
+
105
+ # Load VAE
106
+ vaeloader = VAELoader()
107
+ VAE_MODEL = vaeloader.load_vae(vae_name="FLUX1/ae.safetensors")
108
+
109
+ # Load UNET
110
+ unetloader = UNETLoader()
111
+ UNET_MODEL = unetloader.load_unet(
112
+ unet_name="flux1-depth-dev.safetensors", weight_dtype="default"
113
+ )
114
+
115
+ # Load CLIP Vision
116
+ clipvisionloader = CLIPVisionLoader()
117
+ CLIP_VISION_MODEL = clipvisionloader.load_clip(
118
+ clip_name="sigclip_vision_patch14_384.safetensors"
119
+ )
120
+
121
+ # Load Style Model
122
+ stylemodelloader = StyleModelLoader()
123
+ STYLE_MODEL = stylemodelloader.load_style_model(
124
+ style_model_name="flux1-redux-dev.safetensors"
125
+ )
126
+
127
+ # Initialize samplers
128
+ ksamplerselect = NODE_CLASS_MAPPINGS["KSamplerSelect"]()
129
+ SAMPLER = ksamplerselect.get_sampler(sampler_name="euler")
130
+
131
+ # Initialize depth model
132
+ cr_clip_input_switch = NODE_CLASS_MAPPINGS["CR Clip Input Switch"]()
133
+ downloadandloaddepthanythingv2model = NODE_CLASS_MAPPINGS["DownloadAndLoadDepthAnythingV2Model"]()
134
+ DEPTH_MODEL = downloadandloaddepthanythingv2model.loadmodel(
135
+ model="depth_anything_v2_vitl_fp32.safetensors"
136
+ )
137
+ cliptextencode = CLIPTextEncode()
138
+ loadimage = LoadImage()
139
+ vaeencode = VAEEncode()
140
+ fluxguidance = NODE_CLASS_MAPPINGS["FluxGuidance"]()
141
+ instructpixtopixconditioning = NODE_CLASS_MAPPINGS["InstructPixToPixConditioning"]()
142
+ clipvisionencode = CLIPVisionEncode()
143
+ stylemodelapplyadvanced = NODE_CLASS_MAPPINGS["StyleModelApplyAdvanced"]()
144
+ emptylatentimage = EmptyLatentImage()
145
+ basicguider = NODE_CLASS_MAPPINGS["BasicGuider"]()
146
+ basicscheduler = NODE_CLASS_MAPPINGS["BasicScheduler"]()
147
+ randomnoise = NODE_CLASS_MAPPINGS["RandomNoise"]()
148
+ samplercustomadvanced = NODE_CLASS_MAPPINGS["SamplerCustomAdvanced"]()
149
+ vaedecode = VAEDecode()
150
+ cr_text = NODE_CLASS_MAPPINGS["CR Text"]()
151
+ saveimage = SaveImage()
152
+ getimagesizeandcount = NODE_CLASS_MAPPINGS["GetImageSizeAndCount"]()
153
+ depthanything_v2 = NODE_CLASS_MAPPINGS["DepthAnything_V2"]()
154
+ imageresize = NODE_CLASS_MAPPINGS["ImageResize+"]()
155
+
156
+ @spaces.GPU
157
  def generate_image(prompt: str, structure_image: str, depth_strength: float, style_image: str, style_strength: float, progress=gr.Progress(track_tqdm=True)) -> str:
158
  """Main generation function that processes inputs and returns the path to the generated image."""
 
159
  with torch.inference_mode():
160
  # Set up CLIP
161
  clip_switch = cr_clip_input_switch.switch(