|
""" |
|
Generate a set of new images from a file of prompts with a Stable Diffusion pipeline, to test the quality of the model.
|
""" |
|
from hashlib import md5 |
|
from pathlib import Path |
|
|
|
import torch |
|
from diffusers import StableDiffusionPipeline |
|
from more_itertools import chunked |
|
|
|
from realfake.utils import Args, inject_args |
|
|
|
|
|
class GenerateImagesParams(Args):
    """Parameters for the image-generation script.

    NOTE(review): presumably filled in from the command line by ``inject_args``
    from ``realfake.utils`` — confirm there.
    """

    # Text file with one prompt per line; lines are stripped and blank ones ignored.
    prompts_file: Path
    # Root directory that receives the generated images (created if missing).
    output_dir: Path
    # Images to generate per prompt; a value of 1 switches to a flat, batched layout
    # with a single prompts.txt instead of one sub-directory per prompt.
    num_images: int = 100
    # Upper bound on the number of images requested from the pipeline per call.
    batch_size: int = 16
    # Model identifier passed to StableDiffusionPipeline.from_pretrained (loaded as float16).
    model_id: str = "CompVis/stable-diffusion-v1-4"
    # Device the pipeline is moved to with pipe.to(...).
    device: str = "cuda"
|
|
|
|
|
def _load_prompts(prompts_file: Path) -> "list[str]":
    """Return the stripped, non-blank lines of *prompts_file*, read as UTF-8."""
    # Explicit UTF-8: prompts are hashed as UTF-8 below, so reading them with the
    # platform default encoding could mangle non-ASCII prompts on some systems.
    lines = prompts_file.read_text(encoding="utf-8").split("\n")
    return [line.strip() for line in lines if line.strip()]


def _generate_one_per_prompt(
    pipe: StableDiffusionPipeline, prompts: "list[str]", output_dir: Path, batch_size: int
) -> None:
    """Save one image per prompt as ``output_dir/NNNN.png`` plus an aligned ``prompts.txt``."""
    output_dir.mkdir(exist_ok=True, parents=True)

    count = 0
    for chunk in chunked(prompts, batch_size):
        output = pipe(chunk, num_images_per_prompt=1)
        for image in output.images:
            image.save(output_dir / f"{count:04d}.png")
            count += 1

    # Line i of prompts.txt is the prompt that produced image i.
    (output_dir / "prompts.txt").write_text("\n".join(prompts), encoding="utf-8")


def _generate_many_per_prompt(
    pipe: StableDiffusionPipeline,
    prompts: "list[str]",
    output_dir: Path,
    num_images: int,
    batch_size: int,
) -> None:
    """Save *num_images* images per distinct prompt under ``output_dir/<md5(prompt)>/``."""
    seen = set()
    for prompt in prompts:
        prompt_signature = md5(prompt.encode("utf-8")).hexdigest()
        if prompt_signature in seen:
            # Identical prompts share a directory; generating again would only
            # overwrite the earlier images and waste GPU time.
            continue
        seen.add(prompt_signature)

        prompt_dir = output_dir / prompt_signature
        prompt_dir.mkdir(exist_ok=True, parents=True)
        (prompt_dir / "prompt.txt").write_text(prompt, encoding="utf-8")

        remaining = num_images
        count = 0
        while remaining > 0:
            requested = min(batch_size, remaining)
            output = pipe(prompt, num_images_per_prompt=requested)
            for image in output.images:
                image.save(prompt_dir / f"{count:04d}.png")
                count += 1
            # Subtract what was actually requested (<= batch_size) so the loop
            # bookkeeping matches the number of images asked for.
            remaining -= requested


@inject_args
def main(args: GenerateImagesParams) -> None:
    """Generate images for every prompt listed in ``args.prompts_file``.

    The Stable Diffusion pipeline ``args.model_id`` is loaded in float16 and
    moved to ``args.device``. Two output layouts are produced:

    * ``num_images == 1``: prompts are processed in batches of
      ``args.batch_size`` with one image each, saved as ``NNNN.png`` directly in
      ``args.output_dir``, together with ``prompts.txt`` listing the prompts in
      image order.
    * ``num_images > 1``: each distinct prompt gets a sub-directory named by
      the MD5 hex digest of its UTF-8 bytes, containing ``prompt.txt`` and
      ``num_images`` images ``NNNN.png`` generated in batches of at most
      ``args.batch_size``.

    Raises:
        ValueError: if ``num_images`` or ``batch_size`` is smaller than 1.
    """
    # Validate before loading the (large) model so bad arguments fail fast.
    if args.num_images < 1:
        # Previously this silently created empty prompt directories.
        raise ValueError(f"num_images must be >= 1, got {args.num_images}")
    if args.batch_size < 1:
        # Previously the per-prompt loop never terminated: the remaining
        # count was decremented by a non-positive amount.
        raise ValueError(f"batch_size must be >= 1, got {args.batch_size}")

    prompts = _load_prompts(args.prompts_file)

    pipe = StableDiffusionPipeline.from_pretrained(args.model_id, torch_dtype=torch.float16)
    pipe.to(args.device)

    if args.num_images == 1:
        _generate_one_per_prompt(pipe, prompts, args.output_dir, args.batch_size)
    else:
        _generate_many_per_prompt(
            pipe, prompts, args.output_dir, args.num_images, args.batch_size
        )
|
|
|
|
|
if __name__ == "__main__":
    # main() is called without arguments; the inject_args decorator presumably
    # builds GenerateImagesParams from the command line — confirm in realfake.utils.
    main()
|
|