Johannes Kolbe commited on
Commit
ed6b6d6
1 Parent(s): dd2f594

enable model loading from hf hub

Browse files
.gitignore CHANGED
@@ -20,6 +20,7 @@ __pycache__/
20
  *.zip
21
  events.*
22
 
 
23
  *.pkl
24
  *.h5
25
  *.dat
 
20
  *.zip
21
  events.*
22
 
23
+ /checkpoints/
24
  *.pkl
25
  *.h5
26
  *.dat
.ipynb_checkpoints/model_to_hf_hub-checkpoint.ipynb ADDED
@@ -0,0 +1,255 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 15,
6
+ "metadata": {
7
+ "pycharm": {
8
+ "name": "#%%\n"
9
+ }
10
+ },
11
+ "outputs": [],
12
+ "source": [
13
+ "import huggingface_hub\n",
14
+ "import utils"
15
+ ]
16
+ },
17
+ {
18
+ "cell_type": "code",
19
+ "execution_count": 16,
20
+ "metadata": {
21
+ "pycharm": {
22
+ "name": "#%%\n"
23
+ }
24
+ },
25
+ "outputs": [
26
+ {
27
+ "data": {
28
+ "application/vnd.jupyter.widget-view+json": {
29
+ "model_id": "525a0eaa021f4fdebd9138f4e7c5ab65",
30
+ "version_major": 2,
31
+ "version_minor": 0
32
+ },
33
+ "text/plain": [
34
+ "VBox(children=(HTML(value='<center> <img\\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…"
35
+ ]
36
+ },
37
+ "metadata": {},
38
+ "output_type": "display_data"
39
+ }
40
+ ],
41
+ "source": [
42
+ "huggingface_hub.notebook_login()"
43
+ ]
44
+ },
45
+ {
46
+ "cell_type": "code",
47
+ "execution_count": 13,
48
+ "metadata": {
49
+ "pycharm": {
50
+ "name": "#%%\n"
51
+ }
52
+ },
53
+ "outputs": [
54
+ {
55
+ "name": "stdout",
56
+ "output_type": "stream",
57
+ "text": [
58
+ "Building generator for model `stylegan_animeface512` ...\n",
59
+ "Finish building generator.\n",
60
+ "Loading checkpoint from `checkpoints/stylegan_animeface512.pth` ...\n",
61
+ "Finish loading checkpoint.\n"
62
+ ]
63
+ }
64
+ ],
65
+ "source": [
66
+ "animeface_model = utils.load_generator('stylegan_animeface512')"
67
+ ]
68
+ },
69
+ {
70
+ "cell_type": "code",
71
+ "execution_count": 5,
72
+ "metadata": {
73
+ "pycharm": {
74
+ "name": "#%%\n"
75
+ }
76
+ },
77
+ "outputs": [
78
+ {
79
+ "name": "stderr",
80
+ "output_type": "stream",
81
+ "text": [
82
+ "Cloning https://huggingface.co/johko/stylegan_animeface512 into local empty directory.\n"
83
+ ]
84
+ },
85
+ {
86
+ "data": {
87
+ "application/vnd.jupyter.widget-view+json": {
88
+ "model_id": "6e51c5ae4a504617aa0f1c1ac798ed15",
89
+ "version_major": 2,
90
+ "version_minor": 0
91
+ },
92
+ "text/plain": [
93
+ "Upload file pytorch_model.bin: 0%| | 32.0k/103M [00:00<?, ?B/s]"
94
+ ]
95
+ },
96
+ "metadata": {},
97
+ "output_type": "display_data"
98
+ },
99
+ {
100
+ "name": "stderr",
101
+ "output_type": "stream",
102
+ "text": [
103
+ "To https://huggingface.co/johko/stylegan_animeface512\n",
104
+ " 750cd03..2841156 main -> main\n",
105
+ "\n"
106
+ ]
107
+ },
108
+ {
109
+ "data": {
110
+ "text/plain": [
111
+ "'https://huggingface.co/johko/stylegan_animeface512/commit/2841156bad3c5a5f47f3edbf4a41880ea8fd3ad3'"
112
+ ]
113
+ },
114
+ "execution_count": 5,
115
+ "metadata": {},
116
+ "output_type": "execute_result"
117
+ }
118
+ ],
119
+ "source": [
120
+ "animeface_model.push_to_hub(\"johko/stylegan_animeface512\")"
121
+ ]
122
+ },
123
+ {
124
+ "cell_type": "code",
125
+ "execution_count": 11,
126
+ "metadata": {
127
+ "pycharm": {
128
+ "name": "#%%\n"
129
+ }
130
+ },
131
+ "outputs": [
132
+ {
133
+ "name": "stdout",
134
+ "output_type": "stream",
135
+ "text": [
136
+ "Building generator for model `pggan_celebahq1024` ...\n",
137
+ "Finish building generator.\n",
138
+ "Loading checkpoint from `checkpoints/pggan_celebahq1024.pth` ...\n",
139
+ "Finish loading checkpoint.\n"
140
+ ]
141
+ }
142
+ ],
143
+ "source": [
144
+ "celebhq_model = utils.load_generator(\"pggan_celebahq1024\")"
145
+ ]
146
+ },
147
+ {
148
+ "cell_type": "code",
149
+ "execution_count": 7,
150
+ "metadata": {
151
+ "pycharm": {
152
+ "name": "#%%\n"
153
+ }
154
+ },
155
+ "outputs": [
156
+ {
157
+ "name": "stderr",
158
+ "output_type": "stream",
159
+ "text": [
160
+ "Cloning https://huggingface.co/johko/pggan-celebahq-1024 into local empty directory.\n"
161
+ ]
162
+ },
163
+ {
164
+ "data": {
165
+ "application/vnd.jupyter.widget-view+json": {
166
+ "model_id": "ef4086b23a654b079bd6a3678140c50d",
167
+ "version_major": 2,
168
+ "version_minor": 0
169
+ },
170
+ "text/plain": [
171
+ "Upload file pytorch_model.bin: 0%| | 32.0k/88.1M [00:00<?, ?B/s]"
172
+ ]
173
+ },
174
+ "metadata": {},
175
+ "output_type": "display_data"
176
+ },
177
+ {
178
+ "name": "stderr",
179
+ "output_type": "stream",
180
+ "text": [
181
+ "To https://huggingface.co/johko/pggan-celebahq-1024\n",
182
+ " 780695e..278449f main -> main\n",
183
+ "\n"
184
+ ]
185
+ },
186
+ {
187
+ "data": {
188
+ "text/plain": [
189
+ "'https://huggingface.co/johko/pggan-celebahq-1024/commit/278449f8416d38a0233c980774528d32c4eee99c'"
190
+ ]
191
+ },
192
+ "execution_count": 7,
193
+ "metadata": {},
194
+ "output_type": "execute_result"
195
+ }
196
+ ],
197
+ "source": [
198
+ "celebhq_model.push_to_hub(\"johko/pggan-celebahq-1024\")"
199
+ ]
200
+ },
201
+ {
202
+ "cell_type": "code",
203
+ "execution_count": 17,
204
+ "metadata": {},
205
+ "outputs": [
206
+ {
207
+ "name": "stdout",
208
+ "output_type": "stream",
209
+ "text": [
210
+ "Building generator for model `stylegan_car512` ...\n",
211
+ "Finish building generator.\n",
212
+ "Loading checkpoint from `checkpoints/stylegan_car512.pth` ...\n",
213
+ "Finish loading checkpoint.\n"
214
+ ]
215
+ }
216
+ ],
217
+ "source": [
218
+ "cars_model = utils.load_generator(\"stylegan_car512\")"
219
+ ]
220
+ },
221
+ {
222
+ "cell_type": "code",
223
+ "execution_count": null,
224
+ "metadata": {},
225
+ "outputs": [],
226
+ "source": [
227
+ "cars_model.push_to_hub(\"johko/stylegan_car512\")"
228
+ ]
229
+ }
230
+ ],
231
+ "metadata": {
232
+ "interpreter": {
233
+ "hash": "a8d699d01f596cc27ac2722fbc0550b939d217978c7e1ca888dca7ba146ee4bf"
234
+ },
235
+ "kernelspec": {
236
+ "display_name": "Python 3",
237
+ "language": "python",
238
+ "name": "python3"
239
+ },
240
+ "language_info": {
241
+ "codemirror_mode": {
242
+ "name": "ipython",
243
+ "version": 3
244
+ },
245
+ "file_extension": ".py",
246
+ "mimetype": "text/x-python",
247
+ "name": "python",
248
+ "nbconvert_exporter": "python",
249
+ "pygments_lexer": "ipython3",
250
+ "version": "3.9.9"
251
+ }
252
+ },
253
+ "nbformat": 4,
254
+ "nbformat_minor": 2
255
+ }
app.py CHANGED
@@ -16,7 +16,7 @@ from utils import factorize_weight
16
  @st.cache(allow_output_mutation=True, show_spinner=False)
17
  def get_model(model_name):
18
  """Gets model by name."""
19
- return load_generator(model_name)
20
 
21
 
22
  @st.cache(allow_output_mutation=True, show_spinner=False)
@@ -72,7 +72,7 @@ layer_idx = st.sidebar.selectbox(
72
  layers, boundaries, eigen_values = factorize_model(model, layer_idx)
73
 
74
  num_semantics = st.sidebar.number_input(
75
- 'Number of semantics', value=10, min_value=0, max_value=None, step=1)
76
  steps = {sem_idx: 0 for sem_idx in range(num_semantics)}
77
  if gan_type == 'pggan':
78
  max_step = 5.0
 
16
  @st.cache(allow_output_mutation=True, show_spinner=False)
17
  def get_model(model_name):
18
  """Gets model by name."""
19
+ return load_generator(model_name, from_hf_hub=True)
20
 
21
 
22
  @st.cache(allow_output_mutation=True, show_spinner=False)
 
72
  layers, boundaries, eigen_values = factorize_model(model, layer_idx)
73
 
74
  num_semantics = st.sidebar.number_input(
75
+ 'Number of semantics', value=5, min_value=0, max_value=None, step=1)
76
  steps = {sem_idx: 0 for sem_idx in range(num_semantics)}
77
  if gan_type == 'pggan':
78
  max_step = 5.0
interface.py CHANGED
@@ -16,7 +16,7 @@ from utils import factorize_weight
16
  @st.cache(allow_output_mutation=True, show_spinner=False)
17
  def get_model(model_name):
18
  """Gets model by name."""
19
- return load_generator(model_name)
20
 
21
 
22
  @st.cache(allow_output_mutation=True, show_spinner=False)
@@ -27,7 +27,7 @@ def factorize_model(model, layer_idx):
27
 
28
  def sample(model, gan_type, num=1):
29
  """Samples latent codes."""
30
- codes = torch.randn(num, model.z_space_dim).cuda()
31
  if gan_type == 'pggan':
32
  codes = model.layer0.pixel_norm(codes)
33
  elif gan_type == 'stylegan':
@@ -63,8 +63,7 @@ def main():
63
 
64
  model_name = st.sidebar.selectbox(
65
  'Model to Interpret',
66
- ['pggan_celebahq1024', 'stylegan_animeface512', 'stylegan_car512', 'stylegan_cat256'
67
- ])
68
 
69
  model = get_model(model_name)
70
  gan_type = parse_gan_type(model)
@@ -74,7 +73,7 @@ def main():
74
  layers, boundaries, eigen_values = factorize_model(model, layer_idx)
75
 
76
  num_semantics = st.sidebar.number_input(
77
- 'Number of semantics', value=10, min_value=0, max_value=None, step=1)
78
  steps = {sem_idx: 0 for sem_idx in range(num_semantics)}
79
  if gan_type == 'pggan':
80
  max_step = 5.0
 
16
  @st.cache(allow_output_mutation=True, show_spinner=False)
17
  def get_model(model_name):
18
  """Gets model by name."""
19
+ return load_generator(model_name, from_hf_hub=True)
20
 
21
 
22
  @st.cache(allow_output_mutation=True, show_spinner=False)
 
27
 
28
  def sample(model, gan_type, num=1):
29
  """Samples latent codes."""
30
+ codes = torch.randn(num, model.z_space_dim)
31
  if gan_type == 'pggan':
32
  codes = model.layer0.pixel_norm(codes)
33
  elif gan_type == 'stylegan':
 
63
 
64
  model_name = st.sidebar.selectbox(
65
  'Model to Interpret',
66
+ ['pggan_celebahq1024', 'stylegan_animeface512', 'stylegan_car512', 'stylegan_cat256',])
 
67
 
68
  model = get_model(model_name)
69
  gan_type = parse_gan_type(model)
 
73
  layers, boundaries, eigen_values = factorize_model(model, layer_idx)
74
 
75
  num_semantics = st.sidebar.number_input(
76
+ 'Number of semantics', value=5, min_value=0, max_value=None, step=1)
77
  steps = {sem_idx: 0 for sem_idx in range(num_semantics)}
78
  if gan_type == 'pggan':
79
  max_step = 5.0
model_to_hf_hub.ipynb ADDED
@@ -0,0 +1,297 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 15,
6
+ "metadata": {
7
+ "pycharm": {
8
+ "name": "#%%\n"
9
+ }
10
+ },
11
+ "outputs": [],
12
+ "source": [
13
+ "import huggingface_hub\n",
14
+ "import utils"
15
+ ]
16
+ },
17
+ {
18
+ "cell_type": "code",
19
+ "execution_count": 16,
20
+ "metadata": {
21
+ "pycharm": {
22
+ "name": "#%%\n"
23
+ }
24
+ },
25
+ "outputs": [
26
+ {
27
+ "data": {
28
+ "application/vnd.jupyter.widget-view+json": {
29
+ "model_id": "525a0eaa021f4fdebd9138f4e7c5ab65",
30
+ "version_major": 2,
31
+ "version_minor": 0
32
+ },
33
+ "text/plain": [
34
+ "VBox(children=(HTML(value='<center> <img\\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…"
35
+ ]
36
+ },
37
+ "metadata": {},
38
+ "output_type": "display_data"
39
+ }
40
+ ],
41
+ "source": [
42
+ "huggingface_hub.notebook_login()"
43
+ ]
44
+ },
45
+ {
46
+ "cell_type": "code",
47
+ "execution_count": 13,
48
+ "metadata": {
49
+ "pycharm": {
50
+ "name": "#%%\n"
51
+ }
52
+ },
53
+ "outputs": [
54
+ {
55
+ "name": "stdout",
56
+ "output_type": "stream",
57
+ "text": [
58
+ "Building generator for model `stylegan_animeface512` ...\n",
59
+ "Finish building generator.\n",
60
+ "Loading checkpoint from `checkpoints/stylegan_animeface512.pth` ...\n",
61
+ "Finish loading checkpoint.\n"
62
+ ]
63
+ }
64
+ ],
65
+ "source": [
66
+ "animeface_model = utils.load_generator('stylegan_animeface512')"
67
+ ]
68
+ },
69
+ {
70
+ "cell_type": "code",
71
+ "execution_count": 5,
72
+ "metadata": {
73
+ "pycharm": {
74
+ "name": "#%%\n"
75
+ }
76
+ },
77
+ "outputs": [
78
+ {
79
+ "name": "stderr",
80
+ "output_type": "stream",
81
+ "text": [
82
+ "Cloning https://huggingface.co/johko/stylegan_animeface512 into local empty directory.\n"
83
+ ]
84
+ },
85
+ {
86
+ "data": {
87
+ "application/vnd.jupyter.widget-view+json": {
88
+ "model_id": "6e51c5ae4a504617aa0f1c1ac798ed15",
89
+ "version_major": 2,
90
+ "version_minor": 0
91
+ },
92
+ "text/plain": [
93
+ "Upload file pytorch_model.bin: 0%| | 32.0k/103M [00:00<?, ?B/s]"
94
+ ]
95
+ },
96
+ "metadata": {},
97
+ "output_type": "display_data"
98
+ },
99
+ {
100
+ "name": "stderr",
101
+ "output_type": "stream",
102
+ "text": [
103
+ "To https://huggingface.co/johko/stylegan_animeface512\n",
104
+ " 750cd03..2841156 main -> main\n",
105
+ "\n"
106
+ ]
107
+ },
108
+ {
109
+ "data": {
110
+ "text/plain": [
111
+ "'https://huggingface.co/johko/stylegan_animeface512/commit/2841156bad3c5a5f47f3edbf4a41880ea8fd3ad3'"
112
+ ]
113
+ },
114
+ "execution_count": 5,
115
+ "metadata": {},
116
+ "output_type": "execute_result"
117
+ }
118
+ ],
119
+ "source": [
120
+ "animeface_model.push_to_hub(\"johko/stylegan_animeface512\")"
121
+ ]
122
+ },
123
+ {
124
+ "cell_type": "code",
125
+ "execution_count": 11,
126
+ "metadata": {
127
+ "pycharm": {
128
+ "name": "#%%\n"
129
+ }
130
+ },
131
+ "outputs": [
132
+ {
133
+ "name": "stdout",
134
+ "output_type": "stream",
135
+ "text": [
136
+ "Building generator for model `pggan_celebahq1024` ...\n",
137
+ "Finish building generator.\n",
138
+ "Loading checkpoint from `checkpoints/pggan_celebahq1024.pth` ...\n",
139
+ "Finish loading checkpoint.\n"
140
+ ]
141
+ }
142
+ ],
143
+ "source": [
144
+ "celebhq_model = utils.load_generator(\"pggan_celebahq1024\")"
145
+ ]
146
+ },
147
+ {
148
+ "cell_type": "code",
149
+ "execution_count": 7,
150
+ "metadata": {
151
+ "pycharm": {
152
+ "name": "#%%\n"
153
+ }
154
+ },
155
+ "outputs": [
156
+ {
157
+ "name": "stderr",
158
+ "output_type": "stream",
159
+ "text": [
160
+ "Cloning https://huggingface.co/johko/pggan-celebahq-1024 into local empty directory.\n"
161
+ ]
162
+ },
163
+ {
164
+ "data": {
165
+ "application/vnd.jupyter.widget-view+json": {
166
+ "model_id": "ef4086b23a654b079bd6a3678140c50d",
167
+ "version_major": 2,
168
+ "version_minor": 0
169
+ },
170
+ "text/plain": [
171
+ "Upload file pytorch_model.bin: 0%| | 32.0k/88.1M [00:00<?, ?B/s]"
172
+ ]
173
+ },
174
+ "metadata": {},
175
+ "output_type": "display_data"
176
+ },
177
+ {
178
+ "name": "stderr",
179
+ "output_type": "stream",
180
+ "text": [
181
+ "To https://huggingface.co/johko/pggan-celebahq-1024\n",
182
+ " 780695e..278449f main -> main\n",
183
+ "\n"
184
+ ]
185
+ },
186
+ {
187
+ "data": {
188
+ "text/plain": [
189
+ "'https://huggingface.co/johko/pggan-celebahq-1024/commit/278449f8416d38a0233c980774528d32c4eee99c'"
190
+ ]
191
+ },
192
+ "execution_count": 7,
193
+ "metadata": {},
194
+ "output_type": "execute_result"
195
+ }
196
+ ],
197
+ "source": [
198
+ "celebhq_model.push_to_hub(\"johko/pggan-celebahq-1024\")"
199
+ ]
200
+ },
201
+ {
202
+ "cell_type": "code",
203
+ "execution_count": 17,
204
+ "metadata": {},
205
+ "outputs": [
206
+ {
207
+ "name": "stdout",
208
+ "output_type": "stream",
209
+ "text": [
210
+ "Building generator for model `stylegan_car512` ...\n",
211
+ "Finish building generator.\n",
212
+ "Loading checkpoint from `checkpoints/stylegan_car512.pth` ...\n",
213
+ "Finish loading checkpoint.\n"
214
+ ]
215
+ }
216
+ ],
217
+ "source": [
218
+ "cars_model = utils.load_generator(\"stylegan_car512\")"
219
+ ]
220
+ },
221
+ {
222
+ "cell_type": "code",
223
+ "execution_count": 21,
224
+ "metadata": {},
225
+ "outputs": [
226
+ {
227
+ "name": "stdout",
228
+ "output_type": "stream",
229
+ "text": [
230
+ "Building generator for model `stylegan_cat256` ...\n",
231
+ "Finish building generator.\n",
232
+ "Loading checkpoint from `checkpoints/stylegan_cat256.pth` ...\n",
233
+ "Finish loading checkpoint.\n"
234
+ ]
235
+ }
236
+ ],
237
+ "source": [
238
+ "cats_model = utils.load_generator(\"stylegan_cat256\")"
239
+ ]
240
+ },
241
+ {
242
+ "cell_type": "code",
243
+ "execution_count": null,
244
+ "metadata": {},
245
+ "outputs": [
246
+ {
247
+ "name": "stderr",
248
+ "output_type": "stream",
249
+ "text": [
250
+ "Cloning https://huggingface.co/johko/stylegan_cat256 into local empty directory.\n"
251
+ ]
252
+ },
253
+ {
254
+ "data": {
255
+ "application/vnd.jupyter.widget-view+json": {
256
+ "model_id": "651e9bff9c9f4555814171195e36d4d3",
257
+ "version_major": 2,
258
+ "version_minor": 0
259
+ },
260
+ "text/plain": [
261
+ "Upload file pytorch_model.bin: 0%| | 32.0k/100M [00:00<?, ?B/s]"
262
+ ]
263
+ },
264
+ "metadata": {},
265
+ "output_type": "display_data"
266
+ }
267
+ ],
268
+ "source": [
269
+ "cats_model.push_to_hub(\"johko/stylegan_cat256\")"
270
+ ]
271
+ }
272
+ ],
273
+ "metadata": {
274
+ "interpreter": {
275
+ "hash": "a8d699d01f596cc27ac2722fbc0550b939d217978c7e1ca888dca7ba146ee4bf"
276
+ },
277
+ "kernelspec": {
278
+ "display_name": "Python 3",
279
+ "language": "python",
280
+ "name": "python3"
281
+ },
282
+ "language_info": {
283
+ "codemirror_mode": {
284
+ "name": "ipython",
285
+ "version": 3
286
+ },
287
+ "file_extension": ".py",
288
+ "mimetype": "text/x-python",
289
+ "name": "python",
290
+ "nbconvert_exporter": "python",
291
+ "pygments_lexer": "ipython3",
292
+ "version": "3.9.9"
293
+ }
294
+ },
295
+ "nbformat": 4,
296
+ "nbformat_minor": 2
297
+ }
models/model_zoo.py CHANGED
@@ -9,6 +9,7 @@ MODEL_ZOO = {
9
  gan_type='pggan',
10
  resolution=1024,
11
  url='https://mycuhk-my.sharepoint.com/:u:/g/personal/1155082926_link_cuhk_edu_hk/EW_3jQ6E7xlKvCSHYrbmkQQBAB8tgIv5W5evdT6-GuXiWw?e=gRifVa&download=1',
 
12
  ),
13
  'pggan_bedroom256': dict(
14
  gan_type='pggan',
@@ -181,11 +182,13 @@ MODEL_ZOO = {
181
  gan_type='stylegan',
182
  resolution=256,
183
  url='https://mycuhk-my.sharepoint.com/:u:/g/personal/1155082926_link_cuhk_edu_hk/EVjX8u9HuehLip3z0hRfIHcB7QtoFkTB7NiRDb8nrKOl2w?e=lHcp1B&download=1',
 
184
  ),
185
  'stylegan_car512': dict(
186
  gan_type='stylegan',
187
  resolution=512,
188
  url='https://mycuhk-my.sharepoint.com/:u:/g/personal/1155082926_link_cuhk_edu_hk/EcRJNNzzUzJGjI2X53S9HjkBhXkKT5JRd6Q3IIhCY1AyRw?e=FvMRNj&download=1',
 
189
  ),
190
 
191
  # StyleGAN ours.
@@ -260,6 +263,7 @@ MODEL_ZOO = {
260
  gan_type='stylegan',
261
  resolution=512,
262
  url='https://mycuhk-my.sharepoint.com/:u:/g/personal/1155082926_link_cuhk_edu_hk/EWDWflY6lBpGgX0CGQpd2Z4B5wTEVamTOA9JRYne7zdCvA?e=tOzgYA&download=1',
 
263
  ),
264
  'stylegan_animeportrait512': dict(
265
  gan_type='stylegan',
@@ -296,15 +300,8 @@ MODEL_ZOO = {
296
  'stylegan2_car512': dict(
297
  gan_type='stylegan2',
298
  resolution=512,
299
- url='https://mycuhk-my.sharepoint.com/:u:/g/personal/1155082926_link_cuhk_edu_hk/EYSnUsxU8KJFuMHhZm-JLWoB0nHxdlbrLHNZ_Qkoe3b9LA?e=Ycjp5A&download=1',
300
  ),
301
-
302
- #huggingface models
303
- 'akhaliq/OneshotCLIP-stylegan2-ffhq' : dict(
304
- gan_type='stylegan2',
305
- resolution=512,
306
- url='akhaliq/OneshotCLIP-stylegan2-ffhq',
307
- )
308
  }
309
 
310
  # pylint: enable=line-too-long
 
9
  gan_type='pggan',
10
  resolution=1024,
11
  url='https://mycuhk-my.sharepoint.com/:u:/g/personal/1155082926_link_cuhk_edu_hk/EW_3jQ6E7xlKvCSHYrbmkQQBAB8tgIv5W5evdT6-GuXiWw?e=gRifVa&download=1',
12
+ hf_hub_repo='huggan/pggan-celebahq-1024'
13
  ),
14
  'pggan_bedroom256': dict(
15
  gan_type='pggan',
 
182
  gan_type='stylegan',
183
  resolution=256,
184
  url='https://mycuhk-my.sharepoint.com/:u:/g/personal/1155082926_link_cuhk_edu_hk/EVjX8u9HuehLip3z0hRfIHcB7QtoFkTB7NiRDb8nrKOl2w?e=lHcp1B&download=1',
185
+ hf_hub_repo="huggan/stylegan_cat256"
186
  ),
187
  'stylegan_car512': dict(
188
  gan_type='stylegan',
189
  resolution=512,
190
  url='https://mycuhk-my.sharepoint.com/:u:/g/personal/1155082926_link_cuhk_edu_hk/EcRJNNzzUzJGjI2X53S9HjkBhXkKT5JRd6Q3IIhCY1AyRw?e=FvMRNj&download=1',
191
+ hf_hub_repo="huggan/stylegan_car512"
192
  ),
193
 
194
  # StyleGAN ours.
 
263
  gan_type='stylegan',
264
  resolution=512,
265
  url='https://mycuhk-my.sharepoint.com/:u:/g/personal/1155082926_link_cuhk_edu_hk/EWDWflY6lBpGgX0CGQpd2Z4B5wTEVamTOA9JRYne7zdCvA?e=tOzgYA&download=1',
266
+ hf_hub_repo='huggan/stylegan_animeface512'
267
  ),
268
  'stylegan_animeportrait512': dict(
269
  gan_type='stylegan',
 
300
  'stylegan2_car512': dict(
301
  gan_type='stylegan2',
302
  resolution=512,
303
+ url='https://mycuhk-my.sharepoint.com/:u:/g/personal/1155082926_link_cuhk_edu_hk/EYSnUsxU8KJFuMHhZm-JLWoB0nHxdlbrLHNZ_Qkoe3b9LA?e=Ycjp5A&download=1'
304
  ),
 
 
 
 
 
 
 
305
  }
306
 
307
  # pylint: enable=line-too-long
models/pggan_generator.py CHANGED
@@ -6,6 +6,7 @@ Paper: https://arxiv.org/pdf/1710.10196.pdf
6
  Official TensorFlow implementation:
7
  https://github.com/tkarras/progressive_growing_of_gans
8
  """
 
9
 
10
  import numpy as np
11
 
@@ -13,6 +14,8 @@ import torch
13
  import torch.nn as nn
14
  import torch.nn.functional as F
15
 
 
 
16
  __all__ = ['PGGANGenerator']
17
 
18
  # Resolutions allowed.
@@ -25,7 +28,7 @@ _INIT_RES = 4
25
  _WSCALE_GAIN = np.sqrt(2.0)
26
 
27
 
28
- class PGGANGenerator(nn.Module):
29
  """Defines the generator network in PGGAN.
30
 
31
  NOTE: The synthesized images are with `RGB` channel order and pixel range
@@ -57,7 +60,8 @@ class PGGANGenerator(nn.Module):
57
  fused_scale=False,
58
  use_wscale=True,
59
  fmaps_base=16 << 10,
60
- fmaps_max=512):
 
61
  """Initializes with basic settings.
62
 
63
  Raises:
@@ -81,6 +85,8 @@ class PGGANGenerator(nn.Module):
81
  self.use_wscale = use_wscale
82
  self.fmaps_base = fmaps_base
83
  self.fmaps_max = fmaps_max
 
 
84
 
85
  # Number of convolutional layers.
86
  self.num_layers = (self.final_res_log2 - self.init_res_log2 + 1) * 2
@@ -202,6 +208,46 @@ class PGGANGenerator(nn.Module):
202
  }
203
  return results
204
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
205
 
206
  class PixelNormLayer(nn.Module):
207
  """Implements pixel-wise feature vector normalization layer."""
 
6
  Official TensorFlow implementation:
7
  https://github.com/tkarras/progressive_growing_of_gans
8
  """
9
+ import os
10
 
11
  import numpy as np
12
 
 
14
  import torch.nn as nn
15
  import torch.nn.functional as F
16
 
17
+ from huggingface_hub import PyTorchModelHubMixin, PYTORCH_WEIGHTS_NAME, hf_hub_download
18
+
19
  __all__ = ['PGGANGenerator']
20
 
21
  # Resolutions allowed.
 
28
  _WSCALE_GAIN = np.sqrt(2.0)
29
 
30
 
31
+ class PGGANGenerator(nn.Module, PyTorchModelHubMixin):
32
  """Defines the generator network in PGGAN.
33
 
34
  NOTE: The synthesized images are with `RGB` channel order and pixel range
 
60
  fused_scale=False,
61
  use_wscale=True,
62
  fmaps_base=16 << 10,
63
+ fmaps_max=512,
64
+ **kwargs):
65
  """Initializes with basic settings.
66
 
67
  Raises:
 
85
  self.use_wscale = use_wscale
86
  self.fmaps_base = fmaps_base
87
  self.fmaps_max = fmaps_max
88
+
89
+ self.config = kwargs.pop("config", None)
90
 
91
  # Number of convolutional layers.
92
  self.num_layers = (self.final_res_log2 - self.init_res_log2 + 1) * 2
 
208
  }
209
  return results
210
 
211
+ @classmethod
212
+ def _from_pretrained(
213
+ cls,
214
+ model_id,
215
+ revision,
216
+ cache_dir,
217
+ force_download,
218
+ proxies,
219
+ resume_download,
220
+ local_files_only,
221
+ use_auth_token,
222
+ map_location="cpu",
223
+ strict=False,
224
+ **model_kwargs,
225
+ ):
226
+ """
227
+ Overwrite this method in case you wish to initialize your model in a
228
+ different way.
229
+ """
230
+ map_location = torch.device(map_location)
231
+
232
+ if os.path.isdir(model_id):
233
+ print("Loading weights from local directory")
234
+ model_file = os.path.join(model_id, PYTORCH_WEIGHTS_NAME)
235
+ else:
236
+ model_file = hf_hub_download(
237
+ repo_id=model_id,
238
+ filename=PYTORCH_WEIGHTS_NAME,
239
+ revision=revision,
240
+ cache_dir=cache_dir,
241
+ force_download=force_download,
242
+ proxies=proxies,
243
+ resume_download=resume_download,
244
+ use_auth_token=use_auth_token,
245
+ local_files_only=local_files_only,
246
+ )
247
+
248
+ pretrained = torch.load(model_file, map_location=map_location)
249
+ return pretrained
250
+
251
 
252
  class PixelNormLayer(nn.Module):
253
  """Implements pixel-wise feature vector normalization layer."""
models/stylegan2_generator.py CHANGED
@@ -9,12 +9,14 @@ Paper: https://arxiv.org/pdf/1912.04958.pdf
9
 
10
  Official TensorFlow implementation: https://github.com/NVlabs/stylegan2
11
  """
 
12
 
13
  import numpy as np
14
 
15
  import torch
16
  import torch.nn as nn
17
  import torch.nn.functional as F
 
18
 
19
  from .sync_op import all_gather
20
 
@@ -33,7 +35,7 @@ _ARCHITECTURES_ALLOWED = ['resnet', 'skip', 'origin']
33
  _WSCALE_GAIN = 1.0
34
 
35
 
36
- class StyleGAN2Generator(nn.Module):
37
  """Defines the generator network in StyleGAN2.
38
 
39
  NOTE: The synthesized images are with `RGB` channel order and pixel range
@@ -88,7 +90,8 @@ class StyleGAN2Generator(nn.Module):
88
  demodulate=True,
89
  use_wscale=True,
90
  fmaps_base=32 << 10,
91
- fmaps_max=512):
 
92
  """Initializes with basic settings.
93
 
94
  Raises:
@@ -195,6 +198,45 @@ class StyleGAN2Generator(nn.Module):
195
 
196
  return {**mapping_results, **synthesis_results}
197
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
198
 
199
  class MappingModule(nn.Module):
200
  """Implements the latent space mapping module.
 
9
 
10
  Official TensorFlow implementation: https://github.com/NVlabs/stylegan2
11
  """
12
+ import os
13
 
14
  import numpy as np
15
 
16
  import torch
17
  import torch.nn as nn
18
  import torch.nn.functional as F
19
+ from huggingface_hub import PYTORCH_WEIGHTS_NAME, hf_hub_download, PyTorchModelHubMixin
20
 
21
  from .sync_op import all_gather
22
 
 
35
  _WSCALE_GAIN = 1.0
36
 
37
 
38
+ class StyleGAN2Generator(nn.Module, PyTorchModelHubMixin):
39
  """Defines the generator network in StyleGAN2.
40
 
41
  NOTE: The synthesized images are with `RGB` channel order and pixel range
 
90
  demodulate=True,
91
  use_wscale=True,
92
  fmaps_base=32 << 10,
93
+ fmaps_max=512,
94
+ **kwargs):
95
  """Initializes with basic settings.
96
 
97
  Raises:
 
198
 
199
  return {**mapping_results, **synthesis_results}
200
 
201
+ @classmethod
202
+ def _from_pretrained(
203
+ cls,
204
+ model_id,
205
+ revision,
206
+ cache_dir,
207
+ force_download,
208
+ proxies,
209
+ resume_download,
210
+ local_files_only,
211
+ use_auth_token,
212
+ map_location="cpu",
213
+ strict=False,
214
+ **model_kwargs,
215
+ ):
216
+ """
217
+ Overwrite this method in case you wish to initialize your model in a
218
+ different way.
219
+ """
220
+ map_location = torch.device(map_location)
221
+
222
+ if os.path.isdir(model_id):
223
+ print("Loading weights from local directory")
224
+ model_file = os.path.join(model_id, PYTORCH_WEIGHTS_NAME)
225
+ else:
226
+ model_file = hf_hub_download(
227
+ repo_id=model_id,
228
+ filename="stylegan2-ffhq-config-f.pt",
229
+ revision=revision,
230
+ cache_dir=cache_dir,
231
+ force_download=force_download,
232
+ proxies=proxies,
233
+ resume_download=resume_download,
234
+ use_auth_token=use_auth_token,
235
+ local_files_only=local_files_only,
236
+ )
237
+
238
+ pretrained = torch.load(model_file, map_location=map_location)
239
+ return pretrained
240
 
241
  class MappingModule(nn.Module):
242
  """Implements the latent space mapping module.
models/stylegan_generator.py CHANGED
@@ -5,6 +5,7 @@ Paper: https://arxiv.org/pdf/1812.04948.pdf
5
 
6
  Official TensorFlow implementation: https://github.com/NVlabs/stylegan
7
  """
 
8
 
9
  import numpy as np
10
 
@@ -14,6 +15,8 @@ import torch.nn.functional as F
14
 
15
  from .sync_op import all_gather
16
 
 
 
17
  __all__ = ['StyleGANGenerator']
18
 
19
  # Resolutions allowed.
@@ -33,7 +36,7 @@ _WSCALE_GAIN = np.sqrt(2.0)
33
  _STYLEMOD_WSCALE_GAIN = 1.0
34
 
35
 
36
- class StyleGANGenerator(nn.Module):
37
  """Defines the generator network in StyleGAN.
38
 
39
  NOTE: The synthesized images are with `RGB` channel order and pixel range
@@ -83,7 +86,8 @@ class StyleGANGenerator(nn.Module):
83
  fused_scale='auto',
84
  use_wscale=True,
85
  fmaps_base=16 << 10,
86
- fmaps_max=512):
 
87
  """Initializes with basic settings.
88
 
89
  Raises:
@@ -115,6 +119,9 @@ class StyleGANGenerator(nn.Module):
115
  self.use_wscale = use_wscale
116
  self.fmaps_base = fmaps_base
117
  self.fmaps_max = fmaps_max
 
 
 
118
 
119
  self.num_layers = int(np.log2(self.resolution // self.init_res * 2)) * 2
120
 
@@ -188,6 +195,46 @@ class StyleGANGenerator(nn.Module):
188
 
189
  return {**mapping_results, **synthesis_results}
190
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
191
 
192
  class MappingModule(nn.Module):
193
  """Implements the latent space mapping module.
 
5
 
6
  Official TensorFlow implementation: https://github.com/NVlabs/stylegan
7
  """
8
+ import os
9
 
10
  import numpy as np
11
 
 
15
 
16
  from .sync_op import all_gather
17
 
18
+ from huggingface_hub import PyTorchModelHubMixin, PYTORCH_WEIGHTS_NAME, hf_hub_download
19
+
20
  __all__ = ['StyleGANGenerator']
21
 
22
  # Resolutions allowed.
 
36
  _STYLEMOD_WSCALE_GAIN = 1.0
37
 
38
 
39
+ class StyleGANGenerator(nn.Module, PyTorchModelHubMixin):
40
  """Defines the generator network in StyleGAN.
41
 
42
  NOTE: The synthesized images are with `RGB` channel order and pixel range
 
86
  fused_scale='auto',
87
  use_wscale=True,
88
  fmaps_base=16 << 10,
89
+ fmaps_max=512,
90
+ **kwargs):
91
  """Initializes with basic settings.
92
 
93
  Raises:
 
119
  self.use_wscale = use_wscale
120
  self.fmaps_base = fmaps_base
121
  self.fmaps_max = fmaps_max
122
+
123
+ self.config = kwargs.pop("config", None)
124
+
125
 
126
  self.num_layers = int(np.log2(self.resolution // self.init_res * 2)) * 2
127
 
 
195
 
196
  return {**mapping_results, **synthesis_results}
197
 
198
+ @classmethod
199
+ def _from_pretrained(
200
+ cls,
201
+ model_id,
202
+ revision,
203
+ cache_dir,
204
+ force_download,
205
+ proxies,
206
+ resume_download,
207
+ local_files_only,
208
+ use_auth_token,
209
+ map_location="cpu",
210
+ strict=False,
211
+ **model_kwargs,
212
+ ):
213
+ """
214
+ Overwrite this method in case you wish to initialize your model in a
215
+ different way.
216
+ """
217
+ map_location = torch.device(map_location)
218
+
219
+ if os.path.isdir(model_id):
220
+ print("Loading weights from local directory")
221
+ model_file = os.path.join(model_id, PYTORCH_WEIGHTS_NAME)
222
+ else:
223
+ model_file = hf_hub_download(
224
+ repo_id=model_id,
225
+ filename=PYTORCH_WEIGHTS_NAME,
226
+ revision=revision,
227
+ cache_dir=cache_dir,
228
+ force_download=force_download,
229
+ proxies=proxies,
230
+ resume_download=resume_download,
231
+ use_auth_token=use_auth_token,
232
+ local_files_only=local_files_only,
233
+ )
234
+
235
+ pretrained = torch.load(model_file, map_location=map_location)
236
+ return pretrained
237
+
238
 
239
  class MappingModule(nn.Module):
240
  """Implements the latent space mapping module.
utils.py CHANGED
@@ -50,7 +50,7 @@ def postprocess(images, min_val=-1.0, max_val=1.0):
50
  return images
51
 
52
 
53
- def load_generator(model_name):
54
  """Loads pre-trained generator.
55
 
56
  Args:
@@ -74,19 +74,25 @@ def load_generator(model_name):
74
  generator = build_generator(**model_config)
75
  print(f'Finish building generator.')
76
 
77
- # Load pre-trained weights.
78
- os.makedirs(CHECKPOINT_DIR, exist_ok=True)
79
- checkpoint_path = os.path.join(CHECKPOINT_DIR, model_name + '.pth')
80
- print(f'Loading checkpoint from `{checkpoint_path}` ...')
81
- if not os.path.exists(checkpoint_path):
82
- print(f' Downloading checkpoint from `{url}` ...')
83
- subprocess.call(['wget', '--quiet', '-O', checkpoint_path, url])
84
- print(f' Finish downloading checkpoint.')
85
- checkpoint = torch.load(checkpoint_path, map_location='cpu')
86
- if 'generator_smooth' in checkpoint:
87
- generator.load_state_dict(checkpoint['generator_smooth'])
88
  else:
89
- generator.load_state_dict(checkpoint['generator'])
 
 
 
 
 
 
 
 
 
 
 
 
 
90
  #generator = generator.cuda()
91
  generator.eval()
92
  print(f'Finish loading checkpoint.')
 
50
  return images
51
 
52
 
53
+ def load_generator(model_name, from_hf_hub=False):
54
  """Loads pre-trained generator.
55
 
56
  Args:
 
74
  generator = build_generator(**model_config)
75
  print(f'Finish building generator.')
76
 
77
+ if from_hf_hub and "hf_hub_repo" in model_config.keys():
78
+ checkpoint = generator.from_pretrained(model_config["hf_hub_repo"])
79
+ generator.load_state_dict(checkpoint)
80
+ print("loaded from hf_hub")
 
 
 
 
 
 
 
81
  else:
82
+ # Load pre-trained weights.
83
+ os.makedirs(CHECKPOINT_DIR, exist_ok=True)
84
+ checkpoint_path = os.path.join(CHECKPOINT_DIR, model_name + '.pth')
85
+ print(f'Loading checkpoint from `{checkpoint_path}` ...')
86
+ if not os.path.exists(checkpoint_path):
87
+ print(f' Downloading checkpoint from `{url}` ...')
88
+ subprocess.call(['wget', '--quiet', '-O', checkpoint_path, url])
89
+ print(f' Finish downloading checkpoint.')
90
+ checkpoint = torch.load(checkpoint_path, map_location='cpu')
91
+
92
+ if 'generator_smooth' in checkpoint:
93
+ generator.load_state_dict(checkpoint['generator_smooth'])
94
+ else:
95
+ generator.load_state_dict(checkpoint['generator'])
96
  #generator = generator.cuda()
97
  generator.eval()
98
  print(f'Finish loading checkpoint.')