vitaliykinakh commited on
Commit
5a67fb4
β€’
1 Parent(s): 1085c64

Implement interpolation between labels

Browse files
Files changed (2) hide show
  1. app.py +2 -1
  2. src/app/interpolate_labels.py +141 -0
app.py CHANGED
@@ -2,7 +2,7 @@ import streamlit as st
2
 
3
  # Custom imports
4
  from src.app import MultiPage
5
- from src.app import explore_infoscc_gan, explore_biggan, explore_cvae, compare_models
6
 
7
  # Create an instance of the app
8
  app = MultiPage()
@@ -15,6 +15,7 @@ app.add_page('Compare models', compare_models.app)
15
  app.add_page('Explore BigGAN', explore_biggan.app)
16
  app.add_page('Explore cVAE', explore_cvae.app)
17
  app.add_page('Explore InfoSCC-GAN', explore_infoscc_gan.app)
 
18
 
19
  # The main app
20
  app.run()
 
2
 
3
  # Custom imports
4
  from src.app import MultiPage
5
+ from src.app import explore_infoscc_gan, explore_biggan, explore_cvae, compare_models, interpolate_labels
6
 
7
  # Create an instance of the app
8
  app = MultiPage()
 
15
  app.add_page('Explore BigGAN', explore_biggan.app)
16
  app.add_page('Explore cVAE', explore_cvae.app)
17
  app.add_page('Explore InfoSCC-GAN', explore_infoscc_gan.app)
18
+ app.add_page('Interpolate labels', interpolate_labels.app)
19
 
20
  # The main app
21
  app.run()
src/app/interpolate_labels.py ADDED
@@ -0,0 +1,141 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pathlib import Path
2
+ import math
3
+
4
+ import numpy as np
5
+ import streamlit as st
6
+
7
+ import torch
8
+ import torch.nn.functional as F
9
+
10
+ import src.app.params as params
11
+ from src.models import ConditionalGenerator as InfoSCC_GAN
12
+ from src.models.big.BigGAN2 import Generator as BigGAN2Generator
13
+ from src.models import ConditionalDecoder as cVAE
14
+ from src.data import get_labels_train
15
+ from src.utils import download_file, sample_labels
16
+
17
+
18
+ device = params.device
19
+ size = params.size
20
+ n_layers = int(math.log2(size) - 2)
21
+ bs = 12
22
+ lin_space = torch.linspace(0, 1, bs).unsqueeze(1)
23
+ captions = [f'label_a * {(1 - x):.02f} + label_b * {x:.02f}' for x in lin_space.squeeze().numpy()]
24
+
25
+
26
+ @st.cache(allow_output_mutation=True)
27
+ def load_model(model_type: str):
28
+
29
+ print(f'Loading model: {model_type}')
30
+ if model_type == 'InfoSCC-GAN':
31
+ g = InfoSCC_GAN(size=params.size,
32
+ y_size=params.shape_label,
33
+ z_size=params.noise_dim)
34
+
35
+ if not Path(params.path_infoscc_gan).exists():
36
+ download_file(params.drive_id_infoscc_gan, params.path_infoscc_gan)
37
+
38
+ ckpt = torch.load(params.path_infoscc_gan, map_location=torch.device('cpu'))
39
+ g.load_state_dict(ckpt['g_ema'])
40
+ elif model_type == 'BigGAN':
41
+ g = BigGAN2Generator()
42
+
43
+ if not Path(params.path_biggan).exists():
44
+ download_file(params.drive_id_biggan, params.path_biggan)
45
+
46
+ ckpt = torch.load(params.path_biggan, map_location=torch.device('cpu'))
47
+ g.load_state_dict(ckpt)
48
+ elif model_type == 'cVAE':
49
+ g = cVAE()
50
+
51
+ if not Path(params.path_cvae).exists():
52
+ download_file(params.drive_id_cvae, params.path_cvae)
53
+
54
+ ckpt = torch.load(params.path_cvae, map_location=torch.device('cpu'))
55
+ g.load_state_dict(ckpt)
56
+ else:
57
+ raise ValueError('Unsupported model')
58
+ g = g.eval().to(device=params.device)
59
+ return g
60
+
61
+
62
+ @st.cache
63
+ def get_labels() -> torch.Tensor:
64
+ path_labels = params.path_labels
65
+
66
+ if not Path(path_labels).exists():
67
+ download_file(params.drive_id_labels, path_labels)
68
+
69
+ labels_train = get_labels_train(path_labels)
70
+ return labels_train
71
+
72
+
73
+ def get_eps(n: int) -> torch.Tensor:
74
+ eps = torch.randn((n, params.dim_z), device=device)
75
+ return eps
76
+
77
+
78
+ def app():
79
+
80
+ global lin_space, captions
81
+
82
+ st.title('Interpolate Labels')
83
+ st.markdown('This app allows the generation of the images with the labels that are interpolated between two labels.')
84
+ st.markdown('In each row there are images generated with the same interpolated label by one of the models')
85
+
86
+ biggan = load_model('BigGAN')
87
+ infoscc_gan = load_model('InfoSCC-GAN')
88
+ cvae = load_model('cVAE')
89
+ labels_train = get_labels()
90
+
91
+ # ==================== Labels ==============================================
92
+ label_a = sample_labels(labels_train, n=1).repeat(bs, 1)
93
+ label_b = sample_labels(labels_train, n=1).repeat(bs, 1)
94
+ label_interpolated = (1 - lin_space) * label_a + lin_space * label_b
95
+
96
+ sample_label = st.button('Sample label')
97
+ if sample_label:
98
+ label_a = sample_labels(labels_train, n=1).repeat(bs, 1)
99
+ label_b = sample_labels(labels_train, n=1).repeat(bs, 1)
100
+ label_interpolated = (1 - lin_space) * label_a + lin_space * label_b
101
+ # ==================== Labels ==============================================
102
+
103
+ # ==================== Noise ==============================================
104
+ eps = get_eps(1).repeat(bs, 1)
105
+ eps_infoscc = infoscc_gan.sample_eps(1).repeat(bs, 1)
106
+
107
+ zs = np.array([[0.0] * params.n_basis] * n_layers, dtype=np.float32)
108
+ zs_torch = torch.from_numpy(zs).unsqueeze(0).repeat(bs, 1, 1).to(device)
109
+
110
+ st.subheader('Noise')
111
+ st.markdown(r'Click on __Change eps__ button to change input $\varepsilon$ latent space')
112
+ change_eps = st.button('Change eps')
113
+ if change_eps:
114
+ eps = get_eps(1).repeat(bs, 1)
115
+ eps_infoscc = infoscc_gan.sample_eps(1).repeat(bs, 1)
116
+ # ==================== Noise ==============================================
117
+
118
+ with torch.no_grad():
119
+ imgs_biggan = biggan(eps, label_interpolated).squeeze(0).cpu()
120
+ imgs_infoscc = infoscc_gan(label_interpolated, eps_infoscc, zs_torch).squeeze(0).cpu()
121
+ imgs_cvae = cvae(eps, label_interpolated).squeeze(0).cpu()
122
+
123
+ if params.upsample:
124
+ imgs_biggan = F.interpolate(imgs_biggan, (size * 4, size * 4), mode='bicubic')
125
+ imgs_infoscc = F.interpolate(imgs_infoscc, (size * 4, size * 4), mode='bicubic')
126
+ imgs_cvae = F.interpolate(imgs_cvae, (size * 4, size * 4), mode='bicubic')
127
+
128
+ imgs_biggan = torch.clip(imgs_biggan, 0, 1)
129
+ imgs_biggan = [(imgs_biggan[i].permute(1, 2, 0).numpy() * 255).astype(np.uint8) for i in range(bs)]
130
+ imgs_infoscc = [(imgs_infoscc[i].permute(1, 2, 0).numpy() * 127.5 + 127.5).astype(np.uint8) for i in range(bs)]
131
+ imgs_cvae = [(imgs_cvae[i].permute(1, 2, 0).numpy() * 127.5 + 127.5).astype(np.uint8) for i in range(bs)]
132
+
133
+ c1, c2, c3 = st.columns(3)
134
+ c1.header('BigGAN')
135
+ c1.image(imgs_biggan, use_column_width=True, caption=captions)
136
+
137
+ c2.header('InfoSCC-GAN')
138
+ c2.image(imgs_infoscc, use_column_width=True, caption=captions)
139
+
140
+ c3.header('cVAE')
141
+ c3.image(imgs_cvae, use_column_width=True, caption=captions)