Cropinky commited on
Commit
b422fa9
1 Parent(s): c2321f7
Files changed (2) hide show
  1. app.py +0 -62
  2. image_generator.py +143 -0
app.py CHANGED
@@ -7,68 +7,6 @@ from networks_fastgan import MyGenerator
7
  import click
8
  import PIL
9
 
10
- @click.command()
11
- @click.option('--seeds', type=parse_range, help='List of random seeds (e.g., \'0,1,4-6\')', default = 10-15, required=True)
12
- @click.option('--trunc', 'truncation_psi', type=float, help='Truncation psi', default=1, show_default=True)
13
- @click.option('--class', 'class_idx', type=int, help='Class label (unconditional if not specified)')
14
- @click.option('--noise-mode', help='Noise mode', type=click.Choice(['const', 'random', 'none']), default='const', show_default=True)
15
- @click.option('--translate', help='Translate XY-coordinate (e.g. \'0.3,1\')', type=parse_vec2, default='0,0', show_default=True, metavar='VEC2')
16
- @click.option('--rotate', help='Rotation angle in degrees', type=float, default=0, show_default=True, metavar='ANGLE')
17
- @click.option('--outdir', help='Where to save the output images', type=str, required=True, metavar='DIR')
18
- def generate_images(
19
- seeds: List[int],
20
- truncation_psi: float,
21
- noise_mode: str,
22
- outdir: str,
23
- translate: Tuple[float,float],
24
- rotate: float,
25
- class_idx: Optional[int]
26
- ):
27
- """Generate images using pretrained network pickle.
28
-
29
- Examples:
30
-
31
- \b
32
- # Generate an image using pre-trained AFHQv2 model ("Ours" in Figure 1, left).
33
- python gen_images.py --outdir=out --trunc=1 --seeds=2 \\
34
- --network=https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/stylegan3-r-afhqv2-512x512.pkl
35
-
36
- \b
37
- # Generate uncurated images with truncation using the MetFaces-U dataset
38
- python gen_images.py --outdir=out --trunc=0.7 --seeds=600-605 \\
39
- --network=https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/stylegan3-t-metfacesu-1024x1024.pkl
40
- """
41
-
42
- device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
43
- G = MyGenerator.from_pretrained("Cropinky/projected_gan_impressionism")
44
- print("network loaded")
45
- # Labels.
46
- label = torch.zeros([1, G.c_dim], device=device)
47
- if G.c_dim != 0:
48
- if class_idx is None:
49
- raise click.ClickException('Must specify class label with --class when using a conditional network')
50
- label[:, class_idx] = 1
51
- else:
52
- if class_idx is not None:
53
- print ('warn: --class=lbl ignored when running on an unconditional network')
54
-
55
- # Generate images.
56
- for seed_idx, seed in enumerate(seeds):
57
- print('Generating image for seed %d (%d/%d) ...' % (seed, seed_idx, len(seeds)))
58
- z = torch.from_numpy(np.random.RandomState(seed).randn(1, G.z_dim)).to(device).float()
59
-
60
- # Construct an inverse rotation/translation matrix and pass to the generator. The
61
- # generator expects this matrix as an inverse to avoid potentially failing numerical
62
- # operations in the network.
63
- if hasattr(G.synthesis, 'input'):
64
- m = make_transform(translate, rotate)
65
- m = np.linalg.inv(m)
66
- G.synthesis.input.transform.copy_(torch.from_numpy(m))
67
-
68
- img = G(z, label, truncation_psi=truncation_psi, noise_mode=noise_mode)
69
- img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
70
- PIL.Image.fromarray(img[0].cpu().numpy(), 'RGB').save(f'{outdir}/seed{seed:04d}.png')
71
-
72
 
73
 
74
  def image_generation(model, number_of_images=1):
 
7
  import click
8
  import PIL
9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
 
11
 
12
  def image_generation(model, number_of_images=1):
image_generator.py ADDED
@@ -0,0 +1,143 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ """Generate images using pretrained network pickle."""
10
+
11
+ import os
12
+ import re
13
+ from typing import List, Optional, Tuple, Union
14
+
15
+ import click
16
+ import dnnlib
17
+ import numpy as np
18
+ import PIL.Image
19
+ import torch
20
+ import legacy
21
+
22
+ #----------------------------------------------------------------------------
23
+
24
+ def parse_range(s: Union[str, List]) -> List[int]:
25
+ '''Parse a comma separated list of numbers or ranges and return a list of ints.
26
+
27
+ Example: '1,2,5-10' returns [1, 2, 5, 6, 7]
28
+ '''
29
+ if isinstance(s, list): return s
30
+ ranges = []
31
+ range_re = re.compile(r'^(\d+)-(\d+)$')
32
+ for p in s.split(','):
33
+ m = range_re.match(p)
34
+ if m:
35
+ ranges.extend(range(int(m.group(1)), int(m.group(2))+1))
36
+ else:
37
+ ranges.append(int(p))
38
+ return ranges
39
+
40
+ #----------------------------------------------------------------------------
41
+
42
+ def parse_vec2(s: Union[str, Tuple[float, float]]) -> Tuple[float, float]:
43
+ '''Parse a floating point 2-vector of syntax 'a,b'.
44
+
45
+ Example:
46
+ '0,1' returns (0,1)
47
+ '''
48
+ if isinstance(s, tuple): return s
49
+ parts = s.split(',')
50
+ if len(parts) == 2:
51
+ return (float(parts[0]), float(parts[1]))
52
+ raise ValueError(f'cannot parse 2-vector {s}')
53
+
54
+ #----------------------------------------------------------------------------
55
+
56
+ def make_transform(translate: Tuple[float,float], angle: float):
57
+ m = np.eye(3)
58
+ s = np.sin(angle/360.0*np.pi*2)
59
+ c = np.cos(angle/360.0*np.pi*2)
60
+ m[0][0] = c
61
+ m[0][1] = s
62
+ m[0][2] = translate[0]
63
+ m[1][0] = -s
64
+ m[1][1] = c
65
+ m[1][2] = translate[1]
66
+ return m
67
+
68
+ #----------------------------------------------------------------------------
69
+
70
+ @click.command()
71
+ @click.option('--network', 'network_pkl', help='Network pickle filename', required=True)
72
+ @click.option('--seeds', type=parse_range, help='List of random seeds (e.g., \'0,1,4-6\')', required=True)
73
+ @click.option('--trunc', 'truncation_psi', type=float, help='Truncation psi', default=1, show_default=True)
74
+ @click.option('--class', 'class_idx', type=int, help='Class label (unconditional if not specified)')
75
+ @click.option('--noise-mode', help='Noise mode', type=click.Choice(['const', 'random', 'none']), default='const', show_default=True)
76
+ @click.option('--translate', help='Translate XY-coordinate (e.g. \'0.3,1\')', type=parse_vec2, default='0,0', show_default=True, metavar='VEC2')
77
+ @click.option('--rotate', help='Rotation angle in degrees', type=float, default=0, show_default=True, metavar='ANGLE')
78
+ @click.option('--outdir', help='Where to save the output images', type=str, required=True, metavar='DIR')
79
+ def generate_images(
80
+ network_pkl: str,
81
+ seeds: List[int],
82
+ truncation_psi: float,
83
+ noise_mode: str,
84
+ outdir: str,
85
+ translate: Tuple[float,float],
86
+ rotate: float,
87
+ class_idx: Optional[int]
88
+ ):
89
+ """Generate images using pretrained network pickle.
90
+
91
+ Examples:
92
+
93
+ \b
94
+ # Generate an image using pre-trained AFHQv2 model ("Ours" in Figure 1, left).
95
+ python gen_images.py --outdir=out --trunc=1 --seeds=2 \\
96
+ --network=https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/stylegan3-r-afhqv2-512x512.pkl
97
+
98
+ \b
99
+ # Generate uncurated images with truncation using the MetFaces-U dataset
100
+ python gen_images.py --outdir=out --trunc=0.7 --seeds=600-605 \\
101
+ --network=https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/stylegan3-t-metfacesu-1024x1024.pkl
102
+ """
103
+
104
+ print('Loading networks from "%s"...' % network_pkl)
105
+ device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
106
+ with dnnlib.util.open_url(network_pkl) as f:
107
+ G = legacy.load_network_pkl(f)['G_ema'].to(device) # type: ignore
108
+
109
+ os.makedirs(outdhf_ShHdFYTLiCYbxIuldloQmxMXgPvuEyUPAkir, exist_ok=True)
110
+
111
+ # Labels.
112
+ label = torch.zeros([1, G.c_dim], device=device)
113
+ if G.c_dim != 0:
114
+ if class_idx is None:
115
+ raise click.ClickException('Must specify class label with --class when using a conditional network')
116
+ label[:, class_idx] = 1
117
+ else:
118
+ if class_idx is not None:
119
+ print ('warn: --class=lbl ignored when running on an unconditional network')
120
+
121
+ # Generate images.
122
+ for seed_idx, seed in enumerate(seeds):
123
+ print('Generating image for seed %d (%d/%d) ...' % (seed, seed_idx, len(seeds)))
124
+ z = torch.from_numpy(np.random.RandomState(seed).randn(1, G.z_dim)).to(device).float()
125
+ # Construct an inverse rotation/translation matrix and pass to the generator. The
126
+ # generator expects this matrix as an inverse to avoid potentially failing numerical
127
+ # operations in the network.
128
+ if hasattr(G.synthesis, 'input'):
129
+ m = make_transform(translate, rotate)
130
+ m = np.linalg.inv(m)
131
+ G.synthesis.input.transform.copy_(torch.from_numpy(m))
132
+
133
+ img = G(z, label, truncation_psi=truncation_psi, noise_mode=noise_mode)
134
+ img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
135
+ PIL.Image.fromarray(img[0].cpu().numpy(), 'RGB').save(f'{outdir}/seed{seed:04d}.png')
136
+
137
+
138
+ #----------------------------------------------------------------------------
139
+
140
+ if __name__ == "__main__":
141
+ generate_images() # pylint: disable=no-value-for-parameter
142
+
143
+ #----------------------------------------------------------------------------