| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
|
|
| from collections import defaultdict |
| import torch |
| from gemma import config |
| from gemma import model as gemma_model |
| import numpy as np |
| import argparse |
| import os |
|
|
| |
|
|
def check_file_exists(value):
    """argparse type-checker: accept *value* only if it names an existing path."""
    path = str(value)
    if os.path.exists(path):
        return value
    # argparse renders ArgumentTypeError as a usage error for this flag.
    raise argparse.ArgumentTypeError(f"The file {value} does not appear to exist.")
| |
|
|
def check_model_types(value):
    """argparse type-checker: accept only the supported model sizes (2b, 7b).

    The comparison is case-insensitive; the original (uncased) value is returned.
    """
    if str(value).lower() in ("2b", "7b"):
        return value
    raise argparse.ArgumentTypeError(f"Model type value {value} is not in [2b, 7b].")
| |
|
|
parser = argparse.ArgumentParser()

# (flag, dest, default, help, type) for every CLI option, applied in one loop.
_ARG_SPECS = [
    ("--tokenizer", "tokenizer", "models/tokenizer.spm",
     "Location of tokenizer file (.model or .spm)", check_file_exists),
    ("--weights", "weights", "models/gemma-2b-it.ckpt",
     "Location of input checkpoint file (.ckpt)", check_file_exists),
    ("--output_file", "output_file", "2bit-f32.sbs",
     "Location to write converted weights", str),
    ("--model_type", "model_type", "2b",
     "Model size / type (2b, 7b)", check_model_types),
]

for _flag, _dest, _default, _help, _type in _ARG_SPECS:
    parser.add_argument(_flag, dest=_dest, default=_default, help=_help, type=_type)

args = parser.parse_args()
|
|
|
|
def _identity(x):
    """Pass a tensor through unchanged."""
    return x


# Per-model-size reshapes applied to each named parameter before serialization.
# Unlisted parameter names fall through to the identity via the defaultdict.
TRANSFORMATIONS = {
    "2b": defaultdict(
        lambda: _identity,
        {
            "embedder.weight": _identity,
            # flat qkv projection -> (heads_qkv=10, head_dim=256, model_dim=2048)
            "self_attn.qkv_proj.weight": lambda x: x.reshape((10, 256, 2048)),
            # output projection -> (heads=8, model_dim=2048, head_dim=256)
            "self_attn.o_proj.weight": lambda x: x.reshape((2048, 8, 256)).transpose([1, 0, 2]),
            # add a leading singleton axis to the MLP in-projections
            "mlp.gate_proj.weight": lambda x: x[np.newaxis, :, :],
            "mlp.up_proj.weight": lambda x: x[np.newaxis, :, :],
            "mlp.down_proj.weight": _identity,
        },
    ),
    "7b": defaultdict(
        lambda: _identity,
        {
            "embedder.weight": _identity,
            # flat qkv projection -> (heads=16, qkv=3, head_dim=256, model_dim=3072)
            "self_attn.qkv_proj.weight": lambda x: x.reshape((3, 16, 256, 3072)).transpose([1, 0, 2, 3]),
            # output projection -> (heads=16, model_dim=3072, head_dim=256)
            "self_attn.o_proj.weight": lambda x: x.reshape((3072, 16, 256)).transpose([1, 0, 2]),
            "mlp.gate_proj.weight": lambda x: x[np.newaxis, :, :],
            "mlp.up_proj.weight": lambda x: x[np.newaxis, :, :],
            "mlp.down_proj.weight": _identity,
        },
    ),
}
|
|
# Expected post-transformation shape for every serialized parameter, per model size.
_EXPECTED_SHAPES = {
    "2b": {
        "embedder.weight": (256000, 2048),
        "model.norm.weight": (2048,),
        "self_attn.qkv_proj.weight": (10, 256, 2048),
        "self_attn.o_proj.weight": (8, 2048, 256),
        "mlp.gate_proj.weight": (1, 16384, 2048),
        "mlp.up_proj.weight": (1, 16384, 2048),
        "mlp.down_proj.weight": (2048, 16384),
        "input_layernorm.weight": (2048,),
        "post_attention_layernorm.weight": (2048,),
    },
    "7b": {
        "embedder.weight": (256000, 3072),
        "model.norm.weight": (3072,),
        "self_attn.qkv_proj.weight": (16, 3, 256, 3072),
        "self_attn.o_proj.weight": (16, 3072, 256),
        "mlp.gate_proj.weight": (1, 24576, 3072),
        "mlp.up_proj.weight": (1, 24576, 3072),
        "mlp.down_proj.weight": (3072, 24576),
        "input_layernorm.weight": (3072,),
        "post_attention_layernorm.weight": (3072,),
    },
}

# Build shape-equality predicates from the table above.  The inner call binds
# `expected` eagerly so each lambda tests its own shape (no late-binding bug).
VALIDATIONS = {
    size: {
        layer: (lambda expected: lambda x: x.shape == expected)(shape)
        for layer, shape in shapes.items()
    }
    for size, shapes in _EXPECTED_SHAPES.items()
}
|
|
|
|
def param_names(num_hidden_layers: int) -> list:
    """Return parameter names in the order they are expected for deserialization.

    Each entry is a ``(checkpoint_name, layer_type)`` pair: the fully qualified
    name used to look the tensor up in the model's state dict, and the bare
    per-layer key used to select its transformation/validation.

    Args:
        num_hidden_layers: number of transformer layers in the checkpoint.

    Returns:
        List of 2-tuples of strings, globals first, then the seven per-layer
        parameters for each layer in order.
    """
    layer_params = [
        "self_attn.o_proj.weight",
        "self_attn.qkv_proj.weight",
        "mlp.gate_proj.weight",
        "mlp.up_proj.weight",
        "mlp.down_proj.weight",
        "input_layernorm.weight",
        "post_attention_layernorm.weight",
    ]
    names = [
        ("embedder.weight",) * 2,
        ("model.norm.weight",) * 2,
    ]
    # extend() instead of repeated `names = names + [...]`, which was O(n^2).
    names.extend(
        (f"model.layers.{layer}.{layer_param}", layer_param)
        for layer in range(num_hidden_layers)
        for layer_param in layer_params
    )
    return names
|
|
|
|
def convert_weights():
    """Convert a Gemma PyTorch checkpoint to a flat little-endian float32 file.

    Loads the checkpoint named by the CLI ``args``, applies the per-parameter
    reshape in TRANSFORMATIONS, verifies every resulting shape against
    VALIDATIONS, and only if all checks pass writes the tensors to
    ``args.output_file`` in ``param_names`` order.  Nothing is written when any
    validation fails.
    """
    model_type = args.model_type
    output_file = args.output_file

    model_config = config.get_model_config(model_type)
    model_config.dtype = "float32"
    model_config.tokenizer = args.tokenizer
    device = torch.device("cpu")
    torch.set_default_dtype(torch.float)
    model = gemma_model.GemmaForCausalLM(model_config)

    model.load_weights(args.weights)
    model.to(device).eval()

    model_dict = dict(model.named_parameters())
    param_order = param_names(model_config.num_hidden_layers)

    def _transformed(name, layer_name):
        # detach -> numpy -> model-size-specific reshape/transpose
        arr = model_dict[name].detach().numpy()
        return TRANSFORMATIONS[model_type][layer_name](arr)

    # Pass 1: validate every transformed shape before touching the output file.
    all_ok = True
    print("Checking transformations ...")
    for name, layer_name in param_order:
        arr = _transformed(name, layer_name)
        check = "OK" if VALIDATIONS[model_type][layer_name](arr) else "FAILED"
        if check == "FAILED":
            all_ok = False
            print(f" {name : <60}{str(arr.shape) : <20}{check}")

    # Pass 2: serialize (only reached when every shape validated above).
    if all_ok:
        print("Writing parameters ...")
        with open(output_file, "wb") as bin_handle:
            for name, layer_name in param_order:
                arr = _transformed(name, layer_name)
                check = "OK" if VALIDATIONS[model_type][layer_name](arr) else "FAILED"
                print(f" {name : <60}{str(arr.shape) : <20}{check}")
                arr.flatten().astype(np.float32).tofile(bin_handle)
    else:
        # Previously the failure path fell through silently; make it explicit.
        print("Validation failed; no output written.")
|
|
|
|
# Script entry point: run the conversion and report completion.
if __name__ == "__main__":
    convert_weights()
    print("Done")
|
|