iVAE.model.iVAE
- class iVAE.model.iVAE(irecon, beta, dip, tc, info, state_dim, hidden_dim, latent_dim, i_dim, lr, device, *args, **kwargs)
Interpretable Variational Autoencoder model for single-cell data.
This class combines multiple loss functions and regularization techniques to learn interpretable latent representations of single-cell gene expression data. It uses a negative binomial distribution for reconstruction and supports several disentanglement objectives (β-VAE KL weighting, DIP, β-TCVAE total correlation, and InfoVAE MMD), each controlled by its own weight parameter.
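The negative binomial reconstruction term can be written as a negative log-likelihood over raw counts. The following is a minimal PyTorch sketch, assuming a mean / inverse-dispersion parameterization (`mu`, `theta`); the function name and this parameterization are illustrative assumptions, not taken from the iVAE source.

```python
import torch

def nb_negative_log_likelihood(x, mu, theta, eps=1e-8):
    """Negative binomial NLL, summed over genes and averaged over cells.

    x     -- observed counts, shape (cells, genes)
    mu    -- predicted means (> 0), same shape as x
    theta -- inverse dispersion (> 0), broadcastable to x (e.g. one value per gene)
    """
    log_theta_mu = torch.log(theta + mu + eps)
    log_prob = (
        torch.lgamma(x + theta)
        - torch.lgamma(theta)
        - torch.lgamma(x + 1.0)
        + theta * (torch.log(theta + eps) - log_theta_mu)  # theta * log(theta / (theta + mu))
        + x * (torch.log(mu + eps) - log_theta_mu)          # x * log(mu / (theta + mu))
    )
    return -log_prob.sum(dim=-1).mean()
```

With this parameterization the distribution has mean `mu` and variance `mu + mu**2 / theta`, so a per-gene `theta` of shape `(genes,)` lets the model capture gene-specific overdispersion.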
- Parameters:
irecon (float) – Weight for the interpretative reconstruction loss. If > 0, penalizes the difference between the reconstructions decoded from the interpretative latent code and from the original latent code.
beta (float) – Weight for the KL divergence term. Higher values encourage latent codes to match the prior distribution (standard normal).
dip (float) – Weight for the DIP (Disentangled Inferred Prior) loss. Encourages diagonal covariance structure in the latent space.
tc (float) – Weight for the total correlation (TC) term from Beta-TC VAE. Encourages factorized latent representations.
info (float) – Weight for the InfoVAE MMD (Maximum Mean Discrepancy) loss. Encourages the latent distribution to match the prior.
state_dim (int) – Dimension of the input/output state (number of genes).
hidden_dim (int) – Dimension of the hidden layers.
latent_dim (int) – Dimension of the main latent space.
i_dim (int) – Dimension of the interpretative latent space.
lr (float) – Learning rate for the Adam optimizer.
device (torch.device) – Device to run computations on (CPU or CUDA).
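The regularizers weighted by `beta`, `dip`, `tc` and `info` can be sketched in PyTorch as below. This is a minimal illustration under stated assumptions: a diagonal Gaussian encoder returning `mu` and `logvar`; a DIP-VAE-I style penalty on the batch covariance of `mu` with unit inner weights; the minibatch weighted sampling estimate of total correlation from β-TCVAE; and an RBF-kernel MMD against standard-normal samples. All function names, and the assumption that the terms are simply summed with the constructor weights, are hypothetical. The reconstruction and `irecon` terms are omitted.

```python
import math
import torch

def gaussian_kl(mu, logvar):
    # Closed-form KL(N(mu, exp(logvar)) || N(0, I)), summed over latent dims, averaged over cells.
    return 0.5 * (mu.pow(2) + logvar.exp() - logvar - 1.0).sum(dim=-1).mean()

def dip_penalty(mu):
    # DIP-VAE-I style: push the batch covariance of posterior means toward the identity.
    cov = torch.cov(mu.T)  # (latent_dim, latent_dim); rows of mu.T are variables
    diag = torch.diagonal(cov)
    off_diag = cov - torch.diag(diag)
    return off_diag.pow(2).sum() + (diag - 1.0).pow(2).sum()

def total_correlation(z, mu, logvar, dataset_size):
    # Minibatch weighted sampling estimate of TC = KL(q(z) || prod_k q(z_k)).
    batch = z.size(0)
    diff = z.unsqueeze(1) - mu.unsqueeze(0)  # (batch, batch, latent_dim): z_i - mu_j
    lv = logvar.unsqueeze(0)
    log_q_pairs = -0.5 * (math.log(2 * math.pi) + lv + diff.pow(2) / lv.exp())
    log_norm = math.log(batch * dataset_size)
    log_qz = torch.logsumexp(log_q_pairs.sum(dim=2), dim=1) - log_norm
    log_prod_qz = (torch.logsumexp(log_q_pairs, dim=1) - log_norm).sum(dim=1)
    return (log_qz - log_prod_qz).mean()

def rbf_kernel(a, b):
    # Gaussian kernel with bandwidth equal to the latent dimension (a common heuristic).
    return torch.exp(-torch.cdist(a, b).pow(2) / a.size(1))

def mmd(z):
    # Biased MMD^2 estimate between posterior samples and standard-normal prior samples.
    prior = torch.randn_like(z)
    return (rbf_kernel(z, z).mean() + rbf_kernel(prior, prior).mean()
            - 2.0 * rbf_kernel(z, prior).mean())

def regularization_loss(z, mu, logvar, dataset_size, beta, dip, tc, info):
    # Hypothetical weighted sum using the constructor weights of the same names.
    return (beta * gaussian_kl(mu, logvar)
            + dip * dip_penalty(mu)
            + tc * total_correlation(z, mu, logvar, dataset_size)
            + info * mmd(z))
```

Setting a weight to 0 disables the corresponding term, which makes it straightforward to compare, for example, a plain β-VAE (`dip = tc = info = 0`) against a configuration that adds the MMD term.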
- nn_optimizer
Optimizer for training the VAE.
- Type:
torch.optim.Adam
- loss
Training history storing loss components for each update.
- Type:
list
- __init__(irecon, beta, dip, tc, info, state_dim, hidden_dim, latent_dim, i_dim, lr, device, *args, **kwargs)
Methods
__init__(irecon, beta, dip, tc, info, ...)
take_latent(state) – Extract latent representations from input data.
update(states) – Perform one training step with the given batch of data.
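A common way to drive `update(states)` and `take_latent(state)` is a shuffled minibatch loop. The sketch below assumes, without confirmation from the source, that both methods accept a dense NumPy array of shape (cells, genes) and that `take_latent` returns one latent row per cell. `train`, `epochs`, `batch_size` and `seed` are hypothetical names.

```python
import numpy as np

def train(model, counts, epochs=10, batch_size=128, seed=0):
    """Call model.update() on shuffled minibatches, then embed all cells.

    Assumes (unverified) that update/take_latent accept a NumPy array
    of shape (cells, genes).
    """
    rng = np.random.default_rng(seed)
    n_cells = counts.shape[0]
    for _ in range(epochs):
        order = rng.permutation(n_cells)  # reshuffle cells every epoch
        for start in range(0, n_cells, batch_size):
            model.update(counts[order[start:start + batch_size]])
    return model.take_latent(counts)
```

The last batch of each epoch may be smaller than `batch_size`. Whether that matters depends on the batch-statistics terms (for example the covariance and MMD estimates), which become noisier on very small batches.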