###########################################################################
# Computer vision - Binary neural networks demo software by HyperbeeAI.   #
# Copyrights © 2023 Hyperbee.AI Inc. All rights reserved. hello@hyperbee.ai #
###########################################################################
import argparse
import os
import sys

import matplotlib
import numpy as np
import torch

# Select the non-interactive backend before pyplot is imported so that the
# script can save figures on machines without a display.
matplotlib.use('Agg')
import matplotlib.pyplot as plt

# Second-level state_dict fields that are stored directly on the layer record.
_LAYER_PARAMETER_FIELDS = (
    'weight_bits',
    'output_shift',
    'bias_bits',
    'quantize_activation',
    'adjust_output_shift',
    'shift_quantile',
)


def generate_histogram_for_quantized_layer(layer_key, layer_weight, layer_bias, checkpoint_type, histograms_folderpath):
    """Save a side-by-side weight/bias histogram of one layer as a jpg file.

    The figure is written to ``<histograms_folderpath>/<layer_key>.jpg``; the
    folder is created when it does not exist yet.

    Args:
        layer_key: Layer name, used as the file name of the figure.
        layer_weight: Array with the weights of the layer, or None.
        layer_bias: Array with the biases of the layer, or None when the
            layer has no bias (the bias panel is then left empty).
        checkpoint_type: Either 'hardware' or 'training'; selects the bin
            count limit and the margin added to the upper histogram edge.
        histograms_folderpath: Folder in which the figure is saved.

    Raises:
        ValueError: If checkpoint_type is neither 'hardware' nor 'training'.
    """
    if checkpoint_type == 'hardware':
        weight_margin, bias_margin = 1, 1 / 16384
        max_bins = None
    elif checkpoint_type == 'training':
        weight_margin, bias_margin = 1 / 128, 1 / 128
        max_bins = 800
    else:
        raise ValueError("checkpoint_type must be 'hardware' or 'training', got %r" % (checkpoint_type,))

    os.makedirs(histograms_folderpath, exist_ok=True)

    matplotlib.rcParams.update({'font.size': 16})
    fig, axs = plt.subplots(1, 2, tight_layout=True, figsize=(20, 10))
    try:
        panels = (
            (axs[0], 'weight', layer_weight, weight_margin),
            (axs[1], 'bias', layer_bias, bias_margin),
        )
        for ax, title, tensor, margin in panels:
            ax.grid(True)
            ax.set_title(title, fontdict={'fontsize': 22, 'fontweight': 'medium'})
            if tensor is None:
                continue
            values = np.asarray(tensor).flatten()
            if values.size == 0:
                continue
            # Three bins per distinct value; training checkpoints are capped
            # because they can hold a very large number of distinct values.
            num_bins = len(np.unique(values)) * 3
            if max_bins is not None:
                num_bins = min(num_bins, max_bins)
            ax.hist(values, range=(np.amin(values), np.amax(values) + margin), bins=num_bins, align='left')

        fig.savefig(os.path.join(histograms_folderpath, layer_key + '.jpg'))
    finally:
        # Always release the figure, also when plotting or saving fails.
        plt.close(fig)


def _collect_layers(state_dict, stats_file):
    """Group the state_dict entries by layer name, keeping first-seen order.

    Exits the script when an entry has a field this script does not know.
    """
    layers = {}
    for key, tensor in state_dict.items():
        fields = key.split('.')
        layer = layers.setdefault(fields[0], {
            'key': fields[0],
            'weight_bits': None,
            'bias_bits': None,
            'adjust_output_shift': None,
            'output_shift': None,
            'quantize_activation': None,
            'shift_quantile': None,
            'weight': None,
            'bias': None,
        })

        if len(fields) >= 2 and fields[1] in _LAYER_PARAMETER_FIELDS:
            target = fields[1]
        elif len(fields) >= 3 and fields[1] == 'op':
            # Entries of the wrapped operator, e.g. '<layer>.op.weight'.
            target = fields[2]
        else:
            print('[ERROR]: Unknown field. Exiting', file=stats_file)
            print('[ERROR]: Unknown field. Exiting')
            sys.exit(1)

        layer[target] = tensor.detach().cpu().numpy()
    return list(layers.values())


def _write_tensor_stats(name, tensor, stats_file):
    """Write size/shape/unique/min/max/mean of a tensor; return its unique count.

    Returns None when the tensor is missing or empty, in which case there is
    nothing to compare against a bit width.
    """
    print(' ' + name, file=stats_file)
    if tensor is None:
        print(' not present in this layer', file=stats_file)
        return None

    print(' total # of elements, shape:', np.size(tensor), ',', list(tensor.shape), file=stats_file)
    if np.size(tensor) == 0:
        return None

    num_unique = len(np.unique(tensor))
    print(' # of unique elements: ', num_unique, file=stats_file)
    print(' min, max, mean:', np.amin(tensor), ', ', np.amax(tensor), ', ', np.mean(tensor), file=stats_file)
    return num_unique


def _exceeds_bit_width(num_unique, bits):
    """Return True when num_unique distinct values cannot fit in `bits` bits."""
    if num_unique is None or bits is None:
        return False
    return bool(np.any(num_unique > 2 ** np.asarray(bits)))


def main():
    """Write a statistics file (and optionally histograms) for a checkpoint.

    Reads ``checkpoints/<checkpoint-name>/<checkpoint-type>_checkpoint.pth.tar``
    relative to the current working directory and writes
    ``statistics_<checkpoint-type>_checkpoint`` next to it. With
    ``--generate-histograms``, one jpg per layer is saved in the
    ``histograms_<checkpoint-type>_checkpoint`` sub-folder.

    The script exits with a non-zero status on every error path.
    """
    parser = argparse.ArgumentParser(description='Print out model statistics file and optionally also save weight/bias histogram figures for each layer')
    parser.add_argument('-c', '--checkpoint-name', help='Name of folder under the checkpoints folder for which you want to generate a model statistics file', required=True)
    parser.add_argument('-q', '--checkpoint-type', help='checkpoint type can be either a hardware or training checkpoint.', required=True)
    parser.add_argument('-g', '--generate-histograms', help='Add this flag if you want to save jpg figures inside the checkpoint folder for histograms of bias and weight values of each layer in the network', action='store_true', default=False, required=False)
    args = vars(parser.parse_args())

    checkpoint_folder = os.path.join('checkpoints', args['checkpoint_name'])
    print('')
    if not os.path.isdir(checkpoint_folder):
        print('Could not find checkpoint folder. Please check that:')
        print('1- you are running this script from the top level of the repository, and')
        print('2- the checkpoint folder you gave the name for exists (needs to be created manually)')
        sys.exit(1)
    print('Found checkpoint folder')

    checkpoint_type = args['checkpoint_type']
    if checkpoint_type not in ('hardware', 'training'):
        print('')
        print('Something is wrong, we dont know of a', checkpoint_type, 'checkpoint. Perhaps a misspelling?')
        print('')
        sys.exit(1)

    checkpoint_filename = checkpoint_type + '_checkpoint.pth.tar'
    print('')
    print('Searching for a ' + checkpoint_filename)
    print('')
    # Only hardware checkpoints are checked against their declared bit widths.
    check_for_bit_errors = (checkpoint_type == 'hardware')

    checkpoint_path = os.path.join(checkpoint_folder, checkpoint_filename)
    if not os.path.isfile(checkpoint_path):
        print('[ERROR]: could not find', checkpoint_path)
        sys.exit(1)

    # NOTE: torch.load unpickles the file, so only inspect trusted checkpoints.
    # map_location='cpu' lets checkpoints saved on a GPU load on any machine;
    # every tensor is converted to a numpy array on the CPU below anyway.
    checkpoint = torch.load(checkpoint_path, map_location='cpu')

    flag_generate_histograms = args['generate_histograms']
    if flag_generate_histograms:
        print('[INFO]: Will generate histograms')
    histograms_folder = os.path.join(checkpoint_folder, 'histograms_' + checkpoint_type + '_checkpoint')

    with open(os.path.join(checkpoint_folder, 'statistics_' + checkpoint_type + '_checkpoint'), 'w') as f:
        print('[INFO]: Generating statistics file')
        print('Top:', file=f)
        for key in checkpoint:
            print(' ', key, file=f)

        for required_key in ('arch', 'state_dict', 'extras'):
            if required_key not in checkpoint:
                message = '[ERROR]: there is no key named ' + required_key + ' in this checkpoint'
                print(message, file=f)
                print(message)

        # A missing 'arch' or 'extras' is reported above and shown as None so
        # that the remaining statistics can still be produced.
        print('-------------------------------------', file=f)
        print('arch:', checkpoint.get('arch'), file=f)
        print('-------------------------------------', file=f)
        print('extras:', checkpoint.get('extras'), file=f)
        print('-------------------------------------', file=f)

        if 'state_dict' not in checkpoint:
            # Nothing left to analyze without the parameters themselves.
            sys.exit(1)

        print('state_dict:', file=f)
        layers = _collect_layers(checkpoint['state_dict'], f)

        for layer in layers:
            print(' ', layer['key'], file=f)
            print(' output_shift: ', layer['output_shift'], file=f)
            print(' adjust_output_shift: ', layer['adjust_output_shift'], file=f)
            print(' quantize_activation: ', layer['quantize_activation'], file=f)
            print(' shift_quantile: ', layer['shift_quantile'], file=f)
            print(' weight bits: ', layer['weight_bits'], file=f)
            print(' bias_bits: ', layer['bias_bits'], file=f)

            bias_unique = _write_tensor_stats('bias', layer['bias'], f)
            if check_for_bit_errors and _exceeds_bit_width(bias_unique, layer['bias_bits']):
                print('', file=f)
                print('[WARNING]: # of unique elements in bias tensor is more than that allowed by bias_bits.', file=f)
                print(' This might be OK, since Maxim deployment repository right shifts these.', file=f)
                print('', file=f)
                print('')
                print('[WARNING]: # of unique elements in bias tensor is more than that allowed by bias_bits.')
                print(' This might be OK, since Maxim deployment repository right shifts these.')
                print(' Check stats file for details.')
                print('')

            weight_unique = _write_tensor_stats('weight', layer['weight'], f)
            if check_for_bit_errors and _exceeds_bit_width(weight_unique, layer['weight_bits']):
                print('', file=f)
                print('[ERROR]: # of unique elements in weight tensor is more than that allowed by weight_bits.', file=f)
                print(' This is definitely not OK, weights are used in HW as is.', file=f)
                print(' Exiting.', file=f)
                print('', file=f)
                print('')
                print('[ERROR]: # of unique elements in weight tensor is more than that allowed by weight_bits.')
                print(' This is definitely not OK, weights are used in HW as is.')
                print(' Exiting.')
                print('')
                sys.exit(1)

            if flag_generate_histograms:
                generate_histogram_for_quantized_layer(layer['key'], layer['weight'], layer['bias'], checkpoint_type, histograms_folder)
                print('[INFO]: saved histograms for layer', layer['key'])


if __name__ == '__main__':
    main()