chansung commited on
Commit
e2ea425
1 Parent(s): 0e25a0e

Update gen.py

Browse files
Files changed (1) hide show
  1. gen.py +1 -1
gen.py CHANGED
@@ -63,7 +63,7 @@ def get_pretrained_models(
63
 
64
  llama_ckpt_path = checkpoints[local_rank]
65
  print("Loading")
66
- checkpoint = torch.load(llama_ckpt_path, map_location=lambda storage, loc: storage.cuda(0))
67
  with open(Path(llama_weight_path) / "params.json", "r") as f:
68
  params = json.loads(f.read())
69
 
 
63
 
64
  llama_ckpt_path = checkpoints[local_rank]
65
  print("Loading")
66
+ checkpoint = torch.load(llama_ckpt_path, map_location="cpu")
67
  with open(Path(llama_weight_path) / "params.json", "r") as f:
68
  params = json.loads(f.read())
69