devforfu commited on
Commit
2903d34
1 Parent(s): df2b039

Generating one image per prompt

Browse files
Files changed (1) hide show
  1. realfake/bin/generate_images.py +29 -13
realfake/bin/generate_images.py CHANGED
@@ -6,6 +6,7 @@ from pathlib import Path
6
 
7
  import torch
8
  from diffusers import StableDiffusionPipeline
 
9
 
10
  from realfake.utils import Args, inject_args
11
 
@@ -14,6 +15,7 @@ class GenerateImagesParams(Args):
14
  prompts_file: Path
15
  output_dir: Path
16
  num_images: int = 100
 
17
  model_id: str = "CompVis/stable-diffusion-v1-4"
18
  device: str = "cuda"
19
 
@@ -22,24 +24,38 @@ class GenerateImagesParams(Args):
22
  def main(args: GenerateImagesParams) -> None:
23
  pipe = StableDiffusionPipeline.from_pretrained(args.model_id, torch_dtype=torch.float16)
24
  pipe.to(args.device)
25
-
26
- prompts = (p for p in (p.strip() for p in args.prompts_file.read_text().split("\n")) if p)
27
 
28
- batch_size = 16
29
- for prompt in prompts:
30
- prompt_signature = md5(prompt.encode("utf-8")).hexdigest()
31
- prompt_dir = args.output_dir/prompt_signature
32
- prompt_dir.mkdir(exist_ok=True, parents=True)
33
- (prompt_dir/"prompt.txt").write_text(prompt)
34
 
35
- left_images = args.num_images
 
 
36
  count = 0
37
- while left_images > 0:
38
- output = pipe(prompt, num_images_per_prompt=min(batch_size, left_images))
39
  for image in output.images:
40
- image.save(prompt_dir/f"{count}.png")
41
  count += 1
42
- left_images -= batch_size
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43
 
44
 
45
  if __name__ == "__main__":
 
6
 
7
  import torch
8
  from diffusers import StableDiffusionPipeline
9
+ from more_itertools import chunked
10
 
11
  from realfake.utils import Args, inject_args
12
 
 
15
  prompts_file: Path
16
  output_dir: Path
17
  num_images: int = 100
18
+ batch_size: int = 16
19
  model_id: str = "CompVis/stable-diffusion-v1-4"
20
  device: str = "cuda"
21
 
 
24
  def main(args: GenerateImagesParams) -> None:
25
  pipe = StableDiffusionPipeline.from_pretrained(args.model_id, torch_dtype=torch.float16)
26
  pipe.to(args.device)
27
+ prompts = [p for p in (p.strip() for p in args.prompts_file.read_text().split("\n")) if p]
 
28
 
29
+ batch_size = args.batch_size
 
 
 
 
 
30
 
31
+ if args.num_images == 1:
32
+ args.output_dir.mkdir(exist_ok=True, parents=True)
33
+
34
  count = 0
35
+ for chunk in chunked(prompts, batch_size):
36
+ output = pipe(chunk, num_images_per_prompt=1)
37
  for image in output.images:
38
+ image.save(args.output_dir/f"{count:04d}.png")
39
  count += 1
40
+
41
+ with (args.output_dir/"prompts.txt").open("w") as fp:
42
+ fp.write("\n".join(prompts))
43
+
44
+ else:
45
+ for prompt in prompts:
46
+ prompt_signature = md5(prompt.encode("utf-8")).hexdigest()
47
+ prompt_dir = args.output_dir/prompt_signature
48
+ prompt_dir.mkdir(exist_ok=True, parents=True)
49
+ (prompt_dir/"prompt.txt").write_text(prompt)
50
+
51
+ left_images = args.num_images
52
+ count = 0
53
+ while left_images > 0:
54
+ output = pipe(prompt, num_images_per_prompt=min(batch_size, left_images))
55
+ for image in output.images:
56
+ image.save(prompt_dir/f"{count:04d}.png")
57
+ count += 1
58
+ left_images -= batch_size
59
 
60
 
61
  if __name__ == "__main__":