John6666 commited on
Commit
fe9a742
·
verified ·
1 Parent(s): 184d241

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +72 -32
app.py CHANGED
@@ -4,7 +4,8 @@ import json
4
  import logging
5
  import torch
6
  from PIL import Image
7
- from diffusers import DiffusionPipeline
 
8
  from diffusers import FluxControlNetPipeline, FluxControlNetModel, FluxMultiControlNetModel
9
  from huggingface_hub import hf_hub_download, HfFileSystem, ModelCard, snapshot_download
10
  import copy
@@ -21,35 +22,48 @@ from flux import (search_civitai_lora, select_civitai_lora, search_civitai_lora_
21
  from tagger.tagger import predict_tags_wd, compose_prompt_to_copy
22
  from tagger.fl2flux import predict_tags_fl2_flux
23
 
 
 
 
 
 
 
 
24
  # Initialize the base model
25
  base_model = models[0]
26
  controlnet_model_union_repo = 'InstantX/FLUX.1-dev-Controlnet-Union'
27
  #controlnet_model_union_repo = 'InstantX/FLUX.1-dev-Controlnet-Union-alpha'
28
- pipe = DiffusionPipeline.from_pretrained(base_model, torch_dtype=torch.bfloat16)
 
 
29
  controlnet_union = None
30
  controlnet = None
31
  last_model = models[0]
32
  last_cn_on = False
33
 
 
 
 
 
34
  # https://huggingface.co/InstantX/FLUX.1-dev-Controlnet-Union
35
  # https://huggingface.co/spaces/jiuface/FLUX.1-dev-Controlnet-Union
 
36
  def change_base_model(repo_id: str, cn_on: bool):
37
  global pipe
38
  global controlnet_union
39
  global controlnet
40
  global last_model
41
  global last_cn_on
42
- dtype = torch.bfloat16
43
- #dtype = torch.float8_e4m3fn
44
  try:
45
  if (repo_id == last_model and cn_on is last_cn_on) or not is_repo_name(repo_id) or not is_repo_exists(repo_id): return gr.update(visible=True)
46
  if cn_on:
47
  #progress(0, desc=f"Loading model: {repo_id} / Loading ControlNet: {controlnet_model_union_repo}")
48
  print(f"Loading model: {repo_id} / Loading ControlNet: {controlnet_model_union_repo}")
49
  clear_cache()
50
- controlnet_union = FluxControlNetModel.from_pretrained(controlnet_model_union_repo, torch_dtype=dtype)
51
- controlnet = FluxMultiControlNetModel([controlnet_union])
52
- pipe = FluxControlNetPipeline.from_pretrained(repo_id, controlnet=controlnet, torch_dtype=dtype)
 
53
  last_model = repo_id
54
  last_cn_on = cn_on
55
  #progress(1, desc=f"Model loaded: {repo_id} / ControlNet Loaded: {controlnet_model_union_repo}")
@@ -58,7 +72,8 @@ def change_base_model(repo_id: str, cn_on: bool):
58
  #progress(0, desc=f"Loading model: {repo_id}")
59
  print(f"Loading model: {repo_id}")
60
  clear_cache()
61
- pipe = DiffusionPipeline.from_pretrained(repo_id, torch_dtype=dtype)
 
62
  last_model = repo_id
63
  last_cn_on = cn_on
64
  #progress(1, desc=f"Model loaded: {repo_id}")
@@ -70,12 +85,6 @@ def change_base_model(repo_id: str, cn_on: bool):
70
 
71
  change_base_model.zerogpu = True
72
 
73
- # Load LoRAs from JSON file
74
- with open('loras.json', 'r') as f:
75
- loras = json.load(f)
76
-
77
- MAX_SEED = 2**32-1
78
-
79
  class calculateDuration:
80
  def __init__(self, activity_name=""):
81
  self.activity_name = activity_name
@@ -118,9 +127,13 @@ def update_selection(evt: gr.SelectData, width, height):
118
  @spaces.GPU(duration=70)
119
  def generate_image(prompt_mash, steps, seed, cfg_scale, width, height, lora_scale, cn_on, progress=gr.Progress(track_tqdm=True)):
120
  global pipe
 
 
121
  global controlnet
122
  global controlnet_union
123
  try:
 
 
124
  pipe.to("cuda")
125
  generator = torch.Generator(device="cuda").manual_seed(seed)
126
 
@@ -129,7 +142,7 @@ def generate_image(prompt_mash, steps, seed, cfg_scale, width, height, lora_scal
129
  modes, images, scales = get_control_params()
130
  if not cn_on or len(modes) == 0:
131
  progress(0, desc="Start Inference.")
132
- image = pipe(
133
  prompt=prompt_mash,
134
  num_inference_steps=steps,
135
  guidance_scale=cfg_scale,
@@ -137,12 +150,15 @@ def generate_image(prompt_mash, steps, seed, cfg_scale, width, height, lora_scal
137
  height=height,
138
  generator=generator,
139
  joint_attention_kwargs={"scale": lora_scale},
140
- ).images[0]
 
 
 
141
  else:
142
  progress(0, desc="Start Inference with ControlNet.")
143
  if controlnet is not None: controlnet.to("cuda")
144
  if controlnet_union is not None: controlnet_union.to("cuda")
145
- image = pipe(
146
  prompt=prompt_mash,
147
  control_image=images,
148
  control_mode=modes,
@@ -153,23 +169,35 @@ def generate_image(prompt_mash, steps, seed, cfg_scale, width, height, lora_scal
153
  controlnet_conditioning_scale=scales,
154
  generator=generator,
155
  joint_attention_kwargs={"scale": lora_scale},
156
- ).images[0]
 
157
  except Exception as e:
158
  print(e)
159
  raise gr.Error(f"Inference Error: {e}")
160
- return image
161
 
162
  def run_lora(prompt, cfg_scale, steps, selected_index, randomize_seed, seed, width, height,
163
  lora_scale, lora_json, cn_on, progress=gr.Progress(track_tqdm=True)):
164
  global pipe
 
 
 
 
165
  if selected_index is None and not is_valid_lora(lora_json):
166
  gr.Info("LoRA isn't selected.")
167
  # raise gr.Error("You must select a LoRA before proceeding.")
168
  progress(0, desc="Preparing Inference.")
169
 
 
 
 
 
 
 
 
170
  prompt_mash = prompt
171
  if is_valid_lora(lora_json):
172
- with calculateDuration("Loading LoRA weights"):
 
173
  fuse_loras(pipe, lora_json)
174
  trigger_word = get_trigger_word(lora_json)
175
  prompt_mash = f"{prompt} {trigger_word}"
@@ -200,17 +228,28 @@ def run_lora(prompt, cfg_scale, steps, selected_index, randomize_seed, seed, wid
200
  seed = random.randint(0, MAX_SEED)
201
 
202
  progress(0, desc="Running Inference.")
203
-
204
- image = generate_image(prompt_mash, steps, seed, cfg_scale, width, height, lora_scale, cn_on, progress)
205
- if is_valid_lora(lora_json):
206
- pipe.unfuse_lora()
207
- pipe.unload_lora_weights()
208
- if selected_index is not None: pipe.unload_lora_weights()
209
- pipe.to("cpu")
210
- if controlnet is not None: controlnet.to("cpu")
211
- if controlnet_union is not None: controlnet_union.to("cpu")
212
- clear_cache()
213
- return image, seed
 
 
 
 
 
 
 
 
 
 
 
214
 
215
  def get_huggingface_safetensors(link):
216
  split_link = link.split("/")
@@ -343,6 +382,7 @@ with gr.Blocks(theme='Nymbo/Nymbo_Theme', fill_width=True, css=css) as app:
343
  custom_lora_button = gr.Button("Remove custom LoRA", visible=False)
344
  deselect_lora_button = gr.Button("Deselect LoRA", variant="secondary")
345
  with gr.Column():
 
346
  result = gr.Image(label="Generated Image", format="png", show_share_button=False)
347
  with gr.Group():
348
  model_name = gr.Dropdown(label="Base Model", info="You can enter a huggingface model repo_id to want to use.", choices=models, value=models[0], allow_custom_value=True)
@@ -450,7 +490,7 @@ with gr.Blocks(theme='Nymbo/Nymbo_Theme', fill_width=True, css=css) as app:
450
  fn=run_lora,
451
  inputs=[prompt, cfg_scale, steps, selected_index, randomize_seed, seed, width, height,
452
  lora_scale, lora_repo_json, cn_on],
453
- outputs=[result, seed],
454
  queue=True,
455
  show_api=True,
456
  )
 
4
  import logging
5
  import torch
6
  from PIL import Image
7
+ from diffusers import DiffusionPipeline, AutoencoderTiny, AutoencoderKL
8
+ from live_preview_helpers import calculate_shift, retrieve_timesteps, flux_pipe_call_that_returns_an_iterable_of_images
9
  from diffusers import FluxControlNetPipeline, FluxControlNetModel, FluxMultiControlNetModel
10
  from huggingface_hub import hf_hub_download, HfFileSystem, ModelCard, snapshot_download
11
  import copy
 
22
  from tagger.tagger import predict_tags_wd, compose_prompt_to_copy
23
  from tagger.fl2flux import predict_tags_fl2_flux
24
 
25
+ # Load LoRAs from JSON file
26
+ with open('loras.json', 'r') as f:
27
+ loras = json.load(f)
28
+
29
+ dtype = torch.bfloat16
30
+ #dtype = torch.float8_e4m3fn
31
+ device = "cuda" if torch.cuda.is_available() else "cpu"
32
  # Initialize the base model
33
  base_model = models[0]
34
  controlnet_model_union_repo = 'InstantX/FLUX.1-dev-Controlnet-Union'
35
  #controlnet_model_union_repo = 'InstantX/FLUX.1-dev-Controlnet-Union-alpha'
36
+ taef1 = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=dtype).to(device)
37
+ good_vae = AutoencoderKL.from_pretrained(base_model, subfolder="vae", torch_dtype=dtype).to(device)
38
+ pipe = DiffusionPipeline.from_pretrained(base_model, torch_dtype=dtype, vae=taef1).to(device)
39
  controlnet_union = None
40
  controlnet = None
41
  last_model = models[0]
42
  last_cn_on = False
43
 
44
+ MAX_SEED = 2**32-1
45
+
46
+ pipe.flux_pipe_call_that_returns_an_iterable_of_images = flux_pipe_call_that_returns_an_iterable_of_images.__get__(pipe)
47
+
48
  # https://huggingface.co/InstantX/FLUX.1-dev-Controlnet-Union
49
  # https://huggingface.co/spaces/jiuface/FLUX.1-dev-Controlnet-Union
50
+ @spaces.GPU()
51
  def change_base_model(repo_id: str, cn_on: bool):
52
  global pipe
53
  global controlnet_union
54
  global controlnet
55
  global last_model
56
  global last_cn_on
 
 
57
  try:
58
  if (repo_id == last_model and cn_on is last_cn_on) or not is_repo_name(repo_id) or not is_repo_exists(repo_id): return gr.update(visible=True)
59
  if cn_on:
60
  #progress(0, desc=f"Loading model: {repo_id} / Loading ControlNet: {controlnet_model_union_repo}")
61
  print(f"Loading model: {repo_id} / Loading ControlNet: {controlnet_model_union_repo}")
62
  clear_cache()
63
+ controlnet_union = FluxControlNetModel.from_pretrained(controlnet_model_union_repo, torch_dtype=dtype).to(device)
64
+ controlnet = FluxMultiControlNetModel([controlnet_union]).to(device)
65
+ pipe = FluxControlNetPipeline.from_pretrained(repo_id, controlnet=controlnet, torch_dtype=dtype).to(device)
66
+ #pipe.flux_pipe_call_that_returns_an_iterable_of_images = flux_pipe_call_that_returns_an_iterable_of_images.__get__(pipe)
67
  last_model = repo_id
68
  last_cn_on = cn_on
69
  #progress(1, desc=f"Model loaded: {repo_id} / ControlNet Loaded: {controlnet_model_union_repo}")
 
72
  #progress(0, desc=f"Loading model: {repo_id}")
73
  print(f"Loading model: {repo_id}")
74
  clear_cache()
75
+ pipe = DiffusionPipeline.from_pretrained(repo_id, torch_dtype=dtype, vae=taef1).to(device)
76
+ pipe.flux_pipe_call_that_returns_an_iterable_of_images = flux_pipe_call_that_returns_an_iterable_of_images.__get__(pipe)
77
  last_model = repo_id
78
  last_cn_on = cn_on
79
  #progress(1, desc=f"Model loaded: {repo_id}")
 
85
 
86
  change_base_model.zerogpu = True
87
 
 
 
 
 
 
 
88
  class calculateDuration:
89
  def __init__(self, activity_name=""):
90
  self.activity_name = activity_name
 
127
  @spaces.GPU(duration=70)
128
  def generate_image(prompt_mash, steps, seed, cfg_scale, width, height, lora_scale, cn_on, progress=gr.Progress(track_tqdm=True)):
129
  global pipe
130
+ global taef1
131
+ global good_vae
132
  global controlnet
133
  global controlnet_union
134
  try:
135
+ good_vae.to("cuda")
136
+ taef1.to("cuda")
137
  pipe.to("cuda")
138
  generator = torch.Generator(device="cuda").manual_seed(seed)
139
 
 
142
  modes, images, scales = get_control_params()
143
  if not cn_on or len(modes) == 0:
144
  progress(0, desc="Start Inference.")
145
+ for img in pipe.flux_pipe_call_that_returns_an_iterable_of_images(
146
  prompt=prompt_mash,
147
  num_inference_steps=steps,
148
  guidance_scale=cfg_scale,
 
150
  height=height,
151
  generator=generator,
152
  joint_attention_kwargs={"scale": lora_scale},
153
+ output_type="pil",
154
+ good_vae=good_vae,
155
+ ):
156
+ yield img
157
  else:
158
  progress(0, desc="Start Inference with ControlNet.")
159
  if controlnet is not None: controlnet.to("cuda")
160
  if controlnet_union is not None: controlnet_union.to("cuda")
161
+ for img in pipe(
162
  prompt=prompt_mash,
163
  control_image=images,
164
  control_mode=modes,
 
169
  controlnet_conditioning_scale=scales,
170
  generator=generator,
171
  joint_attention_kwargs={"scale": lora_scale},
172
+ ).images:
173
+ yield img
174
  except Exception as e:
175
  print(e)
176
  raise gr.Error(f"Inference Error: {e}")
 
177
 
178
  def run_lora(prompt, cfg_scale, steps, selected_index, randomize_seed, seed, width, height,
179
  lora_scale, lora_json, cn_on, progress=gr.Progress(track_tqdm=True)):
180
  global pipe
181
+ global taef1
182
+ global good_vae
183
+ global controlnet
184
+ global controlnet_union
185
  if selected_index is None and not is_valid_lora(lora_json):
186
  gr.Info("LoRA isn't selected.")
187
  # raise gr.Error("You must select a LoRA before proceeding.")
188
  progress(0, desc="Preparing Inference.")
189
 
190
+ with calculateDuration("Unloading LoRA"):
191
+ try:
192
+ pipe.unfuse_lora()
193
+ pipe.unload_lora_weights()
194
+ except Exception as e:
195
+ print(e)
196
+
197
  prompt_mash = prompt
198
  if is_valid_lora(lora_json):
199
+ # Load External LoRA weights
200
+ with calculateDuration("Loading External LoRA weights"):
201
  fuse_loras(pipe, lora_json)
202
  trigger_word = get_trigger_word(lora_json)
203
  prompt_mash = f"{prompt} {trigger_word}"
 
228
  seed = random.randint(0, MAX_SEED)
229
 
230
  progress(0, desc="Running Inference.")
231
+ image_generator = generate_image(prompt_mash, steps, seed, cfg_scale, width, height, lora_scale, cn_on, progress)
232
+ # Consume the generator to get the final image
233
+ final_image = None
234
+ step_counter = 0
235
+ for image in image_generator:
236
+ step_counter+=1
237
+ final_image = image
238
+ progress_bar = f'<div class="progress-container"><div class="progress-bar" style="--current: {step_counter}; --total: {steps};"></div></div>'
239
+ yield image, seed, gr.update(value=progress_bar, visible=True)
240
+
241
+ yield final_image, seed, gr.update(value=progress_bar, visible=False)
242
+ #if is_valid_lora(lora_json):
243
+ # pipe.unfuse_lora()
244
+ # pipe.unload_lora_weights()
245
+ #if selected_index is not None: pipe.unload_lora_weights()
246
+ #pipe.to("cpu")
247
+ #good_vae.to("cpu")
248
+ #taef1.to("cpu")
249
+ #if controlnet is not None: controlnet.to("cpu")
250
+ #if controlnet_union is not None: controlnet_union.to("cpu")
251
+ #clear_cache()
252
+ #return final_image, seed # Return the final image and seed
253
 
254
  def get_huggingface_safetensors(link):
255
  split_link = link.split("/")
 
382
  custom_lora_button = gr.Button("Remove custom LoRA", visible=False)
383
  deselect_lora_button = gr.Button("Deselect LoRA", variant="secondary")
384
  with gr.Column():
385
+ progress_bar = gr.Markdown(elem_id="progress",visible=False)
386
  result = gr.Image(label="Generated Image", format="png", show_share_button=False)
387
  with gr.Group():
388
  model_name = gr.Dropdown(label="Base Model", info="You can enter a huggingface model repo_id to want to use.", choices=models, value=models[0], allow_custom_value=True)
 
490
  fn=run_lora,
491
  inputs=[prompt, cfg_scale, steps, selected_index, randomize_seed, seed, width, height,
492
  lora_scale, lora_repo_json, cn_on],
493
+ outputs=[result, seed, progress_bar],
494
  queue=True,
495
  show_api=True,
496
  )