""" Generates a set of new images to test the quality of the model. """ from hashlib import md5 from pathlib import Path import torch from diffusers import StableDiffusionPipeline from more_itertools import chunked from realfake.utils import Args, inject_args class GenerateImagesParams(Args): prompts_file: Path output_dir: Path num_images: int = 100 batch_size: int = 16 model_id: str = "CompVis/stable-diffusion-v1-4" device: str = "cuda" @inject_args def main(args: GenerateImagesParams) -> None: pipe = StableDiffusionPipeline.from_pretrained(args.model_id, torch_dtype=torch.float16) pipe.to(args.device) prompts = [p for p in (p.strip() for p in args.prompts_file.read_text().split("\n")) if p] batch_size = args.batch_size if args.num_images == 1: args.output_dir.mkdir(exist_ok=True, parents=True) count = 0 for chunk in chunked(prompts, batch_size): output = pipe(chunk, num_images_per_prompt=1) for image in output.images: image.save(args.output_dir/f"{count:04d}.png") count += 1 with (args.output_dir/"prompts.txt").open("w") as fp: fp.write("\n".join(prompts)) else: for prompt in prompts: prompt_signature = md5(prompt.encode("utf-8")).hexdigest() prompt_dir = args.output_dir/prompt_signature prompt_dir.mkdir(exist_ok=True, parents=True) (prompt_dir/"prompt.txt").write_text(prompt) left_images = args.num_images count = 0 while left_images > 0: output = pipe(prompt, num_images_per_prompt=min(batch_size, left_images)) for image in output.images: image.save(prompt_dir/f"{count:04d}.png") count += 1 left_images -= batch_size if __name__ == "__main__": main()