Spaces:
Build error
Build error
| #!/usr/bin/env python | |
| import argparse, connexion, os, sys, yaml, json, socket | |
| from netdissect.easydict import EasyDict | |
| from flask import send_from_directory, redirect | |
| from flask_cors import CORS | |
| from netdissect.serverstate import DissectionProject | |
__author__ = 'Hendrik Strobelt, David Bau'
# Name of the per-project configuration file that load_projects() searches for.
CONFIG_FILE_NAME = 'dissect.json'
# Registry of loaded projects: project id (the project directory's base name)
# -> DissectionProject. Filled by load_projects(); read by the request handlers.
projects = {}
# The connexion application; REST endpoints are attached to it via add_api().
app = connexion.App(__name__, debug=False)
def get_all_projects():
    """Summarize every loaded project together with its layer names.

    :return: list of ``{'project': <id>, 'info': {'layers': [<name>, ...]}}``
        dicts, ordered by project id.
    """
    summaries = [
        {
            'project': name,
            'info': {
                'layers': [entry['layer'] for entry in proj.get_layers()],
            },
        }
        for name, proj in projects.items()
    ]
    summaries.sort(key=lambda summary: summary['project'])
    return summaries
def get_layers(project):
    """Return the layer list of ``project``, echoing the request parameters.

    :param project: id of a loaded project (raises KeyError if unknown)
    """
    target = projects[project]
    return dict(
        request=dict(project=project),
        res=target.get_layers(),
    )
def get_units(project, layer):
    """Return the unit descriptions for ``layer`` of ``project``.

    :param project: id of a loaded project (raises KeyError if unknown)
    :param layer: layer name, passed through to the project
    """
    target = projects[project]
    echoed = dict(project=project, layer=layer)
    return dict(request=echoed, res=target.get_units(layer))
def get_rankings(project, layer):
    """Return the unit rankings for ``layer`` of ``project``.

    :param project: id of a loaded project (raises KeyError if unknown)
    :param layer: layer name, passed through to the project
    """
    target = projects[project]
    echoed = dict(project=project, layer=layer)
    return dict(request=echoed, res=target.get_rankings(layer))
def get_levels(project, layer, quantiles):
    """Return per-unit levels of ``layer`` at the requested ``quantiles``.

    :param project: id of a loaded project (raises KeyError if unknown)
    :param layer: layer name, passed through to the project
    :param quantiles: quantile spec, passed through to the project unchanged
    """
    target = projects[project]
    echoed = dict(project=project, layer=layer, quantiles=quantiles)
    return dict(request=echoed, res=target.get_levels(layer, quantiles))
def get_channels(project, layer):
    """Return the channel information of ``layer``, wrapped as ``{'channels': ...}``.

    :param project: id of a loaded project (raises KeyError if unknown)
    :param layer: layer name, passed through to the project
    """
    channel_info = projects[project].get_channels(layer)
    return {
        'request': {'project': project, 'layer': layer},
        'res': {'channels': channel_info},
    }
def post_generate(gen_req):
    """Generate images for a project from latent vectors or sample ids.

    :param gen_req: request dict. Requires 'project' and exactly one of
        'zs' or 'ids'. Optional keys are 'interventions' (default None) and
        'return_urls' (default False). Any 'ablations' entry is ignored.
    :return: dict echoing the request under 'request', with the project's
        ``generate_images`` result under 'res'.
    :raises ValueError: if both or neither of 'zs' and 'ids' are supplied.
    """
    project = gen_req['project']
    zs = gen_req.get('zs', None)
    ids = gen_req.get('ids', None)
    return_urls = gen_req.get('return_urls', False)
    # Explicit check instead of `assert`: asserts are stripped under
    # `python -O`, which would let an ambiguous request through.
    if (zs is None) == (ids is None):
        raise ValueError("request must specify exactly one of 'zs' or 'ids'")
    interventions = gen_req.get('interventions', None)
    generated = projects[project].generate_images(
        zs, ids, interventions, return_urls=return_urls)
    return {
        'request': gen_req,
        'res': generated,
    }
def post_features(feat_req):
    """Compute features for the requested sample ids of a project.

    :param feat_req: request dict. Requires 'project' and 'ids'. Optional
        keys are 'masks', 'layers' and 'interventions', each defaulting to None.
    :return: dict echoing the request under 'request', with the features
        under 'res'.
    """
    name = feat_req['project']
    sample_ids = feat_req['ids']
    options = [feat_req.get(key, None)
               for key in ('masks', 'layers', 'interventions')]
    result = projects[name].get_features(sample_ids, *options)
    return dict(request=feat_req, res=result)
def post_featuremaps(feat_req):
    """Compute feature maps for the requested sample ids of a project.

    :param feat_req: request dict. Requires 'project' and 'ids'. Optional
        keys are 'layers' and 'interventions', each defaulting to None.
    :return: dict echoing the request under 'request', with the feature
        maps under 'res'.
    """
    name = feat_req['project']
    sample_ids = feat_req['ids']
    options = [feat_req.get(key, None) for key in ('layers', 'interventions')]
    result = projects[name].get_featuremaps(sample_ids, *options)
    return dict(request=feat_req, res=result)
def send_static(path):
    """Serve ``path`` from the client directory (``args.client``).

    Backs the ``/client/<path:path>`` route.

    :param path: file path relative to the client directory, taken from the URL
    """
    client_root = args.client
    return send_from_directory(client_root, path)
def send_data(path):
    """Serve ``path`` from the data directory (``args.data``).

    Backs the ``/dissect/<path:path>`` route and logs each request to stdout.

    :param path: file path relative to the data directory, taken from the URL
    """
    print('Got the data route for', path)
    data_root = args.data
    return send_from_directory(data_root, path)
def redirect_home():
    """Send the browser to the client's entry page with a temporary (302) redirect."""
    landing_page = '/client/index.html'
    return redirect(landing_page, code=302)
def load_projects(directory):
    """Discover dissection projects under ``directory`` and register them.

    Walks ``directory`` looking for directories that contain
    ``CONFIG_FILE_NAME``. The search covers ``directory`` itself and the
    directories up to two levels below it. The subdirectories of a matched
    project directory are not searched. Each match is loaded as a
    ``DissectionProject`` and stored in ``projects`` under the match's base
    name. If two matches share a base name, the later one replaces the
    earlier one and a warning is printed.

    :param directory: directory to scan
    :return: None
    """
    # Normalize first so a trailing separator (e.g. 'dissect/') neither
    # shifts the depth baseline nor makes a top-level project's id ''.
    directory = os.path.normpath(directory)
    # Don't search more than 2 dirs deep.
    search_depth = 2 + directory.count(os.path.sep)
    project_dirs = []
    for root, dirs, files in os.walk(directory):
        if CONFIG_FILE_NAME in files:
            project_dirs.append(root)
            # Don't get subprojects under a project dir.
            del dirs[:]
        elif root.count(os.path.sep) >= search_depth:
            del dirs[:]
    for p_dir in project_dirs:
        config_path = os.path.join(p_dir, CONFIG_FILE_NAME)
        print('Loading %s' % config_path)
        # JSON is UTF-8; don't depend on the platform's default encoding.
        with open(config_path, 'r', encoding='utf-8') as jf:
            config = EasyDict(json.load(jf))
        dh_id = os.path.basename(p_dir)
        if dh_id in projects:
            print('Warning: project id %r from %s replaces an earlier project'
                  % (dh_id, p_dir))
        # path_url is a URL path, so always join with '/', whatever os.sep is.
        rel_url = os.path.relpath(p_dir, directory).replace(os.path.sep, '/')
        projects[dh_id] = DissectionProject(
            config=config,
            project_dir=p_dir,
            path_url='data/' + rel_url,
            public_host=args.public_host)
# Attach the REST endpoints described by the OpenAPI spec in server.yaml.
# NOTE(review): the operationIds in server.yaml presumably name the handler
# functions in this module -- confirm against server.yaml.
app.add_api('server.yaml')
# add CORS support on the wrapped Flask application (app.app)
CORS(app.app, headers='Content-Type')
def _str2bool(value):
    """Convert a command-line flag value to a bool.

    The empty string and explicit "false" spellings ('0', 'false', 'f',
    'no', 'n', 'off', case-insensitive) give False. Any other value gives
    True, which keeps the old behaviour where any non-empty string counted
    as set.
    """
    return value.strip().lower() not in ('', '0', 'false', 'f', 'no', 'n', 'off')


# Command-line options, shared by the script entry point and module import.
parser = argparse.ArgumentParser()
# Boolean switches: a bare flag means True, and '--nodebug false' now really
# means False. Before, any given string (even 'False') counted as set.
parser.add_argument("--nodebug", type=_str2bool, nargs='?', const=True,
                    default=False)
parser.add_argument("--address", default="127.0.0.1")  # 0.0.0.0 for nonlocal use
parser.add_argument("--port", type=int, default=5001)
parser.add_argument("--public_host", default=None)
parser.add_argument("--nocache", type=_str2bool, nargs='?', const=True,
                    default=False)
parser.add_argument("--data", type=str, default='dissect')
parser.add_argument("--client", type=str, default='client_dist')
if __name__ == '__main__':
    # Run as a script: parse argv strictly, require both directories to
    # exist, and turn them into absolute paths before serving.
    args = parser.parse_args()
    for d in [args.data, args.client]:
        if not os.path.isdir(d):
            print('No directory %s' % d)
            sys.exit(1)
    args.data = os.path.abspath(args.data)
    args.client = os.path.abspath(args.client)
    # Default public host is "<this machine's FQDN>:<port>".
    if args.public_host is None:
        args.public_host = '%s:%d' % (socket.getfqdn(), int(args.port))
    # NOTE(review): this branch never calls load_projects(). Presumably the
    # handlers named in server.yaml are resolved from a separate import of
    # this module, which takes the else branch and loads projects there.
    # Confirm before relying on this module's own `projects` dict.
    app.run(port=int(args.port), debug=not args.nodebug, host=args.address,
            use_reloader=False)
else:
    # Imported rather than run directly: parse_known_args tolerates argv
    # entries that belong to whatever program imported this module.
    args, _ = parser.parse_known_args()
    if args.public_host is None:
        args.public_host = '%s:%d' % (socket.getfqdn(), int(args.port))
    # NOTE(review): unlike the __main__ branch, args.data and args.client are
    # neither checked for existence nor made absolute here.
    load_projects(args.data)