import json

import numpy as np
import torch

from configuration_gpt1 import GPT1Config
from modeling_gpt1 import GPT1ForCausalLM, GPT1Model

# Register the custom classes with the Auto* API so that the checkpoint saved
# below records them in its config and can later be reloaded through
# AutoModel / AutoModelForCausalLM with trust_remote_code=True.
GPT1Config.register_for_auto_class()
GPT1Model.register_for_auto_class('AutoModel')
GPT1ForCausalLM.register_for_auto_class('AutoModelForCausalLM')


def lists_are_equal(list1, list2):
    """Return True if the two lists have the same length and equal elements."""
    if len(list1) != len(list2):
        return False
    for i, j in zip(list1, list2):
        if i != j:
            return False
    return True
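
# Hypothetical usage (not in the original script): compare two lists of token
# ids element by element, e.g. generations from the original TF model and the
# converted PyTorch model.
#   lists_are_equal([1, 2, 3], [1, 2, 3])   # True
#   lists_are_equal([1, 2, 3], [1, 2, 4])   # False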


def get_weights_from_tf_model():
    # The original GPT-1 release stores its parameters as ten flat .npy shards
    # plus a JSON file listing the shape of each tensor.
    with open('original_gpt1_params/params_shapes.json') as f:
        shapes = json.load(f)
    offsets = np.cumsum([np.prod(shape) for shape in shapes])

    # Concatenate the flat shards, split them back into individual tensors at
    # the cumulative offsets, and restore each tensor's original shape.
    init_params = [np.load('original_gpt1_params/params_{}.npy'.format(n)) for n in range(10)]
    init_params = np.split(np.concatenate(init_params, 0), offsets)[:-1]
    init_params = [param.reshape(shape) for param, shape in zip(init_params, shapes)]
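
    # Sanity check (not in the original script): one tensor should have been
    # recovered for every entry in params_shapes.json.
    assert len(init_params) == len(shapes)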

    # Instantiate the model with random weights, then overwrite them below.
    config = GPT1Config()
    model = GPT1ForCausalLM(config)

    # Token embedding matrix (init_params[1]) and learned position embedding
    # matrix (init_params[0]).
    model.model.embs.weight.data = torch.from_numpy(init_params[1])
    model.model.pos_emb.weight.data = torch.from_numpy(init_params[0])

    layers = model.model.layers

    # Each of the 12 transformer blocks owns 12 consecutive arrays, starting
    # right after the two embedding matrices.
    for i in range(12):
        idx = 12 * i + 2

        # Fused QKV projection weight: squeeze the leading conv1d dimension,
        # split into query/key/value chunks of size 768 along the last axis,
        # and transpose from the TF (in_features, out_features) layout to
        # PyTorch's (out_features, in_features).
        init_params[idx] = np.squeeze(init_params[idx], axis=0)
        q, k, v = torch.split(torch.tensor(init_params[idx]), 768, dim=-1)
        layers[i].attention.q_proj.weight.data = q.detach().clone().transpose(-1, -2).contiguous()
        layers[i].attention.k_proj.weight.data = k.detach().clone().transpose(-1, -2).contiguous()
        layers[i].attention.v_proj.weight.data = v.detach().clone().transpose(-1, -2).contiguous()

        # Fused QKV projection bias.
        q_bias, k_bias, v_bias = torch.split(torch.tensor(init_params[idx + 1]), 768, dim=-1)
        layers[i].attention.q_proj.bias.data = q_bias.detach().clone().contiguous()
        layers[i].attention.k_proj.bias.data = k_bias.detach().clone().contiguous()
        layers[i].attention.v_proj.bias.data = v_bias.detach().clone().contiguous()

        # Attention output projection.
        init_params[idx + 2] = np.squeeze(init_params[idx + 2], axis=0)
        layers[i].attention.o_proj.weight.data = torch.from_numpy(init_params[idx + 2]).transpose(-1, -2).contiguous()
        layers[i].attention.o_proj.bias.data = torch.from_numpy(init_params[idx + 3])

        # Layer norm for the attention sub-block.
        layers[i].attention_norm.weight.data = torch.from_numpy(init_params[idx + 4])
        layers[i].attention_norm.bias.data = torch.from_numpy(init_params[idx + 5])

        # Feed-forward (MLP) block.
        init_params[idx + 6] = np.squeeze(init_params[idx + 6], axis=0)
        layers[i].mlp.fc1.weight.data = torch.from_numpy(init_params[idx + 6]).transpose(-1, -2).contiguous()
        layers[i].mlp.fc1.bias.data = torch.from_numpy(init_params[idx + 7])
        init_params[idx + 8] = np.squeeze(init_params[idx + 8], axis=0)
        layers[i].mlp.fc2.weight.data = torch.from_numpy(init_params[idx + 8]).transpose(-1, -2).contiguous()
        layers[i].mlp.fc2.bias.data = torch.from_numpy(init_params[idx + 9])

        # Layer norm for the MLP sub-block.
        layers[i].mlp_norm.weight.data = torch.from_numpy(init_params[idx + 10])
        layers[i].mlp_norm.bias.data = torch.from_numpy(init_params[idx + 11])

    # Write the converted weights (plus config) in Hugging Face format.
    model.save_pretrained('gpt1-converted-weights/')


get_weights_from_tf_model()
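
# Optional sanity check (a sketch, not part of the original conversion script).
# It assumes the directory written above and the standard transformers
# custom-code loading path; the attribute names used here (config.vocab_size,
# output.logits) are the usual ones and may differ in the custom GPT1 classes.
#
# from transformers import AutoModelForCausalLM
#
# reloaded = AutoModelForCausalLM.from_pretrained(
#     'gpt1-converted-weights/', trust_remote_code=True
# )
# dummy_input = torch.randint(0, reloaded.config.vocab_size, (1, 16))
# with torch.no_grad():
#     output = reloaded(dummy_input)
# print(output.logits.shape)  # expected: (1, 16, vocab_size)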