Source code for iVAE.model

"""
iVAE model implementation with multiple loss functions.

This module implements the core iVAE model that combines a VAE architecture
with various regularization techniques including KL divergence, DIP (Disentangled
Inferred Prior), Beta-TC VAE, and InfoVAE losses.
"""

import torch
import torch.nn as nn
import torch.optim as optim
from .mixin import scviMixin, dipMixin, betatcMixin, infoMixin
from .module import VAE

[docs] class iVAE(scviMixin, dipMixin, betatcMixin, infoMixin): """ Interpretable Variational Autoencoder model for single-cell data. This class combines multiple loss functions and regularization techniques to learn interpretable latent representations of single-cell gene expression data. It uses a negative binomial distribution for reconstruction and supports various disentanglement objectives. Parameters ---------- irecon : float Weight for the interpretative reconstruction loss. If > 0, penalizes the difference between reconstruction from interpretative and original latent codes. beta : float Weight for the KL divergence term. Higher values encourage latent codes to match the prior distribution (standard normal). dip : float Weight for the DIP (Disentangled Inferred Prior) loss. Encourages diagonal covariance structure in the latent space. tc : float Weight for the total correlation (TC) term from Beta-TC VAE. Encourages factorized latent representations. info : float Weight for the InfoVAE MMD (Maximum Mean Discrepancy) loss. Encourages the latent distribution to match the prior. state_dim : int Dimension of the input/output state (number of genes). hidden_dim : int Dimension of the hidden layers. latent_dim : int Dimension of the main latent space. i_dim : int Dimension of the interpretative latent space. lr : float Learning rate for the Adam optimizer. device : torch.device Device to run computations on (CPU or CUDA). Attributes ---------- nn : VAE The VAE neural network model. nn_optimizer : torch.optim.Adam Optimizer for training the VAE. loss : list Training history storing loss components for each update. """
[docs] def __init__( self, irecon, beta, dip, tc, info, state_dim, hidden_dim, latent_dim, i_dim, lr, device, *args, **kwargs ): self.irecon = irecon self.beta = beta self.dip = dip self.tc = tc self.info = info self.nn = VAE(state_dim, hidden_dim, latent_dim, i_dim).to(device) self.nn_optimizer = optim.Adam(self.nn.parameters(), lr=lr) self.device = device self.loss = []
def take_latent( self, state ): """ Extract latent representations from input data. Parameters ---------- state : numpy.ndarray or torch.Tensor Input gene expression data of shape (n_samples, n_genes). Returns ------- numpy.ndarray Latent representations of shape (n_samples, latent_dim). """ state = torch.tensor(state, dtype=torch.float).to(self.device) q_z, _, _, _, _, _, _ = self.nn(state) return q_z.detach().cpu().numpy() def update( self, states ): """ Perform one training step with the given batch of data. This method computes the total loss (reconstruction + regularization terms) and updates the model parameters via backpropagation. Parameters ---------- states : numpy.ndarray or torch.Tensor Batch of gene expression data of shape (batch_size, n_genes). Notes ----- The total loss is composed of: - Negative binomial reconstruction loss - Optional interpretative reconstruction loss (if irecon > 0) - KL divergence (weighted by beta) - Optional DIP loss (if dip > 0) - Optional TC loss (if tc > 0) - Optional MMD loss (if info > 0) """ states = torch.tensor(states, dtype=torch.float).to(self.device) q_z, q_m, q_s, pred_x, le, ld, pred_xl = self.nn(states) # Scale predictions by library size (total count per cell) l = states.sum(-1).view(-1,1) pred_x = pred_x * l # Compute negative binomial reconstruction loss disp = torch.exp(self.nn.decoder.disp) recon_loss = -self._log_nb(states, pred_x, disp).sum(-1).mean() # Compute interpretative reconstruction loss if enabled if self.irecon: pred_xl = pred_xl * l irecon_loss = - self.irecon * self._log_nb(states, pred_xl, disp).sum(-1).mean() else: irecon_loss = torch.zeros(1).to(self.device) # Compute KL divergence from standard normal prior p_m = torch.zeros_like(q_m) p_s = torch.zeros_like(q_s) kl_div = self.beta * self._normal_kl(q_m, q_s, p_m, p_s).sum(-1).mean() # Compute DIP loss if enabled (encourages diagonal covariance) if self.dip: dip_loss = self.dip * self._dip_loss(q_m ,q_s) else: dip_loss = torch.zeros(1).to(self.device) # Compute total correlation loss if enabled (encourages factorization) if self.tc: tc_loss = self.tc * self._betatc_compute_total_correlation(q_z, q_m ,q_s) else: tc_loss = torch.zeros(1).to(self.device) # Compute MMD loss if enabled (matches prior distribution) if self.info: mmd_loss = self.info * self._compute_mmd(q_z, torch.randn_like(q_z)) else: mmd_loss = torch.zeros(1).to(self.device) # Combine all loss terms total_loss = recon_loss + irecon_loss + kl_div + dip_loss + tc_loss + mmd_loss # Backpropagation and optimization step self.nn_optimizer.zero_grad() total_loss.backward() self.nn_optimizer.step() # Record loss components for monitoring self.loss.append(( total_loss.item(), recon_loss.item(), irecon_loss.item(), kl_div.item(), dip_loss.item(), tc_loss.item(), mmd_loss.item() ))