rinong commited on
Commit
720df82
1 Parent(s): ea00796

Cuda uses now check for device

Browse files
Files changed (2) hide show
  1. app.py +1 -2
  2. generate_videos.py +4 -2
app.py CHANGED
@@ -92,7 +92,6 @@ class ImageEditor(object):
92
 
93
  self.e4e_net = pSp(opts, self.device)
94
  self.e4e_net.eval()
95
- self.e4e_net.cuda()
96
 
97
  self.shape_predictor = dlib.shape_predictor(
98
  model_paths["dlib"]
@@ -192,7 +191,7 @@ class ImageEditor(object):
192
 
193
  def run_on_batch(self, inputs):
194
  images, latents = self.e4e_net(
195
- inputs.to("cuda").float(), randomize_noise=False, return_latents=True
196
  )
197
  return images, latents
198
 
 
92
 
93
  self.e4e_net = pSp(opts, self.device)
94
  self.e4e_net.eval()
 
95
 
96
  self.shape_predictor = dlib.shape_predictor(
97
  model_paths["dlib"]
 
191
 
192
  def run_on_batch(self, inputs):
193
  images, latents = self.e4e_net(
194
+ inputs.to(self.device).float(), randomize_noise=False, return_latents=True
195
  )
196
  return images, latents
197
 
generate_videos.py CHANGED
@@ -52,6 +52,8 @@ def project_code(latent_code, boundary, distance=3.0):
52
 
53
  def generate_frames(args, source_latent, g_ema_list, output_dir):
54
 
 
 
55
  alphas = np.linspace(0, 1, num=20)
56
 
57
  interpolate_func = interpolate_with_boundaries # default
@@ -84,7 +86,7 @@ def generate_frames(args, source_latent, g_ema_list, output_dir):
84
  src_pars[k].data.copy_(mix_pars[segment_id][k] * (1 - mix_alpha) + mix_pars[segment_id + 1][k] * mix_alpha)
85
 
86
  if idx == 0 or segments or latent is not latents[idx - 1]:
87
- w = torch.from_numpy(latent).float().cuda()
88
 
89
  with torch.no_grad():
90
  img, _ = g_ema([w], input_is_latent=True, truncation=1, randomize_noise=False)
@@ -205,7 +207,7 @@ def vid_to_gif(vid_path, output_dir, scale=256, fps=35):
205
 
206
 
207
  if __name__ == '__main__':
208
- device = 'cuda'
209
 
210
  parser = argparse.ArgumentParser()
211
 
 
52
 
53
  def generate_frames(args, source_latent, g_ema_list, output_dir):
54
 
55
+ device = "cuda" if torch.cuda.is_available() else "cpu"
56
+
57
  alphas = np.linspace(0, 1, num=20)
58
 
59
  interpolate_func = interpolate_with_boundaries # default
 
86
  src_pars[k].data.copy_(mix_pars[segment_id][k] * (1 - mix_alpha) + mix_pars[segment_id + 1][k] * mix_alpha)
87
 
88
  if idx == 0 or segments or latent is not latents[idx - 1]:
89
+ w = torch.from_numpy(latent).float().to(device)
90
 
91
  with torch.no_grad():
92
  img, _ = g_ema([w], input_is_latent=True, truncation=1, randomize_noise=False)
 
207
 
208
 
209
  if __name__ == '__main__':
210
+ device = "cuda" if torch.cuda.is_available() else "cpu"
211
 
212
  parser = argparse.ArgumentParser()
213