|
|
|
|
|
|
|
|
|
import torch, matplotlib, os, sys, argparse |
|
sys.path.append('..') |
|
|
|
import numpy as np |
|
import matplotlib.pyplot as plt |
|
from matplotlib.lines import Line2D |
|
matplotlib.use('Agg') |
|
|
|
from datetime import date |
|
|
|
|
|
from generate_leaderboard import get_checkpoints_except_reproduce, get_checkpoint_top1s_sizes |
|
|
|
def main(): |
|
cp_full_paths, cp_name_list = get_checkpoints_except_reproduce() |
|
print('') |
|
print('Found checkpoints (except reproduce checkpoints) at these locations:') |
|
for cp_path in cp_full_paths: |
|
print(cp_path) |
|
|
|
print('') |
|
print('Gathering hardware-mode top-1 accuracy and size info from each checkpoint') |
|
cp_best_top1s, cp_sizes_in_bytes, cp_sizes_in_bytes_max78000, cp_sizes_antipodal = get_checkpoint_top1s_sizes(cp_full_paths, cp_name_list) |
|
|
|
print('') |
|
print('Generating results graph under documentation, with timestamp') |
|
|
|
|
|
|
|
fig, ax = plt.subplots(figsize=(9, 5)) |
|
ax.set_xlim((80,400)) |
|
ax.set_ylim((53,68.0)) |
|
ax.grid(True) |
|
ax.set_axisbelow(True) |
|
ax.set_xlabel('Size [KBytes]', fontsize=15) |
|
ax.set_ylabel('Validation set accuracy [%]', fontsize=15) |
|
|
|
|
|
color_maxim = np.asarray([30,30,255])/256 |
|
color_shallow = np.asarray([255,30,30])/256 |
|
for i, name in enumerate(cp_name_list): |
|
if('maxim' in name): |
|
color = color_maxim |
|
annot = 'm'+name[5:8] |
|
elif('shallow' in name): |
|
color = color_shallow |
|
annot = 's'+name[7:10] |
|
else: |
|
print('') |
|
print('whose model is this?! ->', name) |
|
print('exiting') |
|
print('') |
|
sys.exit() |
|
|
|
|
|
if(cp_sizes_antipodal[i]): |
|
ax.scatter(cp_sizes_in_bytes[i]/1000.0, cp_best_top1s[i], color = color, s = 70, linestyle='None', alpha=0.2) |
|
ax.scatter(cp_sizes_in_bytes_max78000[i]/1000.0, cp_best_top1s[i], color = color, s = 70, linestyle='None', alpha=0.8) |
|
ax.plot([cp_sizes_in_bytes[i]/1000.0, cp_sizes_in_bytes_max78000[i]/1000.0], [cp_best_top1s[i], cp_best_top1s[i]], color = color, linestyle='dashed') |
|
else: |
|
ax.scatter(cp_sizes_in_bytes[i]/1000.0, cp_best_top1s[i], color = color, s = 70, linestyle='None', alpha=0.8) |
|
|
|
|
|
|
|
|
|
|
|
custom_lines = [Line2D([0], [0], color=color_maxim, lw=4), |
|
Line2D([0], [0], color=color_shallow, lw=4)] |
|
ax.legend(custom_lines, ['maxim', 'shallow'], loc='upper left', fontsize=12) |
|
plt.title('Models for CIFAR-100', fontsize=15) |
|
|
|
today = date.today() |
|
dd = today.strftime("%Y-%m-%d") |
|
graph_path = 'documentation/'+dd+'-results-graph.png' |
|
plt.savefig(graph_path) |
|
|
|
print('') |
|
print('Saved graph under', graph_path) |
|
print('') |
|
|
|
if __name__ == '__main__': |
|
main() |
|
|