pradachan's picture
Upload folder using huggingface_hub
f71c233 verified
# ON THE RELATIONSHIP BETWEEN DISENTANGLEMENT
## AND MULTI-TASK LEARNING
**Anonymous authors**
Paper under double-blind review
ABSTRACT
One of the main arguments behind studying disentangled representations is the
assumption that they can be easily reused in different tasks. At the same time
finding a joint, adaptable representation of data is one of the key challenges in the
multi-task learning setting. In this paper, we take a closer look at the relationship between disentanglement and multi-task learning based on hard parameter
sharing. We perform a thorough empirical study of the representations obtained
by neural networks trained on automatically generated supervised tasks. Using a
set of standard metrics we show that disentanglement appears naturally during the
process of multi-task neural network training.
1 INTRODUCTION
Disentangled representations have recently become an important topic in the deep learning community (Eastwood & Williams, 2018; Locatello et al., 2019a; Ma et al., 2019; Sanchez et al., 2019; Do
& Tran, 2020). The main assumption in this problem is that the data encountered in the real world
is generated by few independent and explanatory factors of variation. It is commonly accepted that
such representations are not only more interpretable and robust but also perform better in tasks related to transfer learning and one-shot learning (Bengio, 2013; Lake et al., 2017; Sch¨olkopf et al.,
2012; Locatello et al., 2019c).
Intuitively, a disentangled representation encompasses all the factors of variation and as such can
be used for various tasks based on the same input space. On the other hand, non-disentangled
representations, such as those learned by vanilla neural networks, might focus only on one or a
few factors of variations that are relevant for the current task, while discarding the rest. Such a
representation may fail when encountering different tasks that rely on distant aspects of variation
which have not been captured.
Exploiting prevalent features and differences across tasks is also the paradigm of multi-task learning. In a standard formulation of a multi-task setting, a model is given one input and has to return
predictions for multiple tasks at once. The neural network might be therefore implicitly regularized
to capture more factors of variation than a network that learns only a single task. Based on this
intuition, we hypothesize that disentanglement is likely to occur in the latent representations in this
type of problem.
This paper aims to test this hypothesis empirically. We investigate whether the use of disentangled
representations improves the performance of a multi-task neural network and whether disentanglement itself is achieved naturally during the training process in such a setting.
Our key contributions are:
- Construction of synthetic datasets that allow studying the relationship between multi-task
and disentanglement learning.
- Study of the effect of multi-task learning with hard parameter sharing on the level of disentanglement obtained in the latent representation of the model.
- Analysis of the informativeness of the latent representation obtained in the single- and
multi-task training.
- Inspection of the effect of disentangled representations on the performance of a multi-task
model.
-----
We verify our hypotheses by training multiple models in single- and multi-task settings and investigating the level of disentanglement achieved in their latent representations. In our experiments, we
find that in a hard-parameter sharing scenario multi-task learning indeed seems to encourage disentanglement. However, it is inconclusive whether disentangled representations have a clear positive
impact on the models performance, as the obtained by us results in this matter vary for different
datasets.
2 RELATED WORK
2.1 DISENTANGLEMENT
Over the recent years, many methods that directly encourage disentanglement have been proposed.
This includes algorithms based on variational and Wasserstein auto-encoders (Kim & Mnih, 2018;
Higgins et al., 2017; Kumar et al., 2017; Brakel & Bengio, 2017; Spurek et al., 2020), flow networks
(Dinh et al., 2014; Sorrenson et al., 2020) or generative adversarial networks (Chen et al., 2016).
The main interest behind disentanglement learning lays in the assumption that such transformation
unravels the semantically meaningful factors of variation present in the observations and thus it is
desired in training deep learning models. In particular, disentanglement is believed to allow for
informative compression of the data that results in a structural, interpretable representation, which is
easily adaptable for new tasks (Bengio, 2013; Lake et al., 2017; Schmidhuber, 1992; Lipton, 2018).
Several of these properties have been experimentally proven in applications in many domains, including video processing tasks (Hsieh et al., 2018), recommendation systems (Ma et al., 2019) or abstract reasoning (Van Steenkiste et al., 2019; Steenbrugge et al., 2018). Moreover, recent research in
reinforcement learning concludes that disentangling embeddings of skills allows for faster retraining
and better generalization (Petangoda et al., 2019). Finally, disentanglement seems also to be positively correlated with fairness when sensitive variables are not observed (Locatello et al., 2019a).
On the other hand, some empirical studies suggest that one should be cautious while interpreting the
properties of disentangled representations. For instance, the latest studies in the unsupervised learning domain point that increased disentanglement does not lead to a decreased sample complexity in
downstream tasks (Locatello et al., 2019b).
Another key challenge in studying disentangled representations is the fact that measuring the quality
of the disentanglement is a nontrivial task (Do & Tran, 2020; Eastwood & Williams, 2018; Kim
& Mnih, 2018), especially in a unsupervised setting (Locatello et al., 2019b). This motivates the
research on practical advantages of disentanglement representations and their impact on the studied
problem in possible future applications, which is the main focus of our work in the case of multi-task
learning.
2.2 MULTI-TASK LEARNING
Multi-task learning aims at simultaneously solving multiple tasks by exploiting common information
(Ruder, 2017). The approaches used predominantly to this problem are soft (Duong et al., 2015)
and hard (Caruana, 1993) parameter sharing. In hard parameter sharing the weights of the model
are divided into those shared by all tasks, and task-specific. In deep learning, this idea is typically
implemented by sharing consecutive layers of the network, which are responsible for learning a
joint data representation. In soft parameter sharing each task is given a set of separate parameters.
The limitations are then imposed by information-sharing or regularizing the distance between the
parameters by adding an applicable loss to the optimization objective.
Multi-task learning is widely used in the Deep Learning community, for instance in applications
related to natural language processing (Liu et al., 2019), computer vision (Misra et al., 2016) or
molecular property prediction modeled by graph neural networks (Capela et al., 2019). One may
observe that the premises of multi-task and disentanglement learning are related to each other and
thus it is interesting to investigate whether the joint data representation obtained in a multi-task
problem exhibits some disentanglement-related properties.
-----
3 METHODS
In this section, we describe the methods and datasets used for conducting the experiments.
3.1 DATASET CREATION
In order to investigate the relationship between multi-task learning and disentanglement, we require
a dataset that fulfills two conditions:
1. It provides access to the true (disentangled) generative factors z from which the observations x are created.
2. It proposes multiple tasks for a supervised learner by providing labels yi which non-linearly
depend on the true factors z.
The first condition is required in order to measure how well the learned representations approximate the true latent factors z. Access to Shape: heartPosition X: 10px Randomly initializednetworks
the true factors allows for full control over the Position Y: 30px
experimental settings and permits a fair com- Scale: 1.0
parison through the use of supervised disentan- Generative function
glement metrics. Note that even though unsu- Supervised learner
pervised metrics have been proposed in the lit
|initialized andom ly orks R Shape: heart netw Position X: 10px Position Y: 30px Rotation: 0.2 Scale: 1.0 G enerative function|Supervised learner|
|---|---|
erature as well, they typically yield less reliable
Figure 1: The setting of our experiments. Given
results, as we further discuss in section 3.3.
a dataset of pairs (x, z) of observations and their
The second condition is needed to train a net- true generative factors, we generate a set of funcwork on multiple nontrivial tasks to approxi- tions h(z)i which are aimed to approximate realmate the real-world setting of multi-task learn- world supervised tasks. Then, we train a neural
ing. network fφ(x) in a multi-task regression setting
on pairs (x, h(z)). After the training, we investi
To our best knowledge, no nontrivial datasets
gate the hidden representations learned by fφ and
exist that would abide by both those require
explore their relation to true factors z.
ments. Most of the available disentanglement
datasets, such as dSprites, Shapes3D, and MPI3D do fulfill the first condition, as they provide pairs
(x, z) of observations and their true generative factors. However, those datasets do not offer any
type of challenging task on which our model could be trained. On the other hand, many datasets
used for supervised multi-task learning fulfill the second condition by providing pairs (x, y), but do
not equip the researcher with the latent factors z (ground truth), failing the first condition.
Thus, we aim to create our own datasets which fulfill both conditions by incorporating nontrivial
tasks into standard disentanglement datasets. Since in multi-task approaches one often tries to solve
tens of tasks at once, designing them by hand is infeasible and as such we decide to generate them
automatically in a principled way. In particular, since supervised learning tasks might be formalized
as finding a good approximation to an unknown function h(x) given a set of points (x, h(x)), we
generate random functions h(z) which are then used to obtain targets for our dataset (see Figure 1).
We require h(z) to be both nontrivial (i.e. non-linear and non-convex) and sufficiently smooth to
approximate the nature of real-life tasks. In order to find a family of functions that fulfills those
conditions, we take inspiration from the field of extreme learning, which finds that features obtained
from randomly initialized neural networks are useful for training linear models on various realworld problems (Huang et al., 2011). As such, randomly initialized networks should be able to
approximate these tasks up to a linear operation.
In particular, in order to generate the dataset, we define a neural network architecture h(z, θ). For
this purpose, we used an MLP with four hidden layers with 300 units, tanh activations, and an
output layer which returns a single number. Then we sample n weight initializations of this network
from the Gaussian distribution θi (0, 1), where i 1, . . ., n . Each of the networks h(z, θi)
obtained by random initialization defines a single task in our approach. Thus, for a given dataset ∼N _∈{_ _}_
_D = (x, z) containing observations and their true generative factors, we obtain a dataset for multi-_
task supervised learning by applying:
_D˜ = {(x, h(z)) | (x, z) ∈D} = {(x, y)},_
-----
where h(z) is a vector of stacked target values for each task, whose element i is given by h(z)i =
_h(z, θi)._
We use this data as a regression task, i.e. for a given neural network fφ parameterized by φ the goal
is to find:
arg minφ _∥fφ(x) −_ _y∥2[2][.]_
(x,y)
X∈D[˜]
We use this process to create multi-task supervised versions of dSprites, Shapes3D, and MPI3D,
with 10 tasks for each dataset.
3.2 MODELS
3.2.1 MULTI-TASK MODEL
We investigate the relation between disentanglement and multi-task learning based on a hard parameter sharing approach. In this setting, several consecutive hidden layers of the model are shared
across all tasks in order to produce a joint data representation. This representation is then propagated
to separate task-specific layers which are responsible for computing the final predictions.
In particular, we use a network consisting of
a shared convolutional encoder and separate
fully-connected heads for each of the tasks. The
encoder learns the joint representation by transforming the inputs into a d-dimensional latent
space. [1] The heads are implemented by 4-layer
MLPs with ReLU activations, in order to match
the capacity of the networks used for task generating functions hi(x). This overview of the
model is illustrated in Figure 2.
3.2.2 AUTO-ENCODER MODEL
Input
|mputing the final predictions.|ictions.|Col3|Col4|
|---|---|---|---|
|Task 1 Head (FCN) Task 2 Head convolutional (FCN) Latent encoder Task Head (FCN)|Task 1 Head (FCN)|||
|||||
||Task 2 Head (FCN)|||
||Task Head (FCN)|||
|||||
erating functions hi(x). This overview of the Figure 2: The model used for multi-task trainmodel is illustrated in Figure 2.
ing. The convolutional encoder E(x) transforms
the input data x to a latent representation ˜z. The
3.2.2 AUTO-ENCODER MODEL parameters of the encoder are shared across all
tasks. Next, the produced representation is passed
In the second part of our experiments we
to the task-specific heads, which are implemented
want to understand if disentangled represen
by fully-connected networks (FCN).
tation provides some benefits for the multitask problem. In order to produce disentangled representations, we decided to use three different representation-learning algorithms: a vanilla auto-encoder, the (beta)-variational autoencoder (Kingma & Welling, 2013; Higgins et al., 2017) and FactorVAE (Kim & Mnih, 2018).
All these variants of the auto-encoder architecture encompass a similar framework. An auto-encoder
imposes a bottleneck in the network which forces a compressed knowledge representation of the
original input. In some variants of those models, we additionally try to constrain the latent variables
to be highly informative and independent which further correlates to disentanglement, e.g. in models
like β-VAE and FactorVAE. We use latent representations from these models to train task-specific
heads and evaluate if disentanglement helped to decrease an error for that task.
The vanilla auto-encoder is also used in Section 4.2, where we add a decoder with transposed convolutions to pre-trained encoders from Section 4.1. This treatment is aimed to decode information
for particular encoders in the most efficient way. As such, we find auto-encoders to be a useful tool
for investigating disentanglement.
3.3 DISENTANGLEMENT METRICS
Measuring the qualitative and quantitative properties of the disentanglement representation discovered by the model is a nontrivial task. Due to the fact that the true generating factors of a given
1We provide the full model summary in Appendix A. The architecture of the encoder follows the one from
(Abdi et al., 2019), which adopts the work of (Locatello et al., 2019b) for the pytorch package. We use the
[implementations from https://github.com/amir-abdi/disentanglement-pytorch.](https://github.com/amir-abdi/disentanglement-pytorch)
-----
dataset are usually unknown, one may assume that decomposition can be obtained only to some
extend.
Commonly used unsupervised metrics are based on correlation coefficients which measure the intrinsic dependencies between the latent components. Such measures are widely used in the independent component analysis (Hyvarinen & Morioka, 2016; 2017; Hirayama et al., 2017; Brakel &
Bengio, 2017; Spurek et al., 2020; Bedychaj et al., 2020). However, uncorrelatedness does not imply
stochastical independence. Furthermore, metrics based on linear correlations may not be able to capture higher-order dependencies and are often ineffective in large dimensional or in over-determined
spaces. All this makes the use of such unsupervised metrics questionable.
An alternative solution would be to use supervised metrics, which usually are more reliable (Locatello et al., 2019b). This is of course only possible after assuming access to the true generative
factors. Such an assumption is rarely valid for real-world datasets, however, it is satisfied for synthetic datasets. Synthetic datasets present therefore a reasonable baseline for benchmarking disentanglement algorithms.
Frequently used metrics which use supervision are mutual information gap (MIG) (Chen et al.,
2018), the FactorVAE metric (Kim & Mnih, 2018), Separated Attribute Predictability (SAP)
score (Kumar et al., 2018) and disenanglement-completness-informativeness (DCI) (Eastwood &
Williams, 2018). In order to comprehensively assess the level of disentanglement in our experiments, we have decided to use all of the above-mentioned metrics to validate our results. A more
detailed description of those metrics is available in Appendix B.
4 RESULTS AND DISCUSSION
In this section, we describe the performed experiments and discuss the obtained results. For more
details on the training regime and experimental setup please refer to Appendix C.
4.1 DOES HARD PARAMETER SHARING ENCOURAGE DISENTANGLEMENT?
One of the most common approaches to multi-task learning is hard parameter sharing. The key
challenge in this method is to learn a joint representation of the data which is at the same time informative about the input and can be easily processed in more than one task. It is therefore tempting to
verify whether disentanglement arises in those representations implicitly, as a consequence of hard
parameter sharing.
In order to investigate this problem we build a simple multi-task model described in Section 3.2 and
evaluate it on the three datasets discussed in Section 3.1: dSprites, Shapes3D, and MPI3D, each with
10 artificial tasks. After the training is complete, we calculate each of the disentanglement metrics
described in Section 3.3 on the latent representation of the input data[2]. We compare the obtained
results with the same metrics computed for an untrained (randomly initialized) network and for
single-task models. In all the cases we use the same architecture and training regime. Note that in
the single-model scenario we train a separate model for each of the 10 tasks, which is implemented
by utilizing only one, dedicated head in the optimization process. We train all models three times,
using a different random seed in the parameters initialization procedure. We report the mean results
and standard deviations in Figure 3.
We observe that disentanglement metrics computed for the representations obtained in the multi-task
setting are typically significantly better than the values obtained for single-task or random representations. Note that even the maximum mean result over all ten single-task models is in almost every
case further than one standard deviation from the multitask mean. Moreover, this is true for all the
tested datasets.
Let us also point out that instead of using separate heads for each of the tasks in the multi-task model
one could simply use one head with the output dimension equal to the number of tasks and perform
standard multivariate regression (with no parameter sharing). As presented in Figure 4, the latent
representations emerging in such a scenario are less disentangled (in terms of the considered metrics)
[2We use the implementations of Locatello et al. (2019b), which are available at https://github.com/](https://github.com/google-research/disentanglement_lib)
[google-research/disentanglement_lib](https://github.com/google-research/disentanglement_lib)
-----
Figure 3: Different disentanglement metrics computed for random (untrained), single-task and
multi-task models evaluated on the three datasets described in Section 3.1. The higher the value
the better. For the single-task scenario, we report the mean over all task-specific models. Note
that in almost every case the multi-task representations (red bars) outperform the random or singletask representations (dark-gray bars and light gray bars, respectively). Additionally, for single-task
models, we report the maximal and minimal values over all tasks to show that the performance on
multi-task does not rely on any single ’lucky’ task. For tabulated results please refer to Appendix E.
Figure 4: Different disentanglement metrics computed for the multi-task setting with one head
shared between all tasks (one-head) and separate head for every task (multi-head), evaluated on
the three datasets described in Section 3.1. The higher the value the better. One may observe that
multi-head representations perform better than the ones obtained in the standard, one-head multivariate regression task. For tabulated results please refer to Appendix E.
than the representations obtained when utilizing hard parameter sharing. However, the achieved
values are still better than in single-task models. This suggests that even though the increase in
the metrics may be partially caused by simply training the network on higher-dimensional targets,
the positive influence of hard parameter sharing cannot be ignored. This advocates in favor of the
hypothesis that multi-task representations are indeed more disentangled than the ones arising in
single-task learning.
4.2 WHAT ARE THE PROPERTIES OF THE LEARNED REPRESENTATIONS?
The previous section discussed the obtained representations by analyzing quantitative disentanglement metrics. Here, we provide more insights into the characteristics of latent encodings.
-----
Figure 5: UMAP embeddings of the latent representations of the Shapes3D test dataset obtained
for different models. Change of the color within one subplot presents the change in one particular ground truth component. The embeddings obtained by the multi-task model seem to be most
semantically meaningful. See Appendix D for plots for other datasets.
4.2.1 UMAP EMBEDDINGS
In order to gain intuition behind the differences between the representations obtained in the previous experiment we compute a 2D-embedding of the latent encodings using the UMAP algorithm (McInnes et al., 2018). The results are presented in Figure 5.
The embeddings obtained for the multi-task representations are much more semantically meaningful, with easily distinguishable separate clusters. Moreover, the position and internal structure of
the clusters correspond to different values of the true factors. This cannot be observed for the untrained or single-task representations, suggesting that the multi-task representations are indeed more
successful in encompassing the information about the real values of the generative sources of the
data.
input random single-task multi-task
Shapes3D
Figure 6: Reconstructions obtained by the decoders trained on random, single-task, and multi-task
encoding. For reference, we provide the original input images in the first row. The quality of the
reconstruction for the random and single-task representation is very poor. Contrary, the multi-task
encoder provided a latent space that can be successfully decoded into images that closely resemble
the corresponding examples from the input. Thus, we conclude that the multi-task representations
are more informative about the data and provide better compression. See Appendix D.2 for reconstructions for other datasets.
4.2.2 LATENT SPACE TRAVERSAL
Providing qualitative results of the retrieved factors is a common practice in disentanglement learning (Locatello et al., 2019c; Kumar et al., 2017; Sanchez et al., 2019; Sorrenson et al., 2020; Locatello et al., 2019b). In particular, visual presentation of the interpolations over the latent space
allows assessing — from a human perspective — the informativeness and decomposition of the ob
-----
tained representations. Note that such analysis is possible only after adding and training a suitable
decoder network, which maps the retrieved factors back to the image space.
In our setting, the decoder mirrors the architecture of the encoder (the convolutions are replaced by transposed convolutions of the same
size — see Appendix A). Given the latent representations as an input, the decoder optimizes
the reconstruction error (as measured by MSE)
between its outputs and the original images. We
train three separate decoders corresponding to
the different encoders from the previous section
— a randomly initialized encoder, an encoder
produced by one of the single-task models, and
a multi-task encoder.
First, let us discuss the reconstruction quality
achieved by each of the tested decoders. Results of this experiment are presented in Figure 6[3]. Reconstructions produced for the multitask encodings are clearly superior to the ones
obtained for the single-task encodings. In the
first case, the resulting images are sharp and
contain almost no noise. In contrast, the single task reconstructions are blurry and similar
to the ones produced for the randomly initialized encoder. We would like to emphasise that
all the decoders used the same architecture and
that during their optimization the parameters
of the corresponding encoders were kept fixed.
Therefore the quality of the reconstruction is an
important property of a latent representation, as
it allows us to assess the compression capacity
of the representation. From this perspective, the
compression obtained in the multi-task scenario
is much more informative about the input than
in the single-task scenario.
Another approach to the visualisation of the
latent variables is to perform interpolations
(traversals) in the latent space. We start by selecting a random sample from the dataset and
compute its encoding ˜z ∈ R[d]. By modifying
one of the components of vector ˜z from −1 to 1
with 0.1 step and leaving the d − 1 unchanged,
we produce a traversal along that particular factor. We repeat this procedure for all the factors in order to capture their impact on the decoded example. Results of such traverses for
the dSprites dataset are shown in Figure 7.
input
recon
factor 1
factor 2
factor 3
factor 4
factor 5
factor 6
factor 7
factor 8
-1
input
recon
factor 1
factor 2
factor 3
factor 4
factor 5
factor 6
factor 7
factor 8
-1
input
recon
factor 1
factor 2
factor 3
factor 4
factor 5
factor 6
factor 7
factor 8
-1
(a) Random encoder
0
(b) Single task encoder
0
(c) Multi-task encoder
Figure 7: Traverses over latent variable produced
for a given architecture. The same example was
used in all three traverses. The second row of
each image shows how the decoder reconstructed
this example in a particular setting. The rest of
the factors come from latent space generated from
each encoder. Visualization of components from
the multi-task encoder are sharp and distinguish
the generating factors distinctly. The same cannot
be said about the latent factors in single-task and
random encoders, which are blurry and disconnected from any interpretable ground truth factors.
Please refer to Appendix D for the results of the
traversals over other datasets.
Note that since the models were not trained directly for disentanglement but only to solve a supervision task, it is not surprising that the representations are not as clearly factorized as in specialized
methods such as FactorVAE. However, for the multi-task model, certain latent dimensions still appear to be disentangled and one can easily spot the difference in quality between the single and
multi-task representations. In the multi-task traversals, we can notice components that are responsible for the position and scale of a given figure (in Figure 7c, consider the 5th and 7th factors,
respectively). In contrast, the results for single task representations demonstrate that even a slight
change in any of the single latent dimensions leads to a degradation of the reconstructed examples.
3Numerical values for reconstruction errors are presented in Appendix D.2.
-----
As expected, this effect is even more evident for the random (untrained) representations, where the
corruption over latent factor is even more prevalent than in the case of a single-task traversal.
4.3 DOES DISENTANGLEMENT HELP IN TRAINING MULTI-TASK MODELS?
In the previous sections, we studied whether multi-task learning encourages disentanglement. Here
we consider an inverse problem by asking whether using disentangled representation helps in multitask learning. To investigate this issue, we train an auto-encoder-based model devised specifically to
produce disentangled latent representations without access to the true latent factors. Next, we freeze
its parameters and use the encoder function to transform the inputs. The obtained latent encodings
are then passed directly to the heads of a multi-task network which minimizes the average regression
loss given the target values of the artificial tasks.
We consider three different auto-encoder-based algorithms described in Section 3.2.2: a vanilla autoencoder (AE), a variational auto-encoder (VAE), and the FactorVAE. The vanilla auto-encoder does
not directly enforce latent disentanglement during the training. In the VAE model, the prior normal
distribution with identity covariance matrix implies some disentanglement. Finally, FactorVAE introduces a new module to the VAE architecture that explicitly induces informative decomposition.
Therefore, the representations obtained for each subsequent model should be also naturally ordered
by the level of the achieved disentanglement. For the exact values of the calculated metrics please
refer to Appendix F. In addition, we also study a scenario in which we explicitly provide the true
source factors. We trained all regression models three times, using a different random seed in the
parameters initialization procedure.
Table 1: RMSE of multi-task networks trained on latent
Table 1 summarizes the performance
representations obtained by different auto-encoder-based
of the multi-task model trained on the
methods. For comparison, we added the model trained on
representations obtained for the above
ground truth factors. The best results are bolded, and best
discussed methods. Although the rep
out of auto-encoder architectures underlined.
resentations obtained from FactorVAE
are better (see, for instance, MIG or
DCI measures in Appendix F) than
|Dataset|dSprites Shapes3D MPI3D|
|---|---|
|Ground Truth|150.235 ± 3.754 72.979 ± 0.193 108.568 ± 0.285|
|---|---|
others on Shapes3D and MPI3D and
|AE VAE FactorVAE|80.062 ± 0.341 114.939 ± 0.160 150.190 ± 0.097 63.260 ± 0.260 132.072 ± 0.169 194.865 ± 15.61 91.937 ± 0.199 118.396 ± 0.423 151.646 ± 0.336|
|---|---|
being second on dSprites. Note that these results coincide with observations presented in the literature. For example, (Locatello et al., 2019b) compared different models that enforce disentanglement
during the training and showed that even a high value of that property within the factors do not
constitute a better model performance. However, in two out of three datasets, the use of the ground
true factors seems to significantly improve the obtained results. This may suggest that the representations produced by the considered disentanglement methods are not fully factorized. It is therefore
inconclusive whether the discrepancy between the obtained results is due to the shortcomings of the
used methods or a manifestation of the impracticality of disentanglement.
5 CONCLUSIONS
In this paper, we studied the relationship between multi-task and disentanglement representation
learning. A fair evaluation of our hypothesis is impossible on real-world datasets, without provided
ground truth factors. To evaluate our results we had to introduce synthetic datasets that contain
all necessary properties to be seen as a benchmark in this field. Next, we studied the effects of
multi-task learning with hard parameter sharing on representation learning. We found that nontrivial
disentanglement appears in the representations learned in a multi-task setting. Obtained factors have
intuitive interpretations and correspond to the actual ground truth components. Finally, we inverted
the question and investigated the hypothesis that disentangled representation is needed for multitask learning, the results however are not conclusive. We found out that multi-task models benefit
from disentanglement only on specific datasets. However, we cannot name an indicator of when this
unambiguously applies.
-----
REFERENCES
Amir H. Abdi, Purang Abolmaesumi, and Sidney Fels. Variational learning with disentanglementpytorch. arXiv preprint arXiv:1912.05184, 2019.
Andrzej Bedychaj, Przemysław Spurek, Aleksandra Nowak, and Jacek Tabor. Wica: nonlinear
weighted ica, 2020.
Yoshua Bengio. Deep learning of representations: Looking forward, 2013.
Philemon Brakel and Yoshua Bengio. Learning independent features with adversarial nets for nonlinear ica, 2017.
Fabio Capela, Vincent Nouchi, Ruud Van Deursen, Igor V Tetko, and Guillaume Godin. Multitask
learning on graph neural networks applied to molecular property predictions. _arXiv preprint_
_arXiv:1910.13124, 2019._
Richard Caruana. Multitask learning: A knowledge-based source of inductive bias. In Proceedings
_of the Tenth International Conference on Machine Learning, pp. 41–48. Morgan Kaufmann, 1993._
Ricky TQ Chen, Xuechen Li, Roger B Grosse, and David K Duvenaud. Isolating sources of disentanglement in variational autoencoders. In Advances in neural information processing systems,
pp. 2610–2620, 2018.
Xi Chen, Yan Duan, Rein Houthooft, John Schulman, Ilya Sutskever, and Pieter Abbeel. Infogan: Interpretable representation learning by information maximizing generative adversarial nets.
_Advances in neural information processing systems, 29:2172–2180, 2016._
Laurent Dinh, David Krueger, and Yoshua Bengio. Nice: Non-linear independent components estimation. arXiv preprint arXiv:1410.8516, 2014.
Kien Do and Truyen Tran. Theory and evaluation metrics for learning disentangled representations.
[In International Conference on Learning Representations, 2020. URL https://openreview.net/](https://openreview.net/forum?id=HJgK0h4Ywr)
[forum?id=HJgK0h4Ywr.](https://openreview.net/forum?id=HJgK0h4Ywr)
Long Duong, Trevor Cohn, Steven Bird, and Paul Cook. Low resource dependency parsing: Crosslingual parameter sharing in a neural network parser. In Proceedings of the 53rd Annual Meeting
_of the Association for Computational Linguistics and the 7th International Joint Conference on_
_Natural Language Processing (Volume 2: Short Papers), pp. 845–850, 2015._
Cian Eastwood and Christopher K. I. Williams. A framework for the quantitative evaluation of
disentangled representations. In International Conference on Learning Representations, 2018.
[URL https://openreview.net/forum?id=By-7dz-AZ.](https://openreview.net/forum?id=By-7dz-AZ)
I. Higgins, Lo¨ıc Matthey, A. Pal, C. Burgess, Xavier Glorot, M. Botvinick, S. Mohamed, and
Alexander Lerchner. beta-vae: Learning basic visual concepts with a constrained variational
framework. In ICLR, 2017.
Junichiro Hirayama, AJ Hyvarinen, and Motoaki Kawanabe. Splice: Fully tractable hierarchical extension of ica with pooling. In Proceedings of the International Conference on Machine Learning,
volume 70, pp. 1491–1500. Machine Learning Research, 2017.
Jun-Ting Hsieh, Bingbin Liu, De-An Huang, Li F Fei-Fei, and Juan Carlos Niebles. Learning to decompose and disentangle representations for video prediction. In Advances in Neural Information
_Processing Systems, pp. 517–526, 2018._
Guang-Bin Huang, Dian Hui Wang, and Yuan Lan. Extreme learning machines: a survey. Interna_tional journal of machine learning and cybernetics, 2(2):107–122, 2011._
Aapo Hyvarinen and Hiroshi Morioka. Unsupervised feature extraction by time-contrastive learning
and nonlinear ica. In Advances in Neural Information Processing Systems, pp. 3765–3773, 2016.
AJ Hyvarinen and Hiroshi Morioka. Nonlinear ica of temporally dependent stationary sources.
Proceedings of Machine Learning Research, 2017.
-----
Hyunjik Kim and Andriy Mnih. Disentangling by factorising. arXiv preprint arXiv:1802.05983,
2018.
Diederik P Kingma and Jimmy Ba. Adam: A method for stochastic optimization. arXiv preprint
_arXiv:1412.6980, 2014._
Diederik P Kingma and Max Welling. Auto-encoding variational bayes. _arXiv preprint_
_arXiv:1312.6114, 2013._
Abhishek Kumar, Prasanna Sattigeri, and Avinash Balakrishnan. Variational inference of disentangled latent concepts from unlabeled observations. arXiv preprint arXiv:1711.00848, 2017.
Abhishek Kumar, Prasanna Sattigeri, and Avinash Balakrishnan. Variational inference of disentangled latent concepts from unlabeled observations, 2018.
Brenden M Lake, Tomer D Ullman, Joshua B Tenenbaum, and Samuel J Gershman. Building
machines that learn and think like people. Behavioral and brain sciences, 40, 2017.
Zachary C Lipton. The mythos of model interpretability. Queue, 16(3):31–57, 2018.
Xiaodong Liu, Pengcheng He, Weizhu Chen, and Jianfeng Gao. Multi-task deep neural networks
for natural language understanding. arXiv preprint arXiv:1901.11504, 2019.
Francesco Locatello, Gabriele Abbati, Thomas Rainforth, Stefan Bauer, Bernhard Sch¨olkopf, and
Olivier Bachem. On the fairness of disentangled representations. In Advances in Neural Informa_tion Processing Systems, pp. 14611–14624, 2019a._
Francesco Locatello, Stefan Bauer, Mario Lucic, Gunnar R¨atsch, Sylvain Gelly, Bernhard
Sch¨olkopf, and Olivier Bachem. Challenging common assumptions in the unsupervised learning of disentangled representations, 2019b.
Francesco Locatello, Michael Tschannen, Stefan Bauer, Gunnar R¨atsch, Bernhard Sch¨olkopf,
and Olivier Bachem. Disentangling factors of variation using few labels. _arXiv preprint_
_arXiv:1905.01258, 2019c._
Jianxin Ma, Chang Zhou, Peng Cui, Hongxia Yang, and Wenwu Zhu. Learning disentangled representations for recommendation. In Advances in neural information processing systems, pp.
5711–5722, 2019.
Leland McInnes, John Healy, and James Melville. Umap: Uniform manifold approximation and
projection for dimension reduction. arXiv preprint arXiv:1802.03426, 2018.
Ishan Misra, Abhinav Shrivastava, Abhinav Gupta, and Martial Hebert. Cross-stitch networks for
multi-task learning. In Proceedings of the IEEE Conference on Computer Vision and Pattern
_Recognition, pp. 3994–4003, 2016._
Janith C. Petangoda, Sergio Pascual-Diaz, Vincent Adam, Peter Vrancx, and Jordi Grau-Moya.
Disentangled skill embeddings for reinforcement learning, 2019.
Sebastian Ruder. An overview of multi-task learning in deep neural networks. _arXiv preprint_
_arXiv:1706.05098, 2017._
Eduardo Hugo Sanchez, Mathieu Serrurier, and Mathias Ortner. Learning disentangled representations via mutual information estimation, 2019.
J¨urgen Schmidhuber. Learning factorial codes by predictability minimization. Neural computation,
4(6):863–879, 1992.
Bernhard Sch¨olkopf, Dominik Janzing, Jonas Peters, Eleni Sgouritsa, Kun Zhang, and Joris Mooij.
On causal and anticausal learning. arXiv preprint arXiv:1206.6471, 2012.
Peter Sorrenson, Carsten Rother, and Ullrich K¨othe. Disentanglement by nonlinear ica with general
incompressible-flow networks (gin), 2020.
-----
Przemysław Spurek, Aleksandra Nowak, Jacek Tabor, Łukasz Maziarka, and Stanisław Jastrzebski.
Non-linear ica based on cramer-wold metric. In International Conference on Neural Information
_Processing, pp. 294–305. Springer, 2020._
Xander Steenbrugge, Sam Leroux, Tim Verbelen, and Bart Dhoedt. Improving generalization for abstract reasoning tasks using disentangled feature representations. _arXiv preprint_
_arXiv:1811.04784, 2018._
Sjoerd Van Steenkiste, Francesco Locatello, J¨urgen Schmidhuber, and Olivier Bachem. Are disentangled representations helpful for abstract visual reasoning? In Advances in Neural Information
_Processing Systems, pp. 14245–14258, 2019._
-----
A SUMMARY OF THE ARCHITECTURE OF THE MULTI-TASK MODEL
The architecture of the convolutional encoder E(x) is provided in Table 2, together with the architecture of the corresponding decoder, which was used in experiments in Section 4.2. For the
fully-connected heads, we used the same architecture as the one utilized during dataset creation,
which is presented in Table 3.
Table 2: The architecture of auto-encoder-based methods. Non-linearity in all layers is given by
ReLU function.
Encoder Decoder
|Type Kernel Stride Outputs|Type Kernel Stride Outputs|
|---|---|
|Conv 2d 4 2 32 Conv 2d 4 2 32 Conv 2d 4 2 64 Conv 2d 4 2 128 Conv 2d 4 2 256 Conv 2d 4 2 256 Dense output dim|Conv 2d 1 2 256 Conv Transpose 2d 4 2 256 Conv Transpose 2d 4 2 128 Conv Transpose 2d 4 2 128 Conv Transpose 2d 4 2 64 Conv Transpose 2d 4 2 64 Conv Transpose 2d 3 1 num channels|
|---|---|
Table 3: The architecture of a single fully-connected head in the single- and multi-task neural network. We apply non-linearity (given by the ReLU function) after all layers except the last one.
Type Output shape
Dense 300
Dense 300
Dense 300
Dense 10
B DISENTANGLEMENT METRICS
In our experiments, we decided to use four measures of disentanglement to comprehensively validate
our results. For the convenience of the reader, in this part of the appendix, we shortly describe the
used measures (for wider context we encourage the reader to refer to the original papers).
B.1 MUTUAL INFORMATION GAP (MIG)
MIG computes the mutual information between each of the ground truth components zi and the
disentangled factor ˜zj. The mutual information between zi and ˜zj is denoted by I(zi, ˜zj). Next,
the latent dimension with maximum mutual information score is identified for each of the retrieved
factor (denoted by I(zi, ˜zmax1 )), along with the second-best result of the same score (denoted by
respect to the total mutual information associated with the studied factor:I(zi, ˜zmax2 )). The difference between those values gives a gap, which finally is normalized with
MIG = _[I][(][z][i][,][ ˜]zmaxmj=11_ ) −[I][(][z]I[i]([,]z[ ˜]zij, ˜)zmax2 ) _._
Where m is the dimension of ground truth components space. To report one score we average the
P
MIG scores of all factors.
B.2 FACTORVAE METRIC
We start by normalizing retrieved factors by their respective standard deviation computed over the
dataset. For a subset of the dataset, a ground truth component is then randomly selected and fixed at
a random value. Variance is then computed over normalized retrieved factors in this subset. Next,
-----
the lowest variance factor — the one that should mostly resemble the fixed ground truth component
— is associated with that ground truth component.
This procedure with selecting the subsets and fixing one of its ground truth components is then
repeated multiple times (in our experiments 10000 times). As a result, the associations between
disentangled factor and ground truth component are used as inputs in a majority vote classifier.
FactorVAE metric is the mean accuracy of the classifier.
B.3 SEPARATED ATTRIBUTE PREDICTABILITY (SAP)
SAP attributes a score Sij to all pairs of ground truth components zi and disentangled factors ˜zj. For
continuous components, linear regression predicts the disentangled factors, and Sij is the coefficient
of determination (R[2]) of the regression. In the case of categorical features, SAP fits a decision tree
on ground truth components and reports the balanced classification accuracy. The final SAP score is
achieved by computing the difference between the two highest Sij values for all factors:
_SAP = [1]_
_Simax1_ _Simax2_ _,_
_i=1_ _−_
X
where n is the dimension of ground truth components space, Simax1 is the highest score for component zi and Simax2 is the second highest score for the same component.
B.4 DISENTANGLEMENT, COMPLETENESS, AND INFORMATIVENESS (DCI)
Unlike previous measures, DCI is a complete framework that allows verifying several properties
of the achieved representation. Disentanglement and completeness are estimated by inspecting the
regressor’s parameters to derive predictive importance weights Rij for each pair (zi, ˜zj) of ground
truth zi and retrieved ˜zj components.
The completeness for ground truth component zi is given by
_Ci = 1 +_ _pij logn pij,_
_j=1_
X
where m stands for ground truth dimension and pij is the probability that disentangled factor ˜zj is
important to predict zi. These probabilities are obtained by dividing each importance weight by the
sum of all importance weights related to a given component:
_Rij_
_pij =_ _m_ _._
_k=1_ _[R][ik]_
The final compactness score is an average of compactness scores over all components.
P
Disentanglement for retrieved factor ˜zj is given by
_Dj = 1 +_ _pij logd pij_
_i=1_
X
where d is the dimension of the latent space and pij is the probability that the latent factor ˜zj is
important to predict only the component zi. Analogously to completeness, those probabilities are
normalized with respect to potentially disentangled factors:
_Rij_
_pij =_ _d_ _._
_k=1_ _[R][kj]_
The final disentanglement score is a weighted average of the individual disentanglement scores:
P
_n_ _d_
_i=1_ _[R][ij]_
_D =_ _µjDj, where µj =_ _n_ _d_ _._
_j=1_ _kP=1_ _i=1_ _[R][ik]_
X
If a disentangled variable ˜zi is irrelevant for predicting zPj, then itsP _µi (and thus contribution to the_
overall disentanglement) will be near zero.
Finally, the prediction error of the regressor measures the informativeness of the representation. Normalized inputs and outputs allow to compute the estimation error for a completely random mapping
and use it to normalize the score between 0 and 1.
-----
C TRAINING REGIME AND EXPERIMENTAL SETUP
C.1 THE MULTI-TASK MODEL — EXPERIMENT 4.1
We train the multi-task model to minimize the sum of the task errors. The training is performed
for 200 epochs with learning rate 0.001 and batch size 256, by using the AdaM optimizer (Kingma
& Ba, 2014) with β1 = 0.9 and β2 = 0.999. We repeat this procedure three times, changing the
random seed initialization, and report the mean and average values of the disentanglement metrics.
C.2 LATENT VISUALISATIONS — EXPERIMENT 4.2
The encoder architecture was taken from the experiments in Section 4.3. The multi-task model for
each experiment was randomly selected from one of the seeds from the 10 tasks setting. Additionally, one of the single-task encoders was selected out of the trained ones for the same seed. The
random encoder was initialized by the default initialization used by the pytorch library.
The decoder architecture was optimized by minimizing the mean square error between the decoded
and input image. The training was performed over 500 epochs. We used mini-batches of 64 images
and gradually reduced learning rate starting from 0.0002, with a reduction of 50% every 100 epochs.
C.3 CLASSIFICATION BASED ON LATENT FACTORS — EXPERIMENT 4.3
We used the same auto-encoder and multi-task architectures like the one used in previous experiments (and defined in Section A), however with non-linearity given by tanh function. We trained all
auto-encoders for 100 epochs, using batch size 64, learning rate 0.0001, Adam optimizer (Kingma
& Ba, 2014) and latent dimension equal to 8. Other hyperparameters settings were adapted from
(Abdi et al., 2019). Multi-task networks were trained for 30 epochs, using batch size 64, learning
rate 0.0001, and adam optimizer. In order to average the scores over different runs, we repeated the
multi-task network training 3 times.
D VISUALISATIONS OF DECODED REPRESENTATIONS
D.1 UMAP EMBEDDINGS
In order to visualize the latent representations obtained for the random (untrained), single-task, and
multi-task models we embed them into a two-dimensional space by using the UMAP algorithm. The
results are shown in Figure 8. It may be observed that the embeddings obtained for the multi-task
representations are much more semantically meaningful. This is especially evident for the dSprites
and Shapes3D datasets. The MPI3D dataset is a significantly more difficult problem, and although
the multi-task embeddings seem to be correlated to some of the true factors, the difference is not as
visible in this case.
D.2 RECONSTRUCTIONS
As described in Section 4.2, we trained decoders over various latent spaces produced by the encoders
in the experiment from Section 4.1. We provide the numerical values of the reconstruction error in
Table 4 and qualitative images of the reconstructed examples in Figure 9. It may be observed that
the latent representations produced by random and single task encoders do not allow the decoder to
successfully restore the input examples. Moreover, the decoder trained on single-task latent is even
worse (in the case of reconstruction) than the random one.
Table 4: Test reconstruction error between the decoded images and the original input images.
random single-task multi-task
dSprites 308.04 326.30 35.97
Shapes3D 0.044 0.082 0.008
MPI3D 0.0021 0.0061 0.0009
-----
Table 5: The exact values of the metrics computed in the experiment from Section 4.1.
(a) MIG
(b) Factor VAE metric
|model|dSprites|Sbapes3D|MPI3D|
|---|---|---|---|
|random single-mean single-max single-min multi-head one-head|0.00 ± 0.00 0.31 ± 0.03 0.35 ± 0.04 0.26 ± 0.01 0.50 ± 0.11 0.42 ± 0.02|0.00 ± 0.00 0.27 ± 0.03 0.31 ± 0.01 0.23 ± 0.01 0.59 ± 0.04 0.44 ± 0.06|0.00 ± 0.00 0.21 ± 0.02 0.23 ± 0.00 0.18 ± 0.01 0.36 ± 0.04 0.30 ± 0.04|
(d) disentanglement (DCI)
|model|dSprites|Sbapes3D|MPI3D|
|---|---|---|---|
|random single-mean single-max single-min multi-head one-head|0.01 ± 0.01 0.01 ± 0.01 0.02 ± 0.00 0.01 ± 0.00 0.04 ± 0.02 0.02 ± 0.01|0.02 ± 0.01 0.01 ± 0.00 0.01 ± 0.00 0.00 ± 0.00 0.08 ± 0.02 0.02 ± 0.01|0.01 ± 0.00 0.01 ± 0.00 0.01 ± 0.00 0.00 ± 0.00 0.04 ± 0.02 0.02 ± 0.00|
(c) completeness (DCI)
|model|dSprites|Sbapes3D|MPI3D|
|---|---|---|---|
|random single-mean single-max single-min multi-head one-head|0.02 ± 0.01 0.03 ± 0.01 0.05 ± 0.02 0.02 ± 0.00 0.08 ± 0.05 0.06 ± 0.02|0.03 ± 0.01 0.02 ± 0.01 0.03 ± 0.01 0.01 ± 0.00 0.14 ± 0.04 0.05 ± 0.02|0.05 ± 0.01 0.04 ± 0.01 0.06 ± 0.01 0.02 ± 0.00 0.10 ± 0.03 0.05 ± 0.01|
(e) informativeness (DCI)
|model|dSprites|Sbapes3D|MPI3D|
|---|---|---|---|
|random single-mean single-max single-min multi-head one-head|0.23 ± 0.00 0.25 ± 0.03 0.30 ± 0.02 0.23 ± 0.01 0.41 ± 0.01 0.36 ± 0.01|0.40 ± 0.01 0.28 ± 0.01 0.30 ± 0.01 0.26 ± 0.01 0.53 ± 0.03 0.44 ± 0.03|0.43 ± 0.00 0.30 ± 0.01 0.31 ± 0.01 0.29 ± 0.00 0.53 ± 0.04 0.47 ± 0.04|
D.3 TRAVERSALS IN LATENT SPACE
|model|dSprites|Sbapes3D|MPI3D|
|---|---|---|---|
|random single-mean single-max single-min multi-head one-head|0.02 ± 0.01 0.03 ± 0.01 0.04 ± 0.01 0.02 ± 0.01 0.09 ± 0.05 0.05 ± 0.01|0.03 ± 0.01 0.01 ± 0.01 0.03 ± 0.01 0.01 ± 0.00 0.15 ± 0.04 0.05 ± 0.03|0.05 ± 0.01 0.03 ± 0.01 0.04 ± 0.00 0.02 ± 0.00 0.10 ± 0.04 0.05 ± 0.01|
(f) SAP score
|model|dSprites|Sbapes3D|MPI3D|
|---|---|---|---|
|random single-mean single-max single-min multi-head one-head|0.00 ± 0.00 0.01 ± 0.01 0.02 ± 0.00 0.00 ± 0.00 0.01 ± 0.01 0.02 ± 0.02|0.01 ± 0.01 0.01 ± 0.00 0.01 ± 0.00 0.00 ± 0.00 0.04 ± 0.01 0.02 ± 0.01|0.00 ± 0.00 0.01 ± 0.00 0.01 ± 0.00 0.01 ± 0.00 0.02 ± 0.01 0.01 ± 0.01|
In parallel to the study of the quality of the reconstructions, we have also explored the traversals
in latent spaces. Given a latent representation ˜z of an arbitrary image x we compute the traverse
along each one of the components of ˜z, as described in Section 4.2. This traversal represents how
the image changes if only one component is slightly modified. This procedure provides a visually
qualitative way of assessing the level of disentangled in the obtained representations.
In order to complement the discussion conducted in Section 4.2 we present here also the traversals
for the Shapes3D and MPI3D datasets (in Figures 11 and 12, respectively). One may observe that
the results align with the quantitative studies of disentangled metrics from Figure 3 — where we
showed that the most disentangled representation is obtained in the multi-task scenario. Note that
the most informative changes of a particular feature for a given object may be observed in multitask traversals. One may spot that object factors — although not totally disentangled — change
independently from each other.
The same is not true for single-task traversals. In the example from Shapes3D dataset (Figure 11),
we observe that the single-task traversals capture only the color change of the background wall. It is
also not surprising that the least informative traversal comes from the randomly initialized encoder.
E DISENTANGLEMENT AND HARD PARAMETER SHARING
In Section 4.1 we discuss the influence of hard parameter sharing on disentanglement learning. Here
we present the computed metrics for all models (including regression) in a tabulated manner in Table
5. In addition, we also present the average MSE loss on the test dataset in Figure 13.
F DISENTANGLED REPRESENTATION AS BASES FOR MULTI-TASK TRAINING
In Section 4.3 we took the opportunity to discuss how disentanglement influences multi-task training. In this section, we present numerical results of all computed disentanglement metrics across
trained encoders. It is not surprising that FactorVAE representations are most disentangled in the
-----
Table 6: Numerical results of disentanglement metrics for latent on which multi-task training was
performed.
(a) MIG
|Dataset|dSprites Shapes3D MPI3D|
|---|---|
|AE VAE FactorVAE|0.028 0.028 0.023 0.117 0.041 0.011 0.272 0.251 0.040|
|---|---|
(c) Factor VAE metric
|Dataset|dSprites Shapes3D MPI3D|
|---|---|
|AE VAE FactorVAE|0.566 0.565 0.297 0.710 0.564 0.323 0.622 0.690 0.310|
|---|---|
(e) disentanglement (DCI)
(b) SAP score
|Dataset|dSprites Shapes3D MPI3D|
|---|---|
|AE VAE FactorVAE|0.006 0.020 0.009 0.032 0.020 0.017 0.068 0.020 0.011|
|---|---|
(d) informativeness (DCI)
|Dataset|dSprites Shapes3D MPI3D|
|---|---|
|AE VAE FactorVAE|0.395 0.493 0.473 0.579 0.533 0.484 0.664 0.610 0.482|
|---|---|
(f) completeness (DCI)
|Dataset|dSprites Shapes3D MPI3D|
|---|---|
|Dataset|dSprites Shapes3D MPI3D|
|---|---|
|AE VAE FactorVAE|0.052 0.081 0.070 0.257 0.124 0.119 0.356 0.342 0.079|
|---|---|
|AE VAE FactorVAE|0.046 0.078 0.078 0.271 0.120 0.128 0.407 0.331 0.091|
|---|---|
predominant number of cases. What can be read as a surprise is that FactorVAE representations are
never the best in terms of the root mean square error metric of the model that was trained on them.
G INCREASING THE NUMBER OF TASKS
Apart from the tested in the main paper scenario with 10 tasks, we also conducted experiments with
varying number of tasks n from the list of [5, 10, 20, 30, 40, 50]. The results are presented in Figure
14. It is impossible to draw any clear conclusions from this results, as the results vary a lot. It
may be observed that in some cases increasing the number of tasks up to 30 leads to higher values of
selected metrics, but at the same time having a negative impact on the others (for instance, Shapes3D
and Factor VAE versus DCI disentanglement or DCI completeness). These discrepancies are also
not consistent between datasets (consider the top row for the dSprites dataset versus the Shapes3D).
H VARYING THE NUMBER OF USED GENERATING FACTORS
Apart from the presented in Section 4.3 approach which uses all of the factors to generate a task,
we also considered a scenario in which a random subset of the factors is sampled for each task, and
a scenario in which the tasks are generated from disjoint subsets (every odd task depends only on
the first half of the factors and every even task on the other half). We compare these approaches
in Figure 15. The computed disentanglement measures vary and the precise subset of incorporated
factors in the task generating procedure does not have any conclusive impact on the final quality of
the learned representation.
I NUMBER OF RETRIEVED COMPONENTS
In addition to the results presented in Section 4.3 we also compute the number of retrieved components and the mean correlation values between the retrieved components and the ground truth factors in Table 7. The results are computed for the representations obtained on the test splits for each
datasets used in the UMAP embedding experiment in Section 4.3. To get the number of retrieved
components for each of the component of the representation we compute the spearman correlation
with each of the ground true factor and choose the one for which the correlation is the largest and
statistically significant. We next return the number of unique components matched in this way.
-----
dataset factors retrived mean corr std corr min corr max corr
dsprites multi 4 0.385113 0.166741 0.079401 0.616765
dsprites single 4 0.165147 0.057721 0.085172 0.270090
shapes3d multi 5 0.518252 0.307997 0.041420 0.903151
shapes3d single 4 0.311432 0.287550 0.000000 0.793141
mpi3d multi 6 0.317056 0.142792 0.111870 0.585552
mpi3d single 5 0.202390 0.111612 0.061758 0.363220
Table 7: The number of factors retrived by each method (mulit for multi-task models and single
for a single task models) and the average/std/min and max correlation of the retrieved components
with the ground truth factors.
J PERFORMANCE ON SINGLE TASKS
In this section we provide the test losses on all tasks in Tables 8, 9, and 10 for the dSprites, Shapes3D,
and MPI3D datasets, respectively.
-----
loss0 loss1 loss2 loss3 loss4 loss5 loss6 loss7 loss8 loss9 total loss
task
1 77.70 426.78 229.77 582.61 206.75 252.31 300.37 183.85 184.42 175.34 261.99
2 256.64 76.76 229.58 582.95 206.88 252.19 300.19 183.89 184.38 175.47 244.89
3 256.57 426.46 86.85 582.97 206.77 252.34 300.20 183.90 184.24 175.44 265.57
4 256.47 426.55 229.78 87.29 206.84 252.30 300.25 184.02 184.44 175.42 230.34
5 256.76 426.27 229.83 582.84 68.08 252.19 300.53 184.05 184.56 175.52 266.06
6 256.65 426.59 229.90 582.76 206.93 111.47 300.55 184.04 184.56 175.59 265.90
7 256.32 426.19 229.88 583.29 206.79 252.33 79.23 183.94 184.29 175.65 257.79
8 256.34 426.81 229.73 582.06 206.82 252.37 300.31 75.45 184.32 175.37 268.96
9 256.67 426.13 229.68 583.59 206.82 252.41 300.04 183.98 74.28 175.42 268.90
10 256.48 426.91 229.67 582.41 206.78 252.36 300.38 183.89 184.32 84.22 270.74
multi-10 51.96 62.57 55.18 61.30 51.01 79.30 57.34 48.80 46.58 53.25 56.73
Table 8: The test MSE for the experiments from Section 4.3 for dSprites dataset.
loss0 loss1 loss2 loss3 loss4 loss5 loss6 loss7 loss8 loss9 total loss
task
1 21.23 333.18 121.04 230.44 422.14 212.84 235.46 152.17 143.39 227.64 209.95
2 199.57 22.85 121.07 230.50 422.43 212.75 235.65 152.17 143.40 227.78 196.82
3 199.50 333.30 13.18 230.46 422.27 212.83 235.42 152.16 143.39 227.68 217.02
4 199.47 333.41 121.01 15.90 422.50 212.79 235.37 152.15 143.38 227.63 206.36
5 199.44 333.21 121.05 230.33 17.24 213.09 235.59 152.14 143.49 227.69 187.33
6 199.50 333.14 121.04 230.46 422.13 16.98 235.45 152.16 143.41 227.62 208.19
7 199.52 333.24 121.02 230.49 422.26 212.83 22.49 152.17 143.43 227.62 206.51
8 199.51 333.30 121.03 230.45 422.25 212.78 235.45 18.72 143.39 227.63 214.45
9 199.57 333.28 121.00 230.45 422.24 212.87 235.47 152.15 19.97 227.70 215.47
10 199.57 333.27 121.04 230.39 422.54 212.91 235.44 152.14 143.41 24.79 207.55
multi-10 26.50 27.65 17.51 19.41 22.95 21.54 27.80 23.71 24.30 28.92 24.03
Table 9: The test MSE for the experiments from Section 4.3 for Shapes3D dataset.
loss0 loss1 loss2 loss3 loss4 loss5 loss6 loss7 loss8 loss9 total loss
task
1 132.96 293.81 237.19 255.46 218.69 201.70 407.40 279.86 181.91 433.82 264.28
2 172.04 165.08 237.15 255.15 218.70 201.74 407.28 279.80 181.91 433.36 255.22
3 171.97 293.79 130.41 255.34 218.68 201.74 407.48 279.96 181.91 433.63 257.49
4 172.04 293.48 237.21 133.63 218.73 201.80 407.47 279.93 181.89 433.71 255.99
5 172.03 293.67 237.29 255.34 158.39 201.82 407.38 279.89 181.90 433.51 262.12
6 171.94 293.77 237.15 255.31 218.83 117.03 407.56 279.66 181.91 433.91 259.71
7 171.98 293.75 237.25 255.36 218.76 201.64 149.68 279.85 181.93 433.50 242.37
8 171.97 293.77 237.18 255.44 218.69 201.79 407.22 145.36 181.90 433.69 254.70
9 171.98 293.72 237.18 255.32 218.69 201.73 407.43 279.94 156.73 433.62 265.64
10 171.93 293.76 237.23 255.27 218.69 201.72 407.49 279.92 181.92 131.88 237.98
multi-10 77.60 102.41 76.46 82.99 93.97 71.64 91.93 87.51 89.75 80.52 85.48
Table 10: The test MSE for the experiments from Section 4.3 for MPI3D dataset.
-----
(a) dSprites
(b) Shapes3D
(c) MPI3D
Figure 8: The UMAP embeddings obtained for untrained, single-task, and multi-task models on
different datasets (computed on the test splits). The change in color corresponds to a change in the
value of a selected true factor.
-----
dSprites Shapes3D MPI3D
input
random
single-task
multi-task
Figure 9: Reconstructions obtained during the experiments described in Section 4.2. The quality
of the reconstruction for all datasets behaves similarly. One may easily observe that the multi-task
encoder provided a latent space that can be successfully decoded into images that closely resemble
the corresponding examples from the input. This is not the case in single-task or random encoders.
-----
input
recon
factor 1
factor 2
factor 3
factor 4
factor 5
factor 6
factor 7
factor 8
-1
input
recon
factor 1
factor 2
factor 3
factor 4
factor 5
factor 6
factor 7
factor 8
-1
input
recon
factor 1
factor 2
factor 3
factor 4
factor 5
factor 6
factor 7
factor 8
-1
(a) Random encoder
0
(b) Single task encoder
0
(c) Multi-task encoder
Figure 10: Traverses for dSprites dataset over latent variable produced for a given architecture.
-----
input
recon
factor 1
factor 2
factor 3
factor 4
factor 5
factor 6
factor 7
factor 8
-1
input
recon
factor 1
factor 2
factor 3
factor 4
factor 5
factor 6
factor 7
factor 8
-1
input
recon
factor 1
factor 2
factor 3
factor 4
factor 5
factor 6
factor 7
factor 8
-1
(a) Random encoder
0
(b) Single task encoder
0
(c) Multi-task encoder
Figure 11: Traverses for MPI3D dataset over latent variable produced for a given architecture.
-----
input
recon
factor 1
factor 2
factor 3
factor 4
factor 5
factor 6
factor 7
factor 8
-1
input
recon
factor 1
factor 2
factor 3
factor 4
factor 5
factor 6
factor 7
factor 8
-1
input
recon
factor 1
factor 2
factor 3
factor 4
factor 5
factor 6
factor 7
factor 8
-1
(a) Random encoder
0
(b) Single task encoder
0
(c) Multi-task encoder
Figure 12: Traverses for Shapes3D datset over latent variable produced for a given architecture.
-----
Figure 13: The average MSE on the test set computed for the random (untrained) model, a singletask model, and the multi-task models: multi-head and one-head. In the single-task case, we report
the mean over all models for each task. The lower the value the better. As expected, the methods
which jointly optimize the tasks achieve the best results.
Figure 14: The disentanglement metrics computed for the multi-task model for different number of
tasks presented on the x-axis. Experiments for the mpi3d dataset with 40 tasks did not converge
(thus we observe a significantly lower values for this bar).
Figure 15: The disentanglement metrics on different factors splits in the multi-task seeting
-----