csuhan commited on
Commit
5f868e2
1 Parent(s): 8d8ef52
Files changed (1) hide show
  1. app.py +5 -4
app.py CHANGED
@@ -66,7 +66,7 @@ def load(
66
  # ), f"Loading a checkpoint for MP={len(checkpoints)} but world size is {world_size}"
67
  # ckpt_path = checkpoints[local_rank]
68
  print("Loading")
69
- checkpoint = torch.load(ckpt_path, map_location="cpu")
70
  instruct_adapter_checkpoint = torch.load(
71
  instruct_adapter_path, map_location="cpu")
72
  caption_adapter_checkpoint = torch.load(
@@ -87,12 +87,13 @@ def load(
87
  model_args.vocab_size = tokenizer.n_words
88
  torch.set_default_tensor_type(torch.cuda.HalfTensor)
89
  model = Transformer(model_args)
90
- vision_model = VisionModel(model_args)
91
-
92
- torch.set_default_tensor_type(torch.FloatTensor)
93
  model.load_state_dict(checkpoint, strict=False)
94
  del checkpoint
95
  torch.cuda.empty_cache()
 
 
 
 
96
  model.load_state_dict(instruct_adapter_checkpoint, strict=False)
97
  model.load_state_dict(caption_adapter_checkpoint, strict=False)
98
  vision_model.load_state_dict(caption_adapter_checkpoint, strict=False)
 
66
  # ), f"Loading a checkpoint for MP={len(checkpoints)} but world size is {world_size}"
67
  # ckpt_path = checkpoints[local_rank]
68
  print("Loading")
69
+ checkpoint = torch.load(ckpt_path, map_location="cuda")
70
  instruct_adapter_checkpoint = torch.load(
71
  instruct_adapter_path, map_location="cpu")
72
  caption_adapter_checkpoint = torch.load(
 
87
  model_args.vocab_size = tokenizer.n_words
88
  torch.set_default_tensor_type(torch.cuda.HalfTensor)
89
  model = Transformer(model_args)
 
 
 
90
  model.load_state_dict(checkpoint, strict=False)
91
  del checkpoint
92
  torch.cuda.empty_cache()
93
+ vision_model = VisionModel(model_args)
94
+
95
+ torch.set_default_tensor_type(torch.FloatTensor)
96
+
97
  model.load_state_dict(instruct_adapter_checkpoint, strict=False)
98
  model.load_state_dict(caption_adapter_checkpoint, strict=False)
99
  vision_model.load_state_dict(caption_adapter_checkpoint, strict=False)