"""
Mixin classes providing various loss functions and utilities for iVAE.
This module contains mixin classes that implement different loss functions and
regularization techniques used in the iVAE model, as well as environment-related
utilities for data handling and evaluation.
"""
import torch
import numpy as np
from sklearn.neighbors import kneighbors_graph
from sklearn.cluster import KMeans
from sklearn.preprocessing import minmax_scale
from sklearn.metrics import (
    adjusted_mutual_info_score,
    adjusted_rand_score,
    normalized_mutual_info_score,
    silhouette_score,
    calinski_harabasz_score,
    davies_bouldin_score,
)
[docs]
class scviMixin:
    """
    Mixin providing scVI-style loss functions.

    Supplies the elementwise Gaussian KL divergence and the negative binomial
    log-likelihood used by scVI-type models for single-cell count data.
    """

    def _normal_kl(self, mu1, lv1, mu2, lv2):
        """
        Elementwise KL divergence between two diagonal Gaussians.

        Parameters
        ----------
        mu1 : torch.Tensor
            Mean of the first Gaussian.
        lv1 : torch.Tensor
            Log-variance of the first Gaussian.
        mu2 : torch.Tensor
            Mean of the second Gaussian.
        lv2 : torch.Tensor
            Log-variance of the second Gaussian.

        Returns
        -------
        torch.Tensor
            KL(N(mu1, exp(lv1)) || N(mu2, exp(lv2))), elementwise with no
            reduction applied.
        """
        # log(sigma2) - log(sigma1), written with half log-variances.
        log_std_gap = lv2 / 2. - lv1 / 2.
        # (sigma1^2 + (mu1 - mu2)^2) / (2 sigma2^2)
        spread = (torch.exp(lv1) + (mu1 - mu2) ** 2.) / (2. * torch.exp(lv2))
        return log_std_gap + spread - 0.5

    def _log_nb(self, x, mu, theta, eps=1e-8):
        """
        Log-likelihood of a mean/dispersion-parameterised negative binomial.

        Parameters
        ----------
        x : torch.Tensor
            Observed counts.
        mu : torch.Tensor
            Mean of the distribution.
        theta : torch.Tensor
            Dispersion parameter; larger values mean less over-dispersion.
        eps : float, optional
            Added inside every logarithm for numerical stability
            (default 1e-8).

        Returns
        -------
        torch.Tensor
            Elementwise log-likelihood of ``x``.
        """
        # log(theta + mu) appears in both log-probability ratios below.
        shared_log = torch.log(theta + mu + eps)
        log_theta_ratio = torch.log(theta + eps) - shared_log
        log_mu_ratio = torch.log(mu + eps) - shared_log
        return (
            theta * log_theta_ratio
            + x * log_mu_ratio
            + torch.lgamma(x + theta)
            - torch.lgamma(theta)
            - torch.lgamma(x + 1)
        )
[docs]
class betatcMixin:
    """
    Mixin providing Beta-TC VAE loss functions.

    Implements the total correlation (TC) term of the Beta-TC VAE
    decomposition, which encourages factorized latent representations.
    """

    def _betatc_compute_gaussian_log_density(
        self,
        samples,
        mean,
        log_var
    ):
        """
        Compute the elementwise log density of a diagonal Gaussian.

        Parameters
        ----------
        samples : torch.Tensor
            Points at which to evaluate the density.
        mean : torch.Tensor
            Mean of the Gaussian (broadcast against ``samples``).
        log_var : torch.Tensor
            Log-variance of the Gaussian (broadcast against ``samples``).

        Returns
        -------
        torch.Tensor
            ``-0.5 * ((samples - mean)^2 / var + log_var + log(2*pi))``,
            elementwise with no reduction.
        """
        import math
        # A Python float keeps full double precision for float64 inputs and
        # needs no device placement, unlike a float32 CPU tensor.
        normalization = math.log(2 * math.pi)
        inv_sigma = torch.exp(-log_var)
        tmp = samples - mean
        return -0.5 * (tmp * tmp * inv_sigma + log_var + normalization)

    def _betatc_compute_total_correlation(
        self,
        z_sampled,
        z_mean,
        z_logvar
    ):
        """
        Compute the minibatch total correlation (TC) estimate.

        TC measures how far the aggregate posterior q(z) is from the product
        of its marginals. Both are estimated from the minibatch: every sample
        is scored under every posterior in the batch.

        Parameters
        ----------
        z_sampled : torch.Tensor
            Sampled latent codes of shape (batch_size, latent_dim).
        z_mean : torch.Tensor
            Posterior means of shape (batch_size, latent_dim).
        z_logvar : torch.Tensor
            Posterior log-variances of shape (batch_size, latent_dim).

        Returns
        -------
        torch.Tensor
            Scalar: batch mean of ``log q(z_i) - sum_k log q(z_ik)``.
            No ``-log(batch_size)`` normalisation is applied, so the value is
            offset by a constant depending only on batch size and latent_dim.
        """
        # Shape (batch, batch, latent_dim): entry [i, j, k] is
        # log N(z_sampled[i, k]; z_mean[j, k], exp(z_logvar[j, k])).
        log_qz_prob = self._betatc_compute_gaussian_log_density(
            z_sampled.unsqueeze(dim=1),
            z_mean.unsqueeze(dim=0),
            z_logvar.unsqueeze(dim=0)
        )
        # logsumexp replaces exp().sum().log(). The naive form underflows to
        # log(0) = -inf when a sample is far from every posterior, which
        # turns the TC estimate into NaN.
        log_qz_product = torch.logsumexp(log_qz_prob, dim=1).sum(dim=1)
        log_qz = torch.logsumexp(log_qz_prob.sum(dim=2), dim=1)
        return (log_qz - log_qz_product).mean()
[docs]
class infoMixin:
    """
    Mixin providing InfoVAE loss functions.

    Implements Maximum Mean Discrepancy (MMD) with an RBF kernel, used to
    match the aggregated posterior to the prior distribution.
    """

    def _compute_mmd(
        self,
        z_posterior_samples,
        z_prior_samples
    ):
        """
        Compute the MMD between posterior and prior samples.

        Parameters
        ----------
        z_posterior_samples : torch.Tensor
            Posterior samples of shape (n_posterior, z_dim).
        z_prior_samples : torch.Tensor
            Prior samples of shape (n_prior, z_dim). ``n_prior`` need not
            equal ``n_posterior``.

        Returns
        -------
        torch.Tensor
            Scalar MMD estimate. The within-set terms use the unbiased
            (diagonal-excluded) mean; the cross term uses the plain mean.
        """
        mean_pz_pz = self._compute_unbiased_mean(
            self._compute_kernel(z_prior_samples, z_prior_samples), unbiased=True)
        mean_pz_qz = self._compute_unbiased_mean(
            self._compute_kernel(z_prior_samples, z_posterior_samples), unbiased=False)
        mean_qz_qz = self._compute_unbiased_mean(
            self._compute_kernel(z_posterior_samples, z_posterior_samples), unbiased=True)
        return mean_pz_pz - 2 * mean_pz_qz + mean_qz_qz

    def _compute_unbiased_mean(
        self,
        kernel,
        unbiased
    ):
        """
        Compute the (optionally unbiased) mean of a kernel matrix.

        Parameters
        ----------
        kernel : torch.Tensor
            2-D kernel matrix. Must be square when ``unbiased`` is True.
        unbiased : bool
            If True, exclude the diagonal (self-similarity) entries and
            divide by N*(N-1). With N < 2 this divides by zero.

        Returns
        -------
        torch.Tensor
            Scalar mean kernel value.
        """
        if not unbiased:
            return kernel.mean(dim=(0, 1))
        n = kernel.shape[0]
        diagonal_sum = torch.diagonal(kernel, dim1=0, dim2=1).sum(dim=-1)
        return (kernel.sum(dim=(0, 1)) - diagonal_sum) / (n * (n - 1))

    def _compute_kernel(
        self,
        z0,
        z1
    ):
        """
        Compute the RBF kernel matrix between two sample sets.

        Parameters
        ----------
        z0 : torch.Tensor
            First sample set of shape (n0, z_dim).
        z1 : torch.Tensor
            Second sample set of shape (n1, z_dim); ``n1`` may differ
            from ``n0``.

        Returns
        -------
        torch.Tensor
            Kernel matrix of shape (n0, n1).
        """
        # (n0, 1, z_dim) against (1, n1, z_dim) broadcasts to (n0, n1, z_dim).
        # The previous explicit expand to (n0, n0, z_dim) failed when n1 != n0.
        return self._kernel_rbf(z0.unsqueeze(-2), z1.unsqueeze(-3))

    def _kernel_rbf(
        self,
        x,
        y
    ):
        """
        Compute the RBF kernel ``exp(-||x - y||^2 / (4 * z_dim))``.

        Parameters
        ----------
        x : torch.Tensor
            First set of points; the last axis is the feature axis.
        y : torch.Tensor
            Second set of points, broadcastable against ``x``.

        Returns
        -------
        torch.Tensor
            Kernel values with the last (feature) axis reduced.
        """
        z_size = x.shape[-1]
        # Bandwidth denominator grows with dimensionality: 2 * 2 * z_dim.
        sigma = 2 * 2 * z_size
        return torch.exp(-((x - y).pow(2).sum(dim=-1) / sigma))
[docs]
class dipMixin:
    """
    Mixin providing DIP (Disentangled Inferred Prior) loss functions.

    Encourages the covariance of the inferred latent distribution to be
    diagonal with unit variance.
    """

    def _dip_loss(
        self,
        q_m,
        q_s,
        *,
        lambda_d=10,
        lambda_od=5
    ):
        """
        Compute the DIP regulariser.

        Penalises diagonal covariance entries that deviate from 1 and
        off-diagonal entries that deviate from 0.

        Parameters
        ----------
        q_m : torch.Tensor
            Posterior means of shape (batch_size, latent_dim).
        q_s : torch.Tensor
            Posterior log-variances of shape (batch_size, latent_dim).
        lambda_d : float, optional
            Weight of the diagonal (unit-variance) penalty, default 10.
        lambda_od : float, optional
            Weight of the off-diagonal (decorrelation) penalty, default 5.

        Returns
        -------
        torch.Tensor
            Scalar DIP loss.
        """
        cov_matrix = self._dip_cov_matrix(q_m, q_s)
        cov_diag = torch.diagonal(cov_matrix)
        cov_off_diag = cov_matrix - torch.diag(cov_diag)
        dip_loss_d = torch.sum((cov_diag - 1) ** 2)  # deviation from unit variance
        dip_loss_od = torch.sum(cov_off_diag ** 2)  # non-zero cross-covariances
        return lambda_d * dip_loss_d + lambda_od * dip_loss_od

    def _dip_cov_matrix(
        self,
        q_m,
        q_s
    ):
        """
        Compute the covariance of the aggregated latent distribution.

        ``Cov(z) = Cov_batch(q_m) + mean_batch(diag(exp(q_s)))``.

        Parameters
        ----------
        q_m : torch.Tensor
            Posterior means of shape (batch_size, latent_dim).
        q_s : torch.Tensor
            Posterior log-variances of shape (batch_size, latent_dim).

        Returns
        -------
        torch.Tensor
            Covariance matrix of shape (latent_dim, latent_dim).
        """
        # torch.cov expects variables in rows, hence the transpose.
        cov_q_mean = torch.cov(q_m.T)
        # Expected encoder covariance: a diagonal matrix of per-dimension mean
        # variances. The previous torch.diag(q_s.exp()) on a 2-D tensor
        # extracted its diagonal, collapsing to a scalar that was then added to
        # every entry (off-diagonals included).
        E_var = torch.diag(torch.exp(q_s).mean(dim=0))
        return cov_q_mean + E_var
[docs]
class envMixin:
    """
    Mixin providing environment utilities for evaluation.

    Clusters latent representations with K-means and scores the result
    against ground-truth labels and with internal cluster-quality metrics.
    """

    def _calc_score(
        self,
        latent
    ):
        """
        Cluster ``latent`` and compute the evaluation metrics.

        Parameters
        ----------
        latent : numpy.ndarray
            Latent representations of shape (n_samples, latent_dim).

        Returns
        -------
        tuple
            ``(ARI, NMI, ASW, C_H, D_B, P_C)`` as returned by ``_metrics``.
        """
        labels = self._calc_label(latent)
        return self._metrics(latent, labels)

    def _calc_label(
        self,
        latent
    ):
        """
        Run K-means on the latent representations.

        The number of clusters is ``latent.shape[1]`` (the latent width).
        No ``random_state`` is passed, so labels may differ between calls.
        NOTE(review): confirm that tying cluster count to latent width is
        intended.

        Parameters
        ----------
        latent : numpy.ndarray
            Latent representations of shape (n_samples, latent_dim).

        Returns
        -------
        numpy.ndarray
            Cluster label per sample.
        """
        return KMeans(latent.shape[1]).fit_predict(latent)

    def _calc_corr(
        self,
        latent
    ):
        """
        Summarise absolute Pearson correlation between latent dimensions.

        For each dimension, the absolute correlations with all *other*
        dimensions are summed; the mean of these sums is returned. This is a
        per-dimension total, not an average over pairs. A constant column
        makes ``np.corrcoef`` produce NaN, which propagates to the result.

        Parameters
        ----------
        latent : numpy.ndarray
            Latent representations of shape (n_samples, latent_dim).

        Returns
        -------
        float
            Mean over dimensions of the summed off-diagonal |r|.
        """
        acorr = abs(np.corrcoef(latent.T))
        # Each row sum includes the self-correlation of 1; subtract it.
        return acorr.sum(axis=1).mean().item() - 1

    def _metrics(
        self,
        latent,
        labels
    ):
        """
        Compute clustering evaluation metrics.

        Assumes the host class provides ``self.labels`` (ground-truth labels)
        and ``self.idx`` (indices selecting the rows that ``latent`` covers).
        TODO confirm against the host class.

        Parameters
        ----------
        latent : numpy.ndarray
            Latent representations of shape (n_samples, latent_dim).
        labels : numpy.ndarray
            Predicted cluster labels.

        Returns
        -------
        tuple
            ``(ARI, NMI, ASW, C_H, D_B, P_C)`` where:

            - ARI: Adjusted Rand Index vs. ground truth
            - NMI: Normalized Mutual Information vs. ground truth
            - ASW: Average Silhouette Width
            - C_H: Calinski-Harabasz score
            - D_B: Davies-Bouldin score (lower is better)
            - P_C: correlation summary from ``_calc_corr``
        """
        true_labels = self.labels[self.idx]
        # Previously this slot held adjusted_mutual_info_score despite being
        # named and documented as the Adjusted Rand Index.
        ARI = adjusted_rand_score(true_labels, labels)
        NMI = normalized_mutual_info_score(true_labels, labels)
        ASW = silhouette_score(latent, labels)
        C_H = calinski_harabasz_score(latent, labels)
        D_B = davies_bouldin_score(latent, labels)
        P_C = self._calc_corr(latent)
        return ARI, NMI, ASW, C_H, D_B, P_C