iVAE.model.iVAE

class iVAE.model.iVAE(irecon, beta, dip, tc, info, state_dim, hidden_dim, latent_dim, i_dim, lr, device, *args, **kwargs)

Interpretable Variational Autoencoder model for single-cell data.

This class combines multiple loss functions and regularization techniques to learn interpretable latent representations of single-cell gene expression data. It uses a negative binomial likelihood for reconstruction and supports several disentanglement objectives (KL weighting, DIP, total correlation, and InfoVAE MMD), each scaled by its own weight parameter.
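As a rough illustration of the negative binomial reconstruction term, the sketch below computes the NB log-likelihood of a count matrix under the mean / inverse-dispersion parametrization (mean `mu`, inverse dispersion `theta`) that is common in single-cell models. The helper name `nb_log_likelihood` and this parametrization are assumptions; how iVAE parametrizes its decoder output is not documented on this page.

```python
import torch


def nb_log_likelihood(x: torch.Tensor, mu: torch.Tensor, theta: torch.Tensor,
                      eps: float = 1e-8) -> torch.Tensor:
    """Hypothetical helper: NB log-likelihood per cell, summed over genes.

    x     -- observed counts, shape (cells, genes)
    mu    -- predicted means (> 0), same shape as x
    theta -- inverse dispersion (> 0), broadcastable to x
    """
    log_theta_mu = torch.log(theta + mu + eps)
    ll = (
        theta * (torch.log(theta + eps) - log_theta_mu)  # theta * log(theta / (theta + mu))
        + x * (torch.log(mu + eps) - log_theta_mu)       # x * log(mu / (theta + mu))
        + torch.lgamma(x + theta)
        - torch.lgamma(theta)
        - torch.lgamma(x + 1)                            # log(x!)
    )
    return ll.sum(dim=-1)
```

A reconstruction loss would then typically be the negated mean, `-nb_log_likelihood(x, mu, theta).mean()`.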

Parameters:
  • irecon (float) – Weight for the interpretative reconstruction loss. If > 0, penalizes the difference between the reconstructions decoded from the interpretative and the original latent codes.

  • beta (float) – Weight for the KL divergence term. Higher values push the approximate posterior of the latent codes toward the prior, a standard normal distribution.

  • dip (float) – Weight for the DIP (Disentangled Inferred Prior) loss. Encourages the covariance of the inferred latent codes to be diagonal, i.e. decorrelated latent dimensions.

  • tc (float) – Weight for the total correlation (TC) term from β-TC-VAE. Penalizes statistical dependence between latent dimensions, encouraging a factorized representation.

  • info (float) – Weight for the InfoVAE MMD (Maximum Mean Discrepancy) loss. Encourages the aggregate latent distribution to match the prior.

  • state_dim (int) – Dimension of the input/output state (number of genes).

  • hidden_dim (int) – Dimension of the hidden layers.

  • latent_dim (int) – Dimension of the main latent space.

  • i_dim (int) – Dimension of the interpretative latent space.

  • lr (float) – Learning rate for the Adam optimizer.

  • device (torch.device) – Device to run computations on (CPU or CUDA).
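Each weight above scales one regularizer. The sketch below shows textbook PyTorch forms of those terms: the closed-form KL to a standard normal, a DIP-VAE-I-style covariance penalty, the β-TC-VAE minibatch-weighted-sampling estimate of total correlation, and a Gaussian-kernel MMD as used by InfoVAE, plus one plausible linear combination. All function names, the DIP sub-weights `lambda_od` / `lambda_d`, the kernel choice, and the weighting in `combine_losses` are assumptions for illustration, not iVAE's actual implementation; the interpretative reconstruction loss is taken as a precomputed input.

```python
import math

import torch


def kl_standard_normal(mu, logvar):
    # KL(q(z|x) || N(0, I)): summed over latent dims, averaged over the batch
    return (-0.5 * (1 + logvar - mu.pow(2) - logvar.exp()).sum(dim=-1)).mean()


def dip_penalty(mu, lambda_od=10.0, lambda_d=5.0):
    # DIP-VAE-I style: push the covariance of posterior means toward the identity
    cov = torch.cov(mu.T)  # torch.cov expects variables as rows
    diag = torch.diagonal(cov)
    off_diag = cov - torch.diag(diag)
    return lambda_od * off_diag.pow(2).sum() + lambda_d * (diag - 1).pow(2).sum()


def gaussian_log_density(z, mu, logvar):
    # Elementwise log N(z; mu, exp(logvar))
    return -0.5 * (math.log(2 * math.pi) + logvar + (z - mu).pow(2) / logvar.exp())


def total_correlation(z, mu, logvar, dataset_size):
    # beta-TC-VAE minibatch-weighted sampling estimate of TC
    batch = z.size(0)
    # log q(z_i | x_j) per latent dimension, shape (batch, batch, latent_dim)
    log_qz_ij = gaussian_log_density(z.unsqueeze(1), mu.unsqueeze(0), logvar.unsqueeze(0))
    log_norm = math.log(batch * dataset_size)
    log_qz = torch.logsumexp(log_qz_ij.sum(dim=-1), dim=1) - log_norm
    log_prod_qzd = (torch.logsumexp(log_qz_ij, dim=1) - log_norm).sum(dim=-1)
    return (log_qz - log_prod_qzd).mean()


def gaussian_kernel(a, b):
    # k(a, b) = exp(-||a - b||^2 / latent_dim), a common choice for InfoVAE
    sq_dist = (a.unsqueeze(1) - b.unsqueeze(0)).pow(2).sum(dim=-1)
    return torch.exp(-sq_dist / a.size(1))


def mmd(z, z_prior):
    # Biased (V-statistic) MMD^2 estimate between posterior and prior samples
    return (gaussian_kernel(z, z).mean() + gaussian_kernel(z_prior, z_prior).mean()
            - 2 * gaussian_kernel(z, z_prior).mean())


def combine_losses(recon, irecon_loss, kl, dip_loss, tc_loss, mmd_loss,
                   irecon, beta, dip, tc, info):
    # Hypothetical weighting: each constructor weight scales its matching term
    return (recon + irecon * irecon_loss + beta * kl + dip * dip_loss
            + tc * tc_loss + info * mmd_loss)
```

In a training step, `z_prior` would usually be `torch.randn_like(z)`, and setting a weight to 0 removes that term from this particular combination.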

Attributes:
  • nn (VAE) – The VAE neural network model.

  • nn_optimizer (torch.optim.Adam) – Optimizer for training the VAE.

  • loss (list) – Training history storing the loss components recorded at each update.

__init__(irecon, beta, dip, tc, info, state_dim, hidden_dim, latent_dim, i_dim, lr, device, *args, **kwargs)

Methods

  • __init__(irecon, beta, dip, tc, info, ...)

  • take_latent(state) – Extract latent representations from input data.

  • update(states) – Perform one training step with the given batch of data.
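This page does not show the bodies of `take_latent` or `update`. The sketch below is a hypothetical, self-contained illustration of the pattern the documented attributes suggest: a module held in `nn`, an Adam optimizer in `nn_optimizer`, and a `loss` list that grows by one entry per `update` call. `ToyVAE` and `ToyTrainer` are invented names, the encoder/decoder are deliberately minimal, and a squared-error term stands in for the negative binomial likelihood; none of this is iVAE's actual code.

```python
import torch
from torch import nn


class ToyVAE(nn.Module):
    """Hypothetical stand-in for the documented VAE: Gaussian encoder, linear decoder."""

    def __init__(self, state_dim, hidden_dim, latent_dim):
        super().__init__()
        self.encoder = nn.Sequential(nn.Linear(state_dim, hidden_dim), nn.ReLU())
        self.mu = nn.Linear(hidden_dim, latent_dim)
        self.logvar = nn.Linear(hidden_dim, latent_dim)
        self.decoder = nn.Linear(latent_dim, state_dim)

    def forward(self, x):
        h = self.encoder(x)
        mu, logvar = self.mu(h), self.logvar(h)
        z = mu + torch.randn_like(mu) * torch.exp(0.5 * logvar)  # reparameterization trick
        return self.decoder(z), mu, logvar


class ToyTrainer:
    """Hypothetical wrapper mirroring the nn / nn_optimizer / loss attributes."""

    def __init__(self, state_dim, hidden_dim, latent_dim, lr, beta, device):
        self.device = device
        self.beta = beta
        self.nn = ToyVAE(state_dim, hidden_dim, latent_dim).to(device)
        self.nn_optimizer = torch.optim.Adam(self.nn.parameters(), lr=lr)
        self.loss = []  # one (total, recon, kl) tuple per update

    @torch.no_grad()
    def take_latent(self, state):
        # Return posterior means as the latent representation (no gradients needed)
        self.nn.eval()
        x = torch.as_tensor(state, dtype=torch.float32, device=self.device)
        _, mu, _ = self.nn(x)
        return mu.cpu()

    def update(self, states):
        # One optimization step on a batch; squared error stands in for the NB term
        self.nn.train()
        x = torch.as_tensor(states, dtype=torch.float32, device=self.device)
        recon_x, mu, logvar = self.nn(x)
        recon = (recon_x - x).pow(2).sum(dim=-1).mean()
        kl = (-0.5 * (1 + logvar - mu.pow(2) - logvar.exp()).sum(dim=-1)).mean()
        total = recon + self.beta * kl
        self.nn_optimizer.zero_grad()
        total.backward()
        self.nn_optimizer.step()
        self.loss.append((total.item(), recon.item(), kl.item()))


torch.manual_seed(0)
trainer = ToyTrainer(state_dim=50, hidden_dim=32, latent_dim=10, lr=1e-3, beta=1.0,
                     device=torch.device("cpu"))
batch = torch.rand(64, 50)  # stand-in for a normalized expression matrix
for _ in range(5):
    trainer.update(batch)
z = trainer.take_latent(batch)  # shape (64, 10)
```

Returning the posterior mean rather than a sample from `take_latent` keeps embeddings deterministic across calls, which is a common choice for downstream clustering or visualization.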