JackyChunKit committed on
Commit aae6741 · verified · 1 Parent(s): 6ec27fe

Upload convert_fsdp_to_hf.py

Files changed (1)
  1. convert_fsdp_to_hf.py +39 -0
convert_fsdp_to_hf.py ADDED
@@ -0,0 +1,39 @@
+ #!/usr/bin/env python
+ # encoding: utf-8
+ from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
+ import torch
+ import fire
+ from glob import glob
+ from collections import defaultdict
+
+
+ def main(fsdp_checkpoint_path, huggingface_model_path, output_path):
+     state_dict = defaultdict(list)  # collect each parameter's shards across all ranks
+
+     world_size = 8  # number of FSDP ranks the checkpoint was saved with
+     for rank in range(world_size):
+         filepath = f"{fsdp_checkpoint_path}/model_world_size_{world_size}_rank_{rank}.pt"
+         print('loading', filepath)
+         this_state_dict = torch.load(filepath)
+         for key, value in this_state_dict.items():
+             state_dict[key].append(value.to_local())  # DTensor shard -> plain local tensor
+
+     for key in state_dict:
+         state_dict[key] = torch.cat(state_dict[key], dim=0)  # stitch shards back together along dim 0
+
+     config = AutoConfig.from_pretrained(huggingface_model_path)
+     model = AutoModelForCausalLM.from_config(config)  # build the architecture without weights
+     model.load_state_dict(state_dict)
+
+     #for filepath in glob(f'{fsdp_checkpoint_path}/model_*.pt'):
+     #    part_state_dict = torch.load(filepath)
+     #    model.load_state_dict(part_state_dict)
+
+     model.save_pretrained(output_path)
+
+     tokenizer = AutoTokenizer.from_pretrained(huggingface_model_path)
+     tokenizer.save_pretrained(output_path)
+
+
+ if __name__ == "__main__":
+     fire.Fire(main)