# Source-control residue (appears to be author, commit message, commit hash):
# Doron Adler
# * Updated model card
# 6e1c9c6
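"""
Recreate the Core ML model from scratch using
coremltools' neural_network.NeuralNetworkBuilder.

The script loads a pretrained GPT-2-style checkpoint with transformers'
GPT2LMHeadModel and re-emits it as Core ML layers for a fixed input length:
token/position embeddings, transformer blocks, final LayerNorm and LM head.
It then saves the resulting spec as an .mlmodel file.
"""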
import os

import coremltools
import coremltools.models.datatypes as datatypes
import numpy as np
from coremltools.models import neural_network as neural_network
from coremltools.models.utils import save_spec
# Load the pretrained checkpoint and pull out the weight matrices that the
# Core ML graph below is built from.
from transformers import GPT2LMHeadModel, GPT2Tokenizer

model_name = "./distilgpt2-base-pretrained-he"
save_directory = "tmp/coreml/"
file_name = "model.mlmodel"
# Create the output directory now. The notebook version of this script used a
# `!mkdir -p` shell magic, which does nothing when run as plain Python.
os.makedirs(save_directory, exist_ok=True)

tokenizer = GPT2Tokenizer.from_pretrained(model_name)
lm_head_model = GPT2LMHeadModel.from_pretrained(model_name).eval()
model = lm_head_model.transformer

# Embedding tables are transposed to (embedding_dim, rows). That is the
# (output_channels, input_dim) layout NeuralNetworkBuilder.add_embedding
# expects.
wte = model.wte.weight.data.numpy().transpose()  # token table, (768, 50257) for this checkpoint
wpe = model.wpe.weight.data.numpy().transpose()  # position table, (768, 1024)

# Fixed input length of the exported model.
sequence_length = 64
# Number of transformer blocks converted.
steps = 6
# Network interface: two float arrays of length `sequence_length` (token ids
# and their positions) go in. One tensor named "output_logits" comes out; it
# is declared with no explicit datatype (None).
input_features = [
    (feature_name, datatypes.Array(sequence_length))
    for feature_name in ('input_ids', 'position_ids')
]
output_features = [('output_logits', None)]

# rank-5 shape mapping is disabled so the builder accepts the ND layers
# (batched matmul, rank-preserving reshape, squeeze, ...) used below.
builder = neural_network.NeuralNetworkBuilder(
    input_features=input_features,
    output_features=output_features,
    mode=None,
    disable_rank5_shape_mapping=True,
)
# Promote each rank-1 id vector to rank 5, (seq,) -> (seq, 1, 1, 1, 1). That
# is the (Seq, B, 1, 1, 1) layout the embedding layers consume.
for ids_name in ('input_ids', 'position_ids'):
    builder.add_expand_dims(
        name=f'{ids_name}_expanded_to_rank5',
        input_name=ids_name,
        output_name=f'{ids_name}_expanded_to_rank5',
        axes=(1, 2, 3, 4),
    )

# Token and position lookups, each producing (seq, 1, embedding_dim, 1, 1).
# Table sizes come from the transposed (embedding_dim, rows) weight matrices
# instead of being hard-coded, so a checkpoint with a different vocabulary or
# context size still converts.
builder.add_embedding(
    name='token_embeddings',
    input_name='input_ids_expanded_to_rank5',
    output_name='token_embeddings',
    W=wte,
    b=None,
    input_dim=wte.shape[1],
    output_channels=wte.shape[0],
    has_bias=False,
)
builder.add_embedding(
    name='positional_embeddings',
    input_name='position_ids_expanded_to_rank5',
    output_name='positional_embeddings',
    W=wpe,
    b=None,
    input_dim=wpe.shape[1],
    output_channels=wpe.shape[0],
    has_bias=False,
)

# Sum of token and position embeddings, (seq, 1, embedding_dim, 1, 1).
# This is the input of transformer block 0.
builder.add_add_broadcastable(
    name='embeddings_addition',
    input_names=['token_embeddings', 'positional_embeddings'],
    output_name='0_previous_block',
)
# Transformer blocks.
#
# Each iteration emits the Core ML layers for one GPT-2 block: pre-norm causal
# self-attention and a pre-norm MLP, each followed by a residual add.
#   - Block i reads "{i}_previous_block", laid out (seq, 1, n_embd, 1, 1).
#   - It writes "{i+1}_previous_block_final", laid out (1, seq, n_embd, 1, 1).
#   - It also writes that tensor's transpose, "{i+1}_previous_block", laid out
#     (seq, 1, n_embd, 1, 1), which feeds the next block.
n_embd = model.config.n_embd
n_head = model.config.n_head
head_dim = n_embd // n_head


def _add_layer_norm(builder, layer_norm, name, input_name):
    """Append a LayerNorm as an MVN layer plus a per-channel scale/shift.

    The MVN layer is named and outputs ``name``. The affine step is named and
    outputs ``name + "_scaled"``, which is returned.
    """
    n_channels = layer_norm.weight.shape[0]
    builder.add_mvn(
        name=name,
        input_name=input_name,
        output_name=name,
        across_channels=True,
        normalize_variance=True,
        epsilon=layer_norm.eps,
    )
    builder.add_scale(
        name=f"{name}_scaled",
        input_name=name,
        output_name=f"{name}_scaled",
        W=layer_norm.weight.data.numpy(),
        b=layer_norm.bias.data.numpy(),
        has_bias=True,
        shape_scale=[n_channels],
        shape_bias=[n_channels],
    )
    return f"{name}_scaled"


def _add_conv1d(builder, conv1d, name, input_name):
    """Append a transformers Conv1D as an inner-product layer named ``name``.

    Conv1D stores its weight as (in_features, out_features), whereas
    add_inner_product wants (output_channels, input_channels); hence the
    transpose. Returns the output blob name (``name``).
    """
    weight = conv1d.weight.data.numpy().T
    out_channels, in_channels = weight.shape
    builder.add_inner_product(
        name=name,
        input_name=input_name,
        output_name=name,
        input_channels=in_channels,
        output_channels=out_channels,
        W=weight,
        b=conv1d.bias.data.numpy(),
        has_bias=True,
    )
    return name


# Causal mask shared by every block: 1 where a query may attend to a key
# (key position <= query position), 0 elsewhere. This is the lower-triangular
# pattern the original script sliced out of each block's `attn.bias` buffer.
# It is built once as float32 NumPy because that buffer is a torch tensor
# (bool in recent transformers, where `1 - mask` raises) and coremltools
# expects NumPy.
causal_mask = np.tril(
    np.ones((sequence_length, sequence_length), dtype=np.float32)
).reshape((1, 1, 1, sequence_length, sequence_length))
# Additive part of the mask: 0 on allowed positions and -1e4 on masked ones,
# so masked scores vanish after the softmax.
causal_mask_bias = -1e4 * (1 - causal_mask)

for i in range(steps):
    print(i)
    block = model.h[i]

    # --- Self-attention sub-block ------------------------------------------
    # ln_1 on (seq, 1, n_embd, 1, 1). Then move seq into the batch axis so the
    # inner-product layers see (1, seq, n_embd, 1, 1).
    ln_1_out = _add_layer_norm(
        builder, block.ln_1, f"{i}_block_ln_1", f"{i}_previous_block"
    )
    builder.add_transpose(
        name=f"{i}_block_ln_1_reshape",
        input_name=ln_1_out,
        output_name=f"{i}_block_ln_1_scaled_transposed",
        axes=(1, 0, 2, 3, 4),
    )
    # Fused q/k/v projection, (1, seq, 3 * n_embd, 1, 1), then split into
    # three (1, seq, n_embd, 1, 1) tensors.
    qkv = _add_conv1d(
        builder,
        block.attn.c_attn,
        f"{i}_block_attn_conv",
        f"{i}_block_ln_1_scaled_transposed",
    )
    builder.add_split(
        name=f"{i}_block_attn_qkv_split",
        input_name=qkv,
        output_names=[f"{i}_block_attn_q", f"{i}_block_attn_k", f"{i}_block_attn_v"],
    )
    # Split each projection into heads, giving (1, 1, seq, n_head, head_dim).
    # Then permute:
    #   - q and v to (1, 1, n_head, seq, head_dim)
    #   - k to (1, 1, n_head, head_dim, seq), so that q @ k is per-head scores.
    for part, axes in (
        ("q", (0, 1, 3, 2, 4)),
        ("k", (0, 1, 3, 4, 2)),
        ("v", (0, 1, 3, 2, 4)),
    ):
        builder.add_rank_preserving_reshape(
            name=f"{i}_block_attn_{part}_reshape",
            input_name=f"{i}_block_attn_{part}",
            output_name=f"{i}_block_attn_{part}_reshape",
            output_shape=(1, 1, sequence_length, n_head, head_dim),
        )
        builder.add_transpose(
            name=f"{i}_block_attn_{part}_reshape_permuted",
            input_name=f"{i}_block_attn_{part}_reshape",
            output_name=f"{i}_block_attn_{part}_reshape_permuted",
            axes=axes,
        )

    # Attention scores, (1, 1, n_head, seq, seq), scaled by 1/sqrt(head_dim).
    builder.add_batched_mat_mul(
        name=f"{i}_block_attn_qv_matmul",
        input_names=[
            f"{i}_block_attn_q_reshape_permuted",
            f"{i}_block_attn_k_reshape_permuted",
        ],
        output_name=f"{i}_block_attn_qv_matmul",
    )
    builder.add_scale(
        name=f"{i}_block_attn_qv_matmul_scaled",
        input_name=f"{i}_block_attn_qv_matmul",
        output_name=f"{i}_block_attn_qv_matmul_scaled",
        W=np.array(head_dim ** -0.5),
        b=0,
        has_bias=False,
    )
    # Causal masking: multiply by the 0/1 mask, then add -1e4 where masked.
    builder.add_scale(
        name=f"{i}_block_attn_bias",
        input_name=f"{i}_block_attn_qv_matmul_scaled",
        output_name=f"{i}_block_attn_bias",
        W=causal_mask,
        b=None,
        has_bias=False,
        shape_scale=[1, sequence_length, sequence_length],
    )
    builder.add_bias(
        name=f"{i}_block_attn_afterbias",
        input_name=f"{i}_block_attn_bias",
        output_name=f"{i}_block_attn_afterbias",
        b=causal_mask_bias,
        shape_bias=[1, sequence_length, sequence_length],
    )
    # Normalize over the key axis (the last one). The legacy SoftmaxLayer
    # always normalizes along axis -3, which here is the head axis, so use the
    # ND softmax with an explicit axis.
    builder.add_softmax_nd(
        name=f"{i}_block_attn_softmax",
        input_name=f"{i}_block_attn_afterbias",
        output_name=f"{i}_block_attn_softmax",
        axis=-1,
    )
    # Weighted sum of values, (1, 1, n_head, seq, head_dim).
    builder.add_batched_mat_mul(
        name=f"{i}_block_full_attention",
        input_names=[f"{i}_block_attn_softmax", f"{i}_block_attn_v_reshape_permuted"],
        output_name=f"{i}_block_full_attention",
    )
    # Merge the heads back. The tensor goes through these shapes:
    #   (1, 1, n_head, seq, head_dim)
    #   -> (1, 1, seq, n_head, head_dim)
    #   -> (1, 1, 1, seq, n_embd)
    #   -> (1, seq, n_embd, 1, 1)
    builder.add_transpose(
        name=f"{i}_block_full_attention_merged_t",
        input_name=f"{i}_block_full_attention",
        output_name=f"{i}_block_full_attention_merged_t",
        axes=[0, 1, 3, 2, 4],
    )
    builder.add_rank_preserving_reshape(
        name=f"{i}_block_full_attention_merged",
        input_name=f"{i}_block_full_attention_merged_t",
        output_name=f"{i}_block_full_attention_merged",
        output_shape=[1, 1, 1, sequence_length, n_embd],
    )
    builder.add_transpose(
        name=f"{i}_block_attn_conv_proj_t",
        input_name=f"{i}_block_full_attention_merged",
        output_name=f"{i}_block_attn_conv_proj_t",
        axes=[0, 3, 4, 1, 2],
    )
    # Output projection, (1, seq, n_embd, 1, 1).
    attn_out = _add_conv1d(
        builder,
        block.attn.c_proj,
        f"{i}_block_attn_conv_proj",
        f"{i}_block_attn_conv_proj_t",
    )
    # Residual: bring the block input from (seq, 1, n_embd, 1, 1) to
    # (1, seq, n_embd, 1, 1) and add the attention output.
    builder.add_transpose(
        name=f"{i}_previous_block_t",
        input_name=f"{i}_previous_block",
        output_name=f"{i}_previous_block_t",
        axes=[1, 0, 2, 3, 4],
    )
    builder.add_add_broadcastable(
        name=f"{i}_block_xa_sum",
        input_names=[f"{i}_previous_block_t", attn_out],
        output_name=f"{i}_block_xa_sum",
    )

    # --- MLP sub-block -----------------------------------------------------
    # ln_2 -> c_fc (1, seq, n_inner, 1, 1) -> GELU (tanh approximation)
    # -> c_proj (1, seq, n_embd, 1, 1).
    ln_2_out = _add_layer_norm(
        builder, block.ln_2, f"{i}_block_ln_2", f"{i}_block_xa_sum"
    )
    fc_out = _add_conv1d(
        builder, block.mlp.c_fc, f"{i}_block_mlp_conv_fc", ln_2_out
    )
    builder.add_gelu(
        name=f"{i}_block_mlp_gelu",
        input_name=fc_out,
        output_name=f"{i}_block_mlp_gelu",
        mode='TANH_APPROXIMATION',
    )
    mlp_out = _add_conv1d(
        builder, block.mlp.c_proj, f"{i}_block_mlp_conv_proj", f"{i}_block_mlp_gelu"
    )
    # Second residual add. Its (1, seq, n_embd, 1, 1) result is also what the
    # final LayerNorm reads after the last block.
    builder.add_add_broadcastable(
        name=f"{i}_block_xm_sum",
        input_names=[f"{i}_block_xa_sum", mlp_out],
        output_name=f"{i + 1}_previous_block_final",
    )
    # Back to (seq, 1, n_embd, 1, 1) for the next block's ln_1.
    builder.add_transpose(
        name=f"{i}_block_xm_sum_t",
        input_name=f"{i + 1}_previous_block_final",
        output_name=f"{i + 1}_previous_block",
        axes=[1, 0, 2, 3, 4],
    )
# Final LayerNorm (ln_f) on the last block's (1, seq, n_embd, 1, 1) output.
# It is expressed as MVN across channels plus a per-channel scale/shift.
ln_f_weight = model.ln_f.weight.data.numpy()
ln_f_bias = model.ln_f.bias.data.numpy()
ln_f_epsilon = model.ln_f.eps
ln_f_channels = ln_f_weight.shape[0]
builder.add_mvn(
    name="ln_f",
    input_name=f"{steps}_previous_block_final",
    output_name="ln_f",
    across_channels=True,
    normalize_variance=True,
    epsilon=ln_f_epsilon,
)
builder.add_scale(
    name="ln_f_scaled",
    input_name="ln_f",
    output_name="ln_f_scaled",
    W=ln_f_weight,
    b=ln_f_bias,
    has_bias=True,
    shape_scale=[ln_f_channels],
    shape_bias=[ln_f_channels],
)

# Language-model head. lm_head.weight is (vocab_size, hidden_size), which is
# already the (output_channels, input_channels) layout add_inner_product
# expects. Sizes are read from the weight instead of being hard-coded.
# Produces "output_logits", (1, seq, vocab_size, 1, 1).
lm_head_weights = lm_head_model.lm_head.weight.data.numpy()
lm_head_vocab_size, lm_head_hidden_size = lm_head_weights.shape
builder.add_inner_product(
    name="lm_head",
    input_name="ln_f_scaled",
    output_name="output_logits",
    input_channels=lm_head_hidden_size,
    output_channels=lm_head_vocab_size,
    W=lm_head_weights,
    b=None,
    has_bias=False,
)
# Wrap the finished spec in an MLModel (used by the optional check below) and
# write the spec to <save_directory>/<file_name>.
mlmodel = coremltools.models.MLModel(builder.spec)
save_spec(builder.spec, os.path.join(save_directory, file_name))

# Optional sanity check. Uncomment to run; Core ML prediction via coremltools
# is only available on macOS.
# input_data = {
#     'input_ids': np.zeros(sequence_length, dtype=np.float32),
#     'position_ids': np.arange(sequence_length, dtype=np.float32),
# }
# predictions = mlmodel.predict(input_data)["output_logits"]
# print(predictions.shape)