Spaces:
Sleeping
Sleeping
'''
Converts a transformers model to a format compatible with FlexGen.

Each parameter of the model is written as a separate NumPy file under
models/<model name>-np/.
'''
import argparse | |
import os | |
from pathlib import Path | |
import numpy as np | |
import torch | |
from tqdm import tqdm | |
from transformers import AutoModelForCausalLM, AutoTokenizer | |
# Command-line interface. The wider max_help_position keeps long argument
# names and their help text on the same line.
parser = argparse.ArgumentParser(formatter_class=lambda prog: argparse.HelpFormatter(prog, max_help_position=54))
# Required positional (no nargs='?' / default=None): without a model the
# conversion cannot run, so let argparse report the missing argument instead
# of failing later with a TypeError in Path(None).
parser.add_argument('MODEL', type=str, help="Path to the input model.")
# Parsed at import time; the __main__ block reads `args.MODEL`.
args = parser.parse_args()
def disable_torch_init():
    """Turn the default initialization of ``Linear`` and ``LayerNorm`` into a no-op.

    Skipping the redundant random weight initialization speeds up model
    creation. The original ``reset_parameters`` methods are kept in the module
    globals ``torch_linear_init_backup`` and ``torch_layer_norm_init_backup``
    so that ``restore_torch_init`` can reinstall them.
    """
    global torch_linear_init_backup, torch_layer_norm_init_backup

    def _skip_reset(self):
        return None

    # Stash the real initializers before patching them out.
    torch_linear_init_backup = torch.nn.Linear.reset_parameters
    torch_layer_norm_init_backup = torch.nn.LayerNorm.reset_parameters

    torch.nn.Linear.reset_parameters = _skip_reset
    torch.nn.LayerNorm.reset_parameters = _skip_reset
def restore_torch_init():
    """Reinstall the initializers saved by ``disable_torch_init``.

    Reads the module globals ``torch_linear_init_backup`` and
    ``torch_layer_norm_init_backup``, which are created by
    ``disable_torch_init``; calling this first raises ``NameError``.
    """
    torch.nn.Linear.reset_parameters = torch_linear_init_backup
    torch.nn.LayerNorm.reset_parameters = torch_layer_norm_init_backup
if __name__ == '__main__': | |
path = Path(args.MODEL) | |
model_name = path.name | |
print(f"Loading {model_name}...") | |
# disable_torch_init() | |
model = AutoModelForCausalLM.from_pretrained(path, torch_dtype=torch.float16, low_cpu_mem_usage=True) | |
# restore_torch_init() | |
tokenizer = AutoTokenizer.from_pretrained(path) | |
out_folder = Path(f"models/{model_name}-np") | |
if not Path(out_folder).exists(): | |
os.mkdir(out_folder) | |
print(f"Saving the converted model to {out_folder}...") | |
for name, param in tqdm(list(model.model.named_parameters())): | |
name = name.replace("decoder.final_layer_norm", "decoder.layer_norm") | |
param_path = os.path.join(out_folder, name) | |
with open(param_path, "wb") as f: | |
np.save(f, param.cpu().detach().numpy()) | |