"""Collect Linear-layer weight constants from a weight-compressed OpenVINO IR.

The script reads ``openvino_model.xml`` from ``compressed_weight_folder``,
walks the model's ops in order and records, for every constant whose element
count is listed in ``fc_numel`` for the selected model, its shape and whether
it is stored as 4-bit data.  For 4-bit constants the packed data is also
expanded into one value per nibble and stored under ``'weight'``.

Results are left in the module-level ``model_params`` ordered dict, keyed by
the constant's op name.
"""
from collections import OrderedDict
from pathlib import Path

from openvino.runtime import Core
from tqdm import tqdm
import torch
import numpy as np


def get_ir_pair(model_dir):
    """Return the ``(xml, bin)`` paths of the OpenVINO IR inside *model_dir*."""
    root = Path(model_dir)
    return root / "openvino_model.xml", root / "openvino_model.bin"


def _unpack_nibbles(packed, shape):
    """Split every packed element into two 4-bit values and reshape to *shape*.

    For each element the upper nibble is emitted first, followed by the lower
    nibble.  Pairs are formed along the first axis of *packed*, which yields
    the same element order as appending ``upper[i]`` then ``lower[i]`` in a
    Python loop, but runs as a single vectorized NumPy operation.
    """
    upper_bits = packed >> 4    # upper 4 bits of each element
    lower_bits = packed & 0x0F  # lower 4 bits of each element
    # TODO: must verify again that "upper first, then lower" matches the
    # nibble order OpenVINO uses when packing 4-bit constants.
    return np.stack((upper_bits, lower_bits), axis=1).reshape(shape)


# Element counts of the constants to capture, per model.  Matching on exact
# counts is meant to capture only the Linear layers (an earlier revision
# filtered on a [min, max] range of element counts instead).
fc_numel = {
    'llama-2-chat-7b': [16777216, 45088768],
    'mistral-7b': [4194304, 16777216, 58720256],
    'gemma-2b-it': [524288, 4194304, 33554432],
}

# Select the model to inspect; the alternatives are kept for quick switching.
# compressed_weight_folder = "./new_321/gemma-2b-it/INT4_compressed_weights/"
# compressed_weight_folder = "./new_321/mistral-7b/INT4_compressed_weights/"
compressed_weight_folder = "./new_321/llama-2-chat-7b/INT4_compressed_weights/"

# The model name is the third "/"-separated component of the folder path.
model_key = compressed_weight_folder.split("/")[2]
if model_key not in fc_numel:
    # Fail before the (expensive) model load rather than on the first
    # matching constant inside the loop.
    raise KeyError(
        f"Unknown model key {model_key!r} derived from "
        f"{compressed_weight_folder!r}; expected one of {sorted(fc_numel)}"
    )
target_numels = frozenset(fc_numel[model_key])

ir_xml, ir_bin = get_ir_pair(compressed_weight_folder)

ie = Core()
ir_model = ie.read_model(ir_xml)

model_params = OrderedDict()
# Wrap the iterable in tqdm(...) to get a progress bar.
for op in ir_model.get_ordered_ops():
    if 'constant' not in str(op.get_type_info()).lower():
        continue

    shape = tuple(op.get_output_shape(0))
    if len(shape) < 2 or shape[-1] == 1:
        continue

    # Accumulate in int64 so the product does not depend on the platform's
    # default integer width.
    numel = int(np.prod(shape, dtype=np.int64))
    if numel not in target_numels:
        continue

    layer = op.get_name()
    print(f"{numel:15} | {str(shape):15} | {layer}")

    # Constants with a 3-D shape are treated as 4-bit weights; the last
    # dimension is presumably the quantization group size.
    is_4bit = len(shape) == 3
    entry = {'is_4bit': is_4bit, 'ov_shape': shape}
    if is_4bit:
        entry['weight'] = _unpack_nibbles(op.data, shape)
    model_params[layer] = entry

print('Done!')