|
|
|
|
|
|
|
|
|
import torch, sys, time |
|
import torch.nn as nn |
|
import torch.optim as optim |
|
|
|
|
|
import layers, models, dataloader |
|
from library.utils import compute_batch_accuracy, compute_set_accuracy |
|
|
|
# Batch size handed to the CIFAR-100 loaders.
bs = 100

# Loaders for the first evaluation pass, built with act_8b_mode disabled.
train_loader, test_loader = dataloader.load_cifar100(
    batch_size=bs,
    num_workers=6,
    shuffle=False,
    act_8b_mode=False,
)

# Use the GPU when one is visible, otherwise stay on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Instantiate the network and place it on the selected device.
model = models.maxim_nas().to(device)
|
|
|
|
|
# Value passed to every layer as `shift_quantile`.
sq = 0.985

# One row per configured layer, in configuration order:
#   (attribute name on the model, weight_bits, mode string for mode_fpt2qat)
# bias_bits is 8 for every layer, so it is not part of the table.
layer_qat_settings = (
    ('conv1_1', 8, 'qat'),
    ('conv1_2', 2, 'qat'),
    ('conv1_3', 2, 'qat'),
    ('conv2_1', 2, 'qat_ap'),
    ('conv2_2', 2, 'qat_ap'),
    ('conv3_1', 2, 'qat_ap'),
    ('conv3_2', 2, 'qat_ap'),
    ('conv4_1', 2, 'qat_ap'),
    ('conv4_2', 2, 'qat'),
    ('conv5_1', 2, 'qat'),
    ('fc', 8, 'qat'),
)

for layer_name, layer_weight_bits, qat_mode in layer_qat_settings:
    layer = getattr(model, layer_name)
    layer.configure_layer_base(
        weight_bits=layer_weight_bits,
        bias_bits=8,
        shift_quantile=sq,
    )
    layer.mode_fpt2qat(qat_mode)
    # Re-bind the same layer object under the attribute name it came from.
    setattr(model, layer_name, layer)

# Move the model to the device again now that the layers were reconfigured.
model.to(device)
|
|
|
|
|
|
|
# Restore the trained weights.  map_location remaps the stored tensors onto
# the device selected earlier in this script, so a checkpoint that was written
# on a GPU machine can still be loaded when the script falls back to the CPU.
# NOTE(review): torch.load unpickles the file; only load checkpoints from a
# trusted source.
checkpoint = torch.load('training_checkpoint.pth.tar', map_location=device)
model.load_state_dict(checkpoint['state_dict'])

# Evaluate the restored model on the test loader and report the result
# scaled by 100.
print('')
print('Computing test set accuracy, training checkpoint')
test_acc = compute_set_accuracy(model, test_loader)

print('')
print('Test accuracy:', test_acc*100.0)
print('')
|
|
|
# Rebuild the loaders, this time with act_8b_mode enabled.
train_loader, test_loader = dataloader.load_cifar100(
    batch_size=bs,
    num_workers=6,
    shuffle=False,
    act_8b_mode=True,
)

model = model.to(device)

# Walk every attribute name reported by dir(model) and convert the ones that
# are layers.shallow_base_layer instances; everything else is skipped.
for attr_name in dir(model):
    candidate = getattr(model, attr_name)
    if not isinstance(candidate, layers.shallow_base_layer):
        continue

    print('Generating HW parameters for:', attr_name)

    # The conversion method depends on the mode the layer is currently in;
    # a layer in any other mode is left untouched.
    if candidate.mode == 'qat':
        candidate.mode_qat2hw('eval')
    elif candidate.mode == 'qat_ap':
        candidate.mode_qat_ap2hw('eval')

    # Re-bind the same layer object under the attribute name it came from.
    setattr(model, attr_name, candidate)

# Move the model to the device again after the conversion pass.
model.to(device)
|
|
|
# Evaluate the converted model on the rebuilt test loader.
print('')
print('Computing test set accuracy, hardware checkpoint')
test_acc = compute_set_accuracy(model, test_loader)

# Metadata stored next to the weights.  Both top-1 entries hold the accuracy
# just measured, scaled by 100.
# NOTE(review): 123456789 looks like a placeholder epoch number - confirm.
extras = {
    'best epoch': 123456789,
    'best_top1': 100*test_acc.cpu().numpy(),
    'clipping_method': 'MAX_BIT_SHIFT',
    'current_top1': 100*test_acc.cpu().numpy(),
}

hardware_checkpoint = {
    'epoch': 123456789,
    'extras': extras,
    'state_dict': model.state_dict(),
    'arch': 'ai85nascifarnet',
}

torch.save(hardware_checkpoint, 'hardware_checkpoint.pth.tar')

print('')
print('Test accuracy:', test_acc*100.0)
|
|