
Train a machine learning model on a collection

Here, we iterate over the artifacts within a collection to train a machine learning model at scale.

import lamindb as ln
import anndata as ad
import numpy as np
💡 connected lamindb: testuser1/test-scrna
ln.settings.transform.stem_uid = "Qr1kIHvK506r"
ln.settings.transform.version = "1"
ln.track()
💡 notebook imports: anndata==0.9.2 lamindb==0.69.4 numpy==1.26.4 torch==2.2.2
💡 saved: Transform(uid='Qr1kIHvK506r5zKv', name='Train a machine learning model on a collection', key='scrna5', version='1', type='notebook', updated_at=2024-03-31 21:40:36 UTC, created_by_id=1)
💡 saved: Run(uid='qTF6dWpYb0EXdqQ6Obe0', transform_id=5, created_by_id=1)

Query our collection:

collection = ln.Collection.filter(
    name="My versioned scRNA-seq collection", version="2"
).one()
collection.describe()
Collection(uid='Prm7WKmDwQXfBiw2m3Z7', name='My versioned scRNA-seq collection', version='2', hash='HNR3VFV60_yqRnUka11E', visibility=1, updated_at=2024-03-31 21:40:16 UTC)

Provenance:
  📔 transform: Transform(uid='ManDYgmftZ8C5zKv', name='Standardize and append a batch of data', key='scrna2', version='1', type='notebook', updated_at=2024-03-31 21:39:57 UTC, created_by_id=1)
  👣 run: Run(uid='dN0wI7NwQlaUEtrm90vU', started_at=2024-03-31 21:39:57 UTC, is_consecutive=True, transform_id=2, created_by_id=1)
  👤 created_by: User(uid='DzTjkKse', handle='testuser1', name='Test User1', updated_at=2024-03-31 21:38:00 UTC)
  ⬇️ input_of (core.Run): ['2024-03-31 21:40:26 UTC']
Features:
  var: FeatureSet(uid='gaeFlVcQ4kBI4j5Q2l3e', n=36508, type='number', registry='bionty.Gene', hash='b5NMddLHEyZqn-vSYvBI', updated_at=2024-03-31 21:40:14 UTC, created_by_id=1)
    'MIR1302-2HG', 'FAM138A', 'OR4F5', 'None', 'None', 'None', 'None', 'None', 'None', 'None', 'OR4F29', 'None', 'OR4F16', 'None', 'LINC01409', 'FAM87B', 'LINC01128', 'LINC00115', 'FAM41C', 'None', ...
  obs: FeatureSet(uid='kyZGguwRcMDXuL5S7ewi', n=4, registry='core.Feature', hash='Mo__GXLUrqCUMrTaqSUj', updated_at=2024-03-31 21:39:50 UTC, created_by_id=1)
    🔗 donor (12, core.ULabel): 'D496', '621B', 'A29', 'A36', 'A35', '637C', 'A52', 'A37', 'D503', '640C', ...
    🔗 tissue (17, bionty.Tissue): 'blood', 'thoracic lymph node', 'spleen', 'lung', 'mesenteric lymph node', 'lamina propria', 'liver', 'jejunal epithelium', 'omentum', 'bone marrow', ...
    🔗 cell_type (40, bionty.CellType): 'dendritic cell', 'effector memory CD4-positive, alpha-beta T cell, terminally differentiated', 'cytotoxic T cell', 'CD8-positive, CD25-positive, alpha-beta regulatory T cell', 'CD14-positive, CD16-negative classical monocyte', 'CD38-positive naive B cell', 'B cell, CD19-positive', 'CD4-positive, alpha-beta T cell', 'classical monocyte', 'T follicular helper cell', ...
    🔗 assay (3, bionty.ExperimentalFactor): '10x 3' v3', '10x 5' v2', '10x 5' v1'
Labels:
  🏷️ tissues (17, bionty.Tissue): 'blood', 'thoracic lymph node', 'spleen', 'lung', 'mesenteric lymph node', 'lamina propria', 'liver', 'jejunal epithelium', 'omentum', 'bone marrow', ...
  🏷️ cell_types (40, bionty.CellType): 'dendritic cell', 'effector memory CD4-positive, alpha-beta T cell, terminally differentiated', 'cytotoxic T cell', 'CD8-positive, CD25-positive, alpha-beta regulatory T cell', 'CD14-positive, CD16-negative classical monocyte', 'CD38-positive naive B cell', 'B cell, CD19-positive', 'CD4-positive, alpha-beta T cell', 'classical monocyte', 'T follicular helper cell', ...
  🏷️ experimental_factors (3, bionty.ExperimentalFactor): '10x 3' v3', '10x 5' v2', '10x 5' v1'
  🏷️ ulabels (12, core.ULabel): 'D496', '621B', 'A29', 'A36', 'A35', '637C', 'A52', 'A37', 'D503', '640C', ...

Create a map-style dataset

Let us create a map-style dataset using mapped(), which returns a MappedCollection. A map-style dataset is what, for example, the PyTorch DataLoader expects as input.

Under the hood, a MappedCollection performs a virtual join of the features of the underlying AnnData objects, which makes it possible to work with very large collections.

You can either perform a virtual inner join, which keeps only the features shared by all artifacts:

with collection.mapped(label_keys=["cell_type"], join="inner") as dataset:
    print(len(dataset.var_joint))
749

Or a virtual outer join, which keeps the union of the features across all artifacts:

dataset = collection.mapped(label_keys=["cell_type"], join="outer")
len(dataset.var_joint)
36508
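To make the difference between the two joins concrete, here is a toy sketch with plain Python sets. The gene names and the helpers inner_join() and outer_join() are hypothetical; the sketch only illustrates the set semantics, not how lamindb computes var_joint internally, and it sorts the result for readability, which lamindb may not do.

```python
# Toy illustration of inner vs. outer joins over feature (gene) names.
# The var names below are hypothetical stand-ins for the var indices of
# the AnnData objects in a collection.
artifact_vars = [
    ["GENE_A", "GENE_B", "GENE_C"],
    ["GENE_B", "GENE_C", "GENE_D"],
    ["GENE_C", "GENE_B", "GENE_E"],
]


def inner_join(var_lists):
    """Keep only features present in every artifact (set intersection)."""
    joint = set(var_lists[0])
    for names in var_lists[1:]:
        joint &= set(names)
    return sorted(joint)


def outer_join(var_lists):
    """Keep every feature present in at least one artifact (set union)."""
    joint = set()
    for names in var_lists:
        joint |= set(names)
    return sorted(joint)


print(inner_join(artifact_vars))  # ['GENE_B', 'GENE_C']
print(outer_join(artifact_vars))  # ['GENE_A', 'GENE_B', 'GENE_C', 'GENE_D', 'GENE_E']
```

The inner join can never be larger than the outer join, which matches the 749 versus 36508 features reported above.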

This is compatible with a PyTorch DataLoader because MappedCollection implements __getitem__ (and __len__) over a list of backed AnnData objects. The cell at index 5 (the sixth cell) in the collection can be accessed like this:

dataset[5]
{'x': array([ 0.   ,  0.   ,  0.   , ...,  0.   ,  0.   , -0.456], dtype=float32),
 '_storage_idx': 0,
 'cell_type': 5}
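The index arithmetic behind __getitem__ can be sketched as follows: a global index is mapped to a storage (which artifact) and a local row within it, using cumulative row counts. ToyMappedDataset is a hypothetical, simplified stand-in that uses in-memory NumPy arrays instead of backed AnnData objects; it mirrors the 'x', '_storage_idx', and 'cell_type' keys of the output above but is not lamindb's implementation.

```python
import bisect

import numpy as np


class ToyMappedDataset:
    """Hypothetical map-style dataset spanning several storages
    (in-memory NumPy arrays here, instead of backed AnnData objects)."""

    def __init__(self, storages, labels):
        self.storages = storages  # list of 2D arrays, one per artifact
        self.labels = labels      # list of 1D integer label arrays, aligned with storages
        # cumulative row counts, e.g. [0, 3, 5] for storages with 3 and 2 rows
        self.offsets = [0]
        for x in storages:
            self.offsets.append(self.offsets[-1] + x.shape[0])

    def __len__(self):
        return self.offsets[-1]

    def __getitem__(self, idx):
        if not 0 <= idx < len(self):
            raise IndexError(idx)
        # locate the storage whose row range contains the global index
        storage_idx = bisect.bisect_right(self.offsets, idx) - 1
        local_idx = idx - self.offsets[storage_idx]
        return {
            "x": self.storages[storage_idx][local_idx],
            "_storage_idx": storage_idx,
            "cell_type": int(self.labels[storage_idx][local_idx]),
        }


# two hypothetical artifacts with 3 and 2 cells, 2 genes each
storages = [
    np.arange(6, dtype=np.float32).reshape(3, 2),
    np.arange(4, dtype=np.float32).reshape(2, 2) + 100,
]
labels = [np.array([0, 1, 2]), np.array([3, 4])]
ds = ToyMappedDataset(storages, labels)

print(len(ds))  # 5
print(ds[3])    # first cell of the second storage
```

Because only one row is read per call, a real implementation backed by on-disk files never needs to load a whole artifact into memory.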

The cell-type labels are encoded as integers; the mapping is stored in dataset.encoders:

dataset.encoders
{'cell_type': {'effector memory CD4-positive, alpha-beta T cell': 0,
  'alpha-beta T cell': 1,
  'lymphocyte': 2,
  'macrophage': 3,
  'naive thymus-derived CD8-positive, alpha-beta T cell': 4,
  'cytotoxic T cell': 5,
  'plasmacytoid dendritic cell': 6,
  'CD8-positive, alpha-beta memory T cell, CD45RO-positive': 7,
  'CD38-positive naive B cell': 8,
  'plasma cell': 9,
  'effector memory CD8-positive, alpha-beta T cell, terminally differentiated': 10,
  'classical monocyte': 11,
  'effector memory CD4-positive, alpha-beta T cell, terminally differentiated': 12,
  'naive B cell': 13,
  'CD14-positive, CD16-negative classical monocyte': 14,
  'CD4-positive helper T cell': 15,
  'germinal center B cell': 16,
  'CD8-positive, alpha-beta memory T cell': 17,
  'progenitor cell': 18,
  'non-classical monocyte': 19,
  'CD8-positive, CD25-positive, alpha-beta regulatory T cell': 20,
  'megakaryocyte': 21,
  'naive thymus-derived CD4-positive, alpha-beta T cell': 22,
  'plasmablast': 23,
  'conventional dendritic cell': 24,
  'animal cell': 25,
  'regulatory T cell': 26,
  'CD16-positive, CD56-dim natural killer cell, human': 27,
  'dendritic cell': 28,
  'memory B cell': 29,
  'mucosal invariant T cell': 30,
  'T follicular helper cell': 31,
  'gamma-delta T cell': 32,
  'CD16-negative, CD56-bright natural killer cell, human': 33,
  'mast cell': 34,
  'dendritic cell, human': 35,
  'CD4-positive, alpha-beta T cell': 36,
  'B cell, CD19-positive': 37,
  'alveolar macrophage': 38,
  'group 3 innate lymphoid cell': 39}}
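Conceptually, an encoder is just a dictionary from label to integer code, and inverting it turns integer model predictions back into cell-type names. build_encoder() below is a hypothetical helper that assigns codes in order of first appearance; the ordering lamindb uses may differ.

```python
def build_encoder(labels):
    """Map each distinct label to an integer code, in order of first appearance.
    Illustrative only: lamindb's own code assignment may be ordered differently."""
    encoder = {}
    for label in labels:
        if label not in encoder:
            encoder[label] = len(encoder)
    return encoder


cell_types = ["macrophage", "naive B cell", "macrophage", "mast cell"]
encoder = build_encoder(cell_types)
encoded = [encoder[c] for c in cell_types]

# invert the mapping to turn integer predictions back into cell-type names
decoder = {code: label for label, code in encoder.items()}
decoded = [decoder[i] for i in encoded]

print(encoder)  # {'macrophage': 0, 'naive B cell': 1, 'mast cell': 2}
print(encoded)  # [0, 1, 0, 2]
```

The same inversion applies to the real mapping, e.g. decoder = {code: label for label, code in dataset.encoders["cell_type"].items()}.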

Create a PyTorch DataLoader

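Let us use a weighted sampler, which draws cells according to the per-cell weights returned by get_label_weights(), so that rare cell types are not drowned out by abundant ones: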

from torch.utils.data import DataLoader, WeightedRandomSampler

# the label key used for the weights doesn't have to be among the label_keys passed to mapped()
sampler = WeightedRandomSampler(
    weights=dataset.get_label_weights("cell_type"), num_samples=len(dataset)
)
dataloader = DataLoader(dataset, batch_size=128, sampler=sampler)
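Weighted sampling can be sketched with NumPy alone. This assumes the weights are inverse label frequencies (each cell weighted by 1 / the number of cells sharing its label), which is an assumption about get_label_weights() worth checking against your lamindb version; the helper inverse_frequency_weights() and the labels are hypothetical. Like WeightedRandomSampler with its default replacement=True, the draw below samples indices with replacement.

```python
from collections import Counter

import numpy as np


def inverse_frequency_weights(labels):
    """Hypothetical helper: weight of each observation = 1 / count of its label."""
    counts = Counter(labels)
    return np.array([1.0 / counts[label] for label in labels])


# hypothetical, imbalanced labels: 90 cells of type 0, 10 cells of type 1
labels = [0] * 90 + [1] * 10
weights = inverse_frequency_weights(labels)

# each label's weights sum to 1, so both labels get equal total probability
probs = weights / weights.sum()

# draw indices with replacement, as WeightedRandomSampler does by default
rng = np.random.default_rng(0)
sampled_idx = rng.choice(len(labels), size=10_000, replace=True, p=probs)
sampled_labels = np.asarray(labels)[sampled_idx]

print(np.mean(sampled_labels == 1))  # close to 0.5, although type 1 is only 10% of cells
```

Sampling with replacement means abundant cells may be skipped and rare cells repeated within one pass of num_samples draws.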

We can now iterate through the data loader:

for batch in dataloader:
    pass
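The loop above only passes over the batches. A minimal sketch of what a training step could look like follows; to stay self-contained it swaps PyTorch for NumPy, uses synthetic data, and trains a softmax (multinomial logistic) regression with mini-batch gradient descent. collate() is a hypothetical helper that stacks per-cell dicts the way PyTorch's default collation stacks the 'x' and 'cell_type' entries into one batch. In a real notebook, you would instead feed batch["x"] and batch["cell_type"] from the DataLoader to a torch.nn.Module and an optimizer.

```python
import numpy as np


def collate(items):
    """Stack per-cell dicts into one batch dict (hypothetical NumPy stand-in
    for PyTorch's default collation of dicts)."""
    return {
        "x": np.stack([item["x"] for item in items]),
        "cell_type": np.array([item["cell_type"] for item in items]),
    }


def softmax(logits):
    z = logits - logits.max(axis=1, keepdims=True)  # numerical stability
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)


# hypothetical data: 2 classes in a 5-gene space, separated by gene 0
rng = np.random.default_rng(0)
n_cells, n_genes, n_classes = 200, 5, 2
y = rng.integers(0, n_classes, size=n_cells)
X = rng.normal(size=(n_cells, n_genes)).astype(np.float32)
X[:, 0] += 3.0 * y
items = [{"x": X[i], "cell_type": int(y[i])} for i in range(n_cells)]

W = np.zeros((n_genes, n_classes))
b = np.zeros(n_classes)
lr, batch_size = 0.1, 32


def mean_loss():
    """Mean cross-entropy over all cells."""
    p = softmax(X @ W + b)
    return -np.mean(np.log(p[np.arange(n_cells), y] + 1e-12))


loss_before = mean_loss()
for epoch in range(5):
    order = rng.permutation(n_cells)
    for start in range(0, n_cells, batch_size):
        batch = collate([items[i] for i in order[start:start + batch_size]])
        p = softmax(batch["x"] @ W + b)
        # gradient of mean cross-entropy w.r.t. logits: (p - one_hot(y)) / batch size
        p[np.arange(len(p)), batch["cell_type"]] -= 1.0
        p /= len(p)
        W -= lr * batch["x"].T @ p
        b -= lr * p.sum(axis=0)
loss_after = mean_loss()

print(loss_before, loss_after)  # the loss should drop from log(2)
```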

Close the connections that the MappedCollection holds to the underlying files:

dataset.close()
In practice, use a context manager, which closes the connections automatically when the block exits:
with collection.mapped(label_keys=["cell_type"]) as dataset:
    sampler = WeightedRandomSampler(
        weights=dataset.get_label_weights("cell_type"), num_samples=len(dataset)
    )
    dataloader = DataLoader(dataset, batch_size=128, sampler=sampler)
    for batch in dataloader:
        pass
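What the context manager buys you can be illustrated with a toy class whose __exit__ calls close(), so resources are released even if the body raises. ToyBackedDataset is hypothetical and only opens plain temporary files; it is not lamindb's implementation, which holds connections to backed AnnData storages.

```python
import os
import tempfile


class ToyBackedDataset:
    """Hypothetical dataset that keeps file handles open, like backed storages."""

    def __init__(self, paths):
        self.handles = [open(path, "rb") for path in paths]

    def close(self):
        for handle in self.handles:
            handle.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # runs on normal exit and when an exception escapes the with-block
        self.close()
        return False  # don't suppress exceptions


# create two temporary files to stand in for backed artifacts
tmp_paths = []
for _ in range(2):
    with tempfile.NamedTemporaryFile(delete=False) as tmp:
        tmp.write(b"data")
        tmp_paths.append(tmp.name)

with ToyBackedDataset(tmp_paths) as toy:
    print([h.read() for h in toy.handles])  # [b'data', b'data']
print(all(h.closed for h in toy.handles))  # True

# handles are closed even if training fails inside the block
try:
    with ToyBackedDataset(tmp_paths) as toy2:
        raise RuntimeError("training failed")
except RuntimeError:
    pass
print(all(h.closed for h in toy2.handles))  # True

for path in tmp_paths:
    os.remove(path)
```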