import json
import math
import random
import time
from pathlib import Path
from uuid import uuid4

import torch
from diffusers import __version__ as diffusers_version
from huggingface_hub import CommitOperationAdd, create_commit, create_repo

from .upsampling import RealESRGANModel
from .utils import pad_along_axis


def get_all_files(root: Path):
    """Yield every regular file found under `root`, descending into sub-directories.

    Traversal uses an explicit stack instead of recursion, so the order in which
    directories are visited is not alphabetical.

    Args:
        root (Path):
            The directory to walk.
    """
    pending = [root]
    while pending:
        current = pending.pop()
        for candidate in current.iterdir():
            if candidate.is_file():
                yield candidate
            if candidate.is_dir():
                pending.append(candidate)


def get_groups_of_n(n: int, iterator):
    """Yield successive lists of at most `n` items taken from `iterator`.

    Every yielded list has exactly `n` elements except possibly the last one,
    which holds the remainder. Nothing is yielded for an empty iterator.

    Args:
        n (int):
            The maximum size of each group. Must be at least 1.
        iterator:
            Any iterable to split into groups.

    Raises:
        ValueError: If `n` is smaller than 1.
    """
    # Raise instead of assert: asserts are stripped under `python -O`, and n == 0
    # would then yield an empty group followed by one unbounded group.
    if n < 1:
        raise ValueError(f"n must be a positive integer, got {n}.")
    buffer = []
    for elt in iterator:
        if len(buffer) == n:
            yield buffer
            buffer = []
        buffer.append(elt)
    if buffer:
        yield buffer


def upload_folder_chunked(
    repo_id: str,
    upload_dir: Path,
    n: int = 100,
    private: bool = False,
    create_pr: bool = False,
):
    """Upload a folder to the Hugging Face Hub in chunks of n files at a time.

    Each chunk is pushed as its own commit to a dataset repo. Files are stored
    under `<parent directory name>/<file name>` in the repo.

    Args:
        repo_id (str):
            The repo id to upload to.
        upload_dir (Path):
            The directory to upload.
        n (int, *optional*, defaults to 100):
            The number of files to upload at a time.
        private (bool, *optional*):
            Whether to upload the repo as private.
        create_pr (bool, *optional*):
            Whether to create a PR after uploading instead of committing directly.

    Raises:
        ValueError: If `upload_dir` does not exist.
    """
    root = Path(upload_dir)
    # Validate the local directory before touching the Hub so that a bad path
    # does not leave an empty repo behind.
    if not root.exists():
        raise ValueError(f"Upload directory {root} does not exist.")

    url = create_repo(repo_id, exist_ok=True, private=private, repo_type="dataset")
    print(f"Uploading files to: {url}")

    for i, file_paths in enumerate(get_groups_of_n(n, get_all_files(root))):
        print(f"Committing {file_paths}")
        operations = [
            CommitOperationAdd(
                path_in_repo=f"{file_path.parent.name}/{file_path.name}",
                path_or_fileobj=str(file_path),
            )
            for file_path in file_paths
        ]
        create_commit(
            repo_id=repo_id,
            operations=operations,
            commit_message=f"Upload part {i}",
            repo_type="dataset",
            create_pr=create_pr,
        )


def generate_input_batches(pipeline, prompts, seeds, batch_size, height, width):
    """Yield `(batch_idx, text_embeddings, noise)` batches for the given prompts and seeds.

    For each prompt/seed pair the prompt is embedded with `pipeline.embed_text` and a
    noise tensor of shape `(1, pipeline.unet.in_channels, height // 8, width // 8)` is
    drawn from a generator seeded with that seed. Items are concatenated along the
    first dimension until `batch_size` is reached (the final batch may be smaller).

    Args:
        pipeline:
            The pipeline providing `embed_text`, `unet` and `device`.
        prompts (list[str]):
            The prompts to embed.
        seeds (list[int]):
            One seed per prompt.
        batch_size (int):
            The number of items per yielded batch.
        height (int):
            The height of the images to generate.
        width (int):
            The width of the images to generate.

    Raises:
        ValueError: If `prompts` and `seeds` differ in length.
    """
    if len(prompts) != len(seeds):
        raise ValueError("Number of prompts and seeds must be equal.")

    embeds_batch, noise_batch = None, None
    batch_idx = 0
    for i, (prompt, seed) in enumerate(zip(prompts, seeds)):
        embeds = pipeline.embed_text(prompt)
        # The generator is created on CPU when the pipeline runs on an "mps" device.
        noise = torch.randn(
            (1, pipeline.unet.in_channels, height // 8, width // 8),
            device=pipeline.device,
            generator=torch.Generator(device="cpu" if pipeline.device.type == "mps" else pipeline.device).manual_seed(
                seed
            ),
        )
        embeds_batch = embeds if embeds_batch is None else torch.cat([embeds_batch, embeds])
        noise_batch = noise if noise_batch is None else torch.cat([noise_batch, noise])
        batch_is_ready = embeds_batch.shape[0] == batch_size or i + 1 == len(prompts)
        if not batch_is_ready:
            continue
        # NOTE(review): the cast below is hard-coded to CUDA half tensors even though the
        # generator above special-cases "mps" — confirm non-CUDA devices are supported.
        yield batch_idx, embeds_batch.type(torch.cuda.HalfTensor), noise_batch.type(torch.cuda.HalfTensor)
        batch_idx += 1
        # Drop the references and release cached GPU memory before building the next batch.
        del embeds_batch, noise_batch
        torch.cuda.empty_cache()
        embeds_batch, noise_batch = None, None


def generate_images(
    pipeline,
    prompt,
    batch_size=1,
    num_batches=1,
    seeds=None,
    num_inference_steps=50,
    guidance_scale=7.5,
    output_dir="./images",
    image_file_ext=".jpg",
    upsample=False,
    height=512,
    width=512,
    eta=0.0,
    push_to_hub=False,
    repo_id=None,
    private=False,
    create_pr=False,
    name=None,
):
    """Generate images using the StableDiffusion pipeline.

    Images are written to `output_dir/name`, one file per seed named
    `<seed><image_file_ext>`, next to a `prompt_config.json` describing the run.

    Args:
        pipeline (StableDiffusionWalkPipeline):
            The StableDiffusion pipeline instance.
        prompt (str):
            The prompt to use for the image generation.
        batch_size (int, *optional*, defaults to 1):
            The batch size to use for image generation.
        num_batches (int, *optional*, defaults to 1):
            The number of batches to generate.
        seeds (list[int], *optional*):
            The seeds to use for the image generation. Random seeds are drawn when omitted.
        num_inference_steps (int, *optional*, defaults to 50):
            The number of inference steps to take.
        guidance_scale (float, *optional*, defaults to 7.5):
            The guidance scale to use for image generation.
        output_dir (str, *optional*, defaults to "./images"):
            The output directory to save the images to.
        image_file_ext (str, *optional*, defaults to '.jpg'):
            The image file extension to use.
        upsample (bool, *optional*, defaults to False):
            Whether to upsample the images.
        height (int, *optional*, defaults to 512):
            The height of the images to generate.
        width (int, *optional*, defaults to 512):
            The width of the images to generate.
        eta (float, *optional*, defaults to 0.0):
            The eta parameter to use for image generation.
        push_to_hub (bool, *optional*, defaults to False):
            Whether to push the generated images to the Hugging Face Hub.
        repo_id (str, *optional*):
            The repo id to push the images to.
        private (bool, *optional*):
            Whether to push the repo as private.
        create_pr (bool, *optional*):
            Whether to create a PR after pushing instead of committing directly.
        name (str, *optional*, defaults to current timestamp str):
            The name of the sub-directory of output_dir to save the images to.

    Returns:
        list[str]: The paths of the saved image files.

    Raises:
        ValueError: If `push_to_hub` is set without `repo_id`, or if the number of
            seeds differs from `batch_size * num_batches`.
    """
    if push_to_hub and repo_id is None:
        raise ValueError("Must provide repo_id if push_to_hub is True.")

    num_images = batch_size * num_batches
    # randrange draws from the same 0..9999998 range without materialising a
    # ten-million-element list for every seed.
    seeds = seeds or [random.randrange(0, 9999999) for _ in range(num_images)]
    # Validate before creating the output directory so a failed call does not
    # leave a directory behind that blocks a retry with the same name.
    if len(seeds) != num_images:
        raise ValueError("Number of seeds must be equal to batch_size * num_batches.")

    name = name or time.strftime("%Y%m%d-%H%M%S")
    save_path = Path(output_dir) / name
    save_path.mkdir(exist_ok=False, parents=True)
    prompt_config_path = save_path / "prompt_config.json"

    if upsample:
        # Load the upsampler once and keep it on the pipeline for later calls.
        if getattr(pipeline, "upsampler", None) is None:
            pipeline.upsampler = RealESRGANModel.from_pretrained("nateraw/real-esrgan")
        pipeline.upsampler.to(pipeline.device)

    cfg = dict(
        prompt=prompt,
        guidance_scale=guidance_scale,
        eta=eta,
        num_inference_steps=num_inference_steps,
        upsample=upsample,
        height=height,
        width=width,
        scheduler=dict(pipeline.scheduler.config),
        tiled=pipeline.tiled,
        diffusers_version=diffusers_version,
        device_name=torch.cuda.get_device_name(0) if torch.cuda.is_available() else "unknown",
    )
    prompt_config_path.write_text(json.dumps(cfg, indent=2, sort_keys=False))

    frame_index = 0
    frame_filepaths = []
    for batch_idx, embeds, noise in generate_input_batches(
        pipeline, [prompt] * num_images, seeds, batch_size, height, width
    ):
        print(f"Generating batch {batch_idx}")

        outputs = pipeline(
            text_embeddings=embeds,
            latents=noise,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            eta=eta,
            height=height,
            width=width,
            output_type="pil" if not upsample else "numpy",
        )["images"]
        if upsample:
            images = [pipeline.upsampler(output) for output in outputs]
        else:
            images = outputs

        for image in images:
            # frame_index runs over all images so each file is named after its seed.
            frame_filepath = save_path / f"{seeds[frame_index]}{image_file_ext}"
            image.save(frame_filepath)
            frame_filepaths.append(str(frame_filepath))
            frame_index += 1

    # Upload before returning; placed after the return it would never run.
    if push_to_hub:
        upload_folder_chunked(repo_id, save_path, private=private, create_pr=create_pr)

    return frame_filepaths


def generate_images_flax(
    pipeline,
    params,
    prompt,
    batch_size=1,
    num_batches=1,
    seeds=None,
    num_inference_steps=50,
    guidance_scale=7.5,
    output_dir="./images",
    image_file_ext=".jpg",
    upsample=False,
    height=512,
    width=512,
    push_to_hub=False,
    repo_id=None,
    private=False,
    create_pr=False,
    name=None,
):
    """Generate images using the StableDiffusion pipeline.

    Images are written to `output_dir/name`, each under a random UUID file name,
    next to a `prompt_config.json` describing the run.

    Args:
        pipeline (StableDiffusionWalkPipeline):
            The StableDiffusion pipeline instance.
        params (`Union[Dict, FrozenDict]`):
            The model parameters.
        prompt (str):
            The prompt to use for the image generation.
        batch_size (int, *optional*, defaults to 1):
            The batch size to use for image generation.
        num_batches (int, *optional*, defaults to 1):
            The number of batches to generate.
        seeds (int, *optional*):
            The seed to use for the image generation. A random seed is drawn when omitted.
        num_inference_steps (int, *optional*, defaults to 50):
            The number of inference steps to take.
        guidance_scale (float, *optional*, defaults to 7.5):
            The guidance scale to use for image generation.
        output_dir (str, *optional*, defaults to "./images"):
            The output directory to save the images to.
        image_file_ext (str, *optional*, defaults to '.jpg'):
            The image file extension to use.
        upsample (bool, *optional*, defaults to False):
            Whether to upsample the images.
        height (int, *optional*, defaults to 512):
            The height of the images to generate.
        width (int, *optional*, defaults to 512):
            The width of the images to generate.
        push_to_hub (bool, *optional*, defaults to False):
            Whether to push the generated images to the Hugging Face Hub.
        repo_id (str, *optional*):
            The repo id to push the images to.
        private (bool, *optional*):
            Whether to push the repo as private.
        create_pr (bool, *optional*):
            Whether to create a PR after pushing instead of committing directly.
        name (str, *optional*, defaults to current timestamp str):
            The name of the sub-directory of output_dir to save the images to.

    Returns:
        list[str]: The paths of the saved image files.

    Raises:
        ValueError: If `push_to_hub` is set without `repo_id`.
    """
    # Imported lazily so the module can be imported without jax/flax installed.
    import jax
    from flax.training.common_utils import shard

    if push_to_hub and repo_id is None:
        raise ValueError("Must provide repo_id if push_to_hub is True.")

    name = name or time.strftime("%Y%m%d-%H%M%S")
    save_path = Path(output_dir) / name
    save_path.mkdir(exist_ok=False, parents=True)
    prompt_config_path = save_path / "prompt_config.json"

    num_images = batch_size * num_batches
    # Test against None so that an explicit seed of 0 is honoured.
    if seeds is None:
        seeds = random.randrange(0, 9999999)
    prng_seed = jax.random.PRNGKey(seeds)

    if upsample:
        if getattr(pipeline, "upsampler", None) is None:
            pipeline.upsampler = RealESRGANModel.from_pretrained("nateraw/real-esrgan")
            if not torch.cuda.is_available():
                print("Upsampling is recommended to be done on a GPU, as it is very slow on CPU")
            else:
                pipeline.upsampler = pipeline.upsampler.cuda()

    cfg = dict(
        prompt=prompt,
        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
        upsample=upsample,
        height=height,
        width=width,
        scheduler=dict(pipeline.scheduler.config),
        diffusers_version=diffusers_version,
        device_name=torch.cuda.get_device_name(0) if torch.cuda.is_available() else "unknown",
    )
    prompt_config_path.write_text(json.dumps(cfg, indent=2, sort_keys=False))

    NUM_TPU_CORES = jax.device_count()
    jit = True  # force jit, assume params are already sharded
    # With jit, `batch_size` is per device, so one pipeline call handles batch_size * NUM_TPU_CORES prompts.
    batch_size_total = NUM_TPU_CORES * batch_size if jit else batch_size

    prompts = [prompt] * num_images
    frame_filepaths = []
    for batch_idx in range(math.ceil(len(prompts) / batch_size_total)):
        prompt_batch = prompts[batch_idx * batch_size_total : (batch_idx + 1) * batch_size_total]
        print(f"Generating batches: {batch_idx*NUM_TPU_CORES} - {min((batch_idx+1)*NUM_TPU_CORES, num_batches)}")

        prompt_ids_batch = pipeline.prepare_inputs(prompt_batch)
        # NOTE(review): the same PRNG key is passed for every batch; presumably this makes
        # all batches produce the same images for a fixed prompt — confirm and split per batch if so.
        prng_seed_batch = prng_seed

        pad_size = 0
        if jit:
            # Sharding needs the batch to divide evenly across devices, so pad the
            # prompt ids along the batch axis when it does not.
            remainder = len(prompt_batch) % NUM_TPU_CORES
            if remainder != 0:
                pad_size = NUM_TPU_CORES - remainder
                prompt_ids_batch = pad_along_axis(prompt_ids_batch, pad_size, axis=0)
            prompt_ids_batch = shard(prompt_ids_batch)
            prng_seed_batch = jax.random.split(prng_seed, jax.device_count())

        outputs = pipeline(
            params,
            prng_seed=prng_seed_batch,
            prompt_ids=prompt_ids_batch,
            height=height,
            width=width,
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
            output_type="pil" if not upsample else "numpy",
            jit=jit,
        )["images"]
        # Drop the outputs that correspond to padding entries.
        if pad_size:
            outputs = outputs[:-pad_size]

        if upsample:
            images = [pipeline.upsampler(output) for output in outputs]
        else:
            images = outputs

        for image in images:
            frame_filepath = save_path / f"{uuid4()}{image_file_ext}"
            image.save(frame_filepath)
            frame_filepaths.append(str(frame_filepath))

    # Upload before returning; placed after the return it would never run.
    if push_to_hub:
        upload_folder_chunked(repo_id, save_path, private=private, create_pr=create_pr)

    return frame_filepaths