Source code for iVAE.environment

"""
Environment class for managing data and training workflow.

This module provides the Env class that manages the training environment,
including data loading, batching, and evaluation scoring.
"""

from .model import iVAE
from .mixin import envMixin
import numpy as np
from sklearn.cluster import KMeans


class Env(iVAE, envMixin):
    """
    Training environment for the iVAE model.

    This class manages the training workflow including data preprocessing,
    batch sampling, model updates, and performance evaluation. It inherits
    from both iVAE (the model) and envMixin (evaluation utilities).

    Parameters
    ----------
    adata : anndata.AnnData
        Annotated data matrix containing single-cell gene expression data.
    layer : str
        Name of the layer in adata.layers to use for training.
    percent : float
        Proportion of data to use in each training batch (0 < percent <= 1).
    irecon : float
        Weight for interpretative reconstruction loss.
    beta : float
        Weight for KL divergence term.
    dip : float
        Weight for DIP loss.
    tc : float
        Weight for total correlation loss.
    info : float
        Weight for InfoVAE MMD loss.
    hidden_dim : int
        Dimension of hidden layers in the neural network.
    latent_dim : int
        Dimension of the main latent space. Also used as the number of
        K-means clusters for the reference labels.
    i_dim : int
        Dimension of the interpretative latent space.
    lr : float
        Learning rate for optimization.
    device : torch.device
        Device for computation (CPU or CUDA).
    *args, **kwargs
        Accepted for call-site compatibility; not used by this class and not
        forwarded to the parent constructor.

    Attributes
    ----------
    X : numpy.ndarray
        log1p-transformed gene expression matrix of shape (n_cells, n_genes).
    n_obs : int
        Number of cells (observations).
    n_var : int
        Number of genes (variables).
    batch_size : int
        Number of samples per training batch (always at least 1).
    labels : numpy.ndarray
        Reference cluster labels for evaluation (from K-means on input data).
    score : list
        Training history of evaluation scores.
    idx : numpy.ndarray
        Indices of the most recently sampled batch (set by ``load_data``).
    """

    def __init__(
        self,
        adata,
        layer,
        percent,
        irecon,
        beta,
        dip,
        tc,
        info,
        hidden_dim,
        latent_dim,
        i_dim,
        lr,
        device,
        *args,
        **kwargs
    ):
        # The data must be registered before the model is built: n_obs drives
        # the batch size and n_var is the model's input (state) dimension.
        self._register_anndata(adata, layer, latent_dim)
        # int() truncates, so a small dataset combined with a small `percent`
        # would otherwise produce an empty batch; keep at least one sample.
        self.batch_size = max(1, int(percent * self.n_obs))
        super().__init__(
            irecon=irecon,
            beta=beta,
            dip=dip,
            tc=tc,
            info=info,
            state_dim=self.n_var,
            hidden_dim=hidden_dim,
            latent_dim=latent_dim,
            i_dim=i_dim,
            lr=lr,
            device=device,
        )
        self.score = []

    def load_data(self):
        """
        Load a random batch of data for training.

        The indices of the sampled cells are stored on ``self.idx``.

        Returns
        -------
        numpy.ndarray
            A batch of gene expression data of shape (batch_size, n_genes).
        """
        data, idx = self._sample_data()
        self.idx = idx
        return data

    def step(self, data):
        """
        Perform one training step: update model and evaluate performance.

        Calls ``update`` and ``take_latent`` on the batch, scores the
        resulting latent representation with ``_calc_score`` and appends the
        score to ``self.score``. (Those three methods are not defined in this
        class; presumably they come from iVAE / envMixin.)

        Parameters
        ----------
        data : numpy.ndarray
            Batch of gene expression data.
        """
        self.update(data)
        latent = self.take_latent(data)
        score = self._calc_score(latent)
        self.score.append(score)

    def _sample_data(self):
        """
        Sample a random batch from the dataset.

        Returns
        -------
        data : numpy.ndarray
            Batch of gene expression data.
        idx_ : numpy.ndarray
            Indices of the sampled cells.
        """
        # NOTE(review): np.random.choice defaults to replace=True, so a batch
        # may contain the same cell more than once. The two RNG calls are kept
        # exactly as they were so that seeded runs stay reproducible.
        idx = np.random.permutation(self.n_obs)
        idx_ = np.random.choice(idx, self.batch_size)
        data = self.X[idx_, :]
        return data, idx_

    def _register_anndata(self, adata, layer: str, latent_dim):
        """
        Register and preprocess AnnData object.

        This method extracts the specified layer from the AnnData object,
        converts it to a dense array, applies log1p transformation, and
        creates reference labels for evaluation.

        Parameters
        ----------
        adata : anndata.AnnData
            Annotated data matrix.
        layer : str
            Name of the layer to use (must exist in adata.layers).
        latent_dim : int
            Number of clusters to create for reference labels.
        """
        counts = adata.layers[layer]
        # Layers may be stored sparse or dense. The `.A` shortcut does not
        # exist on numpy.ndarray and was removed from SciPy sparse arrays, so
        # densify through `toarray()` when the object offers it.
        if hasattr(counts, "toarray"):
            counts = counts.toarray()
        self.X = np.log1p(np.asarray(counts))
        self.n_obs = adata.shape[0]
        self.n_var = adata.shape[1]
        self.labels = KMeans(latent_dim).fit_predict(self.X)