iVAE.agent.agent
- class iVAE.agent.agent(adata: AnnData, layer: str = 'counts', percent: float = 0.01, irecon: float = 0.0, beta: float = 1.0, dip: float = 0.0, tc: float = 0.0, info: float = 0.0, hidden_dim: int = 128, latent_dim: int = 10, i_dim: int = 2, lr: float = 0.0001, device: device = device(type='cpu'))
High-level interface for training and using iVAE models.
The agent class provides a user-friendly interface for training interpretable Variational Autoencoders (iVAE) on single-cell transcriptomics data. It handles data preprocessing, model training, and extraction of learned representations.
iVAE extends the standard VAE with an interpretative module that increases the correlation between latent components, which helps it capture biologically meaningful gene expression patterns in single-cell data.
- Parameters:
adata (AnnData) – Annotated data matrix containing single-cell gene expression data. Should have at least one layer (e.g., ‘counts’, ‘X’) with raw or normalized counts.
layer (str, optional) – The layer of the AnnData object to use, by default ‘counts’. Common options: ‘counts’, ‘X’, or custom layer names.
percent (float, optional) – Fraction of cells to use per training batch (0 < percent <= 1), by default 0.01. Smaller values give smaller batches and cheaper, more frequent updates; larger values give more stable gradients.
irecon (float, optional) – Weight for interpretative reconstruction loss, by default 0.0. If > 0, penalizes reconstruction errors from the interpretative bottleneck, encouraging more interpretable latent representations.
beta (float, optional) – Weight for the KL divergence term (beta-VAE), by default 1.0. Higher values pull latent codes toward the prior (a standard normal), which can improve disentanglement but may reduce reconstruction quality.
dip (float, optional) – Weight for the DIP (Disentangled Inferred Prior) loss, by default 0.0. If > 0, pushes the covariance of the latent codes toward a diagonal matrix, encouraging disentanglement.
tc (float, optional) – Weight for Total Correlation (TC) loss from Beta-TC VAE, by default 0.0. If > 0, encourages factorized (independent) latent dimensions.
info (float, optional) – Weight for the InfoVAE MMD (Maximum Mean Discrepancy) loss, by default 0.0. If > 0, pulls the aggregated posterior toward the prior by minimizing a kernel-based distance.
hidden_dim (int, optional) – Dimension of hidden layers in encoder/decoder networks, by default 128. Larger values increase model capacity but require more data and computation.
latent_dim (int, optional) – Dimension of the main latent space, by default 10. Should roughly match the expected number of cell types or states.
i_dim (int, optional) – Dimension of the interpretative latent space (bottleneck), by default 2. This compressed representation encourages learning of correlated patterns. Should be smaller than latent_dim.
lr (float, optional) – Learning rate for Adam optimizer, by default 1e-4.
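device (torch.device, optional) – Device to run computations on; by default, the GPU if one is available, otherwise the CPU. The rendered signature shows device(type='cpu'), which is what this default resolves to on a machine without a GPU; pass a torch.device explicitly to control placement.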
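To make `percent` concrete: if the batch size is simply that fraction of the cell count, the default 0.01 on a 20,000-cell dataset gives batches of 200 cells. The sketch below uses a hypothetical helper, `sample_batch_indices`, which is not part of iVAE; the rounding rule, the one-cell floor, and sampling without replacement are all assumptions.

```python
import random


def sample_batch_indices(n_cells: int, percent: float, rng: random.Random) -> list[int]:
    """Hypothetical helper: pick a random subset of cell indices for one batch.

    Assumption: batch size = max(1, int(n_cells * percent)), drawn without
    replacement. iVAE's own load_data() may round or sample differently.
    """
    if not 0 < percent <= 1:
        raise ValueError("percent must satisfy 0 < percent <= 1")
    batch_size = max(1, int(n_cells * percent))
    return rng.sample(range(n_cells), batch_size)


rng = random.Random(0)
idx = sample_batch_indices(n_cells=20_000, percent=0.01, rng=rng)
print(len(idx))  # 200 cells in this batch
```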
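The role of `beta` follows the standard beta-VAE objective: the KL divergence between each cell's diagonal-Gaussian posterior and the standard-normal prior is multiplied by `beta`. The sketch below uses the closed-form KL for that case. It assumes the encoder outputs a mean and a log-variance per latent dimension, and it leaves out how iVAE averages the term over cells.

```python
import math


def kl_standard_normal(mu: list[float], logvar: list[float]) -> float:
    """KL( N(mu, diag(exp(logvar))) || N(0, I) ) for one cell's posterior."""
    return 0.5 * sum(m * m + math.exp(lv) - lv - 1.0 for m, lv in zip(mu, logvar))


def beta_kl_term(mu: list[float], logvar: list[float], beta: float = 1.0) -> float:
    """The KL term scaled by beta, as in a beta-VAE objective."""
    return beta * kl_standard_normal(mu, logvar)


print(kl_standard_normal([0.0, 0.0], [0.0, 0.0]))     # 0.0: posterior equals the prior
print(beta_kl_term([1.0, 0.0], [0.0, 0.0], beta=2.0))  # 1.0: KL of 0.5, doubled by beta
```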
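The DIP penalty can be illustrated with the DIP-VAE-I regularizer: compute the covariance of the posterior means over a batch, then penalize the off-diagonal entries and the deviation of the diagonal entries from 1. This plain-Python sketch is only an illustration. Whether iVAE uses DIP-VAE-I or DIP-VAE-II, whether it targets unit variances, and how `dip` is split across the two terms (`lambda_od` and `lambda_d` are hypothetical parameters here) are assumptions.

```python
def covariance(rows: list[list[float]]) -> list[list[float]]:
    """Population covariance matrix of a batch of vectors (rows = cells)."""
    n, d = len(rows), len(rows[0])
    means = [sum(r[k] for r in rows) / n for k in range(d)]
    return [
        [sum((r[i] - means[i]) * (r[j] - means[j]) for r in rows) / n for j in range(d)]
        for i in range(d)
    ]


def dip_vae_i_penalty(mu_batch: list[list[float]],
                      lambda_od: float = 1.0, lambda_d: float = 1.0) -> float:
    """DIP-VAE-I style penalty on the covariance of posterior means.

    Off-diagonal entries are pushed to 0 and diagonal entries to 1.
    lambda_od / lambda_d are hypothetical knobs, not documented iVAE parameters.
    """
    cov = covariance(mu_batch)
    d = len(cov)
    off_diag = sum(cov[i][j] ** 2 for i in range(d) for j in range(d) if i != j)
    diag = sum((cov[i][i] - 1.0) ** 2 for i in range(d))
    return lambda_od * off_diag + lambda_d * diag


uncorrelated = [[1.0, 1.0], [1.0, -1.0], [-1.0, 1.0], [-1.0, -1.0]]
correlated = [[1.0, 1.0], [-1.0, -1.0]]
print(dip_vae_i_penalty(uncorrelated))  # 0.0: identity covariance
print(dip_vae_i_penalty(correlated))    # 2.0: both off-diagonal entries equal 1
```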
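Total correlation is the KL divergence between the aggregated posterior q(z) and the product of its marginals. β-TC-VAE estimates it on a minibatch with minibatch-weighted sampling, sketched below in plain Python. The sketch assumes diagonal-Gaussian posteriors and takes `dataset_size` to be the number of cells. Whether iVAE uses this estimator or a different one is not stated here.

```python
import math

LOG_2PI = math.log(2.0 * math.pi)


def log_normal(z: float, mu: float, logvar: float) -> float:
    """Log density of N(mu, exp(logvar)) evaluated at z (one dimension)."""
    return -0.5 * (LOG_2PI + logvar + (z - mu) ** 2 / math.exp(logvar))


def logsumexp(values: list[float]) -> float:
    """Numerically stable log(sum(exp(values)))."""
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))


def total_correlation_mws(z, mu, logvar, dataset_size: int) -> float:
    """Minibatch-weighted-sampling estimate of TC = KL(q(z) || prod_k q(z_k)).

    z, mu, logvar are [batch][latent_dim] lists; z[i] is a sample from q(z | x_i).
    """
    m, d = len(z), len(z[0])
    log_nm = math.log(dataset_size * m)
    tc = 0.0
    for i in range(m):
        # log q(z_i[k] | x_j) for every batch member j and latent dimension k
        per_dim = [[log_normal(z[i][k], mu[j][k], logvar[j][k]) for k in range(d)]
                   for j in range(m)]
        # Joint density: sum over dimensions first, then mix over batch members
        log_q_joint = logsumexp([sum(row) for row in per_dim]) - log_nm
        # Product of marginals: mix over batch members separately per dimension
        log_q_marginals = sum(
            logsumexp([per_dim[j][k] for j in range(m)]) - log_nm for k in range(d)
        )
        tc += log_q_joint - log_q_marginals
    return tc / m


z = [[0.2, -0.1], [1.0, 0.9], [-0.8, -1.1]]
mu = [[0.0, 0.0], [1.0, 1.0], [-1.0, -1.0]]
logvar = [[-1.0, -1.0]] * 3
tc_example = total_correlation_mws(z, mu, logvar, dataset_size=10_000)
print(tc_example)
```

With a single latent dimension the joint and the product of marginals coincide, so this estimate is exactly 0. Because the estimator is biased, it can also come out negative even though the true TC is non-negative.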
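InfoVAE's MMD term compares samples of the aggregated posterior with samples from the prior. Below is a plain-Python sketch using a Gaussian (RBF) kernel and the biased estimator. The kernel choice and the `bandwidth` value are assumptions, not iVAE's documented settings.

```python
import math
import random


def rbf_kernel(a: list[float], b: list[float], bandwidth: float) -> float:
    """Gaussian kernel exp(-||a - b||^2 / (2 * bandwidth^2))."""
    sq_dist = sum((x - y) ** 2 for x, y in zip(a, b))
    return math.exp(-sq_dist / (2.0 * bandwidth ** 2))


def mmd(xs: list[list[float]], ys: list[list[float]], bandwidth: float = 1.0) -> float:
    """Biased (V-statistic) estimate of the squared MMD between two samples."""
    def mean_kernel(us, vs):
        total = sum(rbf_kernel(u, v, bandwidth) for u in us for v in vs)
        return total / (len(us) * len(vs))
    return mean_kernel(xs, xs) + mean_kernel(ys, ys) - 2.0 * mean_kernel(xs, ys)


rng = random.Random(0)


def draw_prior(n: int, dim: int) -> list[list[float]]:
    """Samples from the prior p(z) = N(0, I)."""
    return [[rng.gauss(0.0, 1.0) for _ in range(dim)] for _ in range(n)]


prior = draw_prior(100, 2)
close = draw_prior(100, 2)                                          # already matches the prior
shifted = [[v + 3.0 for v in row] for row in draw_prior(100, 2)]  # mismatched "posterior"
print(mmd(close, prior) < mmd(shifted, prior))  # True
```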
- fit(epochs=1000)
Train the model on the data for a specified number of epochs. Returns the trained agent instance.
- get_iembed()
Extract the interpretative embedding (intermediate bottleneck representation). Returns a NumPy array of shape (n_cells, i_dim).
- get_latent()
Extract the main latent representation. Returns a NumPy array of shape (n_cells, latent_dim).
Examples
Basic usage with default parameters:
>>> import scanpy as sc
>>> from iVAE import agent
>>>
>>> # Load single-cell data
>>> adata = sc.read_h5ad('data.h5ad')
>>>
>>> # Train iVAE model
>>> ag = agent(adata, layer='counts', latent_dim=10)
>>> ag.fit(epochs=500)
>>>
>>> # Extract representations
>>> latent = ag.get_latent()
>>> iembed = ag.get_iembed()
With custom regularization:
>>> # Train with interpretative reconstruction and disentanglement
>>> ag = agent(
...     adata,
...     layer='counts',
...     latent_dim=10,
...     i_dim=3,
...     irecon=0.5,  # Enable interpretative reconstruction
...     beta=2.0,    # Stronger KL regularization
...     dip=1.0      # Enable disentanglement
... )
>>> ag.fit(epochs=1000)
Methods
__init__(adata[, layer, percent, irecon, ...])
fit([epochs]) – Train the iVAE model on the data.
get_iembed() – Extract the interpretative embedding from the trained model.
get_latent() – Extract the main latent representation from the trained model.
load_data() – Load a random batch of data for training.
step(data) – Perform one training step: update model and evaluate performance.
take_latent(state) – Extract latent representations from input data.
update(states) – Perform one training step with the given batch of data.
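The method table suggests how `fit` drives training: `load_data()` draws a random batch and `step(data)` updates the model. The toy class below mimics those method names to show one plausible loop. It is hypothetical: it does not import iVAE, trains nothing, and assumes one batch per epoch and the same batch-size rule as `percent` above.

```python
import random


class ToyAgent:
    """Hypothetical stand-in that mimics the method names listed above."""

    def __init__(self, n_cells: int, percent: float = 0.01, seed: int = 0):
        self.n_cells = n_cells
        self.batch_size = max(1, int(n_cells * percent))  # assumed batch-size rule
        self.rng = random.Random(seed)
        self.steps_taken = 0

    def load_data(self) -> list[int]:
        # Stand-in for "Load a random batch of data for training".
        return self.rng.sample(range(self.n_cells), self.batch_size)

    def step(self, data: list[int]) -> None:
        # Stand-in for "update model and evaluate performance"; here it only counts.
        self.steps_taken += 1

    def fit(self, epochs: int = 1000) -> "ToyAgent":
        for _ in range(epochs):
            self.step(self.load_data())
        return self  # fit() returns the agent, as documented for the real class


ag = ToyAgent(n_cells=5_000).fit(epochs=10)
print(ag.steps_taken)  # 10
```

Returning `self` from `fit` is what allows chaining such as `agent(adata).fit(epochs=500).get_latent()` on the real class, since its `fit` is documented to return the trained agent instance.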