Disentanglement of Correlated Factors via Hausdorff Factorized Support
Karsten Roth, Mark Ibrahim, Zeynep Akata, Pascal Vincent*, Diane Bouchacourt*
International Conference on Learning Representations, ICLR
2023

Abstract

A grand goal in deep learning research is to learn representations capable of generalizing across distribution shifts. Disentanglement is one promising direction aimed at aligning a model's representations with the underlying factors generating the data (e.g. color or background). Existing disentanglement methods, however, rely on an often unrealistic assumption: that factors are statistically independent. In reality, factors (like object color and shape) are correlated. To address this limitation, we propose a relaxed disentanglement criterion, the Hausdorff Factorized Support (HFS) criterion, that encourages a factorized support, rather than a factorial distribution, by minimizing a Hausdorff distance. This allows for arbitrary distributions of the factors over their support, including correlations between them. We show that the use of HFS consistently facilitates disentanglement and recovery of ground-truth factors across a variety of correlation settings and benchmarks, even under severe training correlations and correlation shifts, with relative improvements of over +60% over existing disentanglement methods in some settings. In addition, we find that leveraging HFS for representation learning can even facilitate transfer to downstream tasks such as classification under distribution shifts. We hope our original approach and positive empirical results inspire further progress on the open problem of robust generalization.

Introduction

Disentangled representation learning is a promising path to facilitate reliable generalization to in- and out-of-distribution downstream tasks, on top of yielding more interpretable and fair representations (Bengio et al. 2013, Higgins et al. 2018, Locatello et al. 2019, 2020). While various metrics have been proposed to measure disentanglement, the most commonly understood definition is as follows:


Definition. Assuming the data has been generated by a set of unknown, ground-truth latent factors, a representation is said to be disentangled if each factor is recovered in one and only one dimension of the representation.


How to achieve this goal, however, remains an open research question. Weak and semi-supervised settings, e.g. using paired data samples or auxiliary variables, can provably offer disentanglement and recovery of ground-truth factors (e.g. Bouchacourt et al. 2018, Locatello et al. 2020, ...). But fully unsupervised disentanglement, our focus in this study, is in theory impossible to achieve in the general unconstrained nonlinear case (cf. Hyvärinen et al. 1999, Locatello et al. 2019). In practice, however, the inductive biases embodied in common autoencoder architectures allow for effective disentanglement (Rolinek et al. 2019).


Perhaps more problematically, standard unsupervised disentanglement methods (such as $\beta$-(TC)VAE, AnnealedVAE, DIP-VAE, ...) rely on an unrealistic assumption of statistical independence of ground-truth factors. Real data, however, contains correlations (Träuble et al. 2021). Even with well-defined factors such as object shape, color or background, correlations are pervasive: yellow bananas are more frequent than red ones, and cows are much more often found on pasture than on sand dunes. Prior work has shown that existing disentanglement methods fail in these more realistic settings where factors are correlated.


To address this limitation, we propose to relax the unrealistic assumption of statistical independence of factors (i.e. that they have a factorial distribution). Instead, we only assume that the support of the factors' distribution factorizes, which is a much weaker but more realistic constraint. To visualize this, consider a dataset of animal images where background and animal type are heavily correlated (camels most likely on sand, and cows on grass):



[Figure: cartoon of an animal-image dataset in which animal type and background are correlated]



Under the original assumption of factor independence, a model likely learns a shortcut solution in which animal and landscape are captured by the same latent (Beery et al. 2018). With a factorized support, on the other hand, learned factors should be such that any combination of their values has some grounding in reality: a cow on sand is an unlikely, yet not impossible, combination. Just like standard unsupervised disentanglement methods, we still rely on the inductive bias of encoder-decoder architectures to recover factors. However, we expect our method to facilitate robustness to any distribution shift within the support, as it makes no assumptions on the distribution beyond its factorized support.




Method

On this basis, we propose a concrete pairwise Hausdorff Factorized Support (HFS) training criterion to disentangle correlated factors, which aims for all pairs of latents to have a factorized support. Specifically, we encourage a factorized support by minimizing a Hausdorff set-distance between a finite-sample approximation of the actual support and its factorization (cf. Huttenlocher et al. 1993, Rockafellar et al. 1998).

Setup.

To explain, we first describe the general setting. We are given a dataset $\mathcal{D}=\{\mathbf{x}^i\}_{i=1}^{N}$, where each $\mathbf{x}^i$ (e.g. an image) is a realization of a random variable. We consider that each $\mathbf{x}^i$ is generated by an unknown generative process involving a ground-truth latent random vector $\mathbf{z}$. The components of $\mathbf{z}$ correspond to the dataset's underlying factors of variation (such as object shape, color, background, etc.). This process generates an observation $\mathbf{x}$ by first drawing a realization $\mathbf{z}=(z_1,\ldots,z_k)$ from a distribution $p(\mathbf{z})$, i.e. $\mathbf{z} \sim p(\mathbf{z})$. The observation $\mathbf{x}$ is then obtained by drawing $\mathbf{x} \sim p(\mathbf{x}|\mathbf{z})$.

Given $\mathcal{D}$, the goal of disentangled representation learning can be stated as learning a mapping $f_\phi$ that, for any $\mathbf{x}$, recovers the associated $\mathbf{z}$ as well as possible, i.e. $f_\phi(\mathbf{x}) \approx \mathbb{E}[\mathbf{z}| \mathbf{x}]$, up to a permutation of elements and elementwise bijective transformations.

Statistical Independence.

In unsupervised disentanglement, the $\mathbf{z}$ are unobserved, and both $p(\mathbf{z})$ and $p(\mathbf{x}|\mathbf{z})$ are a priori unknown to us, though we might assume specific properties and functional forms.

Most unsupervised disentanglement methods follow the formalization of VAEs and employ parameterized probabilistic generative models of the form $p_\theta(\mathbf{x}, \mathbf{z}) = p_\theta(\mathbf{z}) p_\theta(\mathbf{x} | \mathbf{z})$ to estimate the ground truth generative model over $\mathbf{z},\mathbf{x}$. As in VAEs, these methods make the strong assumption that ground truth factors are statistically independent,


\begin{equation} p(\mathbf{z})=p(z_1) p(z_2) \ldots p(z_k), \end{equation}

and conflate the goal of learning a disentangled representation with that of learning a representation with statistically independent components. This assumption naturally translates to a factorial model prior $p_\theta(\mathbf{z})$.

Factorized Support.

Instead of assuming independent factors (i.e. a factorial distribution on $\mathbf{z}$ as noted above), we will only assume that the support of the distribution factorizes. Let us denote by $\mathcal{S}(p(\mathbf{z}))$ the support of $p(\mathbf{z})$, i.e. the set $\{\mathbf{z} \in \mathcal{Z} \,|\, p(\mathbf{z}) > 0 \}$. We say that $\mathcal{S}(p(\mathbf{z}))$ is factorized if it equals the Cartesian product of the supports of the individual dimensions' marginals, i.e. if:


\begin{equation} \mathcal{S}(p(\mathbf{z})) = \mathcal{S}(p(z_1))\times\mathcal{S}(p(z_2))\times \cdots \times\mathcal{S}(p(z_k)) \stackrel{\text{def}}{=} \mathcal{S}^X(p(\mathbf{z})) \end{equation}

where $\times$ denotes the Cartesian product.


Of course, independence implies a factorized support, but not vice versa. Assuming a factorized support is thus a relaxation of the (unrealistic) assumption of a factorial distribution, i.e. of statistical independence of the disentangled factors. Recall the cartoon example above: the distribution of the two disentangled factors does not satisfy an independence assumption, but it does have a factorized support. Informally, the factorized support assumption merely states that whatever values $z_1$, $z_2$, etc. may take individually, any combination of these is possible (even if not very likely).
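To make the distinction concrete, the following Python sketch checks both properties for a small discrete joint over animal type and background. The probabilities are hypothetical and chosen only to mimic the cartoon example: camels mostly on sand, cows mostly on grass, but every combination has non-zero probability.

```python
from itertools import product

# Hypothetical joint distribution over (animal, background); numbers are illustrative only.
joint = {
    ("camel", "sand"): 0.45,
    ("camel", "grass"): 0.05,
    ("cow", "sand"): 0.05,
    ("cow", "grass"): 0.45,
}


def marginal(dist, axis):
    """Marginal distribution of factor `axis` (0 or 1) of a two-factor joint."""
    out = {}
    for key, p in dist.items():
        out[key[axis]] = out.get(key[axis], 0.0) + p
    return out


def support(dist):
    """Set of outcomes with strictly positive probability."""
    return {k for k, p in dist.items() if p > 0}


p_animal, p_background = marginal(joint, 0), marginal(joint, 1)

# Factorized support: S(p(z)) == S(p(z1)) x S(p(z2)).
factorized_support = support(joint) == set(product(support(p_animal), support(p_background)))

# Independence: p(z1, z2) == p(z1) p(z2) for every combination.
independent = all(
    abs(joint.get((a, b), 0.0) - p_animal[a] * p_background[b]) < 1e-12
    for a, b in product(p_animal, p_background)
)

print(factorized_support, independent)  # True False
```

Both marginals are uniform here, so independence would require 0.25 for every cell. The support, however, is still the full product of the marginal supports.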

Basic Hausdorff Factorization Objective.

For our objective, let us consider deterministic representations obtained by some encoder $\mathbf{z}=f_\phi(\mathbf{x})$. We enforce the factorized support criterion on the aggregate distribution $\bar{q}_\phi(\mathbf{z})=\mathbb{E}_\mathbf{x}[\delta(\mathbf{z}-f_\phi(\mathbf{x}))]$, i.e. the distribution of encoder outputs over the data. $\bar{q}_\phi(\mathbf{z})$ is conceptually similar to the aggregate posterior $q_\phi(\mathbf{z})$ in e.g. $\beta$-TCVAE, though we consider points produced by a deterministic mapping $f_\phi$ rather than a stochastic one.

To encourage support factorization, we now need some divergence or metric that tells us how far the support $\mathcal{S}$ of the encoder outputs is from its factorization $\mathcal{S}^X$. Supports are sets, so it is natural to use a set distance such as the Hausdorff distance, giving


\begin{equation} d_H(\mathcal{S}, \mathcal{S}^X) = \max\left(\sup_{\mathbf{z}\in\mathcal{S}^X}\left[\inf_{\mathbf{z}'\in\mathcal{S}} d(\mathbf{z},\mathbf{z}')\right], \sup_{\mathbf{z}\in\mathcal{S}}\left[\inf_{\mathbf{z}'\in\mathcal{S}^X} d(\mathbf{z},\mathbf{z}')\right]\right) =\sup_{\mathbf{z}\in\mathcal{S}^X}\left[\inf_{\mathbf{z}'\in\mathcal{S}} d(\mathbf{z},\mathbf{z}')\right] \end{equation}

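The second term of the Hausdorff distance equals zero because $\mathcal{S}\subseteq\mathcal{S}^X$. Every point of the support has each of its coordinates in the corresponding marginal support, so it also lies in the Cartesian product of those supports.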
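For finite point sets, the supremum and infimum become a max and a min. The Python sketch below assumes the Euclidean distance for the base distance $d$. It uses a diagonal support and its factorization to show why only one direction of the Hausdorff distance matters when $\mathcal{S}\subseteq\mathcal{S}^X$.

```python
import math


def directed_hausdorff(A, B):
    """max over a in A of the distance from a to its nearest point in B."""
    return max(min(math.dist(a, b) for b in B) for a in A)


def hausdorff(A, B):
    """Symmetric Hausdorff distance between two finite, non-empty point sets."""
    return max(directed_hausdorff(A, B), directed_hausdorff(B, A))


S = [(0.0, 0.0), (1.0, 1.0)]                            # observed support (diagonal only)
S_X = [(0.0, 0.0), (0.0, 1.0), (1.0, 0.0), (1.0, 1.0)]  # product of its marginal supports

print(directed_hausdorff(S, S_X))  # 0.0: S is contained in S_X
print(hausdorff(S, S_X))           # 1.0: the off-diagonal corners are 1 away from S
```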

Practical Implementation.

In practical settings with a finite sample of observations $\{\mathbf{x}^i\}_{i=1}^N$, we further introduce a Monte-Carlo batch approximation. Given a batch of $b$ inputs $\mathbf{X}$ yielding $b$ $k$-dimensional latent representations $\mathbf{Z} = f_\phi(\mathbf{X})\in\mathbb{R}^{b\times k}$, we estimate Hausdorff distances using sample-based approximations of the support:


\begin{equation} \mathcal{S} \approx \mathbf{Z} \text{ and } \mathcal{S}^X \approx \mathbf{Z}_{:, 1}\times\mathbf{Z}_{:, 2}\times\cdots\times\mathbf{Z}_{:, k} = \{ (z_1, \ldots, z_k),\; z_1 \in \mathbf{Z}_{:, 1}, \ldots, z_k \in \mathbf{Z}_{:, k} \}. \end{equation}

Here $\mathbf{Z}_{:, j}$ must be understood as the set (not vector) of all elements in the $j^\mathrm{th}$ column of $\mathbf{Z}$. Plugging these approximations into the equation above yields:


\begin{equation} \hat{d}_{H}(\mathbf{Z}) = \max_{\mathbf{z}\in \mathbf{Z}_{:, 1}\times\mathbf{Z}_{:, 2}\times\cdots\times\mathbf{Z}_{:, k}} \left[\min_{\mathbf{z}'\in \mathbf{Z}} d(\mathbf{z},\mathbf{z}')\right] \end{equation}

where, by writing $\mathbf{z}' \in \mathbf{Z}$, we consider the matrix $\mathbf{Z}$ as a set of rows.
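A direct, unoptimized Python sketch of the batch estimate $\hat{d}_H(\mathbf{Z})$ follows. It assumes the batch is given as a list of rows and uses the Euclidean distance for $d$. The sketch enumerates the full Cartesian product of column sets, which can contain up to $b^k$ points, so it is only practical for very small $k$.

```python
import math
from itertools import product


def hausdorff_factorized_support(Z):
    """Batch estimate d_H(Z): for every point of the Cartesian product of the
    column sets of Z, find its nearest row of Z; return the largest such distance.

    Z is a list of b rows, each a k-dimensional latent vector.
    """
    rows = [tuple(r) for r in Z]
    k = len(rows[0])
    # Each column is treated as a *set* of values, as in the text.
    columns = [sorted({r[j] for r in rows}) for j in range(k)]
    return max(
        min(math.dist(z, z_row) for z_row in rows)  # nearest observed row
        for z in product(*columns)                  # candidate from Z_{:,1} x ... x Z_{:,k}
    )


print(hausdorff_factorized_support([[0.0, 0.0], [1.0, 1.0]]))                          # 1.0
print(hausdorff_factorized_support([[0.0, 0.0], [0.0, 1.0], [1.0, 0.0], [1.0, 1.0]]))  # 0.0
```

A batch whose rows cover every combination of the observed column values (a "grid") has estimate 0. A batch concentrated on the diagonal is penalized by its distance to the missing corners.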

Sliced Hausdorff Factorized Support.

In high dimensions, with many factors, the assumption that every combination of all latent values is possible might still be too strong. Even if we assumed all combinations to be possible in principle, the number of conceivable combinations explodes combinatorially, so we could never hope to observe all of them in a finite dataset of realistic size. However, it is statistically reasonable to expect evidence of a factorized support for all pairs of elements. To encourage such a pairwise factorized support, we minimize a sliced, pairwise Hausdorff estimate. This has the additional benefit of keeping computation tractable when $k$ is large: the full product of column sets can contain up to $b^k$ points, whereas each pairwise product contains at most $b^2$.


\begin{equation} \hat{d}^{(2)}_{H}(\mathbf{Z}) = \sum_{i=1}^{k-1}\sum_{j=i+1}^k\max_{\mathbf{z}\in\mathbf{Z}_{:,i}\times\mathbf{Z}_{:,j}} \left[\min_{\mathbf{z}'\in \mathbf{Z}_{:, (i,j)}} d(\mathbf{z},\mathbf{z}')\right] \end{equation}

where $\mathbf{Z}_{:, (i,j)}$ denotes the concatenation of columns $i$ and $j$, again viewed as a set of rows.
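The pairwise estimate $\hat{d}^{(2)}_H(\mathbf{Z})$ can be sketched in the same way, again assuming the Euclidean distance for $d$ and a batch given as a list of rows. Each of the $k(k-1)/2$ column pairs contributes at most $b^2$ candidate points, each compared against $b$ observed pairs.

```python
import math
from itertools import combinations, product


def pairwise_hfs(Z):
    """Sliced pairwise estimate d_H^(2)(Z): sum over column pairs (i, j), i < j,
    of the batch Hausdorff estimate restricted to the 2-D slice Z[:, (i, j)]."""
    rows = [tuple(r) for r in Z]
    k = len(rows[0])
    total = 0.0
    for i, j in combinations(range(k), 2):
        pairs = [(r[i], r[j]) for r in rows]  # Z_{:,(i,j)} as a set of rows
        col_i, col_j = {p[0] for p in pairs}, {p[1] for p in pairs}
        total += max(
            min(math.dist(z, zp) for zp in pairs)  # nearest observed pair
            for z in product(col_i, col_j)         # candidate from Z_{:,i} x Z_{:,j}
        )
    return total


print(pairwise_hfs([[0, 0, 0], [1, 1, 1]]))                        # 3.0
print(pairwise_hfs([[0, 0, 0], [0, 1, 1], [1, 0, 1], [1, 1, 0]]))  # 0.0
```

The second batch shows that the pairwise criterion is weaker than the full one. Every 2-D slice covers all four combinations, so the estimate is 0, even though only 4 of the 8 three-way combinations appear.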

Final objective: avoiding collapse and retaining input information.

We learn representations $\mathbf{z} = f_\phi(\mathbf{x})$ by learning parameters $\phi$ that optimize a training objective. Because the Hausdorff distance builds on a base distance $d(\mathbf{z},\mathbf{z}')$, minimizing it alone could trivially drive it to 0 by collapsing all representations to a single point. This can be avoided in several ways. One option is a term that encourages the variance of each $\mathbf{Z}_{:,i}$ to be above 1, a technique used, e.g., in the self-supervised learning method VICReg (Bardes et al. 2022). Another option, more in line with traditional VAE variants for disentanglement, is a stochastic autoencoder (SAE) reconstruction error:


\begin{equation} \ell_\mathrm{SAE}(\mathbf{x}; \phi, \theta) = - \mathbb{E}_{q_\phi(\mathbf{z}|\mathbf{x})}\left[\log p_\theta(\mathbf{x}|\mathbf{z})\right] \end{equation}

where typically $q_\phi(\mathbf{z}|\mathbf{x}) = \mathcal{N}(f_\phi(\mathbf{x}), \Sigma_\phi(\mathbf{x}))$, with mean given by our deterministic mapping $f_\phi$ and $\Sigma_\phi(\mathbf{x})$ producing a diagonal covariance parameter. The likelihood can be, e.g., $\log p_\theta(\mathbf{x}|\mathbf{z}) = -\|r_\theta(\mathbf{z}) - \mathbf{x} \|^2$ (up to scaling and an additive constant), with $r_\theta$ a parameterized decoder. The autoencoder term ensures that the representations $f_\phi(\mathbf{x})$ retain as much information as possible about $\mathbf{x}$ for reconstruction, which prevents their collapse to a single point. A minimum scale can also be ensured by constraining $\Sigma_\phi(\mathbf{x})$, by construction, to stay above a minimal threshold.
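Both anti-collapse options can be illustrated with a small pure-Python sketch. Several details are assumptions: the diagonal covariance is parameterized by a per-dimension log-variance, `min_std` is a hypothetical name for the minimal scale threshold, and `identity_decoder` is a hypothetical stand-in for $r_\theta$. The variance term follows the VICReg-style hinge on the batch standard deviation and is not taken from the HFS source.

```python
import math
import random


def sae_loss(x, mean, log_var, decoder, rng, min_std=0.0):
    """One-sample Monte-Carlo estimate of -E_{q(z|x)}[log p(x|z)], using
    log p(x|z) = -||r(z) - x||^2 (scaling and constants dropped).

    q(z|x) = N(mean, diag(exp(log_var))); the std is floored at the
    hypothetical threshold `min_std`.
    """
    std = [max(math.exp(0.5 * lv), min_std) for lv in log_var]
    # Reparameterized sample z = mean + std * eps, eps ~ N(0, I).
    z = [m + s * rng.gauss(0.0, 1.0) for m, s in zip(mean, std)]
    x_hat = decoder(z)
    return sum((a - b) ** 2 for a, b in zip(x_hat, x))


def std_hinge(Z, target=1.0, eps=1e-4):
    """VICReg-style alternative: penalize latent dimensions whose batch standard
    deviation falls below `target`, averaged over dimensions."""
    b, k = len(Z), len(Z[0])
    penalty = 0.0
    for j in range(k):
        col = [row[j] for row in Z]
        mu = sum(col) / b
        var = sum((c - mu) ** 2 for c in col) / (b - 1)  # unbiased batch variance
        penalty += max(0.0, target - math.sqrt(var + eps))
    return penalty / k


def identity_decoder(z):
    """Hypothetical stand-in for the decoder r_theta."""
    return list(z)


rng = random.Random(0)
loss = sae_loss([1.0, 2.0], [1.0, 2.0], [-20.0, -20.0], identity_decoder, rng)
print(loss)                                   # close to 0: tiny posterior noise
print(std_hinge([[0.5, 0.5]] * 4))            # about 0.99: a collapsed batch is penalized
print(std_hinge([[-2.0, 3.0], [2.0, -3.0]]))  # 0.0: both dimensions have std above 1
```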

Consequently, this gives a standard HFS objective:

\begin{equation} \textstyle\mathcal{L}_\mathrm{HFS}(\mathcal{D}; \phi, \theta) = \mathbb{E}_{\mathbf{X} \overset{b}{\sim} \mathcal{D}} \left[ \gamma \hat{d}^{(2)}_{H}(f_\phi(\mathbf{X})) + \frac{1}{b} \sum_{\mathbf{x}\in \mathbf{X}} \ell_\mathrm{SAE}(\mathbf{x}; \phi, \theta) \right] \end{equation}

Alternatively, HFS can be used as a regularizer alongside other disentanglement methods to reliably improve disentanglement. It places the emphasis on support factorization rather than statistical independence, while still leveraging independence where possible, since independence may well be a suitable assumption for some pairs of factors.
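The per-batch term inside $\mathcal{L}_\mathrm{HFS}$ can be sketched as follows. This pure-Python version only evaluates the value of the objective for given encoder outputs and per-sample SAE losses. An actual training setup would compute it in an automatic-differentiation framework so that gradients reach $\phi$ and $\theta$. The function name `hfs_batch_objective` is hypothetical. To use HFS as a regularizer instead, one would add `gamma * pairwise_hfs(Z)` to another method's loss.

```python
import math
from itertools import combinations, product


def pairwise_hfs(Z):
    """Sliced pairwise Hausdorff estimate d_H^(2)(Z) for a batch Z (list of rows)."""
    total = 0.0
    for i, j in combinations(range(len(Z[0])), 2):
        pairs = [(r[i], r[j]) for r in Z]
        total += max(
            min(math.dist(z, zp) for zp in pairs)
            for z in product({p[0] for p in pairs}, {p[1] for p in pairs})
        )
    return total


def hfs_batch_objective(Z, sae_losses, gamma=1.0):
    """Per-batch term of L_HFS: gamma * d_H^(2)(f(X)) + mean per-sample SAE loss.

    Z: encoder outputs f_phi(X) for a batch of b inputs (b rows of k floats).
    sae_losses: the b per-sample reconstruction errors l_SAE(x; phi, theta).
    """
    if len(Z) != len(sae_losses):
        raise ValueError("expected one SAE loss per batch element")
    return gamma * pairwise_hfs(Z) + sum(sae_losses) / len(sae_losses)


Z = [[0.0, 0.0], [1.0, 1.0]]  # hypothetical encoder outputs for b = 2 inputs
print(hfs_batch_objective(Z, sae_losses=[0.5, 1.5], gamma=2.0))  # 3.0
```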




Experiments

Across large-scale experiments on standard disentanglement benchmarks and novel extensions with correlated factors, HFS consistently facilitates disentanglement.


We first create various benchmarks that introduce increasingly difficult artificial correlations between ground-truth factors. Specifically, given two ground-truth factors $z_1$ and $z_2$, we set their joint sampling probability as

\begin{equation} p(z_1, z_2) \propto \exp\left(-(z_1 - f(z_2))^2/(2\sigma^2)\right) \end{equation}

where $f$ is a fixed function relating the two factors. Smaller values of the scale $\sigma$ thus yield stronger correlations between the factors, and the construction can be extended to multiple correlated pairs. We apply this to several benchmark datasets and run existing disentanglement methods both with and without HFS regularization, as well as HFS as a standalone objective. Measured by DCI-D (Eastwood et al. 2018), HFS gives significant improvements in disentanglement performance over the baselines, of over +60% in some settings:


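[Figure: DCI-D disentanglement of existing methods with and without HFS, and of standalone HFS, across correlated benchmarks]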


This provides strong evidence for the benefits of focusing on support factorization over strict statistical independence. Going even further, we introduce an even larger range of possible training correlations. We also introduce correlations in the test data, which creates artificial distribution shifts between training and testing on which we can evaluate our models. Interestingly, we find that leveraging support factorization can provide increased generalization benefits for harder out-of-distribution shifts:


[Figure: disentanglement under increasingly severe correlation shifts between training and test data]
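The correlated sampling probability $p(z_1, z_2) \propto \exp(-(z_1 - f(z_2))^2/(2\sigma^2))$ defined above can be illustrated on two discrete factor grids. In this sketch, the grid of normalized factor values and the choice of $f$ as the identity are hypothetical; the source does not specify them here. The sketch shows that a smaller $\sigma$ strengthens the correlation, while every combination keeps non-zero probability, so the support stays factorized.

```python
import math


def correlated_joint(values_1, values_2, f, sigma):
    """Joint table p(z1, z2) proportional to exp(-(z1 - f(z2))^2 / (2 sigma^2))
    over two discrete factor grids (rows: z1, columns: z2), normalized to sum to 1."""
    weights = [[math.exp(-(z1 - f(z2)) ** 2 / (2 * sigma ** 2)) for z2 in values_2]
               for z1 in values_1]
    total = sum(map(sum, weights))
    return [[w / total for w in row] for row in weights]


def correlation(values_1, values_2, joint):
    """Pearson correlation of (z1, z2) under a joint probability table."""
    cells = [(a, b, p) for a, row in zip(values_1, joint) for b, p in zip(values_2, row)]
    e1 = sum(p * a for a, _, p in cells)
    e2 = sum(p * b for _, b, p in cells)
    cov = sum(p * (a - e1) * (b - e2) for a, b, p in cells)
    v1 = sum(p * (a - e1) ** 2 for a, _, p in cells)
    v2 = sum(p * (b - e2) ** 2 for _, b, p in cells)
    return cov / math.sqrt(v1 * v2)


def identity(z):
    """Hypothetical choice of f: z1 is encouraged to match z2."""
    return z


values = [0.0, 0.25, 0.5, 0.75, 1.0]  # hypothetical normalized factor grid
strong = correlated_joint(values, values, identity, sigma=0.1)
weak = correlated_joint(values, values, identity, sigma=1.0)
print(correlation(values, values, strong), correlation(values, values, weak))
print(min(p for row in strong for p in row) > 0)  # True: every combination stays possible
```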


Finally, the significantly increased degree of unsupervised disentanglement also results in faster adaptation when a gradient-boosted decision tree is trained on the representations with increasingly reduced amounts of training data:


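[Figure: performance of a gradient-boosted decision tree trained on the representations with decreasing amounts of training data]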



Conclusion

Traditional disentanglement assumes factor independence (i.e. a factorial distribution), which is unrealistic because real data is correlated. To avoid this assumption, we thoroughly investigate an approach that only aims at recovering a factorized support. Doing so achieves disentanglement by ensuring the model can encode many possible combinations of generative factors in the learned latent space, while allowing for arbitrary distributions over the support, in particular those with correlations. Indeed, through a practical criterion using pairwise Hausdorff set-distances (HFS), we show that encouraging a pairwise factorized support is sufficient to match traditional disentanglement methods. Furthermore, we show that HFS can steer existing disentanglement methods towards a more factorized support, giving large relative improvements of over +60% on common benchmarks across a large variety of increasingly harder correlation shifts. This improvement in disentanglement across correlation shifts is also reflected in improved out-of-distribution generalization, especially as these shifts become more severe. This addresses a key promise of disentangled representation learning.

(c) 2024 Explainable Machine Learning Munich