# Vui Seng Chua
# Add content
# 43a66d3
from openvino.runtime import Core
from tqdm import tqdm
import torch
from collections import OrderedDict
from pathlib import Path
import numpy as np
def get_ir_pair(model_dir: "str | Path") -> "tuple[Path, Path]":
    """Return the paths of the OpenVINO IR file pair inside ``model_dir``.

    Args:
        model_dir: Directory expected to contain the IR files. The directory
            is not checked for existence and the files are not opened.

    Returns:
        A ``(xml_path, bin_path)`` tuple pointing at ``openvino_model.xml``
        and ``openvino_model.bin`` under ``model_dir``.
    """
    model_root = Path(model_dir)
    return model_root / "openvino_model.xml", model_root / "openvino_model.bin"
# Element counts (product of the constant's shape) of the weight constants
# that the scan below should pick up, listed per model key.  A constant is
# kept only when its element count is exactly one of these values.
# An earlier revision stored a {'min': ..., 'max': ...} range per model
# instead of the exact sizes listed here.
# NOTE(review): the factorizations in the trailing comments are arithmetic
# only; which layer each size corresponds to is not shown here -- confirm.
fc_numel = {
    "llama-2-chat-7b": [
        16_777_216,  # 4096 * 4096
        45_088_768,  # 4096 * 11008
    ],
    "mistral-7b": [
        4_194_304,   # 4096 * 1024
        16_777_216,  # 4096 * 4096
        58_720_256,  # 4096 * 14336
    ],
    "gemma-2b-it": [
        524_288,     # 2048 * 256
        4_194_304,   # 2048 * 2048
        33_554_432,  # 2048 * 16384
    ],
}
# Directory holding the INT4-compressed IR to inspect.  To look at another
# model, change this single assignment; the other folders used so far are:
#   "./new_321/gemma-2b-it/INT4_compressed_weights/"
#   "./new_321/mistral-7b/INT4_compressed_weights/"
compressed_weight_folder = "./new_321/llama-2-chat-7b/INT4_compressed_weights/"

# The model key is the name of the directory that contains the weights folder
# (".../<model_key>/INT4_compressed_weights/").  Taking the parent's name keeps
# this working when the path is absolute or has no leading "./".
model_key = Path(compressed_weight_folder).parent.name
if model_key not in fc_numel:
    # Fail before loading the model rather than part-way through the scan.
    raise KeyError(
        f"model key {model_key!r} derived from {compressed_weight_folder!r} "
        f"has no entry in fc_numel (known: {sorted(fc_numel)})"
    )

ir_xml, ir_bin = get_ir_pair(compressed_weight_folder)

# Load the IR graph; only the .xml path is handed to read_model here.
ie = Core()
ir_model = ie.read_model(ir_xml)
def _unpack_nibbles(packed, shape):
    """Split every packed byte into two 4-bit values and reshape to ``shape``.

    For each byte the upper nibble is emitted first, directly followed by the
    lower nibble.  The work is done with vectorized NumPy operations instead
    of a Python-level loop, so weights with tens of millions of elements are
    unpacked without creating one Python object per value.

    Args:
        packed: NumPy integer array (at least 1-D) holding two values per
            element.
        shape: Target shape; must hold twice as many elements as ``packed``.

    Returns:
        Array of unpacked values with shape ``shape``.

    Raises:
        ValueError: If the unpacked data cannot be reshaped to ``shape``.
    """
    upper_bits = packed >> 4    # Extract the upper 4 bits
    lower_bits = packed & 0x0F  # Extract the lower 4 bits
    # Stacking on axis 1 places each upper nibble right before its lower one.
    # TODO: must verify again that "upper nibble first" matches the order in
    # which OpenVINO packs 4-bit values.
    return np.stack((upper_bits, lower_bits), axis=1).reshape(shape)


# Metadata (and, for 4-bit layers, unpacked weights) of every selected
# constant, keyed by op name in graph order.
model_params = OrderedDict()

# Wrap the iterable in tqdm(...) to get a progress bar.
for op in ir_model.get_ordered_ops():
    # Only constant nodes can hold weights.
    if 'constant' not in str(op.get_type_info()).lower():
        continue

    shape = tuple(op.get_output_shape(0))
    numel = np.prod(shape)

    # Note: This is to capture only Linear layers.  A constant qualifies when
    # it has at least two dims, its last dim is not 1, and its element count
    # is one of the sizes listed for this model.  (Earlier variants filtered
    # on raw 2-D shape bounds or on a min/max element-count range.)
    if len(shape) < 2 or shape[-1] == 1 or numel not in fc_numel[model_key]:
        continue

    print(f"{numel:15} | {str(shape):15} | {op.get_name()}")

    layer = op.get_name()
    is_4bit = len(shape) == 3  # 3-D constants are treated as 4-bit weights
    model_params[layer] = {'is_4bit': is_4bit, 'ov_shape': shape}

    if is_4bit:
        # Kept as a module-level name; not used in the visible part of the
        # script.
        group_size = shape[-1]
        # NOTE(review): assumes op.data exposes the packed bytes (two values
        # per byte) for 4-bit constants -- confirm against the OpenVINO
        # version in use.
        model_params[layer]['weight'] = _unpack_nibbles(op.data, shape)

print('Done!')