"""
Utility functions for iVAE analysis and evaluation.
This module provides utility functions for processing results, computing
evaluation metrics, and analyzing latent representations from iVAE models.
"""
import numpy as np
from numpy import ndarray
import pandas as pd
import scib
from sklearn.cluster import KMeans
from sklearn.neighbors import kneighbors_graph
from sklearn.metrics import adjusted_mutual_info_score, normalized_mutual_info_score, silhouette_score, calinski_harabasz_score, davies_bouldin_score
from scipy.sparse import csr_matrix
from scipy.sparse import csgraph
[docs]
def get_dfs(
mode,
agent_list
):
"""
Aggregate and summarize scores from multiple agent runs.
Parameters
----------
mode : str
Aggregation mode, either 'mean' or 'std'.
agent_list : list
List of trained agent objects with score histories.
Returns
-------
generator
Generator yielding DataFrames with aggregated scores.
Columns are: ARI, NMI, ASW, C_H, D_B, P_C.
"""
if mode == 'mean':
ls = list(map(lambda x: zip(*(np.array(b).mean(axis=0) for b in zip(*((zip(*a.score)) for a in x)))), list(zip(*agent_list))))
else:
ls = list(map(lambda x: zip(*(np.array(b).std(axis=0) for b in zip(*((zip(*a.score)) for a in x)))), list(zip(*agent_list))))
return (map(lambda x:pd.DataFrame(x, columns=['ARI', 'NMI', 'ASW', 'C_H', 'D_B', 'P_C']),ls))
[docs]
def moving_average(
a,
window_size
):
"""
Compute moving average with boundary handling.
This function computes a moving average while properly handling boundary
conditions at the start and end of the array.
Parameters
----------
a : numpy.ndarray
Input array to smooth.
window_size : int
Size of the moving average window.
Returns
-------
numpy.ndarray
Smoothed array of the same length as input.
"""
cumulative_sum = np.cumsum(np.insert(a, 0, 0))
middle = (cumulative_sum[window_size:] - cumulative_sum[:-window_size]) / window_size
r = np.arange(1, window_size-1, 2)
begin = np.cumsum(a[:window_size-1])[::2] / r
end = (np.cumsum(a[:-window_size:-1])[::2] / r)[::-1]
return np.concatenate((begin, middle, end))
[docs]
def fetch_score(
adata,
latent,
label_true,
label_mode='KMeans',
batch=False
):
"""
Compute comprehensive evaluation metrics for latent representations.
This function evaluates the quality of latent representations by computing
clustering metrics, graph connectivity, and batch integration metrics.
Parameters
----------
adata : anndata.AnnData
Annotated data matrix (will be modified in place to add latent embeddings).
latent : numpy.ndarray
Latent representations of shape (n_cells, latent_dim).
label_true : array-like
True cluster labels for cells.
label_mode : str, optional
Method for assigning labels from latent space:
- 'KMeans': Apply K-means clustering (default)
- 'Max': Use argmax of latent dimensions
- 'Min': Use argmin of latent dimensions
batch : bool, optional
If True, compute batch integration metrics (requires 'batch' in adata.obs).
Default is False.
Returns
-------
tuple
If batch=False: (NMI, ARI, ASW, C_H, D_B, G_C, clisi)
If batch=True: (NMI, ARI, ASW, C_H, D_B, G_C, clisi, ilisi, bASW)
where:
- NMI: Normalized Mutual Information
- ARI: Adjusted Rand Index
- ASW: Average Silhouette Width
- C_H: Calinski-Harabasz score
- D_B: Davies-Bouldin score
- G_C: Graph connectivity
- clisi: Cell-type Local Inverse Simpson's Index
- ilisi: Batch Local Inverse Simpson's Index (batch integration)
- bASW: Batch Average Silhouette Width (batch integration)
"""
q_z = latent
if label_mode == 'KMeans':
labels = KMeans(q_z.shape[1]).fit_predict(q_z)
elif label_mode == 'Max':
labels = np.argmax(q_z, axis=1)
elif label_mode == 'Min':
labels = np.argmin(q_z, axis=1)
else:
raise ValueError('Mode must be in one of KMeans, Max and Min')
adata.obsm['X_qz'] = q_z
adata.obs['label'] = pd.Categorical(labels)
NMI = normalized_mutual_info_score(label_true, labels)
ARI = adjusted_mutual_info_score(label_true, labels)
ASW = silhouette_score(q_z, labels)
if label_mode != 'KMeans':
ASW = abs(ASW)
C_H = calinski_harabasz_score(q_z, labels)
D_B = davies_bouldin_score(q_z, labels)
G_C = graph_connection(kneighbors_graph(adata.obsm['X_qz'], 15), adata.obs['label'].values)
clisi = scib.metrics.clisi_graph(adata, 'label', 'embed', 'X_qz', n_cores=26)
if batch:
sub_adata = adata[np.random.choice(adata.obs_names, 5000, replace=False)].copy()
ilisi = scib.metrics.ilisi_graph(sub_adata, 'batch', 'embed', 'X_qz', n_cores=26)
bASW = scib.metrics.silhouette_batch(sub_adata, 'batch', 'label', 'X_qz')
print('Completed')
return NMI, ARI, ASW, C_H, D_B, G_C, clisi, ilisi, bASW
print('Completed')
return NMI, ARI, ASW, C_H, D_B, G_C, clisi
[docs]
def graph_connection(
graph: csr_matrix,
labels: ndarray
):
"""
Compute graph connectivity score for each cluster.
This metric measures how well connected cells within each cluster are
in the k-nearest neighbor graph. Higher values indicate better connectivity.
Parameters
----------
graph : scipy.sparse.csr_matrix
Sparse adjacency matrix representing the k-nearest neighbor graph.
labels : numpy.ndarray
Cluster labels for each cell.
Returns
-------
float
Average connectivity score across all clusters (range: 0 to 1).
Higher values indicate cells in the same cluster are well-connected.
"""
cg_res = []
for l in np.unique(labels):
mask = np.where(labels==l)[0]
subgraph = graph[mask, :][:, mask]
_, lab = csgraph.connected_components(subgraph, connection='strong')
tab = np.unique(lab, return_counts=True)[1]
cg_res.append(tab.max() / tab.sum())
return np.mean(cg_res)