apolinario commited on
Commit
d146181
1 Parent(s): cc182cb

Get the loading back to CUDA

Browse files
Files changed (1) hide show
  1. app.py +1 -1
app.py CHANGED
@@ -27,7 +27,7 @@ import open_clip
27
 
28
  def load_model_from_config(config, ckpt, verbose=False):
29
  print(f"Loading model from {ckpt}")
30
- pl_sd = torch.load(ckpt, map_location="cpu")
31
  sd = pl_sd["state_dict"]
32
  model = instantiate_from_config(config.model)
33
  m, u = model.load_state_dict(sd, strict=False)
 
27
 
28
  def load_model_from_config(config, ckpt, verbose=False):
29
  print(f"Loading model from {ckpt}")
30
+ pl_sd = torch.load(ckpt, map_location="cuda")
31
  sd = pl_sd["state_dict"]
32
  model = instantiate_from_config(config.model)
33
  m, u = model.load_state_dict(sd, strict=False)