Spaces:
Sleeping
Sleeping
File size: 5,784 Bytes
6064c9d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 |
#!/usr/bin/env python
import argparse, connexion, os, sys, yaml, json, socket
from netdissect.easydict import EasyDict
from flask import send_from_directory, redirect
from flask_cors import CORS
from netdissect.serverstate import DissectionProject
__author__ = 'Hendrik Strobelt, David Bau'
# Name of the per-project config file that load_projects() searches for.
CONFIG_FILE_NAME = 'dissect.json'
# Registry of loaded projects, keyed by the base name of each project
# directory; filled by load_projects() and read by the API handlers.
projects = {}
# Connexion application; Flask routes and the API spec are attached to it
# later in this module.
app = connexion.App(__name__, debug=False)
def get_all_projects():
    """List every loaded project together with its layer names.

    :return: a list of ``{'project': name, 'info': {'layers': [...]}}``
        dicts ordered by project name, where ``layers`` collects the
        ``'layer'`` entry of each item from the project's ``get_layers()``.
    """
    summaries = [
        {
            'project': name,
            'info': {
                'layers': [entry['layer'] for entry in proj.get_layers()]
            },
        }
        for name, proj in projects.items()
    ]
    summaries.sort(key=lambda item: item['project'])
    return summaries
def get_layers(project):
    """Return ``projects[project].get_layers()`` wrapped with the request.

    :param project: key of a loaded project
    :return: ``{'request': {'project': ...}, 'res': <layers>}``
    """
    layer_list = projects[project].get_layers()
    return dict(request=dict(project=project), res=layer_list)
def get_units(project, layer):
    """Return ``projects[project].get_units(layer)`` wrapped with the request.

    :param project: key of a loaded project
    :param layer: layer name passed through to the project
    :return: ``{'request': {...}, 'res': <units>}``
    """
    echoed = {'project': project, 'layer': layer}
    unit_data = projects[project].get_units(layer)
    return {'request': echoed, 'res': unit_data}
def get_rankings(project, layer):
    """Return ``projects[project].get_rankings(layer)`` with the request echo.

    :param project: key of a loaded project
    :param layer: layer name passed through to the project
    :return: ``{'request': {...}, 'res': <rankings>}``
    """
    ranking_data = projects[project].get_rankings(layer)
    response = dict(request=dict(project=project, layer=layer))
    response['res'] = ranking_data
    return response
def get_levels(project, layer, quantiles):
    """Return ``projects[project].get_levels(layer, quantiles)`` with echo.

    :param project: key of a loaded project
    :param layer: layer name passed through to the project
    :param quantiles: passed through unchanged to ``get_levels``
    :return: ``{'request': {...}, 'res': <levels>}``
    """
    echoed = dict(project=project, layer=layer, quantiles=quantiles)
    level_data = projects[project].get_levels(layer, quantiles)
    return dict(request=echoed, res=level_data)
def get_channels(project, layer):
    """Return the project's channels for *layer*, keyed under ``'channels'``.

    :param project: key of a loaded project
    :param layer: layer name passed through to the project
    :return: ``{'request': {...}, 'res': {'channels': <channels>}}``
    """
    channel_info = projects[project].get_channels(layer)
    return {
        'request': {'project': project, 'layer': layer},
        'res': {'channels': channel_info},
    }
def post_generate(gen_req):
    """Generate images for a project from either ``zs`` or ``ids``.

    :param gen_req: request dict with ``project`` and exactly one of ``zs``
        or ``ids``; optional ``interventions`` (default None) and
        ``return_urls`` (default False) are forwarded to the project.
    :return: ``{'request': gen_req, 'res': <result of generate_images>}``
    :raises ValueError: if both or neither of ``zs`` and ``ids`` are given.
    """
    project = gen_req['project']
    zs = gen_req.get('zs', None)
    ids = gen_req.get('ids', None)
    return_urls = gen_req.get('return_urls', False)
    # Exactly one input source is allowed. An explicit check (instead of
    # `assert`) keeps this validation active when Python runs with -O.
    if (zs is None) == (ids is None):
        raise ValueError("exactly one of 'zs' or 'ids' must be provided")
    interventions = gen_req.get('interventions', None)
    # An 'ablations' field in the request is not used here; only
    # 'interventions' are forwarded to the generator.
    generated = projects[project].generate_images(
        zs, ids, interventions, return_urls=return_urls)
    return {
        'request': gen_req,
        'res': generated
    }
def post_features(feat_req):
    """Compute features for the requested ids via the project's handler.

    :param feat_req: request dict with required ``project`` and ``ids`` and
        optional ``masks``, ``layers``, ``interventions`` (each default None)
    :return: ``{'request': feat_req, 'res': <result of get_features>}``
    """
    name = feat_req['project']
    ids = feat_req['ids']
    masks, layers, interventions = (
        feat_req.get(key, None) for key in ('masks', 'layers', 'interventions'))
    result = projects[name].get_features(ids, masks, layers, interventions)
    return dict(request=feat_req, res=result)
def post_featuremaps(feat_req):
    """Compute feature maps for the requested ids via the project's handler.

    :param feat_req: request dict with required ``project`` and ``ids`` and
        optional ``layers`` and ``interventions`` (each default None)
    :return: ``{'request': feat_req, 'res': <result of get_featuremaps>}``
    """
    name = feat_req['project']
    ids = feat_req['ids']
    layers, interventions = (
        feat_req.get(key, None) for key in ('layers', 'interventions'))
    result = projects[name].get_featuremaps(ids, layers, interventions)
    return dict(request=feat_req, res=result)
@app.route('/client/<path:path>')
def send_static(path):
    """ serves files from the client directory (``args.client``) at
    ``/client/<path:path>``
    :param path: path from api call, relative to the client directory
    """
    client_dir = args.client
    return send_from_directory(client_dir, path)
@app.route('/data/<path:path>')
def send_data(path):
    """ serves files from the data directory (``args.data``) at
    ``/data/<path:path>``
    :param path: path from api call, relative to the data directory
    """
    data_dir = args.data
    print('Got the data route for', path)
    return send_from_directory(data_dir, path)
@app.route('/')
def redirect_home():
    """Send the site root to the client's index page with an HTTP 302."""
    landing_page = '/client/index.html'
    return redirect(landing_page, code=302)
def load_projects(directory):
    """
    searches for CONFIG_FILE_NAME in directory and its subdirectories (at
    most two levels below it) and registers a DissectionProject in
    ``projects`` for each directory containing one, keyed by that
    directory's base name. Directories below a project dir are not searched.
    :param directory: scan directory
    :return: None
    """
    # Normalize first: a trailing separator would otherwise raise the depth
    # limit by one level and, when the config sits directly in `directory`,
    # make os.path.split() return an empty project id.
    directory = os.path.normpath(directory)
    project_dirs = []
    # Don't search more than 2 dirs deep.
    search_depth = 2 + directory.count(os.path.sep)
    for root, dirs, files in os.walk(directory):
        if CONFIG_FILE_NAME in files:
            project_dirs.append(root)
            # Don't get subprojects under a project dir.
            del dirs[:]
        elif root.count(os.path.sep) >= search_depth:
            del dirs[:]
    for p_dir in project_dirs:
        config_path = os.path.join(p_dir, CONFIG_FILE_NAME)
        print('Loading %s' % config_path)
        # JSON text is UTF-8; don't depend on the process locale.
        with open(config_path, 'r', encoding='utf-8') as jf:
            config = EasyDict(json.load(jf))
        dh_id = os.path.basename(p_dir)
        # path_url is a URL fragment, so always join it with '/'.
        rel_url = os.path.relpath(p_dir, directory).replace(os.path.sep, '/')
        projects[dh_id] = DissectionProject(
            config=config,
            project_dir=p_dir,
            path_url='data/' + rel_url,
            public_host=args.public_host)
# Attach the API described in server.yaml to the connexion app; presumably
# its operations map to the handler functions above — verify in server.yaml.
app.add_api('server.yaml')
# add CORS support on the underlying Flask app (`app.app`).
# NOTE(review): confirm the installed flask_cors honours `headers` (its
# documented option name is `allow_headers`).
CORS(app.app, headers='Content-Type')
parser = argparse.ArgumentParser()
# Flag-style options: a bare `--nodebug` / `--nocache` sets True. An explicit
# value (`--nodebug X`) is still accepted and stored as the string X, which
# is truthy for any non-empty X (including "False").
parser.add_argument("--nodebug", nargs='?', const=True, default=False)
parser.add_argument("--address", default="127.0.0.1")  # 0.0.0.0 for nonlocal use
parser.add_argument("--port", default="5001")
parser.add_argument("--public_host", default=None)
parser.add_argument("--nocache", nargs='?', const=True, default=False)
parser.add_argument("--data", type=str, default='dissect')
parser.add_argument("--client", type=str, default='client_dist')
if __name__ == '__main__':
    # Script mode: validate directories, load projects, then serve.
    args = parser.parse_args()
    for d in [args.data, args.client]:
        if not os.path.isdir(d):
            print('No directory %s' % d, file=sys.stderr)
            sys.exit(1)
    args.data = os.path.abspath(args.data)
    args.client = os.path.abspath(args.client)
    if args.public_host is None:
        args.public_host = '%s:%d' % (socket.getfqdn(), int(args.port))
    # Projects must be registered before serving: app.run() blocks until
    # the server stops, so nothing after it runs while requests are handled.
    load_projects(args.data)
    app.run(port=int(args.port), debug=not args.nodebug, host=args.address,
            use_reloader=False)
else:
    # Imported mode (presumably under a WSGI host): parse_known_args ignores
    # argv entries that belong to the host process.
    args, _ = parser.parse_known_args()
    if args.public_host is None:
        args.public_host = '%s:%d' % (socket.getfqdn(), int(args.port))
    load_projects(args.data)
|