MILK tutorial: Quantifying dynamic cellular information with MILK trees (Paul15 dataset)

Overview

The hierarchical encoding of a dataset provided by a MILK tree can inherently capture dynamic aspects of individual cells. It is now appreciated that cells lie on a continuum of gene expression states, and identifying an “ordering” of cells along that continuum can help explain processes such as single-cell developmental trajectories. This tutorial walks through applying MILK to transcriptomic profiles of myeloid progenitor cells as they undergo lineage commitment.

The following analyses are included:

  1. Quantify the (weighted) average tree depth at which each of the annotated cell types branch off in the MILK tree.

  2. Count the branching events along each cell's path to the root of the MILK tree, stratified by cell type.

  3. Calculate the average clade size required to capture each cell type in the MILK tree across a range of thresholds (summarized by the area under the curve).

These analyses aim to provide simple examples of how cell type-specific branching patterns emerge from the global MILK hierarchy.

Citation:

Paul F, Arkin Y, Giladi A, Jaitin DA, Kenigsberg E, Keren-Shaul H, Winter D, Lara-Astiaso D, Gury M, Weiner A, David E, Cohen N, Lauridsen FK, Haas S, Schlitzer A, Mildner A, Ginhoux F, Jung S, Trumpp A, Porse BT, Tanay A, Amit I. Transcriptional Heterogeneity and Lineage Commitment in Myeloid Progenitors. Cell. 2015 Dec 17;163(7):1663-77. doi: 10.1016/j.cell.2015.11.013. Epub 2015 Nov 25. Erratum in: Cell. 2016 Jan 14;164(1-2):325. PMID: 26627738.

Importing libraries

import os
import pathlib
import pandas as pd
import scanpy as sc

Loading the dataset

The dataset was downloaded through Scanpy as follows:

adata = sc.datasets.paul15()

This loads the dataset as an AnnData object with 2730 cells and 3451 genes. It also includes cluster (i.e., cell type) annotations in adata.obs['paul15_clusters'] and the index of a root cell used for trajectory inference (adata.uns['iroot']).

adata.obs.head()
paul15_clusters
W31105 7MEP
W31106 15Mo
W31107 3Ery
W31108 15Mo
W31109 3Ery

Standard pre-processing

sc.pp.filter_cells(adata,min_genes=100)
sc.pp.filter_genes(adata,min_cells=3)
sc.pp.normalize_total(adata)
sc.pp.log1p(adata)
# only considering highly variable genes
sc.pp.highly_variable_genes(adata)
adata = adata[:,adata.var.highly_variable].copy()

This processing resulted in a dataset with 2730 cells and 657 (highly variable) genes.


First, we plot a UMAP of the dataset to get an overview.

sc.tl.pca(adata)
sc.pp.neighbors(adata, n_neighbors=15, n_pcs=50)
sc.tl.umap(adata)
sc.pl.umap(adata,color="paul15_clusters")

[Figure: UMAP of the processed dataset, colored by paul15_clusters]

MILK input

We apply MILK to the PCA embedding (50 components) of the processed dataset. The embedding is written to a CSV file with no header row, where the first column holds the cell ID.

input_df = pd.DataFrame(adata.obsm["X_pca"],index=adata.obs_names)
input_df.to_csv("input.csv",header=False)
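Before running MILK, it can help to confirm that the file has the layout described above: no header row and cell IDs in the first column. The sketch below uses a hypothetical helper, read_milk_input, built only on pandas.read_csv, and checks the round trip on a toy frame rather than the real data.

```python
import io

import numpy as np
import pandas as pd

def read_milk_input(path_or_buffer):
    """Read a MILK-style input table back: no header row, cell IDs in column 0."""
    return pd.read_csv(path_or_buffer, header=None, index_col=0)

# Toy round trip with the same layout as input.csv (hypothetical values)
toy = pd.DataFrame([[1.5, -2.0, 0.25], [3.0, 4.125, -1.0]], index=["W1", "W2"])
buf = io.StringIO()
toy.to_csv(buf, header=False)  # index written as the first column, no header
buf.seek(0)
back = read_milk_input(buf)
print(back.shape, np.allclose(back.to_numpy(), toy.to_numpy()))  # prints (2, 3) True
```

On the real file, read_milk_input("input.csv") should have the same shape as input_df (2730 rows by 50 columns here) and the same index.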

An excerpt of the input table:

W31105 -1.233177 -3.490964 1.909774 -1.484107 -0.530059 -0.288270 1.743168 -1.370971 -2.089372 -0.942689 ... -1.575776 0.464548 0.166464 0.470086 0.999315 0.081813 -1.461119 0.009120 1.387228 -1.308720
W31106 5.802759 1.664820 0.029968 -1.420301 1.231161 0.887331 1.160496 -0.268688 -0.145290 -0.620861 ... 0.037120 0.006508 0.874887 0.170908 0.631572 0.652908 0.785620 -0.139065 -0.111727 -1.108846
W31107 -5.798246 1.549633 -0.080099 0.738449 -0.551667 1.170946 -0.120539 -0.929517 -0.421885 -0.148952 ... 0.485965 -0.230995 -0.158994 -0.037129 -0.187478 0.069017 -0.315388 0.000821 0.243239 -0.119337
W31108 4.912524 2.162220 2.494962 0.615529 0.100325 2.048262 0.025773 0.511106 0.106394 0.007073 ... -0.360207 -0.187789 0.512162 0.036460 -0.083043 0.025471 -0.791394 0.075791 -0.262446 -0.603879
W31109 -5.867806 1.969801 -0.327709 0.639971 -0.684827 1.480482 0.323854 -0.438390 -0.019980 -0.432133 ... 0.756776 0.077329 -0.294401 -0.188781 0.025858 -0.784789 0.269851 0.277555 0.101574 -0.243650

Running MILK

In a terminal:

milk -i input.csv

Below is an example of the standard output produced when running MILK.

[ Info: MILK
[ Info: Using 1 worker(s) for distributed processing
[ Info: Parsed Arguments:
[ Info:   label: nothing
[ Info:   percentile: 1.0
[ Info:   batch-size: 50
[ Info:   output-dir: ./milk.out
[ Info:   merge-threshold: 100000
[ Info:   sample-size: 1
[ Info:   threads: 1
[ Info:   cache-size-limit: 50000
[ Info:   job-time: 24:00:00
[ Info:   group-stratification-mode: false
[ Info:   job-name: milk
[ Info:   partition-size: 10000
[ Info:   hpc-mode: false
[ Info:   skip-reconstruction: false
[ Info:   metric: euclidean
[ Info:   verbose: false
[ Info:   job-scheduler: nothing
[ Info:   environment-path: nothing
[ Info:   force-overwrite: true
[ Info:   job-account: nothing
[ Info:   job-memory: 4
[ Info:   seed: 21
[ Info:   input-path: input.csv
[ Info: ==========
[ Info: Iteration: 0 (2730 objects)
[ Info: 	Caching 2730 objects from the following path: /home/brett/milk/tutorial/milk.out/input.iteration_00000000.input.csv
[ Info: ========== Starting encapsulated recursions...
[ Info: Iteration: 0 (2730 objects)
[ Info: 	[input.iteration_00000000] 2040 groups (144 groups optimized); 2730 objects (2730 total); threshold: 3.491 (computed); 3725921 comparisons (1 CPU(s); cache); runtime: 0.01 min
[ Info: 	2040 objects after recursive iteration.
[ Info: Iteration: 1 (2040 objects)
[ Info: 	[input.iteration_00000001] 1273 groups (176 groups optimized); 2040 objects (2730 total); threshold: 4.744 (computed); 2959095 comparisons (1 CPU(s); cache; previous_groupings); runtime: 0.01 min
[ Info: 	1273 objects after recursive iteration.
[ Info: Iteration: 2 (1273 objects)
[ Info: 	[input.iteration_00000002] 833 groups (108 groups optimized); 1273 objects (2730 total); threshold: 5.8528 (computed); 2023858 comparisons (1 CPU(s); cache; previous_groupings); runtime: 0.0 min
[ Info: 	833 objects after recursive iteration.

[ ...

[ Info: Iteration: 31 (2 objects)
[ Info: 	[input.iteration_00000031] 1 groups (1 groups optimized); 2 objects (2730 total); threshold: 15.7195 (computed); 2731 comparisons (1 CPU(s); cache; previous_groupings); runtime: 0.0 min
[ Info: 	1 objects after recursive iteration.
┌ Info: 
└ Completed recursive downsampling procedure.
[ Info: Reconstructing hierarchical graph...
[ Info: Hierarchical reconstruction
[ Info: 	Done!

Processing the MILK output

import numpy as np

from glob import glob
from tqdm import tqdm

# Visualization libraries
import seaborn as sns
import matplotlib.cm as cm
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import matplotlib.colors as mcolors

# Constructing a BioPython tree object from the MILK output
from Bio import Phylo
from Bio.Phylo.BaseTree import Tree,Clade
from collections import defaultdict,deque

The following functions can be used to create the tree object.

def instantiate_node(node_id,info_dict):
    node = Clade(name=node_id)
    node_dict = info_dict[node_id]
    node.representative_id = node_dict["representative_id"]
    node.group_size = node_dict["group_size"]
    node.iteration = node_dict["iteration"]
    node.threshold = node_dict["threshold"]
    node.spread = node_dict["spread"]
    node.specificity = node_dict["specificity"]
    node.resolution = node_dict["resolution"]
    return node

def construct_tree_object(vertices_df,edges_df):
    """
    Build a Bio.Phylo Tree from the MILK vertices and edges tables
    (the root is at the top; leaves are at the bottom).
    """
    # parent (source) node -> list of child (target) nodes
    descendents_dict = edges_df.groupby("source")["target"].apply(list).to_dict()
    # the root is the node created in the final (highest) iteration
    max_iteration_df = vertices_df[vertices_df["iteration"] == vertices_df["iteration"].max()]

    root_id = max_iteration_df["node_id"].values[0]
    node_info_dict = vertices_df.set_index("node_id").to_dict(orient="index")

    root_node = instantiate_node(root_id,node_info_dict)
    queue = deque([root_node])
    while queue:
        clade = queue.popleft()
        subclades = []
        for sample_id in descendents_dict.get(clade.name,[]):
            subclade = instantiate_node(sample_id,node_info_dict)
            subclades.append(subclade)
            queue.append(subclade)
        clade.clades = subclades

    return Tree(root_node)

# The MILK output directory contains a vertices table and an edges table
vertices_path = os.path.join(".","milk.out","output","vertices.csv.gz")
vertices_df = pd.read_csv(vertices_path)

edges_path = os.path.join(".","milk.out","output","edges.csv.gz")
edges_df = pd.read_csv(edges_path)

Table with vertex (i.e., node) information:

node_id representative_id group_size iteration threshold spread specificity resolution
0 W39010 W39010 1 0 0.0 NaN NaN 0
1 W31116 W31116 1 0 0.0 NaN NaN 0
2 W38090 W38090 1 0 0.0 NaN NaN 0
3 W37487 W37487 1 0 0.0 NaN NaN 0
4 W38983 W38983 1 0 0.0 NaN NaN 0

Table including edge information:

source target
0 I1 W39010
1 I1 W31116
2 I1 W38090
3 I1 W37487
4 I1 W38983

Refer to the documentation for more information on the output.
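A few structural checks can catch a truncated or mismatched output before the tree is built. The sketch below uses a hypothetical helper, check_milk_tree; the checks are assumptions about what a well-formed tree looks like, not documented MILK guarantees. It verifies that each node has at most one parent, that exactly one node has no parent, and that this root is also the node with the highest iteration, which is how construct_tree_object picks the root.

```python
import pandas as pd

def check_milk_tree(vertices_df, edges_df):
    """Basic structural checks on MILK vertices/edges tables (hypothetical helper).

    Returns (root_id, number_of_leaves, root_group_size).
    """
    # In a tree, every node has at most one parent.
    if edges_df["target"].duplicated().any():
        raise ValueError("some nodes have more than one parent")
    all_nodes = set(vertices_df["node_id"])
    # Exactly one node should never appear as a target: the root.
    roots = all_nodes - set(edges_df["target"])
    if len(roots) != 1:
        raise ValueError(f"expected exactly one root, found {len(roots)}")
    root_id = roots.pop()
    # construct_tree_object takes the root from the highest iteration.
    max_iter = vertices_df["iteration"].max()
    if set(vertices_df.loc[vertices_df["iteration"] == max_iter, "node_id"]) != {root_id}:
        raise ValueError("root is not the unique node of the highest iteration")
    # Leaves are nodes that never appear as a source (no children).
    n_leaves = len(all_nodes - set(edges_df["source"]))
    root_size = int(vertices_df.set_index("node_id").loc[root_id, "group_size"])
    return root_id, n_leaves, root_size

# Toy tables with the same column names as the excerpts above
toy_vertices = pd.DataFrame({
    "node_id": ["a", "b", "c", "I1", "I2"],
    "group_size": [1, 1, 1, 2, 3],
    "iteration": [0, 0, 0, 1, 2],
})
toy_edges = pd.DataFrame({
    "source": ["I1", "I1", "I2", "I2"],
    "target": ["a", "b", "I1", "c"],
})
print(check_milk_tree(toy_vertices, toy_edges))  # prints ('I2', 3, 3)
```

On the real tables, the root's group_size should equal the number of leaves (2730 cells here).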

Creating the tree object:

tree = construct_tree_object(vertices_df,edges_df)

Visualizing the MILK tree in R (ggraph package)

library(tidyverse)
library(ggraph)
library(igraph)
library(paletteer)
milk_dir <- "/home/brett/milk/tutorial/milk.out"
vertices_path <- file.path(milk_dir,"output","vertices.csv.gz")
edges_path <- file.path(milk_dir,"output","edges.csv.gz")
metadata_path <- "/home/brett/milk/tutorial/metadata.csv"

Note that the metadata.csv table was written from Python via adata.obs.to_csv('metadata.csv').
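The R code expects a per-cell paul15_clusters_colors column, which is not part of adata.obs by default: Scanpy stores a categorical palette in adata.uns['paul15_clusters_colors'] (filled in, for example, when sc.pl.umap colors by that key), ordered like the categories. A minimal sketch of copying it into obs before writing metadata.csv follows; add_cluster_colors is a hypothetical helper. The leiden column seen in the excerpt is not used by the plotting code.

```python
import pandas as pd

def add_cluster_colors(obs, key, colors):
    """Return a copy of obs with a per-cell f"{key}_colors" column.

    Assumes `colors` is ordered like obs[key].cat.categories, which is how
    Scanpy stores a categorical palette in adata.uns[f"{key}_colors"].
    """
    categories = list(obs[key].cat.categories)
    if len(colors) != len(categories):
        raise ValueError("expected exactly one color per category")
    color_map = dict(zip(categories, colors))
    out = obs.copy()
    # map on plain objects so the result is a simple string column
    out[f"{key}_colors"] = out[key].astype(object).map(color_map)
    return out

# Toy example with three Paul15 labels and the colors seen in the metadata excerpt
toy_obs = pd.DataFrame(
    {"paul15_clusters": pd.Categorical(["7MEP", "15Mo", "3Ery", "15Mo"],
                                       categories=["15Mo", "3Ery", "7MEP"])},
    index=["W31105", "W31106", "W31107", "W31108"],
)
toy_obs = add_cluster_colors(toy_obs, "paul15_clusters", ["#c49c94", "#279e68", "#e377c2"])
print(toy_obs["paul15_clusters_colors"].tolist())

# With the real data (after the palette exists in adata.uns), for example:
# adata.obs = add_cluster_colors(adata.obs, "paul15_clusters", adata.uns["paul15_clusters_colors"])
# adata.obs.to_csv("metadata.csv")
```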

metadata_df <- read.csv(metadata_path) %>% dplyr::rename(representative_id=X)
  representative_id paul15_clusters n_genes leiden paul15_clusters_colors
  <chr>             <chr>           <int>   <int>  <chr>
1 W31105            7MEP             277    1      #e377c2
2 W31106            15Mo            1021    3      #c49c94
3 W31107            3Ery            1466    0      #279e68
4 W31108            15Mo            1415    4      #c49c94
5 W31109            3Ery            1505    0      #279e68
6 W31110            15Mo             795    2      #c49c94
6W3111015Mo 7952#c49c94

Next, we load the vertices and edges tables into R, this time joining the metadata to the vertex table. Metadata is associated with internal nodes through the group's representative ID, which corresponds to the medoid of the leaves in the clade defined by that internal node.
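To illustrate what "medoid" means here, the sketch below computes the member of a set of cells with the smallest summed Euclidean distance to the others, assuming the same PCA space MILK was given (the log reports metric: euclidean). This is not MILK's implementation: MILK's representative may be computed differently (for example, from child-group representatives rather than all leaves), so treat agreement as a spot check, not a guarantee. euclidean_medoid is a hypothetical helper.

```python
import numpy as np

def euclidean_medoid(ids, coords):
    """Return the ID whose point has the smallest summed Euclidean distance to the others.

    ids:    sequence of n labels (e.g., cell IDs of a clade's leaves)
    coords: (n, d) array of their coordinates (e.g., PCA embedding rows)
    """
    coords = np.asarray(coords, dtype=float)
    # squared distances via ||a||^2 + ||b||^2 - 2 a.b (n x n memory, not n x n x d)
    sq = np.einsum("ij,ij->i", coords, coords)
    d2 = sq[:, None] + sq[None, :] - 2.0 * coords @ coords.T
    dists = np.sqrt(np.clip(d2, 0.0, None))  # clip tiny negative rounding errors
    return ids[int(np.argmin(dists.sum(axis=1)))]

print(euclidean_medoid(["W1", "W2", "W3"], [[0.0, 0.0], [1.0, 0.0], [10.0, 0.0]]))  # prints W2
```

With the tutorial objects, one could compare euclidean_medoid(leaf_ids, input_df.loc[leaf_ids].to_numpy()) against clade.representative_id, where leaf_ids = [t.name for t in clade.get_terminals()].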

vertices_df <- read.csv(vertices_path) %>% left_join(metadata_df,by="representative_id")
edges_df <- read.csv(edges_path)
g <- graph_from_data_frame(edges_df,vertices=vertices_df)

Visualization:

cluster_color_map <- vertices_df %>%
    dplyr::select(paul15_clusters, paul15_clusters_colors) %>%
    distinct() %>%
    tibble::deframe()

p <- ggraph(g,layout="dendrogram",circular=T) + 
    geom_edge_elbow(color="lightgray",width=0.5,alpha=0.75) +
    geom_node_point(aes(color=paul15_clusters,size=group_size),stroke=0,alpha=0.5) +
    scale_color_manual(
        values = cluster_color_map,
        name = "Label",
        guide = guide_legend(override.aes=list(size=6))
    ) +
    scale_size(range=c(2,16),name="Group size") +
    coord_cartesian(clip="off") +
    theme_minimal() + theme(
        legend.position="right",
        legend.title = element_text(size=16),
        legend.text = element_text(size=16),
        panel.grid = element_blank(),
        axis.text = element_blank(),
        axis.title = element_blank(),
        plot.title = element_text(size=16,hjust=0.5),
        aspect.ratio=1
    )

options(repr.plot.width=16,repr.plot.height=16)
print(p)

[Figure: circular MILK dendrogram, with nodes colored by cell type label and sized by group size]

Downstream analyses utilizing the MILK tree output

# cell ID -> cell type
labels_dict = {}
for sample_id,label in zip(adata.obs_names,adata.obs["paul15_clusters"]):
    labels_dict[sample_id] = label

# Total number of cells of each cell type in the dataset.
total_counts_dict = defaultdict(int)
for tip in tree.get_terminals():
    total_counts_dict[labels_dict[tip.representative_id]] += 1

The total_counts_dict summarizes the number of cells per cell type:

Cell type  Count
10GMP        153
9GMP          63
13Baso       300
12Baso        69
14Mo         373
7MEP         167
17Neu         22
11DC          30
15Mo         186
16Neu        164
18Eos          9
19Lymph       31
3Ery         246
4Ery         124
2Ery         329
6Ery         173
5Ery         180
1Ery          43
8Mk           68

# For each clade, count the cell types among its leaves.
for clade in tree.find_clades():
    clade_counts_dict = defaultdict(int)
    for tip in clade.get_terminals():
        clade_counts_dict[labels_dict[tip.representative_id]] += 1
    clade.clade_counts_dict = clade_counts_dict
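The loop above calls get_terminals() on every clade, so each leaf is revisited once per ancestor. A single bottom-up pass yields the same counts. The sketch below uses a hypothetical helper, clade_label_counts, that works on (source, target) pairs like those in edges.csv.gz rather than on the Bio.Phylo objects.

```python
from collections import Counter, defaultdict

def clade_label_counts(edges, leaf_labels):
    """Label counts for every node of a tree, computed in one bottom-up pass.

    edges:       iterable of (parent, child) pairs, e.g. rows of edges.csv.gz
    leaf_labels: dict mapping each leaf ID to its label
    """
    children = defaultdict(list)
    has_parent = set()
    for parent, child in edges:
        children[parent].append(child)
        has_parent.add(child)
    counts = {}
    # iterative post-order traversal from each root (avoids recursion limits)
    for root in [node for node in children if node not in has_parent]:
        stack = [(root, False)]
        while stack:
            node, children_done = stack.pop()
            if node not in children:  # a leaf has no children
                counts[node] = Counter({leaf_labels[node]: 1})
            elif children_done:
                total = Counter()
                for child in children[node]:
                    total.update(counts[child])  # add the child's label counts
                counts[node] = total
            else:
                stack.append((node, True))
                stack.extend((child, False) for child in children[node])
    return counts

toy_edges = [("I2", "I1"), ("I2", "c"), ("I1", "a"), ("I1", "b")]
toy_labels = {"a": "3Ery", "b": "3Ery", "c": "15Mo"}
print(clade_label_counts(toy_edges, toy_labels)["I2"])
```

With the tutorial objects, clade_label_counts(edges_df[["source", "target"]].itertuples(index=False, name=None), labels_dict) should give, for each node ID, the same counts as clade.clade_counts_dict.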

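We can then normalize the clade-specific counts and store them per clade and cell type: the relative proportion (prop) divides a clade's count of a cell type by that cell type's total count, and the purity divides the same count by the clade size. Tree depth is measured as the number of edges from the root: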

treedepth_dict = tree.depths(unit_branch_lengths=True)
results_dict = {
    "clade_name":[],
    "clade_size":[],
    "label":[],
    "clade_count":[],
    "prop":[],
    "purity":[],
    "treedepth":[]
}
for clade in tree.find_clades():
    for label,clade_count in clade.clade_counts_dict.items():
        results_dict["clade_name"].append(clade.name)
        results_dict["clade_size"].append(clade.group_size)
        results_dict["label"].append(label)
        results_dict["clade_count"].append(clade_count)
        results_dict["prop"].append(clade_count/total_counts_dict[label])
        results_dict["purity"].append(clade_count/clade.group_size)
        results_dict["treedepth"].append(treedepth_dict[clade])

results_df = pd.DataFrame(results_dict)

The results_df data frame looks like:

clade_name clade_size label clade_count prop purity treedepth
0 I5973 2730 10GMP 153 1.0 0.056044 0
1 I5973 2730 9GMP 63 1.0 0.023077 0
2 I5973 2730 13Baso 300 1.0 0.109890 0
3 I5973 2730 12Baso 69 1.0 0.025275 0
4 I5973 2730 14Mo 373 1.0 0.136630 0
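In symbols, write n_{c,ℓ} for the number of cells of type ℓ in clade c (clade_count), N_ℓ for the total number of cells of type ℓ (total_counts_dict), and |c| for the clade size (group_size). The two normalized columns are then as follows, with the first row of the table above as a worked example.

```latex
\mathrm{prop}_{c,\ell} = \frac{n_{c,\ell}}{N_\ell},
\qquad
\mathrm{purity}_{c,\ell} = \frac{n_{c,\ell}}{|c|}

% Root clade I5973, label 10GMP:
\mathrm{prop} = \frac{153}{153} = 1.0,
\qquad
\mathrm{purity} = \frac{153}{2730} \approx 0.056044
```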

Weighted average depth of each cell type in the MILK tree

For each cell type, we compute the weighted average tree depth across all clades that capture at least 10% of that cell type (prop >= 0.10). Each clade is weighted by the product of its relative proportion and purity for the cell type of interest.
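Written out, with prop the fraction of all cells of type ℓ that fall in clade c, purity the fraction of clade c's cells that are of type ℓ, and d(c) the number of edges between the root and c (tree.depths(unit_branch_lengths=True), so the root has depth 0), the quantity computed below is:

```latex
\bar{d}_\ell =
\frac{\sum_{c \in C_\ell} w_{c,\ell}\, d(c)}{\sum_{c \in C_\ell} w_{c,\ell}},
\qquad
w_{c,\ell} = \mathrm{prop}_{c,\ell} \cdot \mathrm{purity}_{c,\ell},
\qquad
C_\ell = \{\, c : \mathrm{prop}_{c,\ell} \ge 0.1 \,\}
```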

results_dict = {
    "cell_type": [],
    "wavg_treedepth": []
}

# Only consider clades that capture at least 10% of a cell type's cells.
# This filters out the noisy signal present at the level of
# individual cells.
threshold = 0.1

# compute weighted average treedepth per label
for label, subset_df in results_df.groupby("label"):
    _subset_df = subset_df[subset_df["prop"] >= threshold]

    # weight each clade by prop * purity for this label
    wavg_depth = np.average(
        _subset_df["treedepth"],
        weights=np.multiply(_subset_df["prop"], _subset_df["purity"])
    )

    results_dict["cell_type"].append(label)
    results_dict["wavg_treedepth"].append(wavg_depth)

df_depth = pd.DataFrame(results_dict)
df_depth = df_depth.sort_values("wavg_treedepth") # sort by weighted average treedepth

plt.figure(figsize=(6,4))
sns.barplot(data=df_depth,x="cell_type",y="wavg_treedepth",hue="cell_type",palette="Spectral")
plt.xticks(rotation=90)
plt.ylabel("Weighted average tree depth")
plt.xlabel("Cell type")
plt.tight_layout()
plt.show()

[Figure: weighted average tree depth per cell type]

This shows a cell type-specific bias in the MILK tree topology: a lower depth corresponds to branching higher up in the MILK tree (towards the root), and a higher depth to branching closer to the leaves.

Calculating the number of (productive) branching events from each cell (leaf) to the root

A productive branching event is defined as an internal node with at least 2 children. The resulting count values can be stratified by cell type group and normalized by the total number of cells per group.

This analysis offers a slightly different interpretation, since the number of hierarchical branching events reflects the degree of lineage “fragmentation” in the inferred topology. For example, higher values indicate that a cell type tends to reside in regions of the tree with more successive subdivisions, suggesting a more refined or partitioned lineage structure.
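The definition can be made concrete on a small hand-built tree. The sketch below is a hypothetical helper that walks the tree top-down once, carrying a running count of productive branching nodes, and counts the root as well; the node names are made up.

```python
def branching_counts(children, root):
    """For each leaf, count the nodes with >= 2 children on its root-to-leaf path.

    children: dict mapping each internal node to a list of its children
              (leaves do not appear as keys); the root itself is counted.
    """
    counts = {}
    stack = [(root, 0)]  # (node, productive branchings among its ancestors)
    while stack:
        node, above = stack.pop()
        kids = children.get(node, [])
        if not kids:
            counts[node] = above
            continue
        here = above + (1 if len(kids) >= 2 else 0)
        stack.extend((kid, here) for kid in kids)
    return counts

# Toy tree: R -> {A, c}, A -> {B} (a non-productive single-child node), B -> {a, b}
toy_children = {"R": ["A", "c"], "A": ["B"], "B": ["a", "b"]}
print(branching_counts(toy_children, "R"))
```

With the tutorial objects, the children dict can be built as edges_df.groupby("source")["target"].apply(list).to_dict(), with the root ID taken from the highest iteration; this avoids one get_path call per leaf.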

leaf_branching_counts_dict = {}
for tip in tqdm(tree.get_terminals()):
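    # get_path() lists the clades from just below the root down to the tip
    # (root excluded, tip included), so prepend the root to cover the full path
    path = [tree.root] + tree.get_path(tip)
    leaf_branching_counts_dict[tip.name] = sum(len(clade.clades) >= 2 for clade in path)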

df_cells = pd.DataFrame([
    {"cell": k, "branching_count": v, "label": labels_dict[k]}
    for k, v in leaf_branching_counts_dict.items()
])

# order by median branching count per label (i.e., cell type)
median_order = df_cells.groupby("label")["branching_count"].median().sort_values().index

plt.figure(figsize=(6,4))
sns.boxplot(data=df_cells, x="label", y="branching_count",order=median_order)  # order by median
plt.xticks(rotation=90)
plt.xlabel("Cell type")
plt.ylabel("Number of branching events")
plt.show()

[Figure: number of branching events per cell, grouped by cell type]

Threshold-based quantification of cell type substructure in the MILK tree topology

Another potentially useful metric is the average clade size required to capture each cell type across a range of thresholds. For a given threshold t, it reflects how large a clade must be to contain at least a fraction t of a cell type's cells (averaging over the qualifying clades, weighted by purity), providing insight into how dispersed or consolidated a cell type is within the tree topology.
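In symbols, with n_{c,ℓ} the number of cells of type ℓ in clade c, |c| the clade size, prop = n_{c,ℓ}/N_ℓ and purity = n_{c,ℓ}/|c|, the weighted average clade size at threshold t is given below. Because purity × |c| equals n_{c,ℓ}, it is the n_{c,ℓ}-weighted harmonic mean of the qualifying clade sizes.

```latex
S_\ell(t) =
\frac{\sum_{c \in C_\ell(t)} \mathrm{purity}_{c,\ell}\, |c|}{\sum_{c \in C_\ell(t)} \mathrm{purity}_{c,\ell}}
= \frac{\sum_{c \in C_\ell(t)} n_{c,\ell}}{\sum_{c \in C_\ell(t)} n_{c,\ell} / |c|},
\qquad
C_\ell(t) = \{\, c : \mathrm{prop}_{c,\ell} \ge t \,\}
```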

Evaluating this quantity across a continuum of thresholds produces a sigmoidal curve for each cell type. The resulting curves can be compared directly across cell types (e.g., shifts) and summarized (e.g., via the AUC) to quantify relative topological separability within the tree.
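The two steps can be sketched on toy numbers. capture_curve (hypothetical) reproduces the per-threshold purity-weighted average, and auc_per_cell (hypothetical) applies the trapezoid rule, the same rule as np.trapz (renamed np.trapezoid in NumPy 2.0), written out to avoid version differences, before dividing by the number of cells of that type as the tutorial does with total_counts_dict.

```python
import numpy as np

def capture_curve(prop, purity, size, thresholds):
    """Purity-weighted mean clade size over clades with prop >= t, for each t.

    Thresholds where no clade qualifies are skipped, mirroring the `continue`
    in the tutorial loop.
    """
    prop, purity, size = (np.asarray(a, dtype=float) for a in (prop, purity, size))
    ts, ys = [], []
    for t in thresholds:
        keep = prop >= t
        if not keep.any():
            continue
        ts.append(t)
        ys.append(np.average(size[keep], weights=purity[keep]))
    return np.array(ts), np.array(ys)

def auc_per_cell(ts, ys, n_cells):
    """Trapezoid-rule area under the curve, divided by the cell type's size."""
    area = np.sum((ys[1:] + ys[:-1]) / 2.0 * np.diff(ts))
    return area / n_cells

# Hypothetical cell type with 4 cells: the root (10 cells, all 4 of ours),
# one clade of size 4 holding 3 of them, and the 4 single-cell leaves.
prop = [1.0, 0.75, 0.25, 0.25, 0.25, 0.25]
purity = [0.4, 0.75, 1.0, 1.0, 1.0, 1.0]
size = [10, 4, 1, 1, 1, 1]
ts, ys = capture_curve(prop, purity, size, [0.0, 0.5, 1.0])
print(ys, auc_per_cell(ts, ys, n_cells=4))
```

As the threshold rises, only larger clades can hold the required fraction, so the curve climbs toward the root size (10 in the toy); a cell type that reaches large sizes only at high thresholds has a smaller AUC.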

results_dict = {
    "label":[],
    "threshold":[],
    "wavg_clade_size":[]
}

for threshold in tqdm(np.arange(0,1.01,0.01)):
    for label,subset_df in results_df.groupby("label"):
        _subset_df = subset_df[subset_df["prop"] >= threshold]
        if _subset_df.shape[0] == 0: continue
        results_dict["label"].append(label)
        results_dict["threshold"].append(threshold)
        results_dict["wavg_clade_size"].append(np.average(_subset_df["clade_size"],weights=_subset_df["purity"]))

df = pd.DataFrame(results_dict)

The resulting data frame:

label threshold wavg_clade_size
0 10GMP 0.0 6.948723
1 11DC 0.0 5.044474
2 12Baso 0.0 4.926507
3 13Baso 0.0 11.584472
4 14Mo 0.0 13.294418

plt.figure(figsize=(6,4))
sns.lineplot(data=df, x="threshold", y="wavg_clade_size", hue="label")
plt.xlabel("Cell type capture threshold (relative proportion)")
plt.ylabel("Weighted average clade size")
plt.title("Each line corresponds to a cell type")
plt.legend(bbox_to_anchor=(1.05,1),loc="upper left")
plt.show()

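[Figure: weighted average clade size vs. cell type capture threshold; each line corresponds to a cell type]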

# Summarizing with a normalized AUC value
auc_dict = {}
for label in df["label"].unique():
    sub = df[df["label"] == label]
    sub = sub.sort_values("threshold")
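    # np.trapz was renamed np.trapezoid in NumPy 2.0; use whichever is available
    trapezoid = getattr(np, "trapezoid", None) or np.trapz
    auc = trapezoid(sub["wavg_clade_size"], x=sub["threshold"])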
    auc_dict[label] = auc/total_counts_dict[label]
auc_df = pd.DataFrame({
    "label": list(auc_dict.keys()),
    "auc_per_cell": list(auc_dict.values())
})
auc_df = auc_df.sort_values("auc_per_cell")

plt.figure(figsize=(6,4))
sns.barplot(data=auc_df, x="label", y="auc_per_cell",hue="label",palette="viridis")
plt.xticks(rotation=90)
plt.ylabel("AUC per cell (weighted clade size)")
plt.xlabel("Cell type")
plt.yscale("log")
plt.title("Label separability based on MILK tree clade structure")
plt.tight_layout()
plt.show()

[Figure: Label separability based on MILK tree clade structure (AUC per cell, log scale)]

Overall, these curves describe how well each cell type is resolved by the tree topology.

Cell types whose weighted average clade size stays large across a wide range of thresholds are poorly captured by the MILK tree. Such labels are more dispersed throughout the topology, which could help flag poor-quality labels.

If a cell type can be robustly captured by a small clade size across a range of thresholds, it indicates stronger topological coherence (i.e., separability within the hierarchy).

Thus, differences in curve shape and AUC provide a quantitative measure of how distinctly each cell type is organized in the inferred single-cell topology.