File size: 1,026 Bytes
a50a629
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
import argparse
import os

import torch
from transformers import CognitivessConfig, CognitivessForCausalLM

def convert_cognitivess_checkpoint_to_hf(model_dir, save_dir):
    """Convert a Cognitivess checkpoint directory into Hugging Face format.

    Builds a ``CognitivessForCausalLM`` from the config in *model_dir*, loads
    the weights stored in ``pytorch_model.bin`` inside *model_dir* into it
    (strict key matching), then writes the model and its config to
    *save_dir* via ``save_pretrained``.

    Args:
        model_dir: Directory holding the source config and
            ``pytorch_model.bin``.
        save_dir: Destination directory for the converted model.
    """
    config = CognitivessConfig.from_pretrained(model_dir)
    model = CognitivessForCausalLM(config)

    # Load the model weights from the Cognitivess checkpoint.
    # weights_only=True limits unpickling to tensors and plain containers,
    # so a tampered checkpoint cannot execute arbitrary code while loading.
    checkpoint_path = os.path.join(model_dir, "pytorch_model.bin")
    state_dict = torch.load(checkpoint_path, map_location="cpu", weights_only=True)
    model.load_state_dict(state_dict)

    # Save the model in Hugging Face format
    model.save_pretrained(save_dir)
    config.save_pretrained(save_dir)
    print(f"Model converted and saved to {save_dir}")

if __name__ == "__main__":
    # Command-line entry point: both directories are mandatory string options.
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--model_dir",
        type=str,
        required=True,
        help="Path to the Cognitivess model directory",
    )
    cli.add_argument(
        "--save_dir",
        type=str,
        required=True,
        help="Path to the directory to save the converted model",
    )
    parsed = cli.parse_args()
    convert_cognitivess_checkpoint_to_hf(
        model_dir=parsed.model_dir,
        save_dir=parsed.save_dir,
    )