File size: 5,784 Bytes
8f87579
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
#!/usr/bin/env python

import argparse, connexion, os, sys, yaml, json, socket
from netdissect.easydict import EasyDict
from flask import send_from_directory, redirect
from flask_cors import CORS


from netdissect.serverstate import DissectionProject

__author__ = 'Hendrik Strobelt, David Bau'

# File name that load_projects() looks for to recognise a project directory.
CONFIG_FILE_NAME = 'dissect.json'
# Loaded projects keyed by project directory name; filled by load_projects()
# and read by every API handler in this module.
projects = {}

# The connexion application; static routes are attached with @app.route and
# the API specification with app.add_api() further down in this file.
app = connexion.App(__name__, debug=False)


def get_all_projects():
    """List every loaded project together with the names of its layers.

    :return: list of ``{'project': key, 'info': {'layers': [...]}}`` dicts,
        ordered by project key
    """
    summaries = [
        {
            'project': name,
            'info': {
                'layers': [entry['layer'] for entry in proj.get_layers()],
            },
        }
        for name, proj in projects.items()
    ]
    summaries.sort(key=lambda item: item['project'])
    return summaries

def get_layers(project):
    """Return the layers of one project, echoing the request back.

    :param project: key of the project in ``projects``
    :return: dict with the echoed ``request`` and the layers under ``res``
    """
    request = {'project': project}
    layers = projects[project].get_layers()
    return {'request': request, 'res': layers}

def get_units(project, layer):
    """Return the units of one layer of a project, echoing the request back.

    :param project: key of the project in ``projects``
    :param layer: layer identifier, forwarded to the project's ``get_units``
    :return: dict with the echoed ``request`` and the units under ``res``
    """
    request = {'project': project, 'layer': layer}
    units = projects[project].get_units(layer)
    return {'request': request, 'res': units}

def get_rankings(project, layer):
    """Return the rankings of one layer of a project, echoing the request.

    :param project: key of the project in ``projects``
    :param layer: layer identifier, forwarded to the project's
        ``get_rankings``
    :return: dict with the echoed ``request`` and the rankings under ``res``
    """
    request = {'project': project, 'layer': layer}
    rankings = projects[project].get_rankings(layer)
    return {'request': request, 'res': rankings}

def get_levels(project, layer, quantiles):
    """Return the levels of one layer of a project, echoing the request.

    :param project: key of the project in ``projects``
    :param layer: layer identifier, forwarded to the project's ``get_levels``
    :param quantiles: forwarded unchanged to the project's ``get_levels``
    :return: dict with the echoed ``request`` and the levels under ``res``
    """
    request = {'project': project, 'layer': layer, 'quantiles': quantiles}
    levels = projects[project].get_levels(layer, quantiles)
    return {'request': request, 'res': levels}

def get_channels(project, layer):
    """Return the channels of one layer of a project, echoing the request.

    :param project: key of the project in ``projects``
    :param layer: layer identifier, forwarded to the project's
        ``get_channels``
    :return: dict with the echoed ``request``; ``res`` holds a dict with the
        single key ``channels``
    """
    channels = projects[project].get_channels(layer)
    return {
        'request': {'project': project, 'layer': layer},
        'res': {'channels': channels},
    }

def post_generate(gen_req):
    """Generate images for a project from either ``zs`` or ``ids``.

    :param gen_req: request dict. Reads ``project`` (required), exactly one
        of ``zs`` / ``ids``, and the optional ``interventions`` (default
        ``None``) and ``return_urls`` (default ``False``) entries. An
        ``ablations`` entry, if present, is not forwarded anywhere.
    :return: dict echoing ``gen_req`` under ``request`` with the output of
        the project's ``generate_images`` under ``res``
    :raises ValueError: if neither or both of ``zs`` and ``ids`` are given
    :raises KeyError: if ``project`` is missing from the request or is not a
        key of ``projects``
    """
    project = gen_req['project']
    zs = gen_req.get('zs', None)
    ids = gen_req.get('ids', None)
    # Validate explicitly instead of with ``assert``: assertions are removed
    # when Python runs with -O, which would let bad requests through.
    if (zs is None) == (ids is None):
        raise ValueError(
            "exactly one of 'zs' or 'ids' must be provided, not both")
    return_urls = gen_req.get('return_urls', False)
    interventions = gen_req.get('interventions', None)
    generated = projects[project].generate_images(
        zs, ids, interventions, return_urls=return_urls)
    return {
        'request': gen_req,
        'res': generated
    }

def post_features(feat_req):
    """Fetch features from a project for the requested ids.

    :param feat_req: request dict. Reads ``project`` and ``ids`` (both
        required) plus the optional ``masks``, ``layers`` and
        ``interventions`` entries, each defaulting to ``None``.
    :return: dict echoing ``feat_req`` under ``request`` with the output of
        the project's ``get_features`` under ``res``
    """
    name = feat_req['project']
    wanted_ids = feat_req['ids']
    optional = [feat_req.get(key)
                for key in ('masks', 'layers', 'interventions')]
    result = projects[name].get_features(wanted_ids, *optional)
    return {'request': feat_req, 'res': result}

def post_featuremaps(feat_req):
    """Fetch feature maps from a project for the requested ids.

    :param feat_req: request dict. Reads ``project`` and ``ids`` (both
        required) plus the optional ``layers`` and ``interventions``
        entries, each defaulting to ``None``.
    :return: dict echoing ``feat_req`` under ``request`` with the output of
        the project's ``get_featuremaps`` under ``res``
    """
    name = feat_req['project']
    wanted_ids = feat_req['ids']
    optional = [feat_req.get(key) for key in ('layers', 'interventions')]
    result = projects[name].get_featuremaps(wanted_ids, *optional)
    return {'request': feat_req, 'res': result}

@app.route('/client/<path:path>')
def send_static(path):
    """Serve a file from the client directory under ``/client/<path:path>``.

    :param path: file path taken from the request URL
    """
    client_root = args.client
    return send_from_directory(client_root, path)

@app.route('/data/<path:path>')
def send_data(path):
    """Serve a file from the data directory under ``/data/<path:path>``.

    Each request is logged to stdout before the file is sent.

    :param path: file path taken from the request URL
    """
    print('Got the data route for', path)
    data_root = args.data
    return send_from_directory(data_root, path)


@app.route('/')
def redirect_home():
    """Redirect the site root to the client's index page with HTTP 302."""
    landing_page = '/client/index.html'
    return redirect(landing_page, code=302)


def load_projects(directory):
    """
    Find project directories under ``directory`` and register them.

    A project directory is one that contains CONFIG_FILE_NAME. For each one
    found, the config is read as JSON and a DissectionProject is stored in
    the module-level ``projects`` dict under the directory's base name.

    :param directory: root directory to scan
    :return: None
    """
    # Walking stops descending once a directory's separator count reaches
    # that of ``directory`` plus two.
    depth_limit = directory.count(os.path.sep) + 2
    found_dirs = []
    for root, dirs, files in os.walk(directory):
        is_project = CONFIG_FILE_NAME in files
        if is_project:
            found_dirs.append(root)
        # Emptying ``dirs`` in place keeps os.walk from descending further:
        # either we are inside a project (no nested projects) or too deep.
        if is_project or root.count(os.path.sep) >= depth_limit:
            dirs[:] = []

    for p_dir in found_dirs:
        config_path = os.path.join(p_dir, CONFIG_FILE_NAME)
        print('Loading %s' % config_path)
        with open(config_path, 'r') as config_file:
            settings = EasyDict(json.load(config_file))
            project_key = os.path.basename(p_dir)
            url_prefix = 'data/' + os.path.relpath(p_dir, directory)
            projects[project_key] = DissectionProject(
                config=settings,
                project_dir=p_dir,
                path_url=url_prefix,
                public_host=args.public_host)

# Attach the API described in server.yaml to the connexion app.
app.add_api('server.yaml')

# add CORS support
CORS(app.app, headers='Content-Type')

# Command-line options.
# NOTE(review): --nodebug and --nocache are plain value-taking options (no
# action='store_true'), so any non-empty value, including the string
# 'False', is truthy. --nocache is not read anywhere in the visible code.
parser = argparse.ArgumentParser()
parser.add_argument("--nodebug", default=False)
parser.add_argument("--address", default="127.0.0.1") # 0.0.0.0 for nonlocal use
parser.add_argument("--port", default="5001")
parser.add_argument("--public_host", default=None)
parser.add_argument("--nocache", default=False)
parser.add_argument("--data", type=str, default='dissect')
parser.add_argument("--client", type=str, default='client_dist')

if __name__ == '__main__':
    # Run as a script: validate the directories, make them absolute, fill in
    # a default public host and start the server.
    # NOTE(review): load_projects() is not called in this branch. Presumably
    # the projects get loaded when this file is imported by module name (the
    # else branch below), e.g. when the API handlers are resolved - confirm
    # before adding a call here.
    args = parser.parse_args()
    for d in [args.data, args.client]:
        if not os.path.isdir(d):
            print('No directory %s' % d)
            sys.exit(1)
    args.data = os.path.abspath(args.data)
    args.client = os.path.abspath(args.client)
    if args.public_host is None:
        # Default to this machine's fully qualified name plus the port.
        args.public_host = '%s:%d' % (socket.getfqdn(), int(args.port))
    app.run(port=int(args.port), debug=not args.nodebug, host=args.address,
            use_reloader=False)
else:
    # Imported as a module: parse_known_args() tolerates command-line
    # arguments this parser does not define; projects are loaded right away.
    args, _ = parser.parse_known_args()
    if args.public_host is None:
        args.public_host = '%s:%d' % (socket.getfqdn(), int(args.port))
    load_projects(args.data)