Jiatao Gu commited on
Commit
77c753d
2 Parent(s): 94ada0b 6b4302d

resolve conflict

Browse files
Files changed (1) hide show
  1. app.py +11 -11
app.py CHANGED
@@ -28,7 +28,7 @@ def handler(signum, frame):
28
  if res == 'y':
29
  gr.close_all()
30
  exit(1)
31
-
32
  signal.signal(signal.SIGINT, handler)
33
 
34
 
@@ -56,7 +56,7 @@ def check_name(model_name='FFHQ512'):
56
  """Gets model by name."""
57
  if model_name == 'FFHQ512':
58
  network_pkl = hf_hub_download(repo_id='thomagram/stylenerf-ffhq-config-basic', filename='ffhq_512.pkl')
59
-
60
  # TODO: checkpoint to be updated!
61
  # elif model_name == 'FFHQ512v2':
62
  # network_pkl = "./pretrained/ffhq_512_eg3d.pkl"
@@ -109,10 +109,10 @@ def proc_seed(history, seed):
109
  def f_synthesis(model_name, model_find, render_option, early, trunc, seed1, seed2, mix1, mix2, yaw, pitch, roll, fov, history):
110
  history = history or {}
111
  seeds = []
112
-
113
  if model_find != "":
114
  model_name = model_find
115
-
116
  model_name = check_name(model_name)
117
  if model_name != history.get("model_name", None):
118
  model, res, imgs = get_model(model_name, render_option)
@@ -139,7 +139,7 @@ def f_synthesis(model_name, model_find, render_option, early, trunc, seed1, seed
139
  ws = ws.detach().cpu().numpy()
140
  img = img[0].permute(1,2,0).detach().cpu().numpy()
141
 
142
-
143
  imgs[idx * res // 2: (1 + idx) * res // 2] = cv2.resize(
144
  np.asarray(img).clip(-1, 1) * 0.5 + 0.5,
145
  (res//2, res//2), cv2.INTER_AREA)
@@ -151,7 +151,7 @@ def f_synthesis(model_name, model_find, render_option, early, trunc, seed1, seed
151
  history[f'seed{idx}'] = seed
152
  history['trunc'] = trunc
153
  history['model_name'] = model_name
154
-
155
  set_random_seed(sum(seeds))
156
 
157
  # style mixing (?)
@@ -159,18 +159,18 @@ def f_synthesis(model_name, model_find, render_option, early, trunc, seed1, seed
159
  ws = ws1.clone()
160
  ws[:, :8] = ws1[:, :8] * mix1 + ws2[:, :8] * (1 - mix1)
161
  ws[:, 8:] = ws1[:, 8:] * mix2 + ws2[:, 8:] * (1 - mix2)
162
-
163
  # set visualization for other types of inputs.
164
  if early == 'Normal Map':
165
  render_option += ',normal,early'
166
  elif early == 'Gradient Map':
167
  render_option += ',gradient,early'
168
-
169
  start_t = time.time()
170
  with torch.no_grad():
171
  cam = get_camera_traj(model, pitch, yaw, fov, model_name=model_name)
172
  image = model.get_final_output(
173
- styles=ws, camera_matrices=cam,
174
  theta=roll * np.pi,
175
  render_option=render_option)
176
  end_t = time.time()
@@ -184,7 +184,7 @@ def f_synthesis(model_name, model_find, render_option, early, trunc, seed1, seed
184
  b = int(imgs.shape[1] / imgs.shape[0] * a)
185
  print(f'resize {a} {b} {image.shape} {imgs.shape}')
186
  image = np.concatenate([cv2.resize(imgs, (b, a), cv2.INTER_AREA), image], 1)
187
-
188
  print(f'rendering time = {end_t-start_t:.4f}s')
189
  image = (image * 255).astype('uint8')
190
  return image, history
@@ -210,4 +210,4 @@ gr.Interface(fn=f_synthesis,
210
  outputs=["image", "state"],
211
  layout='unaligned',
212
  css=css, theme='dark-huggingface',
213
- live=True).launch(server_port=port)
28
  if res == 'y':
29
  gr.close_all()
30
  exit(1)
31
+
32
  signal.signal(signal.SIGINT, handler)
33
 
34
 
56
  """Gets model by name."""
57
  if model_name == 'FFHQ512':
58
  network_pkl = hf_hub_download(repo_id='thomagram/stylenerf-ffhq-config-basic', filename='ffhq_512.pkl')
59
+
60
  # TODO: checkpoint to be updated!
61
  # elif model_name == 'FFHQ512v2':
62
  # network_pkl = "./pretrained/ffhq_512_eg3d.pkl"
109
  def f_synthesis(model_name, model_find, render_option, early, trunc, seed1, seed2, mix1, mix2, yaw, pitch, roll, fov, history):
110
  history = history or {}
111
  seeds = []
112
+
113
  if model_find != "":
114
  model_name = model_find
115
+
116
  model_name = check_name(model_name)
117
  if model_name != history.get("model_name", None):
118
  model, res, imgs = get_model(model_name, render_option)
139
  ws = ws.detach().cpu().numpy()
140
  img = img[0].permute(1,2,0).detach().cpu().numpy()
141
 
142
+
143
  imgs[idx * res // 2: (1 + idx) * res // 2] = cv2.resize(
144
  np.asarray(img).clip(-1, 1) * 0.5 + 0.5,
145
  (res//2, res//2), cv2.INTER_AREA)
151
  history[f'seed{idx}'] = seed
152
  history['trunc'] = trunc
153
  history['model_name'] = model_name
154
+
155
  set_random_seed(sum(seeds))
156
 
157
  # style mixing (?)
159
  ws = ws1.clone()
160
  ws[:, :8] = ws1[:, :8] * mix1 + ws2[:, :8] * (1 - mix1)
161
  ws[:, 8:] = ws1[:, 8:] * mix2 + ws2[:, 8:] * (1 - mix2)
162
+
163
  # set visualization for other types of inputs.
164
  if early == 'Normal Map':
165
  render_option += ',normal,early'
166
  elif early == 'Gradient Map':
167
  render_option += ',gradient,early'
168
+
169
  start_t = time.time()
170
  with torch.no_grad():
171
  cam = get_camera_traj(model, pitch, yaw, fov, model_name=model_name)
172
  image = model.get_final_output(
173
+ styles=ws, camera_matrices=cam,
174
  theta=roll * np.pi,
175
  render_option=render_option)
176
  end_t = time.time()
184
  b = int(imgs.shape[1] / imgs.shape[0] * a)
185
  print(f'resize {a} {b} {image.shape} {imgs.shape}')
186
  image = np.concatenate([cv2.resize(imgs, (b, a), cv2.INTER_AREA), image], 1)
187
+
188
  print(f'rendering time = {end_t-start_t:.4f}s')
189
  image = (image * 255).astype('uint8')
190
  return image, history
210
  outputs=["image", "state"],
211
  layout='unaligned',
212
  css=css, theme='dark-huggingface',
213
+ live=True).launch(enable_queue=True)