devforfu commited on
Commit
18045f9
1 Parent(s): 5c9258d

Generating fakes with SD

Browse files
Files changed (1) hide show
  1. realfake/bin/generate_images.py +44 -0
realfake/bin/generate_images.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Generates a set of new images to test the quality of the model.
3
+ """
4
+ from hashlib import md5
5
+ from pathlib import Path
6
+
7
+ import torch
8
+ from diffusers import StableDiffusionPipeline
9
+
10
+ from realfake.utils import Args, inject_args
11
+
12
+
13
+ class GenerateImagesParams(Args):
14
+ prompts_file: Path
15
+ output_dir: Path
16
+ num_images: int = 100
17
+ model_id: str = "CompVis/stable-diffusion-v1-4"
18
+ device: str = "cuda"
19
+
20
+
21
+ @inject_args
22
+ def main(args: GenerateImagesParams) -> None:
23
+ pipe = StableDiffusionPipeline.from_pretrained(args.model_id, torch_dtype=torch.float16)
24
+ pipe.to(args.device)
25
+
26
+ prompts = (p for p in (p.strip() for p in args.prompts_file.read_text().split("\n")) if p)
27
+
28
+ batch_size = 16
29
+ for prompt in prompts:
30
+ prompt_signature = md5(prompt.encode("utf-8")).hexdigest()
31
+ prompt_dir = args.output_dir/prompt_signature
32
+ prompt_dir.mkdir(exist_ok=True, parents=True)
33
+ (prompt_dir/"prompt.txt").write_text(prompt)
34
+
35
+ left_images = args.num_images
36
+ while left_images > 0:
37
+ output = pipe(prompt, num_images_per_prompt=min(batch_size, left_images))
38
+ for i, image in enumerate(output.images):
39
+ image.save(prompt_dir/f"{i}.png")
40
+ left_images -= batch_size
41
+
42
+
43
+ if __name__ == "__main__":
44
+ main()