multimodalart HF staff commited on
Commit
aaff709
β€’
1 Parent(s): 2ed4418

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +28 -12
app.py CHANGED
@@ -5,7 +5,9 @@ import logging
5
  import torch
6
  from PIL import Image
7
  import spaces
8
- from diffusers import DiffusionPipeline
 
 
9
  from huggingface_hub import hf_hub_download, HfFileSystem, ModelCard, snapshot_download
10
  import copy
11
  import random
@@ -16,11 +18,18 @@ with open('loras.json', 'r') as f:
16
  loras = json.load(f)
17
 
18
  # Initialize the base model
 
 
19
  base_model = "black-forest-labs/FLUX.1-dev"
20
- pipe = DiffusionPipeline.from_pretrained(base_model, torch_dtype=torch.bfloat16)
 
 
 
21
 
22
  MAX_SEED = 2**32-1
23
 
 
 
24
  class calculateDuration:
25
  def __init__(self, activity_name=""):
26
  self.activity_name = activity_name
@@ -61,10 +70,9 @@ def update_selection(evt: gr.SelectData, width, height):
61
  def generate_image(prompt_mash, steps, seed, cfg_scale, width, height, lora_scale, progress):
62
  pipe.to("cuda")
63
  generator = torch.Generator(device="cuda").manual_seed(seed)
64
-
65
  with calculateDuration("Generating image"):
66
  # Generate image
67
- image = pipe(
68
  prompt=prompt_mash,
69
  num_inference_steps=steps,
70
  guidance_scale=cfg_scale,
@@ -72,13 +80,14 @@ def generate_image(prompt_mash, steps, seed, cfg_scale, width, height, lora_scal
72
  height=height,
73
  generator=generator,
74
  joint_attention_kwargs={"scale": lora_scale},
75
- ).images[0]
76
- return image
 
 
77
 
78
  def run_lora(prompt, cfg_scale, steps, selected_index, randomize_seed, seed, width, height, lora_scale, progress=gr.Progress(track_tqdm=True)):
79
  if selected_index is None:
80
  raise gr.Error("You must select a LoRA before proceeding.")
81
-
82
  selected_lora = loras[selected_index]
83
  lora_path = selected_lora["repo"]
84
  trigger_word = selected_lora["trigger_word"]
@@ -92,24 +101,31 @@ def run_lora(prompt, cfg_scale, steps, selected_index, randomize_seed, seed, wid
92
  prompt_mash = f"{trigger_word} {prompt}"
93
  else:
94
  prompt_mash = prompt
 
95
  # Load LoRA weights
96
  with calculateDuration(f"Loading LoRA weights for {selected_lora['title']}"):
97
  if "weights" in selected_lora:
98
  pipe.load_lora_weights(lora_path, weight_name=selected_lora["weights"])
99
- #pipe.fuse_lora()
100
  else:
101
  pipe.load_lora_weights(lora_path)
102
- #pipe.fuse_lora()
103
  # Set random seed for reproducibility
104
  with calculateDuration("Randomizing seed"):
105
  if randomize_seed:
106
  seed = random.randint(0, MAX_SEED)
 
 
107
 
108
- image = generate_image(prompt_mash, steps, seed, cfg_scale, width, height, lora_scale, progress)
 
 
 
 
 
109
  pipe.to("cpu")
110
- #pipe.unfuse_lora()
111
  pipe.unload_lora_weights()
112
- return image, seed
 
113
 
114
  def get_huggingface_safetensors(link):
115
  split_link = link.split("/")
 
5
  import torch
6
  from PIL import Image
7
  import spaces
8
+ from diffusers import DiffusionPipeline, AutoencoderTiny, AutoencoderKL
9
+ from live_preview_helpers import calculate_shift, retrieve_timesteps, flux_pipe_call_that_returns_an_iterable_of_images
10
+
11
  from huggingface_hub import hf_hub_download, HfFileSystem, ModelCard, snapshot_download
12
  import copy
13
  import random
 
18
  loras = json.load(f)
19
 
20
  # Initialize the base model
21
+ dtype = torch.bfloat16
22
+ device = "cuda" if torch.cuda.is_available() else "cpu"
23
  base_model = "black-forest-labs/FLUX.1-dev"
24
+
25
+ taef1 = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=dtype).to(device)
26
+ good_vae = AutoencoderKL.from_pretrained(base_model, subfolder="vae", torch_dtype=dtype).to(device)
27
+ pipe = DiffusionPipeline.from_pretrained(base_model, torch_dtype=dtype, vae=taef1).to(device)
28
 
29
  MAX_SEED = 2**32-1
30
 
31
+ pipe.flux_pipe_call_that_returns_an_iterable_of_images = flux_pipe_call_that_returns_an_iterable_of_images.__get__(pipe)
32
+
33
  class calculateDuration:
34
  def __init__(self, activity_name=""):
35
  self.activity_name = activity_name
 
70
  def generate_image(prompt_mash, steps, seed, cfg_scale, width, height, lora_scale, progress):
71
  pipe.to("cuda")
72
  generator = torch.Generator(device="cuda").manual_seed(seed)
 
73
  with calculateDuration("Generating image"):
74
  # Generate image
75
+ for img in pipe.flux_pipe_call_that_returns_an_iterable_of_images(
76
  prompt=prompt_mash,
77
  num_inference_steps=steps,
78
  guidance_scale=cfg_scale,
 
80
  height=height,
81
  generator=generator,
82
  joint_attention_kwargs={"scale": lora_scale},
83
+ output_type="pil",
84
+ good_vae=good_vae,
85
+ ):
86
+ yield img
87
 
88
  def run_lora(prompt, cfg_scale, steps, selected_index, randomize_seed, seed, width, height, lora_scale, progress=gr.Progress(track_tqdm=True)):
89
  if selected_index is None:
90
  raise gr.Error("You must select a LoRA before proceeding.")
 
91
  selected_lora = loras[selected_index]
92
  lora_path = selected_lora["repo"]
93
  trigger_word = selected_lora["trigger_word"]
 
101
  prompt_mash = f"{trigger_word} {prompt}"
102
  else:
103
  prompt_mash = prompt
104
+
105
  # Load LoRA weights
106
  with calculateDuration(f"Loading LoRA weights for {selected_lora['title']}"):
107
  if "weights" in selected_lora:
108
  pipe.load_lora_weights(lora_path, weight_name=selected_lora["weights"])
 
109
  else:
110
  pipe.load_lora_weights(lora_path)
111
+
112
  # Set random seed for reproducibility
113
  with calculateDuration("Randomizing seed"):
114
  if randomize_seed:
115
  seed = random.randint(0, MAX_SEED)
116
+
117
+ image_generator = generate_image(prompt_mash, steps, seed, cfg_scale, width, height, lora_scale, progress)
118
 
119
+ # Consume the generator to get the final image
120
+ final_image = None
121
+ for image in image_generator:
122
+ final_image = image
123
+ yield image, seed # Yield intermediate images and seed
124
+
125
  pipe.to("cpu")
 
126
  pipe.unload_lora_weights()
127
+
128
+ return final_image, seed # Return the final image and seed
129
 
130
  def get_huggingface_safetensors(link):
131
  split_link = link.split("/")