# Source listing header (Hugging Face file viewer):
#   author: prithivMLmods
#   commit: "Update app.py" (68aa420, verified)
#   size: 17 kB
import spaces
import gradio as gr
import torch
from PIL import Image
from diffusers import DiffusionPipeline, AutoencoderTiny, AutoencoderKL
import random
import uuid
from typing import Tuple, Union, List, Optional, Any, Dict
import numpy as np
import time
import zipfile
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
# Markdown heading shown at the top of the app (rendered with gr.Markdown).
DESCRIPTION = """## flux-krea vs qwen"""
# Helper functions
def save_image(img):
    """Save *img* under a fresh UUID-based ``.png`` name and return that name.

    The name is a bare relative filename, so the file is written relative to
    the current working directory. ``img`` only needs a ``save(path)`` method.
    """
    filename = f"{uuid.uuid4()}.png"
    img.save(filename)
    return filename
def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
    """Return *seed* unchanged, or a random seed when *randomize_seed* is truthy.

    Random seeds are drawn uniformly from ``[0, MAX_SEED]`` (both inclusive).
    """
    return random.randint(0, MAX_SEED) if randomize_seed else seed
# Upper bound for seeds (int32 maximum); used for the seed slider and for random seeds.
MAX_SEED = np.iinfo(np.int32).max
# NOTE(review): not referenced anywhere else in the visible source; the
# width/height sliders hard-code 2048 instead.
MAX_IMAGE_SIZE = 2048
# Load pipelines
# Everything below runs at import time: weights are loaded in bfloat16 and
# moved to CUDA when available, otherwise to the CPU.
dtype = torch.bfloat16
device = "cuda" if torch.cuda.is_available() else "cpu"
# Flux.1-krea pipeline
# `taef1` is installed as the pipeline's own VAE (used for the per-step
# previews in the custom call), while `good_vae` is loaded separately and
# passed to the custom call for the final decode.
taef1 = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=dtype).to(device)
good_vae = AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-Krea-dev", subfolder="vae", torch_dtype=dtype).to(device)
pipe_krea = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-Krea-dev", torch_dtype=dtype, vae=taef1).to(device)
# Qwen/Qwen-Image pipeline
# Loaded with its default components and called through the stock pipeline
# `__call__` in generate_qwen.
pipe_qwen = DiffusionPipeline.from_pretrained("Qwen/Qwen-Image", torch_dtype=dtype).to(device)
# Define custom flux_pipe_call for Flux.1-krea
@torch.inference_mode()
def flux_pipe_call_that_returns_an_iterable_of_images(
    self,
    prompt: Union[str, List[str]] = None,
    prompt_2: Optional[Union[str, List[str]]] = None,
    height: Optional[int] = None,
    width: Optional[int] = None,
    num_inference_steps: int = 28,
    timesteps: List[int] = None,
    guidance_scale: float = 3.5,
    num_images_per_prompt: Optional[int] = 1,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    latents: Optional[torch.FloatTensor] = None,
    prompt_embeds: Optional[torch.FloatTensor] = None,
    pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
    output_type: Optional[str] = "pil",
    return_dict: bool = True,
    joint_attention_kwargs: Optional[Dict[str, Any]] = None,
    max_sequence_length: int = 512,
    good_vae: Optional[Any] = None,
):
    """Run the Flux denoising loop as a generator of decoded images.

    Meant to be bound to a Flux pipeline instance (``self``). At every
    denoising step the *current* latents are decoded with ``self.vae`` and
    the first image of the batch is yielded as a preview; the scheduler step
    is applied afterwards. After the last step the final latents are decoded
    once more with ``good_vae`` and that image is yielded last.

    Args:
        prompt: Prompt text or list of prompts. May be ``None`` when
            ``prompt_embeds`` is supplied.
        prompt_2: Secondary prompt forwarded to ``encode_prompt``.
        height: Output height; defaults to
            ``default_sample_size * vae_scale_factor``.
        width: Output width; same default rule as ``height``.
        num_inference_steps: Number of denoising steps.
        timesteps: Forwarded to ``retrieve_timesteps``.
            NOTE(review): sigmas are always computed here, so a non-``None``
            value makes ``retrieve_timesteps`` raise ``ValueError``.
        guidance_scale: Guidance value fed to the transformer when its config
            has ``guidance_embeds`` enabled.
        num_images_per_prompt: Images per prompt (multiplies the batch size).
        generator: RNG(s) forwarded to ``prepare_latents``.
        latents: Optional pre-made latents forwarded to ``prepare_latents``.
        prompt_embeds: Optional precomputed prompt embeddings.
        pooled_prompt_embeds: Optional precomputed pooled embeddings.
        output_type: Forwarded to ``image_processor.postprocess``.
        return_dict: Accepted for signature compatibility; not used.
        joint_attention_kwargs: Extra attention kwargs; its ``"scale"`` entry
            (if any) is used as the LoRA scale for prompt encoding.
        max_sequence_length: Forwarded to prompt validation/encoding.
        good_vae: VAE used for the final decode. Falls back to ``self.vae``
            when omitted.

    Yields:
        One preview per executed denoising step, then the final image. Only
        index ``[0]`` of each post-processed batch is yielded.

    Raises:
        ValueError: If neither ``prompt`` nor ``prompt_embeds`` is available
            to derive the batch size from.
    """
    height = height or self.default_sample_size * self.vae_scale_factor
    width = width or self.default_sample_size * self.vae_scale_factor

    self.check_inputs(
        prompt,
        prompt_2,
        height,
        width,
        prompt_embeds=prompt_embeds,
        pooled_prompt_embeds=pooled_prompt_embeds,
        max_sequence_length=max_sequence_length,
    )

    self._guidance_scale = guidance_scale
    self._joint_attention_kwargs = joint_attention_kwargs
    self._interrupt = False

    # Derive the batch size from whichever prompt representation was given.
    # Previously `len(prompt)` crashed when only `prompt_embeds` was passed.
    if isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None:
        batch_size = len(prompt)
    elif prompt_embeds is not None:
        batch_size = prompt_embeds.shape[0]
    else:
        raise ValueError("Provide either `prompt` or `prompt_embeds`.")

    # Without an explicit high-quality VAE, decode the final image with the
    # pipeline's own VAE instead of dereferencing None.
    final_vae = good_vae if good_vae is not None else self.vae

    device = self._execution_device

    lora_scale = joint_attention_kwargs.get("scale", None) if joint_attention_kwargs is not None else None
    prompt_embeds, pooled_prompt_embeds, text_ids = self.encode_prompt(
        prompt=prompt,
        prompt_2=prompt_2,
        prompt_embeds=prompt_embeds,
        pooled_prompt_embeds=pooled_prompt_embeds,
        device=device,
        num_images_per_prompt=num_images_per_prompt,
        max_sequence_length=max_sequence_length,
        lora_scale=lora_scale,
    )

    num_channels_latents = self.transformer.config.in_channels // 4
    latents, latent_image_ids = self.prepare_latents(
        batch_size * num_images_per_prompt,
        num_channels_latents,
        height,
        width,
        prompt_embeds.dtype,
        device,
        generator,
        latents,
    )

    # Linear sigma schedule from 1.0 down to 1/num_inference_steps; the shift
    # `mu` is interpolated from the latent sequence length.
    sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
    image_seq_len = latents.shape[1]
    mu = calculate_shift(
        image_seq_len,
        self.scheduler.config.base_image_seq_len,
        self.scheduler.config.max_image_seq_len,
        self.scheduler.config.base_shift,
        self.scheduler.config.max_shift,
    )
    timesteps, num_inference_steps = retrieve_timesteps(
        self.scheduler,
        num_inference_steps,
        device,
        timesteps,
        sigmas,
        mu=mu,
    )
    self._num_timesteps = len(timesteps)

    guidance = (
        torch.full([1], guidance_scale, device=device, dtype=torch.float32).expand(latents.shape[0])
        if self.transformer.config.guidance_embeds
        else None
    )

    for t in timesteps:
        # An interrupted run skips the remaining steps but still falls through
        # to the final decode below.
        if self.interrupt:
            continue

        timestep = t.expand(latents.shape[0]).to(latents.dtype)

        noise_pred = self.transformer(
            hidden_states=latents,
            timestep=timestep / 1000,
            guidance=guidance,
            pooled_projections=pooled_prompt_embeds,
            encoder_hidden_states=prompt_embeds,
            txt_ids=text_ids,
            img_ids=latent_image_ids,
            joint_attention_kwargs=self.joint_attention_kwargs,
            return_dict=False,
        )[0]

        # Preview: decode the latents as they are *before* this step's update.
        latents_for_image = self._unpack_latents(latents, height, width, self.vae_scale_factor)
        latents_for_image = (latents_for_image / self.vae.config.scaling_factor) + self.vae.config.shift_factor
        image = self.vae.decode(latents_for_image, return_dict=False)[0]
        yield self.image_processor.postprocess(image, output_type=output_type)[0]

        latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
        torch.cuda.empty_cache()

    # Final image, decoded with the dedicated VAE and its own scaling/shift.
    latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
    latents = (latents / final_vae.config.scaling_factor) + final_vae.config.shift_factor
    image = final_vae.decode(latents, return_dict=False)[0]
    self.maybe_free_model_hooks()
    torch.cuda.empty_cache()
    yield self.image_processor.postprocess(image, output_type=output_type)[0]
# Bind the generator function to the pipeline instance through the descriptor
# protocol, so it can be called as a method with `pipe_krea` as `self`.
pipe_krea.flux_pipe_call_that_returns_an_iterable_of_images = flux_pipe_call_that_returns_an_iterable_of_images.__get__(pipe_krea)
# Helper functions for Flux.1-krea
def calculate_shift(
    image_seq_len,
    base_seq_len: int = 256,
    max_seq_len: int = 4096,
    base_shift: float = 0.5,
    max_shift: float = 1.16,
):
    """Linearly interpolate the shift ``mu`` for a given image sequence length.

    The line passes through ``(base_seq_len, base_shift)`` and
    ``(max_seq_len, max_shift)``. Values outside that range are extrapolated
    on the same line, not clamped.
    """
    slope = (max_shift - base_shift) / (max_seq_len - base_seq_len)
    intercept = base_shift - slope * base_seq_len
    return image_seq_len * slope + intercept
def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    sigmas: Optional[List[float]] = None,
    **kwargs,
):
    """Configure *scheduler* and return ``(timesteps, num_inference_steps)``.

    Exactly one schedule source is used, in this order of precedence: custom
    ``timesteps``, custom ``sigmas``, otherwise ``num_inference_steps``. With
    a custom schedule the returned step count is the length of the schedule
    the scheduler produced; otherwise ``num_inference_steps`` is returned
    unchanged. Extra ``kwargs`` are forwarded to ``scheduler.set_timesteps``.

    Raises:
        ValueError: If both ``timesteps`` and ``sigmas`` are provided.
    """
    custom_timesteps = timesteps is not None
    custom_sigmas = sigmas is not None

    if custom_timesteps and custom_sigmas:
        raise ValueError("Only one of `timesteps` or `sigmas` can be passed.")

    if not (custom_timesteps or custom_sigmas):
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        return scheduler.timesteps, num_inference_steps

    if custom_timesteps:
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
    else:
        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)

    schedule = scheduler.timesteps
    return schedule, len(schedule)
# Aspect ratios
# Preset label -> (width, height) in pixels; consumed by the aspect-ratio
# dropdown, which writes the pair into the width and height sliders.
# NOTE(review): several presets (e.g. 1328, 928, 1140) are not on the sliders'
# 64-pixel step grid starting at 512 - confirm the sliders keep these values.
aspect_ratios = {
"1:1": (1328, 1328),
"16:9": (1664, 928),
"9:16": (928, 1664),
"4:3": (1472, 1140),
"3:4": (1140, 1472)
}
# Generation function for Flux.1-krea
@spaces.GPU
def generate_krea(
    prompt: str,
    seed: int = 0,
    width: int = 1024,
    height: int = 1024,
    guidance_scale: float = 4.5,
    randomize_seed: bool = False,
    num_inference_steps: int = 28,
    num_images: int = 1,
    zip_images: bool = False,
    progress=gr.Progress(track_tqdm=True),
):
    """Generate ``num_images`` images with the Flux.1-krea pipeline.

    Images are produced one after another from a single seeded generator, so
    each image differs while the whole batch stays reproducible for a seed.

    Args:
        prompt: Text prompt.
        seed: Seed to use when ``randomize_seed`` is false.
        width: Output width in pixels.
        height: Output height in pixels.
        guidance_scale: Guidance value forwarded to the pipeline call.
        randomize_seed: Draw a random seed instead of using ``seed``.
        num_inference_steps: Number of denoising steps per image.
        num_images: How many images to generate.
        zip_images: Also bundle the saved PNGs into a ZIP archive.
        progress: Gradio progress tracker (tqdm tracking enabled).

    Returns:
        ``(image_paths, seed, duration, zip_path)`` where ``duration`` is the
        generation time in seconds formatted with two decimals and
        ``zip_path`` is ``None`` unless ``zip_images`` is set.
    """
    # Callers may deliver whole numbers as floats (TODO confirm for the
    # Gradio version in use); range(), manual_seed() and the pipeline's size
    # arithmetic need real ints.
    seed = int(randomize_seed_fn(seed, randomize_seed))
    width = int(width)
    height = int(height)
    num_inference_steps = int(num_inference_steps)
    num_images = int(num_images)

    generator = torch.Generator(device).manual_seed(seed)

    start_time = time.time()
    images = []
    for _ in range(num_images):
        # The pipeline call yields a preview per step and the final image
        # last. Keep only the most recent frame instead of materialising every
        # intermediate image in a list.
        final_img = None
        for final_img in pipe_krea.flux_pipe_call_that_returns_an_iterable_of_images(
            prompt=prompt,
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
            width=width,
            height=height,
            generator=generator,
            output_type="pil",
            good_vae=good_vae,
        ):
            pass
        images.append(final_img)
    duration = time.time() - start_time

    image_paths = [save_image(img) for img in images]

    zip_path = None
    if zip_images:
        zip_path = str(uuid.uuid4()) + ".zip"
        with zipfile.ZipFile(zip_path, 'w') as zipf:
            for i, img_path in enumerate(image_paths):
                zipf.write(img_path, arcname=f"Img_{i}.png")

    return image_paths, seed, f"{duration:.2f}", zip_path
# Generation function for Qwen/Qwen-Image
@spaces.GPU
def generate_qwen(
    prompt: str,
    negative_prompt: str = "",
    seed: int = 0,
    width: int = 1024,
    height: int = 1024,
    guidance_scale: float = 4.0,
    randomize_seed: bool = False,
    num_inference_steps: int = 50,
    num_images: int = 1,
    zip_images: bool = False,
    progress=gr.Progress(track_tqdm=True),
):
    """Generate ``num_images`` images with the Qwen-Image pipeline.

    Args:
        prompt: Text prompt.
        negative_prompt: Negative prompt; an empty value is replaced by a
            single space (see the note in the body).
        seed: Seed to use when ``randomize_seed`` is false.
        width: Output width in pixels.
        height: Output height in pixels.
        guidance_scale: Classifier-free guidance strength, forwarded to the
            pipeline as ``true_cfg_scale``.
        randomize_seed: Draw a random seed instead of using ``seed``.
        num_inference_steps: Number of denoising steps.
        num_images: Images generated in one pipeline call.
        zip_images: Also bundle the saved PNGs into a ZIP archive.
        progress: Gradio progress tracker (tqdm tracking enabled).

    Returns:
        ``(image_paths, seed, duration, zip_path)`` where ``duration`` is the
        generation time in seconds formatted with two decimals and
        ``zip_path`` is ``None`` unless ``zip_images`` is set.
    """
    # Callers may deliver whole numbers as floats (TODO confirm for the
    # Gradio version in use); manual_seed() and the pipeline need real ints.
    seed = int(randomize_seed_fn(seed, randomize_seed))
    width = int(width)
    height = int(height)
    num_inference_steps = int(num_inference_steps)
    num_images = int(num_images)

    generator = torch.Generator(device).manual_seed(seed)

    # Qwen-Image applies classifier-free guidance through `true_cfg_scale`,
    # and only when a negative prompt is present; its `guidance_scale`
    # argument targets guidance-distilled checkpoints. Passing the UI value as
    # `guidance_scale` therefore left the slider without effect. A single
    # space is the neutral negative prompt that keeps guidance enabled.
    # NOTE(review): enabling CFG adds an unconditional pass per step - verify
    # against the diffusers QwenImagePipeline version pinned for this app.
    effective_negative_prompt = negative_prompt if negative_prompt else " "

    start_time = time.time()
    images = pipe_qwen(
        prompt=prompt,
        negative_prompt=effective_negative_prompt,
        height=height,
        width=width,
        true_cfg_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
        num_images_per_prompt=num_images,
        generator=generator,
        output_type="pil",
    ).images
    duration = time.time() - start_time

    image_paths = [save_image(img) for img in images]

    zip_path = None
    if zip_images:
        zip_path = str(uuid.uuid4()) + ".zip"
        with zipfile.ZipFile(zip_path, 'w') as zipf:
            for i, img_path in enumerate(image_paths):
                zipf.write(img_path, arcname=f"Img_{i}.png")

    return image_paths, seed, f"{duration:.2f}", zip_path
# Main generation function
@spaces.GPU
def generate(
    model_choice: str,
    prompt: str,
    negative_prompt: str = "",
    use_negative_prompt: bool = False,
    seed: int = 0,
    width: int = 1024,
    height: int = 1024,
    guidance_scale: float = 3.5,
    randomize_seed: bool = False,
    num_inference_steps: int = 28,
    num_images: int = 1,
    zip_images: bool = False,
    progress=gr.Progress(track_tqdm=True),
):
    """Route a generation request to the backend selected by *model_choice*.

    ``"Flux.1-krea"`` goes to ``generate_krea`` (the negative prompt is not
    forwarded). ``"Qwen Image"`` goes to ``generate_qwen``, with the negative
    prompt replaced by ``""`` unless ``use_negative_prompt`` is set. The
    backend's return value is passed through unchanged.

    Raises:
        ValueError: If ``model_choice`` is neither of the two known labels.
    """
    if model_choice not in ("Flux.1-krea", "Qwen Image"):
        raise ValueError("Invalid model choice")

    # Arguments both backends accept under the same names.
    shared_kwargs = dict(
        prompt=prompt,
        seed=seed,
        width=width,
        height=height,
        guidance_scale=guidance_scale,
        randomize_seed=randomize_seed,
        num_inference_steps=num_inference_steps,
        num_images=num_images,
        zip_images=zip_images,
        progress=progress,
    )

    if model_choice == "Flux.1-krea":
        return generate_krea(**shared_kwargs)

    final_negative_prompt = negative_prompt if use_negative_prompt else ""
    return generate_qwen(negative_prompt=final_negative_prompt, **shared_kwargs)
# Examples
# Example prompts offered in the UI; each entry only fills the prompt textbox.
examples = [
"An attractive young woman with blue eyes lying face down on the bed, light white and light amber, timeless beauty, sunrays shine upon it",
"Headshot of handsome young man, wearing dark gray sweater, brown hair and short beard, serious look, black background, soft studio lighting",
"A medium-angle shot of a young woman with long brown hair, wearing glasses, standing in front of purple and white lights",
"High-resolution photograph of a woman, photorealistic, vibrant colors"
]
# Custom stylesheet passed to gr.Blocks: caps the container at 590px and
# centres it, centres h1 headings, and hides the footer.
css = '''
.gradio-container {
max-width: 590px !important;
margin: 0 auto !important;
}
h1 {
text-align: center;
}
footer {
visibility: hidden;
}
'''
# Gradio interface
# NOTE(review): indentation was lost in this copy of the file, so which
# components sit inside each Row/Accordion below cannot be determined from
# here. The code is left untouched; restore the nesting from the upstream file.
with gr.Blocks(css=css, theme="bethecloud/storj_theme") as demo:
gr.Markdown(DESCRIPTION)
with gr.Row():
# Single-line prompt box; submitting it triggers generation (see gr.on below).
prompt = gr.Text(
label="Prompt",
show_label=False,
max_lines=1,
placeholder="Enter your prompt",
container=False,
)
run_button = gr.Button("Run", scale=0, variant="primary")
# Gallery that receives the list of generated image paths.
result = gr.Gallery(label="Result", columns=1, show_label=False, preview=True)
with gr.Row():
# These labels must match the strings generate() and update_settings() compare against.
model_choice = gr.Radio(
choices=["Flux.1-krea", "Qwen Image"],
label="Select Model",
value="Flux.1-krea"
)
with gr.Accordion("Additional Options", open=False):
aspect_ratio = gr.Dropdown(
label="Aspect Ratio",
choices=list(aspect_ratios.keys()),
value="1:1",
)
# Hidden initially; update_settings() shows it when "Qwen Image" is selected.
use_negative_prompt = gr.Checkbox(
label="Use negative prompt (Qwen Image only)",
value=False,
visible=False
)
# Hidden initially; its visibility follows the checkbox above.
negative_prompt = gr.Text(
label="Negative prompt",
max_lines=1,
placeholder="Enter a negative prompt",
visible=False,
)
seed = gr.Slider(
label="Seed",
minimum=0,
maximum=MAX_SEED,
step=1,
value=0,
)
# On by default, so the seed slider is ignored unless this is unticked.
randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
with gr.Row():
# NOTE(review): the initial 1024x1024 differs from the default "1:1" preset
# (1328x1328), which is only applied when the dropdown changes.
width = gr.Slider(
label="Width",
minimum=512,
maximum=2048,
step=64,
value=1024,
)
height = gr.Slider(
label="Height",
minimum=512,
maximum=2048,
step=64,
value=1024,
)
# Defaults below match the Flux.1-krea branch of update_settings().
guidance_scale = gr.Slider(
label="Guidance Scale",
minimum=0.0,
maximum=20.0,
step=0.1,
value=3.5,
)
num_inference_steps = gr.Slider(
label="Number of inference steps",
minimum=1,
maximum=100,
step=1,
value=28,
)
num_images = gr.Slider(
label="Number of images",
minimum=1,
maximum=5,
step=1,
value=1,
)
zip_images = gr.Checkbox(label="Zip generated images", value=False)
# Read-only outputs filled from generate()'s return tuple.
gr.Markdown("### Output Information")
seed_display = gr.Textbox(label="Seed used", interactive=False)
generation_time = gr.Textbox(label="Generation time (seconds)", interactive=False)
zip_file = gr.File(label="Download ZIP")
# Aspect-ratio preset -> write the preset's (width, height) into both sliders.
def set_dimensions(ar):
w, h = aspect_ratios[ar]
return gr.update(value=w), gr.update(value=h)
aspect_ratio.change(
fn=set_dimensions,
inputs=aspect_ratio,
outputs=[width, height]
)
# Model switch -> reset steps and guidance to per-model defaults and toggle
# the negative-prompt checkbox. The tuple order matches `outputs` below.
# NOTE(review): there is no branch for other values (implicit None), and
# switching back to Flux hides the checkbox but not the negative-prompt textbox.
def update_settings(mc):
if mc == "Flux.1-krea":
return (
gr.update(value=28),
gr.update(value=3.5),
gr.update(visible=False)
)
elif mc == "Qwen Image":
return (
gr.update(value=50),
gr.update(value=4.0),
gr.update(visible=True)
)
model_choice.change(
fn=update_settings,
inputs=model_choice,
outputs=[num_inference_steps, guidance_scale, use_negative_prompt]
)
# Show the negative-prompt textbox only while the checkbox is ticked.
use_negative_prompt.change(
fn=lambda x: gr.update(visible=x),
inputs=use_negative_prompt,
outputs=negative_prompt
)
# Run generation on prompt submit or button click. The `inputs` order must
# match generate()'s positional parameters.
gr.on(
triggers=[prompt.submit, run_button.click],
fn=generate,
inputs=[
model_choice,
prompt,
negative_prompt,
use_negative_prompt,
seed,
width,
height,
guidance_scale,
randomize_seed,
num_inference_steps,
num_images,
zip_images,
],
outputs=[result, seed_display, generation_time, zip_file],
api_name="run",
)
# Example prompts.
# NOTE(review): only `prompt` is wired as input while generate() takes
# model_choice as its first parameter; with cache_examples=False this is
# presumably never invoked - confirm before enabling caching or run-on-click.
gr.Examples(
examples=examples,
inputs=prompt,
outputs=[result, seed_display, generation_time, zip_file],
fn=generate,
cache_examples=False,
)
if __name__ == "__main__":
    # Enable the request queue (max_size=30) first, then start the server.
    queued_demo = demo.queue(max_size=30)
    queued_demo.launch(mcp_server=True, ssr_mode=False, show_error=True)