gpt1 / tf_weights_to_hf.py
Alexandru Gherghescu
Add original model weights + conversion script
bbb5d39 unverified
raw
history blame
No virus
3.51 kB
import itertools
import json
import os

import numpy as np
import torch

from configuration_gpt1 import GPT1Config
from modeling_gpt1 import GPT1ForCausalLM, GPT1Model
# Register the custom GPT-1 classes with the transformers auto-class machinery:
# the config (no explicit auto class given), the bare model as 'AutoModel' and
# the causal-LM model as 'AutoModelForCausalLM'.
# NOTE(review): presumably this is what makes save_pretrained() in the
# conversion write an `auto_map` entry so the saved checkpoint can be loaded
# through the Auto* classes with trust_remote_code — confirm against the
# transformers "custom models" documentation.
GPT1Config.register_for_auto_class()
GPT1Model.register_for_auto_class('AutoModel')
GPT1ForCausalLM.register_for_auto_class('AutoModelForCausalLM')
def lists_are_equal(list1, list2):
    """Return True if the two iterables have the same length and equal elements.

    Elements are compared pairwise with ``!=``. A plain ``zip`` stops at the
    shorter input, which would make e.g. ``[768]`` and ``[768, 3]`` compare
    equal; here inputs of different length are reported as unequal.

    Args:
        list1: first finite iterable (list, tuple, shape, ...).
        list2: second finite iterable.

    Returns:
        True when both have the same length and no pair of elements differs.
    """
    # Pads the shorter input; a fresh object() is never one of the real elements.
    missing = object()
    for a, b in itertools.zip_longest(list1, list2, fillvalue=missing):
        # Check the sentinel first so it is never passed to a custom __ne__.
        if a is missing or b is missing or a != b:
            return False
    return True
# Conversion of the original GPT-1 weights (params_*.npy files) into a
# GPT1ForCausalLM checkpoint written with save_pretrained().

# Number of checkpoint arrays stored per transformer block (order documented
# in _load_block).
_TF_PARAMS_PER_LAYER = 12


def _load_tf_params(params_dir, num_shards):
    """Load the original checkpoint as a list of arrays, in checkpoint order.

    The shards ``params_0.npy`` .. ``params_{num_shards - 1}.npy`` are
    concatenated along axis 0 into one buffer, which is then cut into
    consecutive pieces shaped as listed in ``params_shapes.json``.
    """
    # `with` closes the JSON file (json.load(open(...)) leaked the handle).
    with open(os.path.join(params_dir, 'params_shapes.json')) as f:
        shapes = json.load(f)
    shards = [np.load(os.path.join(params_dir, 'params_{}.npy'.format(n)))
              for n in range(num_shards)]
    flat = np.concatenate(shards, 0)
    # Cumulative end offset of every tensor. Splitting at all of them leaves a
    # trailing remainder piece (empty when the sizes add up), which is dropped.
    offsets = np.cumsum([int(np.prod(shape)) for shape in shapes])
    pieces = np.split(flat, offsets)[:-1]
    return [piece.reshape(shape) for piece, shape in zip(pieces, shapes)]


def _linear_weight(kernel):
    """Squeeze the leading singleton axis off a projection kernel and transpose it.

    The squeezed kernel's last axis is the output axis (the fused q/k/v kernel
    is split along it), so the transpose yields an (out, in) matrix.
    NOTE(review): assumes the model's projections use nn.Linear's
    (out_features, in_features) weight layout, as the conversion's transposes imply.
    """
    return np.squeeze(kernel, axis=0).T


def _copy_into(param, array):
    """Copy *array* into the existing model tensor *param*, in place.

    Raises:
        ValueError: if the shapes differ. Assigning ``param.data = ...`` would
            silently swap in a tensor of a different shape instead.
    """
    src = torch.as_tensor(array)
    if tuple(src.shape) != tuple(param.shape):
        raise ValueError(
            'shape mismatch: model tensor is {}, checkpoint array is {}'.format(
                tuple(param.shape), tuple(src.shape)))
    with torch.no_grad():
        param.copy_(src)


def _load_block(layer, params, base):
    """Copy ``params[base:base + 12]`` into one transformer block.

    Per-block checkpoint order: fused q/k/v weight and bias, attention output
    weight and bias, attention norm weight and bias, fc1 weight and bias,
    fc2 weight and bias, mlp norm weight and bias.
    """
    (qkv_w, qkv_b, out_w, out_b, attn_norm_w, attn_norm_b,
     fc1_w, fc1_b, fc2_w, fc2_b,
     mlp_norm_w, mlp_norm_b) = params[base:base + _TF_PARAMS_PER_LAYER]
    attn = layer.attention

    # q, k and v are fused along the output axis: split the (out, in) weight
    # and the bias into three equal parts, in q, k, v order. Deriving the part
    # size avoids hard-coding the hidden width (768).
    projections = (attn.q_proj, attn.k_proj, attn.v_proj)
    weights = np.split(_linear_weight(qkv_w), 3, axis=0)
    biases = np.split(qkv_b, 3, axis=-1)
    for proj, weight, bias in zip(projections, weights, biases):
        _copy_into(proj.weight, weight)
        _copy_into(proj.bias, bias)

    # Attention output projection and the norm that follows attention.
    _copy_into(attn.o_proj.weight, _linear_weight(out_w))
    _copy_into(attn.o_proj.bias, out_b)
    _copy_into(layer.attention_norm.weight, attn_norm_w)
    _copy_into(layer.attention_norm.bias, attn_norm_b)

    # MLP projections and the norm that follows the MLP.
    _copy_into(layer.mlp.fc1.weight, _linear_weight(fc1_w))
    _copy_into(layer.mlp.fc1.bias, fc1_b)
    _copy_into(layer.mlp.fc2.weight, _linear_weight(fc2_w))
    _copy_into(layer.mlp.fc2.bias, fc2_b)
    _copy_into(layer.mlp_norm.weight, mlp_norm_w)
    _copy_into(layer.mlp_norm.bias, mlp_norm_b)


def get_weights_from_tf_model(params_dir='original_gpt1_params',
                              output_dir='gpt1-converted-weights/',
                              num_shards=10):
    """Convert the original GPT-1 weights and save them with save_pretrained.

    Builds a fresh ``GPT1ForCausalLM(GPT1Config())``, copies the embedding
    and per-layer arrays of the original checkpoint into it, and writes the
    result with ``model.save_pretrained(output_dir)``.

    Args:
        params_dir: directory holding ``params_shapes.json`` and the
            ``params_{n}.npy`` shards.
        output_dir: directory passed to ``save_pretrained``.
        num_shards: number of ``params_{n}.npy`` shards to read.

    Returns:
        The converted model.

    Raises:
        ValueError: if the checkpoint has too few arrays for the model's
            layers, or an array's shape does not match its model tensor.
    """
    params = _load_tf_params(params_dir, num_shards)

    config = GPT1Config()
    model = GPT1ForCausalLM(config)
    backbone = model.model
    layers = backbone.layers

    # Two embedding arrays, then one contiguous run of arrays per block.
    needed = 2 + _TF_PARAMS_PER_LAYER * len(layers)
    if len(params) < needed:
        raise ValueError(
            'checkpoint has {} arrays, but {} transformer layers need {}'.format(
                len(params), len(layers), needed))

    # Checkpoint slot 0 holds the positional embeddings, slot 1 the token ones.
    _copy_into(backbone.pos_emb.weight, params[0])
    _copy_into(backbone.embs.weight, params[1])

    for i, layer in enumerate(layers):
        _load_block(layer, params, 2 + _TF_PARAMS_PER_LAYER * i)

    model.save_pretrained(output_dir)
    return model
if __name__ == '__main__':
    # Convert only when run as a script, so importing this module (e.g. to
    # reuse its helpers) does not start a conversion as a side effect.
    get_weights_from_tf_model()