import os
import struct
import argparse
import json

import numpy as np
import torch


def SerializeFP32(file, tensor):
    """ flattens the tensor and writes it to file as raw fp32 values """
    d = tensor.detach().cpu().view(-1).to(torch.float32).numpy()
    b = struct.pack(f'{len(d)}f', *d)
    file.write(b)


def SerializeINT8(file, tensor):
    """ flattens the tensor and writes it to file as raw int8 values """
    d = tensor.detach().cpu().view(-1).numpy().astype(np.int8)
    b = struct.pack(f'{len(d)}b', *d)
    file.write(b)


def QuantizeINT8(w, group_size):
    """ symmetric (Q8_0) quantization of w into int8 groups of group_size;
    returns the int8 values, the per-group fp32 scales and the max error """
    assert w.numel() % group_size == 0
    w = w.float()
    w = w.reshape(-1, group_size)

    # scale each group so that its max absolute value maps to 127
    wmax = torch.abs(w).max(dim=1).values
    scale = wmax / 127.0
    quant = w / scale[:, None]

    # quantize, then dequantize to measure the reconstruction error
    int8val = torch.round(quant).to(torch.int8)
    fp32val = (int8val.float() * scale[:, None]).view(-1)
    fp32valr = fp32val.reshape(-1, group_size)

    err = torch.abs(fp32valr - w).max(dim=1).values
    maxerr = err.max().item()

    return int8val, scale, maxerr


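# For reference, a minimal sketch of the inverse operation (not used by the
# export path itself): the fp32 weights are reconstructed from the int8 values
# and per-group scales produced by QuantizeINT8. The helper name
# DequantizeINT8 and the ori_shape argument are illustrative, not part of the
# original script.
def DequantizeINT8(int8val, scale, ori_shape):
    # int8val: (n_groups, group_size) int8, scale: (n_groups,) fp32
    w = int8val.float() * scale[:, None]
    return w.reshape(ori_shape)

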
def WriteWeightsFP32(file, model, key):
    """ writes one fp32 tensor from the model dict to file """
    print(f"writing {key} {list(model[key].shape)[::-1]}")
    SerializeFP32(file, model[key])


def WriteWeightsINT8(file, model, key, group_size=64):
    """ writes one quantized tensor to file: int8 values first, then the fp32 scales """
    q, s, err = QuantizeINT8(model[key], group_size)

    SerializeINT8(file, q)
    SerializeFP32(file, s)

    print(f"{key} quantized {tuple(model[key].shape)} to Q8_0 with max error {err}")


def WriteLayersFP32(file, model, layer, n_layers):
    """ writes one fp32 tensor per layer to file; layer is a printf-style key
    template such as 'layers.%d.mixer.in_proj.weight' """
    for n in range(n_layers):
        WriteWeightsFP32(file, model, layer % n)


def WriteLayersINT8(file, model, layer, n_layers, group_size=64):
    """ quantizes one tensor per layer and writes all int8 buffers first,
    followed by all fp32 scale buffers """
    qtensors = {"q": [], "s": []}

    for n in range(n_layers):
        q, s, err = QuantizeINT8(model[layer % n], group_size)

        qtensors["q"].append(q)
        qtensors["s"].append(s)

        print(f"{layer % n} quantized {tuple(model[layer % n].shape)} to Q8_0 with max error {err}")

    for q in qtensors["q"]:
        SerializeINT8(file, q)

    for s in qtensors["s"]:
        SerializeFP32(file, s)


def LoadConfig(config_path):
    """ loads the model config from a json file """
    with open(config_path) as f:
        config = json.load(f)

    return config


def LoadModel(model_path):
    """ loads the state dict and strips the 'backbone.' prefix from the keys """
    model = torch.load(model_path, map_location='cpu')

    unwanted_prefix = 'backbone.'
    for k in list(model.keys()):
        if k.startswith(unwanted_prefix):
            model[k[len(unwanted_prefix):]] = model.pop(k)

    return model


def ExportModelFP32(model, config, output_path):
    """ exports all model weights to output_path as raw fp32 """
    out_file = open(output_path, 'wb')

    n_layers = config['n_layer']

    '''
    Example of the model structure:
    embedding.weight - [50280, 768]
    layers.0.mixer.D - [1536]
    layers.0.mixer.in_proj.weight - [3072, 768]
    layers.0.mixer.conv1d.weight - [1536, 1, 4]
    layers.0.mixer.conv1d.bias - [1536]
    layers.0.mixer.x_proj.weight - [80, 1536]
    layers.0.mixer.dt_proj.weight - [1536, 48]
    layers.0.mixer.dt_proj.bias - [1536]
    layers.0.mixer.A_log - [1536, 16]
    layers.0.mixer.out_proj.weight - [768, 1536]
    layers.0.norm.weight - [768]
    norm_f.weight - [768]
    lm_head.weight - [50280, 768]
    '''

    # convert A_log to A = -exp(A_log) before export
    for n in range(n_layers):
        a_log = f'layers.{n}.mixer.A_log'
        if a_log in model:
            model[f'layers.{n}.mixer.A'] = -torch.exp(model.pop(a_log))

    WriteWeightsFP32(out_file, model, 'embedding.weight')

    WriteLayersFP32(out_file, model, 'layers.%d.mixer.in_proj.weight', n_layers)
    WriteLayersFP32(out_file, model, 'layers.%d.mixer.conv1d.weight', n_layers)
    WriteLayersFP32(out_file, model, 'layers.%d.mixer.conv1d.bias', n_layers)
    WriteLayersFP32(out_file, model, 'layers.%d.mixer.x_proj.weight', n_layers)
    WriteLayersFP32(out_file, model, 'layers.%d.mixer.dt_proj.weight', n_layers)
    WriteLayersFP32(out_file, model, 'layers.%d.mixer.dt_proj.bias', n_layers)
    WriteLayersFP32(out_file, model, 'layers.%d.mixer.A', n_layers)
    WriteLayersFP32(out_file, model, 'layers.%d.mixer.D', n_layers)
    WriteLayersFP32(out_file, model, 'layers.%d.mixer.out_proj.weight', n_layers)
    WriteLayersFP32(out_file, model, 'layers.%d.norm.weight', n_layers)

    WriteWeightsFP32(out_file, model, 'norm_f.weight')
    WriteWeightsFP32(out_file, model, 'lm_head.weight')

    out_file.close()

    print(f"Exported FP32 model to {output_path}")


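# A minimal sketch (not part of the original script) of how a reader could load
# the first tensor back from the FP32 file: 'embedding.weight' is written first,
# as vocab_size * d_model consecutive float32 values. The helper name is
# illustrative only.
def ReadEmbeddingFP32(model_path, config):
    with open(model_path, 'rb') as f:
        count = config['vocab_size'] * config['d_model']
        emb = np.fromfile(f, dtype=np.float32, count=count)
    return emb.reshape(config['vocab_size'], config['d_model'])

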
def ExportModelINT8(model, config, output_path, group_size=64):
    """ exports the model weights to output_path, quantizing the large
    projection matrices, the embedding and the lm head to int8 (Q8_0) and
    keeping the remaining tensors in fp32 """
    out_file = open(output_path, 'wb')

    n_layers = config['n_layer']

    '''
    Example of the model structure:
    embedding.weight - [50280, 768]
    layers.0.mixer.D - [1536]
    layers.0.mixer.in_proj.weight - [3072, 768]
    layers.0.mixer.conv1d.weight - [1536, 1, 4]
    layers.0.mixer.conv1d.bias - [1536]
    layers.0.mixer.x_proj.weight - [80, 1536]
    layers.0.mixer.dt_proj.weight - [1536, 48]
    layers.0.mixer.dt_proj.bias - [1536]
    layers.0.mixer.A_log - [1536, 16]
    layers.0.mixer.out_proj.weight - [768, 1536]
    layers.0.norm.weight - [768]
    norm_f.weight - [768]
    lm_head.weight - [50280, 768]
    '''

    # convert A_log to A = -exp(A_log) before export
    for n in range(n_layers):
        a_log = f'layers.{n}.mixer.A_log'
        if a_log in model:
            model[f'layers.{n}.mixer.A'] = -torch.exp(model.pop(a_log))

    WriteWeightsINT8(out_file, model, 'embedding.weight', group_size)

    WriteLayersINT8(out_file, model, 'layers.%d.mixer.in_proj.weight', n_layers, group_size)

    WriteLayersFP32(out_file, model, 'layers.%d.mixer.conv1d.weight', n_layers)
    WriteLayersFP32(out_file, model, 'layers.%d.mixer.conv1d.bias', n_layers)

    WriteLayersINT8(out_file, model, 'layers.%d.mixer.x_proj.weight', n_layers, group_size)

    WriteLayersFP32(out_file, model, 'layers.%d.mixer.dt_proj.weight', n_layers)
    WriteLayersFP32(out_file, model, 'layers.%d.mixer.dt_proj.bias', n_layers)

    WriteLayersFP32(out_file, model, 'layers.%d.mixer.A', n_layers)
    WriteLayersFP32(out_file, model, 'layers.%d.mixer.D', n_layers)

    WriteLayersINT8(out_file, model, 'layers.%d.mixer.out_proj.weight', n_layers, group_size)

    WriteLayersFP32(out_file, model, 'layers.%d.norm.weight', n_layers)
    WriteWeightsFP32(out_file, model, 'norm_f.weight')

    WriteWeightsINT8(out_file, model, 'lm_head.weight', group_size)

    out_file.close()

    print(f"Exported INT8 model to {output_path}")


def ExportConfig(model, config, output_path):
    """
    Exports the config to a C header file, following this configuration example:

    #define VOCAB_SIZE 256
    #define N_LAYER 12
    #define D_MODEL 768
    #define D_INNER 1536
    #define DT_RANK 48
    #define D_STATE 16
    #define D_CONV 4
    #define GS 64

    #define [KEY] [VALUE]
    key is converted to uppercase and value is taken from the config dictionary
    or derived from the tensor shapes
    """

    vocab_size = config['vocab_size']
    # round the vocab size up to the next multiple of 8
    rounded_vocab_size = vocab_size if vocab_size % 8 == 0 else vocab_size + (8 - (vocab_size % 8))

    with open(output_path, 'w') as f:
        f.write("#pragma once\n\n")
        f.write("#define VOCAB_SIZE %d\n" % vocab_size)
        f.write("#define ROUNDED_VOCAB_SIZE %d\n\n" % rounded_vocab_size)
        f.write("#define N_LAYER %d\n" % config['n_layer'])
        f.write("#define D_MODEL %d\n" % config['d_model'])
        f.write("#define D_INNER %d\n" % (2 * config['d_model']))
        f.write("#define DT_RANK %d\n" % model['layers.0.mixer.dt_proj.weight'].shape[1])
        # 'layers.0.mixer.A' only exists after one of the export functions has
        # converted A_log to A, so ExportConfig must be called after them
        f.write("#define D_STATE %d\n" % model['layers.0.mixer.A'].shape[1])
        f.write("#define D_CONV %d\n\n" % model['layers.0.mixer.conv1d.weight'].shape[2])
        f.write("#define GS 64\n")

    print(f"Exported C compatible config (header) to {output_path}")


def ExportAll():
    """ loads the checkpoint and config, then exports the fp32 model, the int8
    model and the C header """
    model = LoadModel('pytorch_model.bin')
    config = LoadConfig('config.json')

    ExportModelFP32(model, config, 'model.fp32.bin')
    ExportModelINT8(model, config, 'model.int8.bin')

    ExportConfig(model, config, 'config.h')


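# argparse is imported above but never used; a minimal entry point is sketched
# here. The flag names are assumptions, not part of the original script, and
# the defaults match the paths hardcoded in ExportAll.
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='export the checkpoint to fp32/int8 binaries and a C header')
    parser.add_argument('--model', default='pytorch_model.bin', help='path to the PyTorch checkpoint')
    parser.add_argument('--config', default='config.json', help='path to the model config json')
    args = parser.parse_args()

    model = LoadModel(args.model)
    config = LoadConfig(args.config)

    ExportModelFP32(model, config, 'model.fp32.bin')
    ExportModelINT8(model, config, 'model.int8.bin')
    ExportConfig(model, config, 'config.h')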