csuhan committed
Commit e2752b7
1 Parent(s): 5f868e2
Files changed (1)
  1. app.py +19 -6
app.py CHANGED
@@ -49,7 +49,8 @@ def setup_model_parallel() -> Tuple[int, int]:
 
 
 def load(
-    ckpt_path: str,
+    ckpt0_path: str,
+    ckpt1_path: str,
     param_path: str,
     tokenizer_path: str,
     instruct_adapter_path: str,
@@ -66,7 +67,7 @@ def load(
     # ), f"Loading a checkpoint for MP={len(checkpoints)} but world size is {world_size}"
     # ckpt_path = checkpoints[local_rank]
     print("Loading")
-    checkpoint = torch.load(ckpt_path, map_location="cuda")
+    # checkpoint = torch.load(ckpt_path, map_location="cuda")
     instruct_adapter_checkpoint = torch.load(
         instruct_adapter_path, map_location="cpu")
     caption_adapter_checkpoint = torch.load(
@@ -87,9 +88,18 @@ def load(
     model_args.vocab_size = tokenizer.n_words
     torch.set_default_tensor_type(torch.cuda.HalfTensor)
     model = Transformer(model_args)
-    model.load_state_dict(checkpoint, strict=False)
-    del checkpoint
+    checkpoint1 = torch.load(ckpt0_path, map_location='cuda')
+    model.load_state_dict(checkpoint1, strict=False)
+    del checkpoint1
     torch.cuda.empty_cache()
+
+    checkpoint2 = torch.load(ckpt1_path, map_location='cuda')
+    model.load_state_dict(checkpoint2, strict=False)
+    del checkpoint2
+    torch.cuda.empty_cache()
+
+    # model.load_state_dict(checkpoint, strict=False)
+    # del checkpoint
     vision_model = VisionModel(model_args)
 
     torch.set_default_tensor_type(torch.FloatTensor)
@@ -173,7 +183,10 @@ def download_llama_adapter(instruct_adapter_path, caption_adapter_path):
 # ckpt_path = "/data1/llma/7B/consolidated.00.pth"
 # param_path = "/data1/llma/7B/params.json"
 # tokenizer_path = "/data1/llma/tokenizer.model"
-ckpt_path = hf_hub_download(repo_id="nyanko7/LLaMA-7B", filename="consolidated.00.pth")
+# ckpt_path = hf_hub_download(repo_id="nyanko7/LLaMA-7B", filename="consolidated.00.pth")
+# param_path = hf_hub_download(repo_id="nyanko7/LLaMA-7B", filename="params.json")
+ckpt0_path = hf_hub_download(repo_id="csuhan/llama_storage", filename="consolidated.00_part0.pth")
+ckpt1_path = hf_hub_download(repo_id="csuhan/llama_storage", filename="consolidated.00_part1.pth")
 param_path = hf_hub_download(repo_id="nyanko7/LLaMA-7B", filename="params.json")
 tokenizer_path = hf_hub_download(repo_id="nyanko7/LLaMA-7B", filename="tokenizer.model")
 instruct_adapter_path = "llama_adapter_len10_layer30_release.pth"
@@ -190,7 +203,7 @@ if local_rank > 0:
     sys.stdout = open(os.devnull, "w")
 
 generator = load(
-    ckpt_path, param_path, tokenizer_path, instruct_adapter_path, caption_adapter_path, local_rank, world_size, max_seq_len, max_batch_size
+    ckpt0_path, ckpt1_path, param_path, tokenizer_path, instruct_adapter_path, caption_adapter_path, local_rank, world_size, max_seq_len, max_batch_size
 )
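The new load() path works because each part file is expected to hold a disjoint slice of the original consolidated.00.pth state dict, so calling model.load_state_dict(..., strict=False) once per part reconstitutes the full set of weights while only one part is resident in memory at a time. The sketch below, which is not part of this commit, shows one way the split files hosted in csuhan/llama_storage could have been produced; the helper name split_checkpoint and the even split by key count are assumptions for illustration only.

# Minimal sketch (assumed, not from the commit): split one checkpoint into two
# part files whose key sets are disjoint, so both can later be loaded with
# load_state_dict(..., strict=False) as in the new load().
import torch

def split_checkpoint(src_path: str, dst0_path: str, dst1_path: str) -> None:
    # Load the full state dict on CPU to avoid GPU memory pressure.
    state_dict = torch.load(src_path, map_location="cpu")
    keys = list(state_dict.keys())
    half = len(keys) // 2
    # Each part gets a disjoint subset of the keys.
    part0 = {k: state_dict[k] for k in keys[:half]}
    part1 = {k: state_dict[k] for k in keys[half:]}
    torch.save(part0, dst0_path)
    torch.save(part1, dst1_path)

# Illustrative usage with the file names referenced by this commit:
# split_checkpoint("consolidated.00.pth",
#                  "consolidated.00_part0.pth",
#                  "consolidated.00_part1.pth")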