ludusc committed
Commit a16f7a6
1 Parent(s): dc026d8

updated model
backend/disentangle_concepts.py CHANGED
@@ -1,22 +1,34 @@
 import numpy as np
 from sklearn.svm import SVC
+from sklearn.linear_model import LogisticRegression
 from sklearn.model_selection import train_test_split
 import torch
+from umap import UMAP
 import PIL
 
-def get_separation_space(type_bin, annotations, df):
+def get_separation_space(type_bin, annotations, df, samples=100, method='LR', C=0.1):
     abstracts = np.array([float(ann) for ann in df[type_bin]])
-    abstract_idxs = list(np.argsort(abstracts))[:200]
-    repr_idxs = list(np.argsort(abstracts))[-200:]
+    abstract_idxs = list(np.argsort(abstracts))[:samples]
+    repr_idxs = list(np.argsort(abstracts))[-samples:]
     X = np.array([annotations['z_vectors'][i] for i in abstract_idxs+repr_idxs])
-    X = X.reshape((400, 512))
-    y = np.array([1]*200 + [0]*200)
+    X = X.reshape((2*samples, 512))
+    y = np.array([1]*samples + [0]*samples)
     x_train, x_val, y_train, y_val = train_test_split(X, y, test_size=0.2)
-    svc = SVC(gamma='auto', kernel='linear')
-    svc.fit(x_train, y_train)
-    print(svc.score(x_val, y_val))
-    imp_features = (np.abs(svc.coef_) > 0.1).sum()
-    return svc.coef_, imp_features
+    if method == 'SVM':
+        svc = SVC(gamma='auto', kernel='linear', random_state=0, C=C)
+        svc.fit(x_train, y_train)
+        print('Val performance SVM', svc.score(x_val, y_val))
+        imp_features = (np.abs(svc.coef_) > 0.2).sum()
+        imp_nodes = np.where(np.abs(svc.coef_) > 0.2)[1]
+        return svc.coef_, imp_features, imp_nodes
+    elif method == 'LR':
+        clf = LogisticRegression(random_state=0, C=C)
+        clf.fit(x_train, y_train)
+        print('Val performance logistic regression', clf.score(x_val, y_val))
+        imp_features = (np.abs(clf.coef_) > 0.2).sum()
+        imp_nodes = np.where(np.abs(clf.coef_) > 0.2)[1]
+        return clf.coef_, imp_features, imp_nodes
+
 
 def regenerate_images(model, z, decision_boundary, min_epsilon=-3, max_epsilon=3, count=5):
     device = torch.device('cpu')
@@ -55,4 +67,24 @@ def generate_original_image(z, model):
     z = torch.from_numpy(z.copy()).to(device)
     img = G(z, label, truncation_psi=0.7, noise_mode='random')
     img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
-    return PIL.Image.fromarray(img[0].cpu().numpy(), 'RGB')
+    return PIL.Image.fromarray(img[0].cpu().numpy(), 'RGB')
+
+
+def get_concepts_vectors(concepts, annotations, df, samples=100, method='LR', C=0.1):
+    important_nodes = []
+    vectors = np.zeros((len(concepts), 512))
+    for i, conc in enumerate(concepts):
+        vec, _, imp_nodes = get_separation_space(conc, annotations, df, samples=samples, method=method, C=C)
+        vectors[i,:] = vec
+        important_nodes.append(set(imp_nodes))
+
+    reducer = UMAP(n_neighbors=3,   # default 15, size of the local neighborhood used for manifold approximation
+                   n_components=3,  # default 2, dimension of the space to embed into
+                   min_dist=0.1,    # default 0.1, effective minimum distance between embedded points
+                   spread=2.0,      # default 1.0, effective scale of embedded points; with min_dist this controls how clustered the embedding is
+                   random_state=0,  # default None; an int seeds the random number generator
+                   )
+
+    projection = reducer.fit_transform(vectors)
+    nodes_in_common = set.intersection(*important_nodes)
+    return vectors, projection, nodes_in_common
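A quick sanity check of the updated API. This is a minimal sketch with synthetic stand-ins for the annotation pickle and score CSV; the seed count, concept names, and random data below are illustrative, not from the repo.

import numpy as np
import pandas as pd

# Assumes the repo root is on PYTHONPATH.
from backend.disentangle_concepts import get_separation_space, get_concepts_vectors

rng = np.random.default_rng(0)
n = 1000  # hypothetical number of annotated seeds
annotations = {'z_vectors': rng.standard_normal((n, 512))}  # stand-in for seeds0000-100000.pkl
concepts = ['Op Art', 'Minimalism', 'Surrealism', 'Baroque', 'Woodcut']
df = pd.DataFrame({c: rng.random(n) for c in concepts})     # stand-in for sim_seeds0000-100000.csv

# One concept: a (1, 512) separation vector, the count of coefficients above
# the 0.2 threshold, and the indices of those latent dimensions ("nodes").
vec, n_important, nodes = get_separation_space('Surrealism', annotations, df,
                                               samples=100, method='LR', C=0.1)

# All concepts at once: per-concept vectors, a 3-D UMAP projection of them,
# and the latent nodes shared by every concept's classifier.
vectors, projection, common_nodes = get_concepts_vectors(concepts, annotations, df)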
data/annotated_files/seeds0000-100000.pkl ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:7b3a4fd155fa86df0953ad1cb660d50729189606de307fcee09fd893ba047228
+size 420351795
data/annotated_files/sim_seeds0000-100000.csv ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:4e501641d051743b0f1eec385bf7cb2d769e3cb15f1fffc08dce6d38c1f2bbf8
+size 14059984
data/model_files/network-snapshot-010600.pkl ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:9a46e8aecd50191b82632b5de7bf3b9e219a59564c54994dd203f016b7a8270e
+size 357344749
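The three data files above are tracked with Git LFS, so their diffs show only pointer files: the spec version, a SHA-256 of the content, and the size in bytes. A minimal sketch of reading such a pointer before `git lfs pull` materializes the real file; read_lfs_pointer is a hypothetical helper, not part of this repo.

def read_lfs_pointer(path):
    # Parse a Git LFS pointer file into its key/value fields.
    fields = {}
    with open(path) as f:
        for line in f:
            key, _, value = line.strip().partition(' ')
            fields[key] = value
    fields['size'] = int(fields['size'])  # size of the real object, in bytes
    return fields

# read_lfs_pointer('data/model_files/network-snapshot-010600.pkl')
# -> {'version': 'https://git-lfs.github.com/spec/v1',
#     'oid': 'sha256:9a46e8ae...', 'size': 357344749}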
nx.html ADDED
@@ -0,0 +1,155 @@
+<html>
+    <head>
+        <meta charset="utf-8">
+
+        <script src="lib/bindings/utils.js"></script>
+        <link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/vis-network/9.1.2/dist/dist/vis-network.min.css" integrity="sha512-WgxfT5LWjfszlPHXRmBWHkV2eceiWTOBvrKCNbdgDYTHrT2AeLCGbF4sZlZw3UMN3WtL0tGUoIAKsu8mllg/XA==" crossorigin="anonymous" referrerpolicy="no-referrer" />
+        <script src="https://cdnjs.cloudflare.com/ajax/libs/vis-network/9.1.2/dist/vis-network.min.js" integrity="sha512-LnvoEWDFrqGHlHmDD2101OrLcbsfkrzoSpvtSQtxK3RMnRV0eOkhhBN2dXHKRrUU8p2DGRTk35n4O8nWSVe1mQ==" crossorigin="anonymous" referrerpolicy="no-referrer"></script>
+
+        <center>
+            <h1></h1>
+        </center>
+
+        <!-- <link rel="stylesheet" href="../node_modules/vis/dist/vis.min.css" type="text/css" />
+        <script type="text/javascript" src="../node_modules/vis/dist/vis.js"> </script>-->
+        <link
+            href="https://cdn.jsdelivr.net/npm/bootstrap@5.0.0-beta3/dist/css/bootstrap.min.css"
+            rel="stylesheet"
+            integrity="sha384-eOJMYsd53ii+scO/bJGFsiCZc+5NDVN2yr8+0RDqr0Ql0h+rP48ckxlpbzKgwra6"
+            crossorigin="anonymous"
+        />
+        <script
+            src="https://cdn.jsdelivr.net/npm/bootstrap@5.0.0-beta3/dist/js/bootstrap.bundle.min.js"
+            integrity="sha384-JEW9xMcG8R+pH31jmWH6WWP0WintQrMb4s7ZOdauHnUtxwoG2vI5DkLtS3qm9Ekf"
+            crossorigin="anonymous"
+        ></script>
+
+        <center>
+            <h1></h1>
+        </center>
+        <style type="text/css">
+            #mynetwork {
+                width: 100%;
+                height: 750px;
+                background-color: #ffffff;
+                border: 1px solid lightgray;
+                position: relative;
+                float: left;
+            }
+        </style>
+    </head>
+
+    <body>
+        <div class="card" style="width: 100%">
+            <div id="mynetwork" class="card-body"></div>
+        </div>
+
+        <script type="text/javascript">
+
+            // initialize global variables.
+            var edges;
+            var nodes;
+            var allNodes;
+            var allEdges;
+            var nodeColors;
+            var originalNodes;
+            var network;
+            var container;
+            var options, data;
+            var filter = {
+                item : '',
+                property : '',
+                value : []
+            };
+
+            // This method is responsible for drawing the graph, returns the drawn network
+            function drawGraph() {
+                var container = document.getElementById('mynetwork');
+
+                // parsing and collecting nodes and edges from the python
+                nodes = new vis.DataSet([{"color": "#97c2fc", "id": "Op Art", "label": "Op Art", "shape": "dot", "title": "Op Art"}, {"color": "#97c2fc", "id": "Minimalism", "label": "Minimalism", "shape": "dot", "title": "Minimalism"}, {"color": "#97c2fc", "id": "Surrealism", "label": "Surrealism", "shape": "dot", "title": "Surrealism"}, {"color": "#97c2fc", "id": "Baroque", "label": "Baroque", "shape": "dot", "title": "Baroque"}, {"color": "#97c2fc", "id": "Lithography", "label": "Lithography", "shape": "dot", "title": "Lithography"}, {"color": "#97c2fc", "id": "Woodcut", "label": "Woodcut", "shape": "dot", "title": "Woodcut"}, {"color": "#97c2fc", "id": "etching", "label": "etching", "shape": "dot", "title": "etching"}, {"color": "#97c2fc", "id": "Intaglio", "label": "Intaglio", "shape": "dot", "title": "Intaglio"}]);
+                edges = new vis.DataSet([{"from": "Op Art", "title": "Op Art to Minimalism similarity 0.432", "to": "Minimalism", "value": 0.432}, {"from": "Op Art", "title": "Op Art to Surrealism similarity -0.086", "to": "Surrealism", "value": -0.086}, {"from": "Op Art", "title": "Op Art to Baroque similarity -0.047", "to": "Baroque", "value": -0.047}, {"from": "Op Art", "title": "Op Art to Lithography similarity 0.054", "to": "Lithography", "value": 0.054}, {"from": "Op Art", "title": "Op Art to Woodcut similarity 0.125", "to": "Woodcut", "value": 0.125}, {"from": "Op Art", "title": "Op Art to etching similarity 0.117", "to": "etching", "value": 0.117}, {"from": "Op Art", "title": "Op Art to Intaglio similarity 0.094", "to": "Intaglio", "value": 0.094}, {"from": "Minimalism", "title": "Minimalism to Surrealism similarity -0.042", "to": "Surrealism", "value": -0.042}, {"from": "Minimalism", "title": "Minimalism to Baroque similarity -0.052", "to": "Baroque", "value": -0.052}, {"from": "Minimalism", "title": "Minimalism to Lithography similarity 0.046", "to": "Lithography", "value": 0.046}, {"from": "Minimalism", "title": "Minimalism to Woodcut similarity 0.069", "to": "Woodcut", "value": 0.069}, {"from": "Minimalism", "title": "Minimalism to etching similarity 0.1", "to": "etching", "value": 0.1}, {"from": "Minimalism", "title": "Minimalism to Intaglio similarity 0.03", "to": "Intaglio", "value": 0.03}, {"from": "Surrealism", "title": "Surrealism to Baroque similarity 0.067", "to": "Baroque", "value": 0.067}, {"from": "Surrealism", "title": "Surrealism to Lithography similarity -0.235", "to": "Lithography", "value": -0.235}, {"from": "Surrealism", "title": "Surrealism to Woodcut similarity -0.16", "to": "Woodcut", "value": -0.16}, {"from": "Surrealism", "title": "Surrealism to etching similarity -0.171", "to": "etching", "value": -0.171}, {"from": "Surrealism", "title": "Surrealism to Intaglio similarity -0.076", "to": "Intaglio", "value": -0.076}, {"from": "Baroque", "title": "Baroque to Lithography similarity -0.125", "to": "Lithography", "value": -0.125}, {"from": "Baroque", "title": "Baroque to Woodcut similarity -0.022", "to": "Woodcut", "value": -0.022}, {"from": "Baroque", "title": "Baroque to etching similarity -0.102", "to": "etching", "value": -0.102}, {"from": "Baroque", "title": "Baroque to Intaglio similarity -0.046", "to": "Intaglio", "value": -0.046}, {"from": "Lithography", "title": "Lithography to Woodcut similarity 0.258", "to": "Woodcut", "value": 0.258}, {"from": "Lithography", "title": "Lithography to etching similarity 0.268", "to": "etching", "value": 0.268}, {"from": "Lithography", "title": "Lithography to Intaglio similarity 0.123", "to": "Intaglio", "value": 0.123}, {"from": "Woodcut", "title": "Woodcut to etching similarity 0.21", "to": "etching", "value": 0.21}, {"from": "Woodcut", "title": "Woodcut to Intaglio similarity 0.209", "to": "Intaglio", "value": 0.209}, {"from": "etching", "title": "etching to Intaglio similarity 0.178", "to": "Intaglio", "value": 0.178}]);
+
+                nodeColors = {};
+                allNodes = nodes.get({ returnType: "Object" });
+                for (nodeId in allNodes) {
+                    nodeColors[nodeId] = allNodes[nodeId].color;
+                }
+                allEdges = edges.get({ returnType: "Object" });
+                // adding nodes and edges to the graph
+                data = {nodes: nodes, edges: edges};
+
+                var options = {
+                    "configure": {
+                        "enabled": false
+                    },
+                    "edges": {
+                        "color": {
+                            "inherit": true
+                        },
+                        "smooth": {
+                            "enabled": true,
+                            "type": "dynamic"
+                        }
+                    },
+                    "interaction": {
+                        "dragNodes": true,
+                        "hideEdgesOnDrag": false,
+                        "hideNodesOnDrag": false
+                    },
+                    "physics": {
+                        "enabled": true,
+                        "stabilization": {
+                            "enabled": true,
+                            "fit": true,
+                            "iterations": 1000,
+                            "onlyDynamicEdges": false,
+                            "updateInterval": 50
+                        }
+                    }
+                };
+
+                network = new vis.Network(container, data, options);
+
+                return network;
+            }
+            drawGraph();
+        </script>
+    </body>
+</html>
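nx.html is a self-contained vis-network page: one node per concept, with edge values and hover titles that read like cosine similarities between the concepts' separation vectors. The markup matches what pyvis emits, so it was plausibly generated along these lines; a hedged sketch, with concepts and vectors as returned by get_concepts_vectors, and build_similarity_graph a hypothetical helper.

import numpy as np
from pyvis.network import Network

def build_similarity_graph(concepts, vectors, out='nx.html'):
    # One dot-shaped node per concept, mirroring the nodes array in nx.html.
    net = Network(height='750px', width='100%', bgcolor='#ffffff')
    for c in concepts:
        net.add_node(c, label=c, title=c, shape='dot')
    # Cosine similarity between each pair of separation vectors becomes an edge
    # value plus a title like "Op Art to Minimalism similarity 0.432".
    norms = np.linalg.norm(vectors, axis=1)
    for i in range(len(concepts)):
        for j in range(i + 1, len(concepts)):
            sim = round(float(vectors[i] @ vectors[j] / (norms[i] * norms[j])), 3)
            net.add_edge(concepts[i], concepts[j], value=sim,
                         title=f'{concepts[i]} to {concepts[j]} similarity {sim}')
    net.write_html(out)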
pages/1_Disentanglement.py CHANGED
@@ -9,6 +9,7 @@ from matplotlib.backends.backend_agg import RendererAgg
 
 from backend.disentangle_concepts import *
 import torch_utils
+import dnnlib
 
 _lock = RendererAgg.lock
 
@@ -32,11 +33,11 @@ with st.expander("See more instruction", expanded=False):
     st.write(instruction_text)
 
 
-annotations_file = './data/annotated_files/annotations_parallel_seeds0000-10000.pkl'
+annotations_file = './data/annotated_files/seeds0000-100000.pkl'
 with open(annotations_file, 'rb') as f:
     annotations = pickle.load(f)
 
-ann_df = pd.read_csv('./data/annotated_files/sim_seeds0000-10000.csv')
+ann_df = pd.read_csv('./data/annotated_files/sim_seeds0000-100000.csv')
 concepts = './data/concepts.txt'
 
 with open(concepts) as f:
@@ -117,7 +118,7 @@ with input_col_2:
     random_id = st.form_submit_button('Generate a random image')
 
     if random_id:
-        image_id = random.randint(0, 10000)
+        image_id = random.randint(0, 100000)
         st.session_state.image_id = image_id
         chosen_image_id_input.number_input('Image ID:', format='%d', step=1, value=st.session_state.image_id)
 
@@ -135,7 +136,10 @@
 
 # ---------------------------- DISPLAY COL 2 ROW 1 ------------------------------
 
-model = torch.load('./data/model_files/pytorch_model.bin', map_location=torch.device('cpu'))
+#model = torch.load('./data/model_files/pytorch_model.bin', map_location=torch.device('cpu'))
+with dnnlib.util.open_url('./data/model_files/network-snapshot-010600.pkl') as f:
+    model = legacy.load_network_pkl(f)['G_ema'].to('cpu') # type: ignore
+
 original_image_vec = annotations['z_vectors'][st.session_state.image_id]
 img = generate_original_image(original_image_vec, model)
 # input_image = original_image_dict['image']
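One caveat on the new loading code: it calls legacy.load_network_pkl, but this hunk only adds import dnnlib, and the starred import from backend.disentangle_concepts does not bring in legacy either, so the page presumably needs an import legacy as well. A minimal sketch of the load-and-generate path, assuming the StyleGAN2-ADA / StyleGAN3 utility modules dnnlib and legacy are importable:

import torch
import dnnlib
import legacy  # ships with the StyleGAN2-ADA / StyleGAN3 codebases

device = torch.device('cpu')
with dnnlib.util.open_url('./data/model_files/network-snapshot-010600.pkl') as f:
    G = legacy.load_network_pkl(f)['G_ema'].to(device)  # EMA generator weights

z = torch.randn(1, G.z_dim, device=device)      # random latent code
label = torch.zeros(1, G.c_dim, device=device)  # empty label for an unconditional model
img = G(z, label, truncation_psi=0.7, noise_mode='random')  # NCHW float image in [-1, 1]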
view_predictions.ipynb ADDED
The diff for this file is too large to render.