Spaces:
Runtime error
Runtime error
File size: 4,401 Bytes
0392181 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 |
"""
=========================================================================================
Trojan VQA
Written by Matthew Walmer
Helper scripts to check if a job has already been run to aid orchestrator.py.
=========================================================================================
"""
import os
import numpy as np
def featfile_to_id(file_name):
    """Return the integer id encoded at the end of a feature file name.

    Up to two trailing extensions are stripped (``name.ext1.ext2`` ->
    ``name``); the text after the last underscore of what remains is then
    parsed as an int.

    Args:
        file_name: file name (or path) of a cached feature file.

    Returns:
        The trailing id as an int.

    Raises:
        ValueError: if the text after the last underscore is not an integer.
    """
    stem = file_name
    # strip two extension layers, one per pass
    for _ in range(2):
        stem = os.path.splitext(stem)[0]
    return int(stem.split('_')[-1])
def check_feature_extraction(s, downstream=None, debug=False):
    """Check whether cached features for spec ``s`` are already extracted.

    Looks under ``data/feature_cache/<feat_id>/<detector>/`` for the
    ``train2014`` and ``val2014`` directories.

    Args:
        s: mapping providing the ``'feat_id'`` and ``'detector'`` keys.
        downstream: optional comma-separated list of data spec ids. When
            given, the train set only has to contain the image ids listed in
            ``data/feature_reqs/<id>_reqs.npy`` for each id; otherwise the
            train directory must hold exactly 82783 entries.
        debug: when True, a missing req file is reported and treated as
            "not complete" instead of being loaded.

    Returns:
        True if the train requirement is met and the val directory holds
        exactly 40504 entries, else False.

    Raises:
        Whatever ``np.load`` raises if a req file is missing while ``debug``
        is False, and ValueError from ``featfile_to_id`` if a train file
        name does not end in an integer id.
    """
    # 82783 / 40504 are presumably the COCO train2014 / val2014 image
    # counts -- TODO confirm against the dataset.
    cache_root = os.path.join('data', 'feature_cache', s['feat_id'], s['detector'])

    # train set features
    train_dir = os.path.join(cache_root, 'train2014')
    if not os.path.isdir(train_dir):
        return False
    if downstream is not None:
        # gather the ids required by every downstream data spec;
        # split(',') already yields a single-item list when there is no comma
        required = set()
        for ds in downstream.split(','):
            req_file = os.path.join('data', 'feature_reqs', ds + '_reqs.npy')
            if debug and not os.path.isfile(req_file):
                print('DEBUG MODE: assuming req file is not complete')
                return False
            # outside debug mode a missing req file is left to fail loudly here
            required.update(np.load(req_file))
        # check if requirements met
        present = {featfile_to_id(f) for f in os.listdir(train_dir)}
        if required - present:
            return False
    elif len(os.listdir(train_dir)) != 82783:
        return False

    # val set features (always needs the full set)
    val_dir = os.path.join(cache_root, 'val2014')
    if not os.path.isdir(val_dir):
        return False
    return len(os.listdir(val_dir)) == 40504
def check_dataset_composition(s):
    """Check whether the composed dataset for spec ``s`` already exists.

    Two outputs under ``data/<data_id>/`` are required:
      * the butd-format file ``trainval_<detector>_<nb>.tsv``
      * the openvqa-format directories ``openvqa/<detector>/train2014`` and
        ``openvqa/<detector>/val2014`` holding exactly 82783 and 40504
        entries respectively.

    Args:
        s: mapping providing the ``'data_id'``, ``'detector'`` and ``'nb'``
            keys.

    Returns:
        True only if every output above is present and complete.
    """
    data_root = os.path.join('data', s['data_id'])

    # butd tsv file format
    tsv_file = os.path.join(data_root, 'trainval_%s_%s.tsv' % (s['detector'], s['nb']))
    if not os.path.isfile(tsv_file):
        return False

    # openvqa feature format: each split directory needs its exact entry count
    feat_root = os.path.join(data_root, 'openvqa', s['detector'])
    for split, expected in (('train2014', 82783), ('val2014', 40504)):
        split_dir = os.path.join(feat_root, split)
        if not os.path.isdir(split_dir):
            return False
        if len(os.listdir(split_dir)) != expected:
            return False
    return True
def check_vqa_model(s, model_type):
    """Check whether the trained model checkpoint for spec ``s`` exists.

    Args:
        s: mapping providing the ``'model_id'`` key.
        model_type: ``'butd_eff'`` or ``'openvqa'``; selects which checkpoint
            location is checked.

    Returns:
        True if the expected checkpoint file exists:
        ``bottom-up-attention-vqa/saved_models/<model_id>/model_19.pth`` for
        ``'butd_eff'``, ``openvqa/ckpts/ckpt_<model_id>/epoch13.pkl`` for
        ``'openvqa'``.

    Raises:
        AssertionError: if ``model_type`` is not one of the two known values.
    """
    assert model_type in ['butd_eff', 'openvqa'], 'unknown model_type: %s' % (model_type,)
    if model_type == 'butd_eff':
        ckpt = os.path.join('bottom-up-attention-vqa', 'saved_models', s['model_id'], 'model_19.pth')
    else:
        ckpt = os.path.join('openvqa', 'ckpts', 'ckpt_' + s['model_id'], 'epoch13.pkl')
    return os.path.isfile(ckpt)
# check for models in the model_sets/v1/ location instead
def check_vqa_model_set(s, model_type):
    """Check for the model checkpoint of spec ``s`` under ``model_sets/v1/``.

    Args:
        s: mapping providing the ``'model_id'`` key.
        model_type: ``'butd_eff'`` or ``'openvqa'``; selects which checkpoint
            location is checked.

    Returns:
        True if the expected checkpoint file exists:
        ``model_sets/v1/bottom-up-attention-vqa/saved_models/<model_id>/model_19.pth``
        for ``'butd_eff'``,
        ``model_sets/v1/openvqa/ckpts/ckpt_<model_id>/epoch13.pkl`` for
        ``'openvqa'``.

    Raises:
        AssertionError: if ``model_type`` is not one of the two known values.
    """
    assert model_type in ['butd_eff', 'openvqa'], 'unknown model_type: %s' % (model_type,)
    if model_type == 'butd_eff':
        ckpt = os.path.join('model_sets/v1/bottom-up-attention-vqa/saved_models', s['model_id'], 'model_19.pth')
    else:
        ckpt = os.path.join('model_sets/v1/openvqa/ckpts', 'ckpt_' + s['model_id'], 'epoch13.pkl')
    return os.path.isfile(ckpt)
def check_vqa_train(s, model_type):
    """Check whether the exported eval files for spec ``s`` all exist.

    One result file per test config is expected. A spec whose ``'feat_id'``
    is ``'clean'`` only needs the ``clean`` config; any other spec needs
    ``clean``, ``troj``, ``troji`` and ``trojq``.

    Args:
        s: mapping providing the ``'feat_id'`` and ``'model_id'`` keys.
        model_type: ``'butd_eff'`` or ``'openvqa'``; selects the results
            directory and file naming scheme.

    Returns:
        True if every expected result file exists, else False.

    Raises:
        AssertionError: if ``model_type`` is not one of the two known values.
    """
    assert model_type in ['butd_eff', 'openvqa'], 'unknown model_type: %s' % (model_type,)
    if s['feat_id'] == 'clean':
        configs = ['clean']
    else:
        configs = ['clean', 'troj', 'troji', 'trojq']

    # the location/naming scheme depends only on model_type, so pick it once
    if model_type == 'butd_eff':
        result_dir = os.path.join('bottom-up-attention-vqa', 'results')
        name_fmt = 'results_%s_%s.json'
    else:
        result_dir = os.path.join('openvqa', 'results', 'result_test')
        name_fmt = 'result_run_%s_%s.json'

    # check for exported eval files
    return all(
        os.path.isfile(os.path.join(result_dir, name_fmt % (s['model_id'], tc)))
        for tc in configs
    )
def check_vqa_eval(s):
    """Check whether the eval result file for spec ``s`` exists.

    Args:
        s: mapping providing the ``'model_id'`` key.

    Returns:
        True if ``results/<model_id>.npy`` is an existing file, else False.
    """
    result_file = os.path.join('results', '%s.npy' % s['model_id'])
    return os.path.isfile(result_file)
def check_butd_preproc(s):
    """Check whether the butd preprocessing outputs for spec ``s`` exist.

    Args:
        s: mapping providing the ``'data_id'``, ``'detector'`` and ``'nb'``
            keys.

    Returns:
        True if both ``data/<data_id>/train_<detector>_<nb>.hdf5`` and
        ``data/<data_id>/val_<detector>_<nb>.hdf5`` are existing files.
    """
    data_root = os.path.join('data', s['data_id'])
    name_args = (s['detector'], s['nb'])
    return all(
        os.path.isfile(os.path.join(data_root, name_fmt % name_args))
        for name_fmt in ('train_%s_%s.hdf5', 'val_%s_%s.hdf5')
    )
|