Spaces:
Runtime error
Runtime error
""" | |
========================================================================================= | |
Trojan VQA | |
Written by Indranil Sur | |
Get weight histogram features for weight sensitivity analysis. | |
========================================================================================= | |
""" | |
import os | |
import sys | |
import errno | |
import argparse | |
import numpy as np | |
import pandas as pd | |
sys.path.append("..") | |
sys.path.append("../openvqa") | |
from openvqa.openvqa_inference_wrapper import Openvqa_Wrapper | |
sys.path.append("../bottom-up-attention-vqa") | |
from butd_inference_wrapper import BUTDeff_Wrapper | |
def load_model_util(model_spec, set_dir):
    """Load the VQA network described by ``model_spec`` from ``set_dir``.

    Args:
        model_spec: metadata record (e.g. a row of METADATA.csv) providing
            ``'model'`` and ``'model_name'``, plus ``'nb'`` for models other
            than ``butd_eff``.
        set_dir: dataset root directory that contains a ``models/`` folder.

    Returns:
        ``BUTDeff_Wrapper(path).model`` for ``butd_eff`` models, otherwise
        ``Openvqa_Wrapper(model, path, nb).net``.
    """
    is_butd = model_spec['model'] == 'butd_eff'
    # The checkpoint extension depends on the model family:
    # 'pth' for butd_eff, 'pkl' for every OpenVQA model.
    extension = 'pth' if is_butd else 'pkl'
    checkpoint = os.path.join(set_dir, 'models', model_spec['model_name'],
                              'model.%s' % extension)
    if is_butd:
        return BUTDeff_Wrapper(checkpoint).model
    wrapper = Openvqa_Wrapper(model_spec['model'], checkpoint, model_spec['nb'])
    return wrapper.net
def _final_layer_weights(model):
    """Return the weights of the model's projection/classifier output layer as a NumPy array.

    Probes, in order, ``model.proj``, the last module of
    ``model.classifier.main`` and the last module of ``model.classifer``.
    NOTE(review): ``classifer`` looks like an upstream spelling of the
    attribute on some model definitions; keep it unless those change.

    Raises:
        ValueError: if none of the supported attributes is present.
    """
    if hasattr(model, 'proj'):
        layer = model.proj
    elif hasattr(model, 'classifier'):
        layer = model.classifier.main[-1]
    elif hasattr(model, 'classifer'):
        layer = model.classifer[-1]
    else:
        raise ValueError(
            'Unsupported model {}: expected a "proj", "classifier" or '
            '"classifer" attribute'.format(type(model).__name__))
    # np.histogram only reads the array, so no defensive copy is needed.
    return layer.weight.data.cpu().numpy()


def get_feature(info, root, bins=50):
    """Return a normalized histogram of a VQA model's output-layer weights.

    Args:
        info: model metadata record passed through to ``load_model_util``.
        root: dataset root directory passed through to ``load_model_util``.
        bins: number of equal-width histogram bins. The default of 50 matches
            the script's default ``fc_wt_hist_50`` feature name.

    Returns:
        1-D float ``numpy.ndarray`` of length ``bins`` holding relative
        frequencies. The entries sum to 1 for a non-empty weight array.

    Raises:
        ValueError: if the loaded model exposes none of the supported layers.
    """
    model = load_model_util(info, root)
    weights = _final_layer_weights(model)
    counts, _ = np.histogram(weights, bins=bins)
    # Normalize counts to relative frequencies so that layers with
    # different numbers of weights produce comparable features.
    return counts / counts.sum()
if __name__ == "__main__":
    # Extract the weight-histogram feature for one model and save it as
    # <feat_root>/<ds>/<feat_name>/<split>/<model_id>.npy
    parser = argparse.ArgumentParser(description='Get Wt features')
    parser.add_argument('--ds_root', type=str, help='Root of data', required=True)
    parser.add_argument('--model_id', type=str, help='model_id', default='m00001')
    parser.add_argument('--ds', type=str, help='dataset', default='v1')
    parser.add_argument('--split', type=str, help='split', default='train')
    parser.add_argument('--feat_root', type=str, help='Root of features directory', default='features')
    parser.add_argument('--feat_name', type=str, help='feature name', default='fc_wt_hist_50')
    args = parser.parse_args()

    args.feat_dir = os.path.join(args.feat_root, args.ds, args.feat_name, args.split)
    # The dataset itself lives in <ds_root>/<ds>-<split>-dataset/
    args.ds_root = os.path.join(args.ds_root, '{}-{}-dataset/'.format(args.ds, args.split))

    # exist_ok=True tolerates an already-existing directory but still
    # surfaces real failures such as permission errors.
    os.makedirs(args.feat_dir, exist_ok=True)

    _file = os.path.join(args.feat_dir, '{}.npy'.format(args.model_id))
    # Skip models whose features have already been extracted.
    if os.path.exists(_file):
        sys.exit()

    metadata_path = os.path.join(args.ds_root, 'METADATA.csv')
    metadata = pd.read_csv(metadata_path)
    matches = metadata[metadata.model_name == args.model_id]
    if matches.empty:
        sys.exit('model_id {} not found in {}'.format(args.model_id, metadata_path))
    info = matches.iloc[0]
    feat = get_feature(info, args.ds_root)
    np.save(_file, feat)