Disentangling Causal Mechanisms

How we built disentangled representations by explicitly partitioning latent spaces and obstructing unwanted correlations

When AI Learns the Wrong Things

Imagine training an AI model to recognize cows, only to discover it’s actually learned to identify green, grassy backgrounds instead of the animals themselves. Show it a car parked on a grassy field, and it confidently predicts “cow.” This isn’t just a hypothetical—it’s a real problem plaguing modern deep neural networks.

The core issue: Deep neural networks tend to entangle features in their internal representations, leading to:

  • Spurious correlations: Models learn shortcuts that don’t generalize
  • Black-box behavior: Predictions become uninterpretable
  • Non-compositional representations: Features can’t be independently manipulated
  • Poor generalization: Models fail on out-of-distribution data

This is where disentangled representation learning comes in—the quest to build AI systems that learn independent, interpretable features that mirror how we naturally understand the world.

Projection-Based Disentanglement

Most existing approaches to disentanglement, like β-VAEs and InfoGAN, rely on implicit constraints and information-theoretic objectives. We took a different path: what if we explicitly force independent mechanisms into separate parts of the latent space?

Our key insight: Use adversarial projection to actively obstruct unwanted correlations while preserving wanted information.

R-LACE

Our method builds on Relaxed Linear Adversarial Concept Erasure (R-LACE), originally designed for debiasing word embeddings. R-LACE works by solving a minimax game:

min_θ max_P Σₙ ℓ(yₙ, g⁻¹(θᵀ P xₙ))

Where:

  • P is an orthogonal projection matrix that removes unwanted information
  • θ represents classifier parameters trying to recover that information
  • The game finds the optimal projection that maximally confuses classifiers

Our contribution: We adapted this adversarial projection approach to enforce disentanglement during autoencoder training.
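
Before moving to the full method, here is a minimal, self-contained illustration of linear concept erasure itself. This is not the R-LACE optimization: the class-mean difference direction stands in for the direction R-LACE would find adversarially, and everything else (data, probe) is synthetic.

import numpy as np
from sklearn.linear_model import LogisticRegression

rng = np.random.default_rng(0)

# Toy data: 6-D representations where a binary label is linearly encoded
# along a single "concept" direction plus noise.
n, d = 2000, 6
y = rng.integers(0, 2, size=n)
concept_direction = rng.normal(size=d)
concept_direction /= np.linalg.norm(concept_direction)
X = rng.normal(size=(n, d)) + np.outer(2 * y - 1, concept_direction)

# Rank-1 erasing projection P = I - v v^T, with v the class-mean difference
# (a closed-form stand-in for the adversarially learned direction).
v = X[y == 1].mean(axis=0) - X[y == 0].mean(axis=0)
v /= np.linalg.norm(v)
P = np.eye(d) - np.outer(v, v)

probe = LogisticRegression(max_iter=1000)
print("probe accuracy before erasure:", probe.fit(X, y).score(X, y))
print("probe accuracy after erasure: ", probe.fit(X @ P, y).score(X @ P, y))
# Expected: near 1.0 before erasure, close to chance (~0.5) after.

R-LACE generalizes this to rank-k subspaces and finds the erased directions by solving the minimax game rather than relying on a closed-form heuristic.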

The Method: Partitioned Latent Spaces with Adversarial Constraints

Experimental Setup: Colored MNIST

To demonstrate our approach, we designed a controlled experiment using Colored MNIST:

  • Dataset: MNIST digits randomly colored red, green, or blue
  • Goal: Partition latent space to encode digit and color independently
  • Challenge: Prevent the model from learning digit-color correlations

Figure: Examples from the Colored MNIST dataset showing digits 0-9 in red, green, and blue colors
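
The actual generation scripts live in our repository; the following is only a rough sketch of the idea, assuming torchvision's MNIST and one uniformly random color per image.

import torch
from torchvision import datasets, transforms

def colorize(img, color_idx):
    # Place a grayscale digit into one RGB channel (0 = red, 1 = green, 2 = blue)
    rgb = torch.zeros(3, *img.shape[-2:])
    rgb[color_idx] = img
    return rgb

mnist = datasets.MNIST(root="data", train=True, download=True,
                       transform=transforms.ToTensor())

colored_images, digit_labels, color_labels = [], [], []
for img, digit in mnist:
    color = torch.randint(0, 3, (1,)).item()      # red, green, or blue
    colored_images.append(colorize(img.squeeze(0), color))
    digit_labels.append(digit)
    color_labels.append(color)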

Architecture: The Disentangled Autoencoder

import torch

class RLACE_AE:
    def __init__(self, rank=1, d=6):
        # Standard autoencoder components
        self.d = d
        self.encoder = Encoder(encoded_space_dim=d)
        self.decoder = Decoder(encoded_space_dim=d)
        
        # Placeholder projection matrices for the two latent partitions;
        # solve_adv_game() replaces them with the R-LACE solutions.
        self.P1 = torch.randn(d // 2, d // 2)  # will remove color info
        self.P2 = torch.randn(d // 2, d // 2)  # will remove digit info

Our architecture implements a partitioned latent space approach:

Figure: Architecture diagram showing the disentangled autoencoder with R-LACE

  1. Encoder: Maps images to d-dimensional latent vectors
  2. Partition: Split latent space into two d/2-dimensional subspaces
  3. Projection: Apply R-LACE to each partition:
    • Partition 1: Remove color information → pure digit encoding
    • Partition 2: Remove digit information → pure color encoding
  4. Decoder: Reconstruct images from projected latent representations

Training Process: Alternating Optimization

def solve_adv_game(self, dataloader, o_epochs=100, a_epochs=10):
    for o_step in range(o_epochs):
        # Train the autoencoder for a_epochs with the current (frozen) projections;
        # `optimizer` is the optimizer over encoder and decoder parameters
        for a_step in range(a_epochs):
            train_ae_loss = train_epoch_with_projection(
                self.encoder, self.decoder, self.P1, self.P2, 
                dataloader, torch.nn.MSELoss(), optimizer
            )
        
        # Re-encode a batch (the loader yields images, digit labels, color labels)
        # and update the projections with R-LACE
        batch_data, digit_labels, color_labels = next(iter(dataloader))
        X = self.encoder(batch_data).detach()
        X1, X2 = X[:, :self.d // 2], X[:, self.d // 2:]  # partition the latent space
        
        # Solve one adversarial game per partition
        rlace_output1 = rlace(X1, color_labels, rank=1)  # erase color from partition 1
        rlace_output2 = rlace(X2, digit_labels, rank=1)  # erase digit from partition 2
        
        self.P1 = rlace_output1.best_P
        self.P2 = rlace_output2.best_P

Key innovation: The alternating optimization between autoencoder training and projection matrix updates creates an information bottleneck that forces independent mechanisms into separate latent dimensions.
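
The helper train_epoch_with_projection is where the projections enter the forward pass. Here is a minimal sketch of what it might look like, simplified relative to our actual implementation; it assumes the dataloader yields (images, digit_labels, color_labels) and that P1 and P2 are float tensors on the model's device.

def train_epoch_with_projection(encoder, decoder, P1, P2,
                                dataloader, loss_fn, optimizer):
    total_loss = 0.0
    for images, _, _ in dataloader:           # labels unused for reconstruction
        z = encoder(images)                    # shape (batch, d)
        d_half = z.shape[1] // 2

        # Apply the (frozen) partition-wise projections before decoding
        z1 = z[:, :d_half] @ P1                # digit partition, color erased
        z2 = z[:, d_half:] @ P2                # color partition, digit erased
        recon = decoder(torch.cat([z1, z2], dim=1))

        loss = loss_fn(recon, images)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    return total_loss / len(dataloader)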

Figure: Illustration of the R-LACE projection concept

Results: Achieving True Disentanglement

Our experiments revealed clear evidence of successful disentanglement:

| Latent Partition Size | Reconstruction Loss | Digit Accuracy (Full) | Color Accuracy (Full) |
|---|---|---|---|
| n=1 | 0.03 | 0.55 (random) | 0.33 (random) |
| n=2 | 0.02 | 0.51 | 0.95 |
| n=3 | 0.009 | 0.89 | 0.92 |
| n=5 | 0.006 | 0.89 | 0.97 |

Critical insight: At least 3 dimensions per partition are needed for effective disentanglement. With n=1, R-LACE removes too much information; with n≥3, we achieve both good reconstruction and strong disentanglement.

Disentanglement Verification

To verify true disentanglement, we trained classifiers on each projected latent partition:

| Partition Size | Digit Encoding → Digit Acc | Digit Encoding → Color Acc | Color Encoding → Digit Acc | Color Encoding → Color Acc |
|---|---|---|---|---|
| n=3 | 0.42 | 0.38 (near random) | 0.32 (near random) | 0.94 |
| n=5 | 0.74 | 0.48 | 0.39 | 0.99 |

Success criteria:

  • High accuracy when predicting the “corresponding” feature (0.74 digit, 0.99 color)
  • Near-random accuracy when predicting the “opposite” feature (0.39-0.48)
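
In practice this verification is just a pair of linear probes per partition. A hedged sketch follows; Z1 and Z2 stand for the projected digit and color partitions collected over the test set (assumed numpy arrays), and our repository contains the actual evaluation code.

from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split

def probe_accuracy(Z_partition, labels):
    # Fit a linear probe on one latent partition and report held-out accuracy
    Z_tr, Z_te, y_tr, y_te = train_test_split(Z_partition, labels, test_size=0.2)
    return LogisticRegression(max_iter=1000).fit(Z_tr, y_tr).score(Z_te, y_te)

# Z1, Z2: projected digit / color partitions (assumed to be collected beforehand)
digit_from_digit = probe_accuracy(Z1, digit_labels)   # should be high
color_from_digit = probe_accuracy(Z1, color_labels)   # should be near chance
digit_from_color = probe_accuracy(Z2, digit_labels)   # should be near chance
color_from_color = probe_accuracy(Z2, color_labels)   # should be high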

Visual Evidence: Latent Space Visualization

The most compelling evidence comes from visualizing the learned latent spaces:

Figure: Latent space visualizations showing disentanglement: (a) vanilla autoencoder, (b) digit encoding dimensions, (c) color encoding dimensions

Before Disentanglement (Vanilla Autoencoder):

  • Digit and color information completely entangled
  • No clear separation between concepts

After Disentanglement:

  • Digit Partition: Clear digit clusters, no color separation
  • Color Partition: Clear color clusters, digits highly entangled

This visual separation confirms that our method successfully isolates independent causal mechanisms.

The Mathematics of Adversarial Projection

R-LACE solves a constrained minimax game to find optimal projection matrices:

P ∈ Pₖ ⟺ P = I_d − WᵀW,  W ∈ ℝ^(k×d),  WWᵀ = Iₖ

Where P projects onto the orthogonal complement of a k-dimensional bias subspace. The projection neutralizes unwanted correlations while preserving other information.
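
These properties are easy to verify numerically; a standalone check, independent of our training code:

import numpy as np

d, k = 6, 1
rng = np.random.default_rng(0)

# Random orthonormal rows W (k x d) via QR, then P = I - W^T W
Q, _ = np.linalg.qr(rng.normal(size=(d, k)))    # orthonormal columns
W = Q.T                                          # orthonormal rows, W Wᵀ = I_k
P = np.eye(d) - W.T @ W

assert np.allclose(P @ P, P)                     # idempotent: a projection
assert np.allclose(P, P.T)                       # symmetric: an orthogonal projection
assert np.allclose(W @ P, 0)                     # the k erased directions map to zero
print("rank of P:", np.linalg.matrix_rank(P))    # d - k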

Information Bottleneck Interpretation

Our alternating training creates an explicit information bottleneck:

  1. Autoencoder loss pressures the model to preserve all information needed for reconstruction
  2. R-LACE projection removes specific correlations between partitions
  3. Competition forces the model to encode different concepts in different partitions

This is fundamentally different from implicit approaches like β-VAE, which rely on regularization to encourage disentanglement.
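
For contrast, the standard β-VAE objective only re-weights a KL regularizer and never specifies which latent dimension should encode which factor:

max 𝔼_q(z|x)[ log p(x|z) ] − β · D_KL( q(z|x) ‖ p(z) )

Any disentanglement it produces is an emergent side effect of the pressure toward the factorized prior, not a structural guarantee.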

Advantages Over Existing Methods

| Method | Explicit Mapping | Known Structure | Interpretability | Performance |
|---|---|---|---|---|
| β-VAE | No | No | Moderate | Good |
| InfoGAN | No | No | Good | Moderate |
| Our Method | Yes | Yes | Excellent | Excellent |

Key advantages:

  • Explicit mapping: We know exactly which dimensions encode which features
  • Principled approach: Based on solid theoretical foundations (adversarial projection)
  • Flexible framework: Can be adapted to any number of known causal factors

Limitations and Future Directions

  1. Known causal structure required: Our method assumes you know the independent factors a priori
  2. Training stability: The alternating optimization can be unstable with poor hyperparameter choices
  3. Scalability: Tested only on simple datasets (Colored MNIST)
  4. Independence assumption: Requires truly independent generative factors

While our approach shows promise in controlled settings, several fundamental challenges limit its real-world applicability:

The Superposition Problem

Real neural networks exhibit superposition: features are represented as linear combinations across many dimensions rather than being cleanly separated. As demonstrated in recent work on mechanistic interpretability, individual neurons often encode multiple concepts simultaneously, and individual concepts are distributed across multiple neurons. This creates several problems for our approach:

  • Feature interference: When concepts naturally superpose, enforcing strict partitioning may damage both concepts
  • Representation efficiency: Neural networks may achieve better compression by allowing controlled feature mixing
  • Emergent representations: Some high-level concepts only emerge through combinations of lower-level features
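
A toy numerical illustration of this interference, in the spirit of the usual superposition toy model (purely illustrative, not part of our experiments):

import numpy as np

# Three "features" squeezed into a 2-dimensional representation:
# the directions cannot all be orthogonal, so features interfere.
W = np.array([[1.0, 0.0],
              [-0.5, np.sqrt(3) / 2],
              [-0.5, -np.sqrt(3) / 2]])   # 3 features -> 2 dims, 120 degrees apart

x = np.array([1.0, 0.0, 0.0])             # only feature 0 is active
z = W.T @ x                                # compress to 2 dims
x_hat = W @ z                              # read the features back out

print(x_hat)   # feature 0 reads 1.0, but features 1 and 2 spuriously read -0.5:
               # interference that a hard partition of dimensions cannot remove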

Unknown Intrinsic Dimensionality

A major practical limitation is that we rarely know the intrinsic dimensionality of real generative factors:

  • How many dimensions does “color” really need? In our Colored MNIST experiment, we assumed color could be encoded in d/2 dimensions, but real color spaces are complex
  • What about hierarchical factors? Object identity might require 50 dimensions, while lighting might need 20, but we don’t know these numbers a priori
  • Interaction effects: Some factors may require additional dimensions when they interact (e.g., how material appearance changes under different lighting)

Example failure case: If we allocate too few dimensions to a complex factor, R-LACE will successfully remove “unwanted” correlations, but the remaining space won’t be sufficient to represent the factor adequately.

The Generative Factor Discovery Problem

Perhaps the most fundamental limitation: we typically don’t know what the true generative factors are:

  • Natural images: What are the independent factors generating a photo? Object identity, pose, lighting, camera parameters, background, weather, time of day…?
  • Language: Syntax, semantics, pragmatics, style, register, emotional content…?
  • Medical data: Disease state, patient demographics, imaging modality, technical factors…?

Our Colored MNIST example benefits from a perfectly controlled environment where we artificially created exactly two independent factors. Real-world data lacks this luxury.

Computational and Scalability Issues

Memory complexity: As the number of factors grows, our approach requires:

  • Separate projection matrices for each factor: O(k × d²/k) = O(d²) space
  • Separate R-LACE optimization for each partition: O(k) computational overhead
  • Joint optimization across all partitions: potentially exponential in the number of factors

Training instability: The alternating optimization can fail when:

  • Factors are not truly independent (most real cases)
  • Latent dimensions are insufficient
  • Multiple factors compete for the same representational space

When the Approach Breaks Down

Consider these realistic scenarios where our method would struggle:

Scenario 1: Natural Language

factors = ["syntax", "semantics", "style"]

# Reality:
# - Syntax and semantics are deeply intertwined
# - Style affects both syntax and semantic choices
# - Cultural context influences all factors
# - Individual word meanings depend on context
# - We can't cleanly separate these concepts

Scenario 2: Real-World Images

factors = ["object_identity", "pose", "lighting", "background"]

# Why this fails:
# - Object appearance changes dramatically with pose and lighting
# - Background affects object visibility and interpretation
# - Some objects are defined partly by their typical backgrounds
# - Pose space dimensionality varies dramatically by object type

The Fundamental Trade-off

Our approach reveals a fundamental tension in disentangled representation learning:

Perfect disentanglement vs. representational efficiency

  • Strict partitioning ensures interpretability but may waste representational capacity
  • Allowing superposition enables efficient compression and simulation of higher dimensional networks but loses interpretability
  • Real neural networks appear to prefer efficiency over interpretability

This suggests that the goal of “perfect disentanglement” may be fundamentally at odds with how neural networks naturally want to represent information.

Implementation and Reproducibility

Our complete implementation is available on GitHub, including:

  • Dataset generation: Colored MNIST creation scripts
  • Model architecture: Complete autoencoder with R-LACE integration
  • Training pipeline: Alternating optimization implementation
  • Evaluation metrics: Disentanglement quality assessment
  • Visualization tools: Latent space plotting and reconstruction galleries

Conclusion

Our work introduces a fundamentally new approach to disentangled representation learning that moves beyond implicit regularization to explicit structural constraints. By leveraging adversarial projection to obstruct unwanted correlations, we achieve:

  • True disentanglement: Independent factors in separate latent dimensions
  • Explicit mapping: Known correspondence between dimensions and concepts
  • Principled foundation: Based on solid theoretical understanding
  • Practical effectiveness: Demonstrated on concrete experimental tasks

Key insight: Sometimes the best way to learn independent representations is to actively fight against entanglement, rather than hoping regularization will encourage it. However, this approach may be fundamentally limited by the reality that natural data often exhibits meaningful entanglement and superposition.

Still, we believe our work represents a step toward more interpretable and trustworthy AI systems. By explicitly partitioning causal mechanisms, we enable:

Compositional Generation

# Mix and match independent factors using the trained encoder, decoder,
# and projections (conceptual sketch)
z_digit = encoder(image_of_7)[:, :d // 2]          # digit partition of a "7"
z_color = encoder(blue_image)[:, d // 2:]          # color partition of a blue image
new_image = decoder(torch.cat([z_digit @ P1, z_color @ P2], dim=1))
# Result: a blue 7

Robust Domain Transfer

  • Models with disentangled representations should generalize better to new color-digit combinations
  • Less susceptible to spurious correlations in training data

Interpretable Interventions

  • Modify specific attributes without affecting others
  • Enable precise control over generated content
  • Support counterfactual reasoning

Looking Forward

As AI systems become more complex and consequential, the ability to understand and control their internal representations becomes increasingly critical. Our projection-based approach provides a concrete step toward building AI systems that are not just powerful, but interpretable, controllable, and trustworthy.

However, the practical limitations we’ve identified suggest that the path forward may require fundamentally different approaches. Rather than enforcing perfect disentanglement, future work might focus on controllable entanglement - systems that can flexibly adjust the degree of factor separation based on the task at hand.

The marriage of causal structure knowledge with end-to-end neural optimization opens exciting possibilities, but the challenges of superposition, unknown dimensionality, and factor discovery remind us that the goal of truly disentangled AI remains an active area of research with significant unsolved problems.


This research was conducted at Brown University’s Department of Computer Science with collaborator Neil Xu. The complete codebase, experimental data, and additional visualizations are available in our GitHub repository.

Technical Appendix

R-LACE Implementation Details

Core Algorithm:

import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import log_loss

def rlace(X, y, rank=1, max_iter=100):
    # Initialize the projection with Spectral Attribute Removal (SAL)
    P = sal_initialization(X, y, rank)
    I = np.eye(X.shape[1])
    
    for iteration in range(max_iter):
        # Train the adversarial classifier on the projected data
        # (any linear classifier works; logistic regression as a stand-in)
        clf = LogisticRegression(max_iter=1000).fit(X @ (I - P), y)
        clf_loss = log_loss(y, clf.predict_proba(X @ (I - P)))
        
        # Update the projection to maximize the classifier's loss
        P = optimize_projection(P, X, y, clf)
        
        if convergence_check(P, clf_loss):
            break
    
    # Project back onto the set of valid orthogonal projection matrices
    return orthogonalize(P)

Convergence Criteria:

  • Gradient norm threshold: 1e-2
  • Loss improvement tolerance: 1e-2
  • Maximum iterations: 100

Evaluation Metrics

Disentanglement Score:

DS = (Acc_corresponding - Acc_opposite) / Acc_corresponding

Where:

  • Acc_corresponding: Accuracy predicting correct factor from its partition
  • Acc_opposite: Accuracy predicting wrong factor from partition

Perfect disentanglement: DS = 1.0
No disentanglement: DS = 0.0
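
Plugging in the n=5 probe accuracies from the verification table above:

def disentanglement_score(acc_corresponding, acc_opposite):
    return (acc_corresponding - acc_opposite) / acc_corresponding

# n=5 row of the verification table
print(disentanglement_score(0.74, 0.48))   # digit partition: ~0.35
print(disentanglement_score(0.99, 0.39))   # color partition: ~0.61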

References

Please see the writeup for the full references.