MILK tutorial: Network visualization of MILK trees (PBMC 3k dataset)

This notebook provides a basic walkthrough of how MILK can be applied to scRNA-seq datasets. It uses the processed 3k PBMC dataset from 10x Genomics, as produced by the “Preprocessing and clustering 3k PBMCs (legacy workflow)” tutorial and provided by Scanpy.

Importing libraries

import os
import pathlib
import numpy as np
import pandas as pd
import scanpy as sc

Loading the dataset

adata = sc.datasets.pbmc3k_processed()

The processed AnnData object contains the following information:

AnnData object with n_obs × n_vars = 2638 × 1838
    obs: 'n_genes', 'percent_mito', 'n_counts', 'louvain'
    var: 'n_cells'
    uns: 'draw_graph', 'louvain', 'louvain_colors', 'neighbors', 'pca', 'rank_genes_groups'
    obsm: 'X_pca', 'X_tsne', 'X_umap', 'X_draw_graph_fr'
    varm: 'PCs'
    obsp: 'distances', 'connectivities'

Visualizing the processed dataset, colored by annotated cell types:

sc.pl.umap(adata,color="louvain")

[Figure: UMAP embedding of the processed dataset, colored by cell type]

MILK input

We can apply MILK to the PCA embeddings (50 components) of the processed dataset.

input_df = pd.DataFrame(adata.obsm["X_pca"],index=adata.obs_names)
input_df.to_csv("input.csv",header=False)

First rows of input_df (cell barcode followed by the PCA coordinates; middle columns elided):

AAACATACAACCAC-1 5.556233 0.257714 -0.186810 2.800131 -0.033783 -0.189702 0.310228 -1.323691 2.691945 0.125928 ... -0.266174 1.024464 -0.709844 -0.052780 -0.686898 -1.419867 -2.865078 0.027601 2.671032 -0.297620
AAACATTGAGCTAC-1 7.209530 7.481985 0.162706 -8.018575 3.012900 0.322293 2.270888 -0.605055 -0.905611 1.225260 ... 0.158161 0.819037 0.578912 -1.169742 0.955408 0.068133 -0.883082 2.930932 0.354197 -1.081801
AAACATTGATCAGC-1 2.694438 -1.583658 -0.663126 2.205649 -1.686360 -1.965395 -1.894999 -1.522103 1.914985 -0.481202 ... -1.054254 0.805932 1.543282 1.504834 -0.831818 -0.236549 1.883515 1.084782 0.381470 0.064662
AAACCGTGCTTCCG-1 -10.143295 -1.368530 1.209812 -0.700096 -2.872336 0.230617 1.278005 0.487900 -0.447965 -0.328465 ... 1.297246 0.611073 -0.007878 -0.648735 0.543566 3.156763 1.691134 -0.301377 -0.225427 0.962879
AAACCGTGTATGCG-1 -1.112816 -8.152788 1.332405 -4.252473 2.036407 5.597797 -0.110658 -0.102257 0.014520 0.581409 ... 1.191032 1.042533 1.734694 -0.142114 0.586381 0.636326 -1.451625 1.809683 -0.087072 -0.737833
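
Before running MILK, it can be worth confirming that the file has the layout written above: no header row, a cell barcode in the first column, and the same number of numeric coordinates on every row. The following is a minimal sketch using only the standard library. Note that check_input_csv is a hypothetical helper (not part of MILK), and the demo file is a made-up two-row table rather than the real input.csv.

```python
import csv
import os
import tempfile


def check_input_csv(path):
    """Return (n_rows, n_features) for a header-less CSV of `id,value,value,...` rows.

    Hypothetical helper for illustration; it raises ValueError if a row has a
    different number of values than the first row or a value is not numeric.
    """
    n_rows = 0
    n_features = None
    with open(path, newline="") as handle:
        for row in csv.reader(handle):
            # float() raises ValueError on non-numeric text
            values = [float(v) for v in row[1:]]
            if n_features is None:
                n_features = len(values)
            elif len(values) != n_features:
                raise ValueError(
                    f"row {n_rows + 1} has {len(values)} values, expected {n_features}"
                )
            n_rows += 1
    return n_rows, n_features


# Made-up demo data (two cells, three coordinates), written to a temporary file.
demo_path = os.path.join(tempfile.mkdtemp(), "demo_input.csv")
with open(demo_path, "w", newline="") as handle:
    writer = csv.writer(handle)
    writer.writerow(["CELL_A-1", 5.5, 0.25, -0.125])
    writer.writerow(["CELL_B-1", 7.25, 7.5, 0.5])

demo_shape = check_input_csv(demo_path)
print(demo_shape)
```

For the tutorial data, check_input_csv("input.csv") should report 2638 rows and 50 coordinates if the file was written as shown above.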

MILK execution

After adding the MILK directory to PATH, run MILK from a terminal:

milk -i input.csv

The following is an excerpt of the standard output:

[ Info: MILK
[ Info: Using 1 worker(s) for distributed processing
[ Info: Parsed Arguments:
[ Info:   label: nothing
[ Info:   percentile: 1.0
[ Info:   batch-size: 50
[ Info:   output-dir: ./milk.out
[ Info:   merge-threshold: 100000
[ Info:   sample-size: 1
[ Info:   threads: 1
[ Info:   cache-size-limit: 50000
[ Info:   job-time: 24:00:00
[ Info:   group-stratification-mode: false
[ Info:   job-name: milk
[ Info:   partition-size: 10000
[ Info:   hpc-mode: false
[ Info:   skip-reconstruction: false
[ Info:   metric: euclidean
[ Info:   verbose: false
[ Info:   job-scheduler: nothing
[ Info:   environment-path: nothing
[ Info:   force-overwrite: true
[ Info:   job-account: nothing
[ Info:   job-memory: 4
[ Info:   seed: 21
[ Info:   input-path: input.csv
[ Info: ==========
[ Info: Iteration: 0 (2638 objects)
[ Info: 	Caching 2638 objects from the following path: /home/brett/milk/tutorial/milk.out/input.iteration_00000000.input.csv
[ Info: ========== Starting encapsulated recursions...
[ Info: Iteration: 0 (2638 objects)
[ Info: 	[input.iteration_00000000] 1723 groups (257 groups optimized); 2638 objects (2638 total); threshold: 10.4644 (computed); 3479378 comparisons (1 CPU(s); cache); runtime: 0.01 min
[ Info: 	1723 objects after recursive iteration.
[ Info: Iteration: 1 (1723 objects)
[ Info: 	[input.iteration_00000001] 958 groups (220 groups optimized); 1723 objects (2638 total); threshold: 12.3272 (computed); 2361059 comparisons (1 CPU(s); cache; previous_groupings); runtime: 0.0 min
[ Info: 	958 objects after recursive iteration.

[ ...

[ Info: Iteration: 31 (2 objects)
[ Info: 	[input.iteration_00000031] 1 groups (1 groups optimized); 2 objects (2638 total); threshold: 50.5063 (computed); 2639 comparisons (1 CPU(s); cache; previous_groupings); runtime: 0.0 min
[ Info: 	1 objects after recursive iteration.
┌ Info: 
└ Completed recursive downsampling procedure.
[ Info: Reconstructing hierarchical graph...
[ Info: Hierarchical reconstruction
[ Info: 	Done!
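
For longer runs it may help to pull the per-iteration numbers out of the log rather than read them by eye. The sketch below parses the summary lines with a regular expression. It assumes the line format matches the excerpt above (other MILK versions may differ), and parse_summary_lines is a hypothetical helper, not part of MILK.

```python
import re

# Pattern for the per-iteration summary lines shown in the excerpt.
# Assumption: the format is exactly as printed above.
SUMMARY_RE = re.compile(
    r"\[(?P<name>[^\[\]]+)\]\s+"
    r"(?P<groups>\d+) groups \((?P<optimized>\d+) groups optimized\); "
    r"(?P<objects>\d+) objects \((?P<total>\d+) total\); "
    r"threshold: (?P<threshold>[0-9.]+)"
)


def parse_summary_lines(log_text):
    """Extract name, groups, objects and threshold for each iteration summary line."""
    rows = []
    for match in SUMMARY_RE.finditer(log_text):
        rows.append(
            {
                "name": match.group("name"),
                "groups": int(match.group("groups")),
                "objects": int(match.group("objects")),
                "threshold": float(match.group("threshold")),
            }
        )
    return rows


# Two summary lines copied from the excerpt (trailing fields shortened).
excerpt = (
    "[ Info: \t[input.iteration_00000000] 1723 groups (257 groups optimized); "
    "2638 objects (2638 total); threshold: 10.4644 (computed); 3479378 comparisons\n"
    "[ Info: \t[input.iteration_00000001] 958 groups (220 groups optimized); "
    "1723 objects (2638 total); threshold: 12.3272 (computed); 2361059 comparisons\n"
)

summary = parse_summary_lines(excerpt)
for row in summary:
    print(row["name"], row["objects"], "->", row["groups"], "threshold", row["threshold"])
```

In the excerpt, the number of groups produced by one iteration (1723) equals the number of objects entering the next, which is a convenient consistency check when scanning a full log.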

Visualization of the single-cell MILK tree embedding

import numpy as np

from glob import glob
from tqdm import tqdm

import graph_tool.all as gt

import seaborn as sns
import matplotlib.cm as cm
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import matplotlib.colors as mcolors

from Bio import Phylo
from Bio.Phylo.BaseTree import Tree,Clade
from collections import defaultdict,deque
gt.openmp_set_num_threads(1)

These Python functions construct a Biopython tree object from the MILK vertex and edge tables:

def instantiate_node(node_id,info_dict):
    node = Clade(name=node_id)
    node_dict = info_dict[node_id]
    node.representative_id = node_dict["representative_id"]
    node.group_size = node_dict["group_size"]
    node.iteration = node_dict["iteration"]
    node.threshold = node_dict["threshold"]
    node.spread = node_dict["spread"]
    node.specificity = node_dict["specificity"]
    node.resolution = node_dict["resolution"]
    return node

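def construct_tree_object(vertices_df,edges_df):
    """
    Build a Bio.Phylo tree from the MILK vertex and edge tables.
    Each source node in edges_df is a parent and its targets are its
    children (subnodes). The root is the vertex from the highest
    iteration; the leaves sit at the bottom of the hierarchy.
    """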
    descendents_dict = edges_df.groupby("source")["target"].apply(list).to_dict()
    max_iteration_df = vertices_df[vertices_df["iteration"] == vertices_df["iteration"].max()]

    root_id = max_iteration_df["node_id"].values[0]
    node_info_dict = vertices_df.set_index("node_id").to_dict(orient="index")

    root_node = instantiate_node(root_id,node_info_dict)
    queue = deque([root_node])
    while queue:
        clade = queue.popleft()
        subclades = []
        for sample_id in descendents_dict.get(clade.name,[]):
            subclade = instantiate_node(sample_id,node_info_dict)
            subclades.append(subclade)
            queue.append(subclade)
        clade.clades = subclades

    return Tree(root_node)

# The MILK output directory contains a vertices table and an edges table
vertices_path = os.path.join(".","milk.out","output","vertices.csv.gz")
vertices_df = pd.read_csv(vertices_path)

edges_path = os.path.join(".","milk.out","output","edges.csv.gz")
edges_df = pd.read_csv(edges_path)

Table with vertex (i.e., node) information:

node_id representative_id group_size iteration threshold spread specificity resolution
0 CAATCGGAGAAACA-1 CAATCGGAGAAACA-1 1 0 0.000000 NaN NaN 0
1 TCGGTAGAGTAGGG-1 TCGGTAGAGTAGGG-1 1 0 0.000000 NaN NaN 0
2 TAGCCGCTTACGAC-1 TAGCCGCTTACGAC-1 1 0 0.000000 NaN NaN 0
3 I1 CAATCGGAGAAACA-1 3 1 10.464395 9.64075 6.0 2
4 CCTCGAACACTTTC-1 CCTCGAACACTTTC-1 1 0 0.000000 NaN NaN 0

Table with edge information:

source target
0 I1 CAATCGGAGAAACA-1
1 I1 TCGGTAGAGTAGGG-1
2 I1 TAGCCGCTTACGAC-1
3 I2 CCTCGAACACTTTC-1
4 I3 GAACCTGATGAACC-1

Refer to the documentation for more information on the output.
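
To illustrate how an edge table in this (source, target) layout encodes a hierarchy, here is a minimal standard-library sketch. The edge list is made up for illustration and does not come from the tutorial output. Unlike construct_tree_object, which picks the root from the highest iteration, this sketch finds the root as the only source that never appears as a target.

```python
from collections import defaultdict, deque

# Made-up edge list in the same (source, target) layout as edges.csv.gz.
# Node names are illustrative only.
toy_edges = [
    ("I1", "cell_a"),
    ("I1", "cell_b"),
    ("I1", "cell_c"),
    ("I2", "cell_d"),
    ("I3", "I1"),
    ("I3", "I2"),
]

# Map each parent to its list of children and remember every node that is a child.
children = defaultdict(list)
targets = set()
for source, target in toy_edges:
    children[source].append(target)
    targets.add(target)

# In this sketch the root is the only parent that is never itself a child.
roots = [node for node in children if node not in targets]
root = roots[0]


def leaves_under(node):
    """Breadth-first walk collecting the leaves (nodes without children) below `node`."""
    found = []
    queue = deque([node])
    while queue:
        current = queue.popleft()
        kids = children.get(current, [])
        if not kids:
            found.append(current)
        queue.extend(kids)
    return found


print(root, leaves_under(root))
print("I1", leaves_under("I1"))
```

The same breadth-first pattern is what construct_tree_object uses when it attaches subclades to each clade.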

Creating the tree object:

tree = construct_tree_object(vertices_df,edges_df)

Cell types annotated in this dataset:

labels = sorted(adata.obs["louvain"].unique())
  • B cells
  • CD14+ Monocytes
  • CD4 T cells
  • CD8 T cells
  • Dendritic cells
  • FCGR3A+ Monocytes
  • Megakaryocytes
  • NK cells

labels_dict = {}
for sample_id,label in zip(adata.obs_names,adata.obs["louvain"]):
    labels_dict[sample_id] = label
color_dict = {}
for label,col in zip(adata.obs['louvain'].cat.categories,adata.uns["louvain_colors"]):
    color_dict[label] = tuple(list(mcolors.to_rgb(col))+[0.75])
neutral_col = (0.75,0.75,0.75,0.25)
neutral_threshold = 0.8

g = gt.Graph(directed=False)

vertex_labels = g.new_vertex_property("vector<float>")
vertex_sizes = g.new_vertex_property("float")
edge_widths = g.new_edge_property("int")

root = g.add_vertex()
vertex_labels[root] = neutral_col # color_dict[labels_dict[tree.clade.representative_id]]
vertex_sizes[root] = np.log1p(tree.count_terminals())*2

stack = [(root,tree.clade)]
while stack:
    vertex,clade = stack.pop()
    for subclade in clade:
        subvertex = g.add_vertex()
        edge = g.add_edge(vertex,subvertex)
        representative_label = labels_dict[subclade.representative_id]
        representative_col = color_dict[representative_label]
        leaf_labels = []
        for leaf in subclade.get_terminals():
            leaf_labels.append(labels_dict[leaf.representative_id])
        prop_matching = np.mean([int(representative_label == label) for label in leaf_labels])

        # Color the clade only if at least neutral_threshold (0.8) of its
        # leaves share the representative's cell type; otherwise draw it in
        # gray. This keeps complex tree structures readable.
        if prop_matching >= neutral_threshold:
            vertex_labels[subvertex] = representative_col
        else:
            vertex_labels[subvertex] = neutral_col
        vertex_sizes[subvertex] = np.log1p(subclade.count_terminals())*2
        edge_widths[edge] = np.log1p(subclade.count_terminals())

        stack.append((subvertex,subclade))

g.vertex_properties["cell_type"] = vertex_labels
g.vertex_properties["group_size"] = vertex_sizes
g.edge_properties["pen_width"]    = edge_widths
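
The coloring and sizing rule in the loop above can be isolated into a small function and checked without graph_tool. This is a minimal sketch under stated assumptions: clade_style is a hypothetical helper, and the palette and labels are made up for illustration.

```python
import math

NEUTRAL = (0.75, 0.75, 0.75, 0.25)  # same gray as neutral_col above


def clade_style(representative_label, leaf_labels, palette, threshold=0.8):
    """Return (fill_color, vertex_size) for a clade, following the rule used above."""
    # Fraction of leaves whose label matches the representative's label.
    matching = sum(label == representative_label for label in leaf_labels)
    proportion = matching / len(leaf_labels)
    # Colored if the clade is pure enough, gray otherwise.
    color = palette[representative_label] if proportion >= threshold else NEUTRAL
    # Vertex size grows with the log of the number of leaves.
    size = math.log1p(len(leaf_labels)) * 2
    return color, size


# Made-up palette and leaf labels, for illustration only.
toy_palette = {"B cells": (0.1, 0.4, 0.8, 0.75), "NK cells": (0.8, 0.2, 0.2, 0.75)}

pure = clade_style("B cells", ["B cells"] * 4 + ["NK cells"], toy_palette)
mixed = clade_style("B cells", ["B cells"] * 3 + ["NK cells"] * 2, toy_palette)
print(pure)
print(mixed)
```

Here the first clade (4 of 5 leaves matching) meets the 0.8 threshold and keeps its cell-type color, while the second (3 of 5) falls back to gray. Lowering the threshold colors more of the internal vertices at the cost of hiding mixed clades.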

Scalable force-directed placement (SFDP) layout provided by the graph-tool Python library:

pos = gt.sfdp_layout(g,C=10)
gt.graph_draw(
    g,
    pos=pos,
    vertex_fill_color=vertex_labels,
    vertex_color=[1,1,1,0.5],
    edge_color=[0.75,0.75,0.75,0.25],
    edge_pen_width=edge_widths,
    vertex_size=vertex_sizes,
    output_size=(400,600),
    bg_color=None
)

[Figure: SFDP layout of the MILK tree, with vertices colored by cell type]

We can compare this layout with the UMAP embedding:

sc.pl.umap(adata,color="louvain")

[Figure: UMAP embedding of the processed dataset, colored by cell type]