KingNish commited on
Commit
59628f2
·
verified ·
1 Parent(s): 312a5de

Update custom_pipeline.py

Browse files
Files changed (1) hide show
  1. custom_pipeline.py +160 -146
custom_pipeline.py CHANGED
@@ -1,151 +1,165 @@
1
- import torch
2
  import numpy as np
3
- from diffusers import FluxPipeline, FlowMatchEulerDiscreteScheduler
4
- from typing import Any, Dict, List, Optional, Union
5
- from PIL import Image
6
-
7
- # Constants for shift calculation
8
- BASE_SEQ_LEN = 256
9
- MAX_SEQ_LEN = 4096
10
- BASE_SHIFT = 0.5
11
- MAX_SHIFT = 1.2
12
-
13
- # Helper functions
14
- def calculate_timestep_shift(image_seq_len: int) -> float:
15
- """Calculates the timestep shift (mu) based on the image sequence length."""
16
- m = (MAX_SHIFT - BASE_SHIFT) / (MAX_SEQ_LEN - BASE_SEQ_LEN)
17
- b = BASE_SHIFT - m * BASE_SEQ_LEN
18
- mu = image_seq_len * m + b
19
- return mu
20
-
21
- def prepare_timesteps(
22
- scheduler: FlowMatchEulerDiscreteScheduler,
23
- num_inference_steps: Optional[int] = None,
24
- device: Optional[Union[str, torch.device]] = None,
25
- timesteps: Optional[List[int]] = None,
26
- sigmas: Optional[List[float]] = None,
27
- mu: Optional[float] = None,
28
- ) -> (torch.Tensor, int):
29
- """Prepares the timesteps for the diffusion process."""
30
- if timesteps is not None and sigmas is not None:
31
- raise ValueError("Only one of `timesteps` or `sigmas` can be passed.")
32
-
33
- if timesteps is not None:
34
- scheduler.set_timesteps(timesteps=timesteps, device=device)
35
- elif sigmas is not None:
36
- scheduler.set_timesteps(sigmas=sigmas, device=device)
37
- else:
38
- scheduler.set_timesteps(num_inference_steps, device=device, mu=mu)
39
-
40
- timesteps = scheduler.timesteps
41
- num_inference_steps = len(timesteps)
42
- return timesteps, num_inference_steps
43
-
44
- # FLUX pipeline function
45
- class HighSpeedFluxPipeline(FluxPipeline):
46
- """
47
- Extends the FluxPipeline to yield intermediate images during the denoising process
48
- with progressively increasing resolution for faster generation.
49
- """
50
- @torch.inference_mode()
51
- def generate_images(
52
- self,
53
- prompt: Union[str, List[str]] = None,
54
- prompt_2: Optional[Union[str, List[str]]] = None,
55
- height: Optional[int] = None,
56
- width: Optional[int] = None,
57
- num_inference_steps: int = 4,
58
- timesteps: List[int] = None,
59
- num_images_per_prompt: Optional[int] = 1,
60
- generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
61
- latents: Optional[torch.FloatTensor] = None,
62
- prompt_embeds: Optional[torch.FloatTensor] = None,
63
- pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
64
- output_type: Optional[str] = "pil",
65
- return_dict: bool = True,
66
- max_sequence_length: int = 128,
67
- ):
68
- """Generates images and yields intermediate results during the denoising process."""
69
- height = height or self.default_sample_size * self.vae_scale_factor
70
- width = width or self.default_sample_size * self.vae_scale_factor
71
-
72
- # 1. Check inputs
73
- self.check_inputs(
74
- prompt,
75
- prompt_2,
76
- height,
77
- width,
78
- prompt_embeds=prompt_embeds,
79
- pooled_prompt_embeds=pooled_prompt_embeds,
80
- max_sequence_length=max_sequence_length,
81
- )
82
 
83
- # 2. Define call parameters
84
- batch_size = 1 if isinstance(prompt, str) else len(prompt)
85
- device = self._execution_device
86
 
87
- prompt_embeds, pooled_prompt_embeds, text_ids = self.encode_prompt(
 
88
  prompt=prompt,
89
- prompt_2=prompt_2,
90
- prompt_embeds=prompt_embeds,
91
- pooled_prompt_embeds=pooled_prompt_embeds,
92
- device=device,
93
- num_images_per_prompt=num_images_per_prompt,
94
- max_sequence_length=max_sequence_length,
95
- )
96
- # 4. Prepare latent variables
97
- num_channels_latents = self.transformer.config.in_channels // 4
98
- latents, latent_image_ids = self.prepare_latents(
99
- batch_size * num_images_per_prompt,
100
- num_channels_latents,
101
- height,
102
- width,
103
- prompt_embeds.dtype,
104
- device,
105
- generator,
106
- latents,
107
- )
108
- # 5. Prepare timesteps
109
- sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
110
- image_seq_len = latents.shape[1]
111
- mu = calculate_timestep_shift(image_seq_len)
112
- timesteps, num_inference_steps = prepare_timesteps(
113
- self.scheduler,
114
- num_inference_steps,
115
- device,
116
- timesteps,
117
- sigmas,
118
- mu=mu,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
119
  )
120
- self._num_timesteps = len(timesteps)
121
-
122
- # 6. Denoising loop
123
- for i, t in enumerate(timesteps):
124
-
125
- timestep = t.expand(latents.shape[0]).to(latents.dtype)
126
-
127
- noise_pred = self.transformer(
128
- hidden_states=latents,
129
- timestep=timestep / 1000,
130
- pooled_projections=pooled_prompt_embeds,
131
- encoder_hidden_states=prompt_embeds,
132
- txt_ids=text_ids,
133
- img_ids=latent_image_ids,
134
- return_dict=False,
135
- )[0]
136
-
137
- latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
138
- torch.cuda.empty_cache()
139
-
140
- # Final image
141
- return self._decode_latents_to_image(latents, height, width, output_type)
142
- self.maybe_free_model_hooks()
143
- torch.cuda.empty_cache()
144
-
145
- def _decode_latents_to_image(self, latents, height, width, output_type, vae=None):
146
- """Decodes the given latents into an image."""
147
- vae = vae or self.vae
148
- latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
149
- latents = (latents / vae.config.scaling_factor) + vae.config.shift_factor
150
- image = vae.decode(latents, return_dict=False)[0]
151
- return self.image_processor.postprocess(image, output_type=output_type)[0]
 
1
+ import gradio as gr
2
  import numpy as np
3
+ import random
4
+ import spaces
5
+ import torch
6
+ import time
7
+ from diffusers import DiffusionPipeline
8
+ from custom_pipeline import FLUXPipelineWithIntermediateOutputs
9
+
10
+ # Constants
11
+ MAX_SEED = np.iinfo(np.int32).max
12
+ MAX_IMAGE_SIZE = 2048
13
+ DEFAULT_WIDTH = 1024
14
+ DEFAULT_HEIGHT = 1024
15
+ DEFAULT_INFERENCE_STEPS = 1
16
+
17
+ # Device and model setup
18
+ dtype = torch.float16
19
+ pipe = FLUXPipelineWithIntermediateOutputs.from_pretrained(
20
+ "black-forest-labs/FLUX.1-schnell", torch_dtype=dtype
21
+ ).to("cuda")
22
+ torch.cuda.empty_cache()
23
+
24
+ # Inference function
25
+ @spaces.GPU(duration=25)
26
+ def generate_image(prompt, seed=42, width=DEFAULT_WIDTH, height=DEFAULT_HEIGHT, randomize_seed=False, num_inference_steps=2, progress=gr.Progress(track_tqdm=True)):
27
+ if randomize_seed:
28
+ seed = random.randint(0, MAX_SEED)
29
+ generator = torch.Generator().manual_seed(int(float(seed)))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
 
31
+ start_time = time.time()
 
 
32
 
33
+ # Only generate the last image in the sequence
34
+ for img in pipe.generate_images(
35
  prompt=prompt,
36
+ guidance_scale=0, # as Flux schnell is guidance free
37
+ num_inference_steps=num_inference_steps,
38
+ width=width,
39
+ height=height,
40
+ generator=generator
41
+ ):
42
+ latency = f"Latency: {(time.time()-start_time):.2f} seconds"
43
+ yield img, seed, latency
44
+
45
+ # Example prompts
46
+ examples = [
47
+ "a tiny astronaut hatching from an egg on the moon",
48
+ "a cute white cat holding a sign that says hello world",
49
+ "an anime illustration of a wiener schnitzel",
50
+ "Create mage of Modern house in minecraft style",
51
+ "Imagine steve jobs as Star Wars movie character",
52
+ "Lion",
53
+ "Photo of a young woman with long, wavy brown hair tied in a bun and glasses. She has a fair complexion and is wearing subtle makeup, emphasizing her eyes and lips. She is dressed in a black top. The background appears to be an urban setting with a building facade, and the sunlight casts a warm glow on her face.",
54
+ ]
55
+
56
+ # --- Gradio UI ---
57
+ with gr.Blocks() as demo:
58
+ with gr.Column(elem_id="app-container"):
59
+ gr.Markdown("# 🎨 Realtime FLUX Image Generator")
60
+ gr.Markdown("Generate stunning images in real-time with Modified Flux.Schnell pipeline.")
61
+ gr.Markdown("<span style='color: red;'>Note: Sometimes it stucks or stops generating images (I don't know why). In that situation just refresh the site.</span>")
62
+
63
+ with gr.Row():
64
+ with gr.Column(scale=3):
65
+ result = gr.Image(label="Generated Image", show_label=False, interactive=False)
66
+ with gr.Column(scale=1):
67
+ prompt = gr.Text(
68
+ label="Prompt",
69
+ placeholder="Describe the image you want to generate...",
70
+ lines=3,
71
+ show_label=False,
72
+ container=False,
73
+ )
74
+ generateBtn = gr.Button("🖼️ Generate Image")
75
+ enhanceBtn = gr.Button("🚀 Enhance Image")
76
+
77
+ with gr.Column("Advanced Options"):
78
+ with gr.Row():
79
+ realtime = gr.Checkbox(label="Realtime Toggler", info="If TRUE then uses more GPU but create image in realtime.", value=False)
80
+ latency = gr.Text(label="Latency")
81
+ with gr.Row():
82
+ seed = gr.Number(label="Seed", value=42)
83
+ randomize_seed = gr.Checkbox(label="Randomize Seed", value=False)
84
+ with gr.Row():
85
+ width = gr.Slider(label="Width", minimum=256, maximum=MAX_IMAGE_SIZE, step=32, value=DEFAULT_WIDTH)
86
+ height = gr.Slider(label="Height", minimum=256, maximum=MAX_IMAGE_SIZE, step=32, value=DEFAULT_HEIGHT)
87
+ num_inference_steps = gr.Slider(label="Inference Steps", minimum=1, maximum=4, step=1, value=DEFAULT_INFERENCE_STEPS)
88
+
89
+ with gr.Row():
90
+ gr.Markdown("### 🌟 Inspiration Gallery")
91
+ with gr.Row():
92
+ gr.Examples(
93
+ examples=examples,
94
+ fn=generate_image,
95
+ inputs=[prompt],
96
+ outputs=[result, seed, latency],
97
+ cache_examples="lazy"
98
+ )
99
+
100
+ def enhance_image(*args):
101
+ gr.Info("Enhancing Image") # currently just runs optimized pipeline for 2 steps. Further implementations later.
102
+ return next(generate_image(*args))
103
+
104
+ enhanceBtn.click(
105
+ fn=enhance_image,
106
+ inputs=[prompt, seed, width, height],
107
+ outputs=[result, seed, latency],
108
+ show_progress="hidden",
109
+ api_name=False,
110
+ queue=False,
111
+ concurrency_limit=None
112
+ )
113
+
114
+ generateBtn.click(
115
+ fn=generate_image,
116
+ inputs=[prompt, seed, width, height, randomize_seed, num_inference_steps],
117
+ outputs=[result, seed, latency],
118
+ show_progress="full",
119
+ api_name="RealtimeFlux",
120
+ queue=False,
121
+ concurrency_limit=None
122
+ )
123
+
124
+ def update_ui(realtime_enabled):
125
+ return {
126
+ prompt: gr.update(interactive=True),
127
+ generateBtn: gr.update(visible=not realtime_enabled)
128
+ }
129
+
130
+ realtime.change(
131
+ fn=update_ui,
132
+ inputs=[realtime],
133
+ outputs=[prompt, generateBtn],
134
+ queue=False,
135
+ concurrency_limit=None
136
+ )
137
+
138
+ def realtime_generation(*args):
139
+ if args[0]: # If realtime is enabled
140
+ return next(generate_image(*args[1:]))
141
+
142
+ prompt.submit(
143
+ fn=generate_image,
144
+ inputs=[prompt, seed, width, height, randomize_seed, num_inference_steps],
145
+ outputs=[result, seed, latency],
146
+ show_progress="full",
147
+ api_name=False,
148
+ queue=False,
149
+ concurrency_limit=None
150
+ )
151
+
152
+ for component in [prompt, width, height, num_inference_steps]:
153
+ component.input(
154
+ fn=realtime_generation,
155
+ inputs=[realtime, prompt, seed, width, height, randomize_seed, num_inference_steps],
156
+ outputs=[result, seed, latency],
157
+ show_progress="hidden",
158
+ api_name=False,
159
+ trigger_mode="always_last",
160
+ queue=False,
161
+ concurrency_limit=None
162
  )
163
+
164
+ # Launch the app
165
+ demo.launch()