realfake / realfake /bin /generate_images.py
devforfu
Generating one image per prompt
2903d34
"""
Generates a set of new images from text prompts with a Stable Diffusion pipeline to test the quality of the model.
"""
from hashlib import md5
from pathlib import Path
import torch
from diffusers import StableDiffusionPipeline
from more_itertools import chunked
from realfake.utils import Args, inject_args
class GenerateImagesParams(Args):
    # Parameters consumed by `main`; how `Args` fills them in is defined by
    # realfake.utils (presumably from the command line -- confirm there).

    # Text file with prompts; `main` treats every non-blank line as one prompt.
    prompts_file: Path
    # Directory that receives the generated PNG files and the prompt listing(s).
    output_dir: Path

    # Images to generate per prompt. A value of exactly 1 selects the flat
    # layout (all images in `output_dir`); any other value selects the
    # one-subdirectory-per-prompt layout.
    num_images: int = 100
    # Upper bound on the work requested from the pipeline in a single call:
    # prompts per call in the flat layout, images per call otherwise.
    batch_size: int = 16

    # Identifier handed to `StableDiffusionPipeline.from_pretrained`.
    model_id: str = "CompVis/stable-diffusion-v1-4"
    # Target handed to `pipe.to(...)`.
    device: str = "cuda"
def _save_images(images, directory: Path, first_index: int) -> int:
    """Save *images* into *directory* as ``NNNN.png``; return the next free index."""
    index = first_index
    for image in images:
        image.save(directory / f"{index:04d}.png")
        index += 1
    return index


def _generate_one_per_prompt(pipe, prompts: list, output_dir: Path, batch_size: int) -> None:
    """Render one image per prompt into a single flat directory.

    Prompts are sent to the pipeline in chunks of at most *batch_size*; the
    resulting images are numbered consecutively across chunks. The prompts
    are written to ``prompts.txt`` in the order they were submitted.
    """
    output_dir.mkdir(exist_ok=True, parents=True)
    next_index = 0
    for chunk in chunked(prompts, batch_size):
        output = pipe(chunk, num_images_per_prompt=1)
        next_index = _save_images(output.images, output_dir, next_index)
    # Explicit UTF-8 so non-ASCII prompts are stored identically on every platform.
    (output_dir / "prompts.txt").write_text("\n".join(prompts), encoding="utf-8")


def _generate_many_per_prompt(
    pipe, prompts: list, output_dir: Path, num_images: int, batch_size: int
) -> None:
    """Render *num_images* images for every prompt, one subdirectory per prompt.

    The subdirectory name is the MD5 hex digest of the UTF-8 encoded prompt;
    the prompt text itself is stored next to the images as ``prompt.txt``.
    """
    for prompt in prompts:
        signature = md5(prompt.encode("utf-8")).hexdigest()
        prompt_dir = output_dir / signature
        prompt_dir.mkdir(exist_ok=True, parents=True)
        (prompt_dir / "prompt.txt").write_text(prompt, encoding="utf-8")

        next_index = 0
        # Step through the requested total in slices of `batch_size`; the last
        # slice is shortened so that no more than `num_images` are requested.
        for already_requested in range(0, num_images, batch_size):
            request = min(batch_size, num_images - already_requested)
            output = pipe(prompt, num_images_per_prompt=request)
            next_index = _save_images(output.images, prompt_dir, next_index)


@inject_args
def main(args: GenerateImagesParams) -> None:
    """Generate images for every prompt listed in ``args.prompts_file``.

    With ``args.num_images == 1`` all images are written directly into
    ``args.output_dir`` together with a ``prompts.txt`` listing; otherwise each
    prompt gets its own subdirectory holding ``args.num_images`` requested
    images and a ``prompt.txt`` file.

    Raises:
        ValueError: if ``args.batch_size`` is not a positive integer. A
            non-positive batch size would otherwise make the per-prompt
            generation loop run forever.
    """
    batch_size = args.batch_size
    if batch_size < 1:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

    # Read the prompts before loading the model so a bad path fails fast.
    # Decoded as UTF-8 to match the encoding used for the prompt signatures.
    raw_lines = args.prompts_file.read_text(encoding="utf-8").split("\n")
    prompts = [line.strip() for line in raw_lines if line.strip()]

    pipe = StableDiffusionPipeline.from_pretrained(args.model_id, torch_dtype=torch.float16)
    pipe.to(args.device)

    if args.num_images == 1:
        _generate_one_per_prompt(pipe, prompts, args.output_dir, batch_size)
    else:
        _generate_many_per_prompt(pipe, prompts, args.output_dir, args.num_images, batch_size)
# Run the generator only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()