MILK tutorial: Network visualization of MILK trees for PBMCs (3k dataset)
This notebook provides a basic walkthrough of how MILK can be applied to scRNA-seq datasets. Specifically, it uses the processed 3k PBMC dataset from 10X Genomics, as produced by the “Preprocessing and clustering 3k PBMCs (legacy workflow)” tutorial and provided by Scanpy.
Importing libraries
import os
import pathlib
import numpy as np
import pandas as pd
import scanpy as sc
Loading the dataset
adata = sc.datasets.pbmc3k_processed()
The processed anndata object contains the following information:
AnnData object with n_obs × n_vars = 2638 × 1838
obs: 'n_genes', 'percent_mito', 'n_counts', 'louvain'
var: 'n_cells'
uns: 'draw_graph', 'louvain', 'louvain_colors', 'neighbors', 'pca', 'rank_genes_groups'
obsm: 'X_pca', 'X_tsne', 'X_umap', 'X_draw_graph_fr'
varm: 'PCs'
obsp: 'distances', 'connectivities'
Visualizing the processed dataset, colored by annotated cell types:
sc.pl.umap(adata,color="louvain")

MILK input
We can apply MILK to the PCA embeddings (50 components) of the processed dataset.
input_df = pd.DataFrame(adata.obsm["X_pca"],index=adata.obs_names)
input_df.to_csv("input.csv",header=False)
| | | | | | | | | | | | | | | | | | | | | | |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| AAACATACAACCAC-1 | 5.556233 | 0.257714 | -0.186810 | 2.800131 | -0.033783 | -0.189702 | 0.310228 | -1.323691 | 2.691945 | 0.125928 | ... | -0.266174 | 1.024464 | -0.709844 | -0.052780 | -0.686898 | -1.419867 | -2.865078 | 0.027601 | 2.671032 | -0.297620 |
| AAACATTGAGCTAC-1 | 7.209530 | 7.481985 | 0.162706 | -8.018575 | 3.012900 | 0.322293 | 2.270888 | -0.605055 | -0.905611 | 1.225260 | ... | 0.158161 | 0.819037 | 0.578912 | -1.169742 | 0.955408 | 0.068133 | -0.883082 | 2.930932 | 0.354197 | -1.081801 |
| AAACATTGATCAGC-1 | 2.694438 | -1.583658 | -0.663126 | 2.205649 | -1.686360 | -1.965395 | -1.894999 | -1.522103 | 1.914985 | -0.481202 | ... | -1.054254 | 0.805932 | 1.543282 | 1.504834 | -0.831818 | -0.236549 | 1.883515 | 1.084782 | 0.381470 | 0.064662 |
| AAACCGTGCTTCCG-1 | -10.143295 | -1.368530 | 1.209812 | -0.700096 | -2.872336 | 0.230617 | 1.278005 | 0.487900 | -0.447965 | -0.328465 | ... | 1.297246 | 0.611073 | -0.007878 | -0.648735 | 0.543566 | 3.156763 | 1.691134 | -0.301377 | -0.225427 | 0.962879 |
| AAACCGTGTATGCG-1 | -1.112816 | -8.152788 | 1.332405 | -4.252473 | 2.036407 | 5.597797 | -0.110658 | -0.102257 | 0.014520 | 0.581409 | ... | 1.191032 | 1.042533 | 1.734694 | -0.142114 | 0.586381 | 0.636326 | -1.451625 | 1.809683 | -0.087072 | -0.737833 |
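Before running MILK, it can help to confirm the layout of the exported file: one row per cell, the cell ID in the first column, one column per principal component, and no header row. The sketch below checks this with the standard csv module. The helper name check_milk_input and the three-component toy rows are illustrative assumptions, not part of MILK or of the dataset.

```python
import csv
import io


def check_milk_input(handle, n_components):
    """Validate a header-less embedding CSV and return its cell IDs.

    Assumes the layout written above: ID first, then n_components numbers.
    """
    ids = []
    for line_no, row in enumerate(csv.reader(handle), start=1):
        if len(row) != n_components + 1:
            raise ValueError(
                f"line {line_no}: expected {n_components + 1} fields, got {len(row)}"
            )
        for value in row[1:]:
            float(value)  # raises ValueError on non-numeric fields
        ids.append(row[0])
    return ids


# Toy stand-in for input.csv with three made-up components per cell.
toy_csv = io.StringIO(
    "cell_a,5.5,0.25,-0.18\n"
    "cell_b,7.2,7.48,0.16\n"
)
cell_ids = check_milk_input(toy_csv, n_components=3)
print(cell_ids)
```

For the file written above, the matching call would pass open("input.csv", newline="") and n_components=50.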
MILK execution
After adding the MILK directory to your PATH, run the following in a terminal:
milk -i input.csv
The following is an excerpt of the standard output:
[ Info: MILK
[ Info: Using 1 worker(s) for distributed processing
[ Info: Parsed Arguments:
[ Info: label: nothing
[ Info: percentile: 1.0
[ Info: batch-size: 50
[ Info: output-dir: ./milk.out
[ Info: merge-threshold: 100000
[ Info: sample-size: 1
[ Info: threads: 1
[ Info: cache-size-limit: 50000
[ Info: job-time: 24:00:00
[ Info: group-stratification-mode: false
[ Info: job-name: milk
[ Info: partition-size: 10000
[ Info: hpc-mode: false
[ Info: skip-reconstruction: false
[ Info: metric: euclidean
[ Info: verbose: false
[ Info: job-scheduler: nothing
[ Info: environment-path: nothing
[ Info: force-overwrite: true
[ Info: job-account: nothing
[ Info: job-memory: 4
[ Info: seed: 21
[ Info: input-path: input.csv
[ Info: ==========
[ Info: Iteration: 0 (2638 objects)
[ Info: Caching 2638 objects from the following path: /home/brett/milk/tutorial/milk.out/input.iteration_00000000.input.csv
[ Info: ========== Starting encapsulated recursions...
[ Info: Iteration: 0 (2638 objects)
[ Info: [input.iteration_00000000] 1723 groups (257 groups optimized); 2638 objects (2638 total); threshold: 10.4644 (computed); 3479378 comparisons (1 CPU(s); cache); runtime: 0.01 min
[ Info: 1723 objects after recursive iteration.
[ Info: Iteration: 1 (1723 objects)
[ Info: [input.iteration_00000001] 958 groups (220 groups optimized); 1723 objects (2638 total); threshold: 12.3272 (computed); 2361059 comparisons (1 CPU(s); cache; previous_groupings); runtime: 0.0 min
[ Info: 958 objects after recursive iteration.
[ ...
[ Info: Iteration: 31 (2 objects)
[ Info: [input.iteration_00000031] 1 groups (1 groups optimized); 2 objects (2638 total); threshold: 50.5063 (computed); 2639 comparisons (1 CPU(s); cache; previous_groupings); runtime: 0.0 min
[ Info: 1 objects after recursive iteration.
┌ Info:
└ Completed recursive downsampling procedure.
[ Info: Reconstructing hierarchical graph...
[ Info: Hierarchical reconstruction
[ Info: Done!
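The per-iteration summary lines in this log follow a regular pattern, so the progress of the recursion can be tabulated with a small parser. The sketch below is based only on the excerpt shown above. The regular expression and the function name parse_iteration_line are assumptions and may need adjusting for other MILK versions or options.

```python
import re

# Pattern inferred from the log excerpt; not an official MILK log format.
ITERATION_RE = re.compile(
    r"\[input\.iteration_(?P<iteration>\d+)\]\s+"
    r"(?P<groups>\d+) groups \((?P<optimized>\d+) groups optimized\);\s+"
    r"(?P<objects>\d+) objects \((?P<total>\d+) total\);\s+"
    r"threshold: (?P<threshold>[\d.]+)"
)


def parse_iteration_line(line):
    """Return a dict of fields from one summary line, or None if no match."""
    match = ITERATION_RE.search(line)
    if match is None:
        return None
    fields = {
        key: int(value)
        for key, value in match.groupdict().items()
        if key != "threshold"
    }
    fields["threshold"] = float(match.group("threshold"))
    return fields


log_lines = [
    "[ Info: [input.iteration_00000000] 1723 groups (257 groups optimized); 2638 objects (2638 total); threshold: 10.4644 (computed); 3479378 comparisons (1 CPU(s); cache); runtime: 0.01 min",
    "[ Info: 1723 objects after recursive iteration.",
    "[ Info: [input.iteration_00000031] 1 groups (1 groups optimized); 2 objects (2638 total); threshold: 50.5063 (computed); 2639 comparisons (1 CPU(s); cache; previous_groupings); runtime: 0.0 min",
]
records = [rec for rec in map(parse_iteration_line, log_lines) if rec is not None]
for rec in records:
    print(rec["iteration"], rec["objects"], "->", rec["groups"], "threshold", rec["threshold"])
```

Each record shows how many objects entered an iteration, how many groups came out, and the distance threshold reported for that iteration.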
Visualization of the single-cell MILK tree embedding
import numpy as np
from glob import glob
from tqdm import tqdm
import graph_tool.all as gt
import seaborn as sns
import matplotlib.cm as cm
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import matplotlib.colors as mcolors
from Bio import Phylo
from Bio.Phylo.BaseTree import Tree,Clade
from collections import defaultdict,deque
gt.openmp_set_num_threads(1)
These Python functions construct a BioPython tree object from the MILK output files:
def instantiate_node(node_id,info_dict):
    node = Clade(name=node_id)
    node_dict = info_dict[node_id]
    node.representative_id = node_dict["representative_id"]
    node.group_size = node_dict["group_size"]
    node.iteration = node_dict["iteration"]
    node.threshold = node_dict["threshold"]
    node.spread = node_dict["spread"]
    node.specificity = node_dict["specificity"]
    node.resolution = node_dict["resolution"]
    return node
def construct_tree_object(vertices_df,edges_df):
    """
    For each (parent) node, provide list of (children) subnodes
    (top is root; bottom refers to leaves)
    """
    descendents_dict = edges_df.groupby("source")["target"].apply(list).to_dict()
    max_iteration_df = vertices_df[vertices_df["iteration"] == vertices_df["iteration"].max()]
    root_id = max_iteration_df["node_id"].values[0]
    node_info_dict = vertices_df.set_index("node_id").to_dict(orient="index")
    root_node = instantiate_node(root_id,node_info_dict)
    queue = deque([root_node])
    while queue:
        clade = queue.popleft()
        subclades = []
        for sample_id in descendents_dict.get(clade.name,[]):
            subclade = instantiate_node(sample_id,node_info_dict)
            subclades.append(subclade)
            queue.append(subclade)
        clade.clades = subclades
    return Tree(root_node)
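To see what the breadth-first construction does without BioPython or the MILK output files, the same traversal can be sketched on a few hand-written edges using plain dictionaries. The node names below (root, I1, I2, cell_a, and so on) are made up for illustration, and build_nested_tree and count_leaves are hypothetical helpers rather than part of MILK.

```python
from collections import defaultdict, deque


def build_nested_tree(edges, root_id):
    """Build a nested {"name", "children"} dict by walking edges breadth-first."""
    # Same idea as grouping the edge table by "source".
    children_of = defaultdict(list)
    for source, target in edges:
        children_of[source].append(target)

    root = {"name": root_id, "children": []}
    queue = deque([root])
    while queue:
        node = queue.popleft()
        for child_id in children_of.get(node["name"], []):
            child = {"name": child_id, "children": []}
            node["children"].append(child)
            queue.append(child)
    return root


def count_leaves(node):
    """Count terminal nodes (nodes without children) at or below a node."""
    if not node["children"]:
        return 1
    return sum(count_leaves(child) for child in node["children"])


# Hypothetical (source, target) pairs: parent first, child second.
toy_edges = [
    ("root", "I1"), ("root", "I2"),
    ("I1", "cell_a"), ("I1", "cell_b"), ("I1", "cell_c"),
    ("I2", "cell_d"),
]
toy_tree = build_nested_tree(toy_edges, "root")
print([child["name"] for child in toy_tree["children"]])
print(count_leaves(toy_tree))
```

The queue guarantees that every node is created once and attached to its parent before its own children are visited, which is the same order in which construct_tree_object fills clade.clades.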
# MILK output directory contains a vertices and edges table
vertices_path = os.path.join(".","milk.out","output","vertices.csv.gz")
vertices_df = pd.read_csv(vertices_path)
edges_path = os.path.join(".","milk.out","output","edges.csv.gz")
edges_df = pd.read_csv(edges_path)
Table with vertex (i.e., node) information:
| | node_id | representative_id | group_size | iteration | threshold | spread | specificity | resolution |
|---|---|---|---|---|---|---|---|---|
| 0 | CAATCGGAGAAACA-1 | CAATCGGAGAAACA-1 | 1 | 0 | 0.000000 | NaN | NaN | 0 |
| 1 | TCGGTAGAGTAGGG-1 | TCGGTAGAGTAGGG-1 | 1 | 0 | 0.000000 | NaN | NaN | 0 |
| 2 | TAGCCGCTTACGAC-1 | TAGCCGCTTACGAC-1 | 1 | 0 | 0.000000 | NaN | NaN | 0 |
| 3 | I1 | CAATCGGAGAAACA-1 | 3 | 1 | 10.464395 | 9.64075 | 6.0 | 2 |
| 4 | CCTCGAACACTTTC-1 | CCTCGAACACTTTC-1 | 1 | 0 | 0.000000 | NaN | NaN | 0 |
Table including edge information:
| | source | target |
|---|---|---|
| 0 | I1 | CAATCGGAGAAACA-1 |
| 1 | I1 | TCGGTAGAGTAGGG-1 |
| 2 | I1 | TAGCCGCTTACGAC-1 |
| 3 | I2 | CCTCGAACACTTTC-1 |
| 4 | I3 | GAACCTGATGAACC-1 |
Refer to the documentation for more information on the output.
Creating the tree object:
tree = construct_tree_object(vertices_df,edges_df)
Cell types annotated in this dataset:
labels = sorted(adata.obs["louvain"].unique())
B cells
CD14+ Monocytes
CD4 T cells
CD8 T cells
Dendritic cells
FCGR3A+ Monocytes
Megakaryocytes
NK cells
labels_dict = {}
for sample_id,label in zip(adata.obs_names,adata.obs["louvain"]):
    labels_dict[sample_id] = label
color_dict = {}
for label,col in zip(adata.obs['louvain'].cat.categories,adata.uns["louvain_colors"]):
    color_dict[label] = tuple(list(mcolors.to_rgb(col))+[0.75])
neutral_col = (0.75,0.75,0.75,0.25)
neutral_threshold = 0.8
g = gt.Graph(directed=False)
vertex_labels = g.new_vertex_property("vector<float>")
vertex_sizes = g.new_vertex_property("float")
edge_widths = g.new_edge_property("int")
root = g.add_vertex()
vertex_labels[root] = neutral_col # color_dict[labels_dict[tree.clade.representative_id]]
vertex_sizes[root] = np.log1p(tree.count_terminals())*2
stack = [(root,tree.clade)]
while stack:
    vertex,clade = stack.pop()
    for subclade in clade:
        subvertex = g.add_vertex()
        edge = g.add_edge(vertex,subvertex)
        representative_label = labels_dict[subclade.representative_id]
        representative_col = color_dict[representative_label]
        leaf_labels = []
        for leaf in subclade.get_terminals():
            leaf_labels.append(labels_dict[leaf.representative_id])
        prop_matching = np.mean([int(representative_label == label) for label in leaf_labels])
        # Color a clade only if at least neutral_threshold (0.8) of its leaves
        # share the cell type of the clade's representative; otherwise draw it
        # in gray. This helps when visualizing complex tree structures.
        if prop_matching >= neutral_threshold:
            vertex_labels[subvertex] = representative_col
        else:
            vertex_labels[subvertex] = neutral_col
        vertex_sizes[subvertex] = np.log1p(subclade.count_terminals())*2
        edge_widths[edge] = np.log1p(subclade.count_terminals())
        stack.append((subvertex,subclade))
g.vertex_properties["cell_type"] = vertex_labels
g.vertex_properties["group_size"] = vertex_sizes
g.edge_properties["pen_width"] = edge_widths
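The coloring rule inside the loop can be isolated as a small function, which makes it easy to try other thresholds. This is a dependency-free sketch of the same logic. The name clade_color and the toy labels and colors are illustrative assumptions, not part of MILK or graph_tool.

```python
NEUTRAL_COL = (0.75, 0.75, 0.75, 0.25)  # same gray RGBA as neutral_col


def clade_color(representative_label, leaf_labels, color_dict, threshold=0.8):
    """Return the label's color if enough leaves share it, else neutral gray.

    leaf_labels is assumed to be non-empty (every clade has at least one leaf).
    """
    matches = sum(label == representative_label for label in leaf_labels)
    proportion = matches / len(leaf_labels)
    if proportion >= threshold:
        return color_dict[representative_label]
    return NEUTRAL_COL


# Made-up RGBA colors for two of the annotated cell types.
toy_colors = {
    "B cells": (0.1, 0.2, 0.9, 0.75),
    "NK cells": (0.9, 0.2, 0.1, 0.75),
}

# 9 of 10 leaves match the representative: colored.
pure = clade_color("B cells", ["B cells"] * 9 + ["NK cells"], toy_colors)
# 3 of 5 leaves match the representative: falls back to gray.
mixed = clade_color("B cells", ["B cells"] * 3 + ["NK cells"] * 2, toy_colors)
print(pure)
print(mixed)
```

Raising the threshold makes more internal nodes gray, which emphasizes only the purest subtrees. Lowering it colors nodes closer to the root.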
Scalable force-directed placement (SFDP) layout, provided by the graph_tool Python library:
pos = gt.sfdp_layout(g,C=10)
gt.graph_draw(
    g,
    pos=pos,
    vertex_fill_color=vertex_labels,
    vertex_color=[1,1,1,0.5],
    edge_color=[0.75,0.75,0.75,0.25],
    edge_pen_width=edge_widths,
    vertex_size=vertex_sizes,
    output_size=(400,600),
    bg_color=None
)
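The vertex sizes passed to the drawing call are log-scaled, so a clade covering thousands of cells is drawn roughly an order of magnitude larger than a single cell rather than thousands of times larger. The short sketch below evaluates the same formula with the standard math module for a few example leaf counts. The helper name vertex_size is an assumption for illustration.

```python
import math


def vertex_size(n_leaves, scale=2.0):
    """Mirror np.log1p(n_leaves) * 2 from the graph construction code."""
    return math.log1p(n_leaves) * scale


# Example leaf counts: a single cell, a small group, and the full dataset.
for n_leaves in (1, 3, 2638):
    print(n_leaves, round(vertex_size(n_leaves), 2))
```

Changing the scale argument enlarges or shrinks all vertices by the same factor without altering their relative sizes.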

We can compare this network layout with the UMAP embedding:
sc.pl.umap(adata,color="louvain")
