Jiatao Gu commited on
Commit
368dc9b
1 Parent(s): 22790a0

fix some errors. update code

Browse files
Files changed (3) hide show
  1. .gitignore +3 -0
  2. app.py +31 -35
  3. gradio_queue.db +0 -0
.gitignore CHANGED
@@ -23,3 +23,6 @@ scripts/research/
23
  .ipynb_checkpoints/
24
  _screenshots/
25
  flagged
 
 
 
23
  .ipynb_checkpoints/
24
  _screenshots/
25
  flagged
26
+
27
+ *.db
28
+ gradio_queue.db
app.py CHANGED
@@ -20,13 +20,20 @@ from huggingface_hub import hf_hub_download
20
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
21
  port = int(sys.argv[1]) if len(sys.argv) > 1 else 21111
22
 
 
 
 
 
 
 
 
23
 
24
  def set_random_seed(seed):
25
  torch.manual_seed(seed)
26
  np.random.seed(seed)
27
 
28
 
29
- def get_camera_traj(model, pitch, yaw, fov=12, batch_size=1, model_name='FFHQ512'):
30
  gen = model.synthesis
31
  range_u, range_v = gen.C.range_u, gen.C.range_v
32
  if not (('car' in model_name) or ('Car' in model_name)): # TODO: hack, better option?
@@ -41,22 +48,10 @@ def get_camera_traj(model, pitch, yaw, fov=12, batch_size=1, model_name='FFHQ512
41
  return cam
42
 
43
 
44
- def check_name(model_name='FFHQ512'):
45
  """Gets model by name."""
46
- if model_name == 'FFHQ512':
47
- network_pkl = hf_hub_download(repo_id='thomagram/stylenerf-ffhq-config-basic', filename='ffhq_512.pkl')
48
-
49
- # TODO: checkpoint to be updated!
50
- # elif model_name == 'FFHQ512v2':
51
- # network_pkl = "./pretrained/ffhq_512_eg3d.pkl"
52
- # elif model_name == 'AFHQ512':
53
- # network_pkl = "./pretrained/afhq_512.pkl"
54
- # elif model_name == 'MetFaces512':
55
- # network_pkl = "./pretrained/metfaces_512.pkl"
56
- # elif model_name == 'CompCars256':
57
- # network_pkl = "./pretrained/cars_256.pkl"
58
- # elif model_name == 'FFHQ1024':
59
- # network_pkl = "./pretrained/ffhq_1024.pkl"
60
  else:
61
  if os.path.isdir(model_name):
62
  network_pkl = sorted(glob.glob(model_name + '/*.pkl'))[-1]
@@ -85,7 +80,7 @@ def get_model(network_pkl, render_option=None):
85
  return G2, res, imgs
86
 
87
 
88
- global_states = list(get_model(check_name()))
89
  wss = [None, None]
90
 
91
  def proc_seed(history, seed):
@@ -98,7 +93,8 @@ def proc_seed(history, seed):
98
  def f_synthesis(model_name, model_find, render_option, early, trunc, seed1, seed2, mix1, mix2, yaw, pitch, roll, fov, history):
99
  history = history or {}
100
  seeds = []
101
-
 
102
  if model_find != "":
103
  model_name = model_find
104
 
@@ -124,7 +120,7 @@ def f_synthesis(model_name, model_find, render_option, early, trunc, seed1, seed
124
  set_random_seed(seed)
125
  z = torch.from_numpy(np.random.RandomState(int(seed)).randn(1, model.z_dim).astype('float32')).to(device)
126
  ws = model.mapping(z=z, c=None, truncation_psi=trunc)
127
- img = model.get_final_output(styles=ws, camera_matrices=get_camera_traj(model, 0, 0), render_option=render_option)
128
  ws = ws.detach().cpu().numpy()
129
  img = img[0].permute(1,2,0).detach().cpu().numpy()
130
 
@@ -178,26 +174,26 @@ def f_synthesis(model_name, model_find, render_option, early, trunc, seed1, seed
178
  image = (image * 255).astype('uint8')
179
  return image, history
180
 
181
- model_name = gr.inputs.Dropdown(['FFHQ512']) # 'FFHQ512v2', 'AFHQ512', 'MetFaces512', 'CompCars256', 'FFHQ1024'
182
- model_find = gr.inputs.Textbox(label="checkpoint path", default="")
183
- render_option = gr.inputs.Textbox(label="rendering options", default='steps:40')
184
- trunc = gr.inputs.Slider(default=0.7, maximum=1.0, minimum=0.0, label='truncation trick')
185
- seed1 = gr.inputs.Number(default=1, label="seed1")
186
- seed2 = gr.inputs.Number(default=9, label="seed2")
187
- mix1 = gr.inputs.Slider(minimum=0, maximum=1, default=0, label="linear mixing ratio (geometry)")
188
- mix2 = gr.inputs.Slider(minimum=0, maximum=1, default=0, label="linear mixing ratio (apparence)")
189
- early = gr.inputs.Radio(['None', 'Normal Map', 'Gradient Map'], default='None', label='intermedia output')
190
- yaw = gr.inputs.Slider(minimum=-1, maximum=1, default=0, label="yaw")
191
- pitch = gr.inputs.Slider(minimum=-1, maximum=1, default=0, label="pitch")
192
- roll = gr.inputs.Slider(minimum=-1, maximum=1, default=0, label="roll (optional, not suggested)")
193
- fov = gr.inputs.Slider(minimum=9, maximum=15, default=12, label="fov")
194
  css = ".output-image, .input-image, .image-preview {height: 600px !important} "
195
 
196
  gr.Interface(fn=f_synthesis,
197
  inputs=[model_name, model_find, render_option, early, trunc, seed1, seed2, mix1, mix2, yaw, pitch, roll, fov, "state"],
198
- title="Interctive Web Demo for StyleNeRF (ICLR 2022)",
199
- description="Demo for ICLR 2022 Papaer: A Style-based 3D-Aware Generator for High-resolution Image Synthesis. Currently the demo runs on CPU only.",
200
  outputs=["image", "state"],
201
  layout='unaligned',
202
- css=css, theme='dark-huggingface',
203
  live=True).launch(enable_queue=True)
20
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
21
  port = int(sys.argv[1]) if len(sys.argv) > 1 else 21111
22
 
23
+ model_lists = {
24
+ 'ffhq-512x512-basic': dict(repo_id='facebook/stylenerf-ffhq-config-basic', filename='ffhq_512.pkl'),
25
+ 'ffhq-256x256-basic': dict(repo_id='facebook/stylenerf-ffhq-config-basic', filename='ffhq_256.pkl'),
26
+ 'ffhq-1024x1024-basic': dict(repo_id='facebook/stylenerf-ffhq-config-basic', filename='ffhq_1024.pkl'),
27
+ }
28
+ model_names = [name for name in model_lists]
29
+
30
 
31
  def set_random_seed(seed):
32
  torch.manual_seed(seed)
33
  np.random.seed(seed)
34
 
35
 
36
+ def get_camera_traj(model, pitch, yaw, fov=12, batch_size=1, model_name=None):
37
  gen = model.synthesis
38
  range_u, range_v = gen.C.range_u, gen.C.range_v
39
  if not (('car' in model_name) or ('Car' in model_name)): # TODO: hack, better option?
48
  return cam
49
 
50
 
51
+ def check_name(model_name):
52
  """Gets model by name."""
53
+ if model_name in model_lists:
54
+ network_pkl = hf_hub_download(**model_lists[model_name])
 
 
 
 
 
 
 
 
 
 
 
 
55
  else:
56
  if os.path.isdir(model_name):
57
  network_pkl = sorted(glob.glob(model_name + '/*.pkl'))[-1]
80
  return G2, res, imgs
81
 
82
 
83
+ global_states = list(get_model(check_name(model_names[0])))
84
  wss = [None, None]
85
 
86
  def proc_seed(history, seed):
93
  def f_synthesis(model_name, model_find, render_option, early, trunc, seed1, seed2, mix1, mix2, yaw, pitch, roll, fov, history):
94
  history = history or {}
95
  seeds = []
96
+ trunc = trunc / 100
97
+
98
  if model_find != "":
99
  model_name = model_find
100
 
120
  set_random_seed(seed)
121
  z = torch.from_numpy(np.random.RandomState(int(seed)).randn(1, model.z_dim).astype('float32')).to(device)
122
  ws = model.mapping(z=z, c=None, truncation_psi=trunc)
123
+ img = model.get_final_output(styles=ws, camera_matrices=get_camera_traj(model, 0, 0, model_name=model_name), render_option=render_option)
124
  ws = ws.detach().cpu().numpy()
125
  img = img[0].permute(1,2,0).detach().cpu().numpy()
126
 
174
  image = (image * 255).astype('uint8')
175
  return image, history
176
 
177
+ model_name = gr.inputs.Dropdown(model_names)
178
+ model_find = gr.inputs.Textbox(label="Checkpoint path (folder or .pkl file)", default="")
179
+ render_option = gr.inputs.Textbox(label="Additional rendering options", default='freeze_bg,steps:50')
180
+ trunc = gr.inputs.Slider(default=70, maximum=100, minimum=0, label='Truncation trick (%)')
181
+ seed1 = gr.inputs.Number(default=1, label="Random seed1")
182
+ seed2 = gr.inputs.Number(default=9, label="Random seed2")
183
+ mix1 = gr.inputs.Slider(minimum=0, maximum=1, default=0, label="Linear mixing ratio (geometry)")
184
+ mix2 = gr.inputs.Slider(minimum=0, maximum=1, default=0, label="Linear mixing ratio (apparence)")
185
+ early = gr.inputs.Radio(['None', 'Normal Map', 'Gradient Map'], default='None', label='Intermedia output')
186
+ yaw = gr.inputs.Slider(minimum=-1, maximum=1, default=0, label="Yaw")
187
+ pitch = gr.inputs.Slider(minimum=-1, maximum=1, default=0, label="Pitch")
188
+ roll = gr.inputs.Slider(minimum=-1, maximum=1, default=0, label="Roll (optional, not suggested for basic config)")
189
+ fov = gr.inputs.Slider(minimum=10, maximum=14, default=12, label="Fov")
190
  css = ".output-image, .input-image, .image-preview {height: 600px !important} "
191
 
192
  gr.Interface(fn=f_synthesis,
193
  inputs=[model_name, model_find, render_option, early, trunc, seed1, seed2, mix1, mix2, yaw, pitch, roll, fov, "state"],
194
+ title="Interactive Web Demo for StyleNeRF (ICLR 2022)",
195
+ description="StyleNeRF: A Style-based 3D-Aware Generator for High-resolution Image Synthesis. Currently the demo runs on CPU only.",
196
  outputs=["image", "state"],
197
  layout='unaligned',
198
+ css=css, theme='dark-seafoam',
199
  live=True).launch(enable_queue=True)
gradio_queue.db CHANGED
Binary files a/gradio_queue.db and b/gradio_queue.db differ