# Matthew
# initial commit
# 0392181
"""
=========================================================================================
Trojan VQA
Written by Indranil Sur
Get weight histogram features for weight sensitivity analysis.
=========================================================================================
"""
import os
import sys
import errno
import argparse
import numpy as np
import pandas as pd
sys.path.append("..")
sys.path.append("../openvqa")
from openvqa.openvqa_inference_wrapper import Openvqa_Wrapper
sys.path.append("../bottom-up-attention-vqa")
from butd_inference_wrapper import BUTDeff_Wrapper
def load_model_util(model_spec, set_dir):
    """Load a trained VQA model and return its underlying network object.

    The checkpoint is read from
    ``<set_dir>/models/<model_spec['model_name']>/model.<ext>``, where the
    extension is ``pth`` for ``butd_eff`` models and ``pkl`` for every other
    model type.

    Args:
        model_spec: Record indexable by ``'model'`` and ``'model_name'``
            (and by ``'nb'`` for models other than ``butd_eff``).
        set_dir: Directory that contains the ``models`` sub-directory.

    Returns:
        The ``model`` attribute of a ``BUTDeff_Wrapper`` for ``butd_eff``
        models, otherwise the ``net`` attribute of an ``Openvqa_Wrapper``.
    """
    is_butd_eff = model_spec['model'] == 'butd_eff'
    extension = 'pth' if is_butd_eff else 'pkl'
    checkpoint = os.path.join(
        set_dir, 'models', model_spec['model_name'], 'model.%s' % extension)

    # The two model families are served by different inference wrappers,
    # which expose the loaded network under different attribute names.
    if is_butd_eff:
        return BUTDeff_Wrapper(checkpoint).model

    wrapper = Openvqa_Wrapper(model_spec['model'], checkpoint, model_spec['nb'])
    return wrapper.net
def get_feature(info, root, bins=50):
    """Return a normalized histogram of a model's output-layer weights.

    The model described by ``info`` is loaded from ``root`` and the weight
    matrix of its last layer is binned with ``numpy.histogram``. The bin
    counts are divided by their total so the result sums to 1.

    Args:
        info: Model record passed through to ``load_model_util``.
        root: Dataset directory passed through to ``load_model_util``.
        bins: Number of histogram bins (defaults to 50, the value that was
            previously hard-coded).

    Returns:
        A 1-D ``numpy`` array of length ``bins`` holding the fraction of
        weights that fall into each bin.

    Raises:
        AttributeError: If the loaded model exposes none of the attributes
            this function knows how to read the final layer from.
    """
    model = load_model_util(info, root)

    # The final layer lives under a different attribute depending on the
    # model implementation, so probe the known locations in order.
    if hasattr(model, 'proj'):
        layer = model.proj
    elif hasattr(model, 'classifier'):
        layer = model.classifier.main[-1]
    elif hasattr(model, 'classifer'):
        # NOTE(review): the spelling 'classifer' is kept from the original
        # code; presumably it matches an attribute name in one of the model
        # implementations -- verify before "correcting" it.
        layer = model.classifer[-1]
    else:
        # Previously this case fell through and failed later with an
        # UnboundLocalError; fail here with an actionable message instead.
        raise AttributeError(
            "model of type {} has none of the expected attributes "
            "'proj', 'classifier', 'classifer'".format(type(model).__name__))

    wt = layer.weight.data.cpu().numpy().copy()
    hist = np.histogram(wt, bins=bins)[0]
    return hist / hist.sum()
if __name__ == "__main__":
    # Command-line entry point: compute the weight-histogram feature for one
    # model and store it as
    # <feat_root>/<ds>/<feat_name>/<split>/<model_id>.npy
    parser = argparse.ArgumentParser(description='Get Wt features')
    parser.add_argument('--ds_root', type=str, help='Root of data', required=True)
    parser.add_argument('--model_id', type=str, help='model_id', default='m00001')
    parser.add_argument('--ds', type=str, help='dataset', default='v1')
    parser.add_argument('--split', type=str, help='split', default='train')
    parser.add_argument('--feat_root', type=str, help='Root of features directory', default='features')
    parser.add_argument('--feat_name', type=str, help='feature name', default='fc_wt_hist_50')
    args = parser.parse_args()

    args.feat_dir = os.path.join(args.feat_root, args.ds, args.feat_name, args.split)
    args.ds_root = os.path.join(args.ds_root, '{}-{}-dataset/'.format(args.ds, args.split))

    try:
        os.makedirs(args.feat_dir)
    except OSError as e:
        # An already-existing directory is fine; any other failure (e.g.
        # permissions) must surface here rather than later at np.save.
        if e.errno != errno.EEXIST:
            raise

    _file = os.path.join(args.feat_dir, '{}.npy'.format(args.model_id))
    if os.path.exists(_file):
        # Feature already computed for this model; nothing to do.
        sys.exit()

    metadata_path = os.path.join(args.ds_root, 'METADATA.csv')
    metadata = pd.read_csv(metadata_path)
    matches = metadata[metadata.model_name == args.model_id]
    if matches.empty:
        # Report a missing model explicitly instead of an IndexError from
        # .iloc[0] on an empty selection.
        sys.exit("model_id '{}' not found in {}".format(args.model_id, metadata_path))
    info = matches.iloc[0]

    feat = get_feature(info, args.ds_root)
    np.save(_file, feat)