hysts HF staff commited on
Commit
0dbb9c8
1 Parent(s): 9b6b1a2
Files changed (1) hide show
  1. app.py +2 -2
app.py CHANGED
@@ -66,7 +66,7 @@ def load_encoder(device: torch.device) -> nn.Module:
66
  use_auth_token=TOKEN)
67
  ckpt = torch.load(ckpt_path, map_location='cpu')
68
  opts = ckpt['opts']
69
- opts['device'] = 'cpu'
70
  opts['checkpoint_path'] = ckpt_path
71
  opts = argparse.Namespace(**opts)
72
  model = pSp(opts)
@@ -151,7 +151,7 @@ def run(
151
  resize=False)
152
  img_rec = torch.clamp(img_rec.detach(), -1, 1)
153
 
154
- latent = torch.tensor(exstyles[stylename]).repeat(2, 1, 1)
155
  # latent[0] for both color and structrue transfer and latent[1] for only structrue transfer
156
  latent[1, 7:18] = instyle[0, 7:18]
157
  exstyle = generator.generator.style(
 
66
  use_auth_token=TOKEN)
67
  ckpt = torch.load(ckpt_path, map_location='cpu')
68
  opts = ckpt['opts']
69
+ opts['device'] = device.type
70
  opts['checkpoint_path'] = ckpt_path
71
  opts = argparse.Namespace(**opts)
72
  model = pSp(opts)
 
151
  resize=False)
152
  img_rec = torch.clamp(img_rec.detach(), -1, 1)
153
 
154
+ latent = torch.tensor(exstyles[stylename]).repeat(2, 1, 1).to(device)
155
  # latent[0] for both color and structrue transfer and latent[1] for only structrue transfer
156
  latent[1, 7:18] = instyle[0, 7:18]
157
  exstyle = generator.generator.style(