torch_em
torch_em is a library for deep learning applied to microscopy images. It supports segmentation and other relevant image analysis tasks and is based on PyTorch.
Installation
From conda
You can use conda (or its faster alternative mamba) to install torch_em and its dependencies from the conda-forge package:
conda install -c conda-forge torch_em
From source
To install torch_em from source you should create a dedicated conda environment. We provide the environment file environment.yaml for this purpose. After cloning the torch_em repository, you can set up the environment and install torch_em via:
conda env create -f environment.yaml
conda activate torch-em-dev
pip install -e .
Usage & Examples
torch_em provides functionality for training deep neural networks, for segmentation tasks in torch_em.segmentation and classification tasks in torch_em.classification.
To customize the models and model training it implements multiple neural network architectures in torch_em.model, loss functions in torch_em.loss, data pipelines in torch_em.data, and data transformations in torch_em.transform.
It also provides ready-to-use datasets for many different bio-medical image analysis tasks in torch_em.data.datasets.
These datasets are explained in detail in Biomedical Datasets.
It provides inference functionality for neural networks in torch_em.util.prediction and a function to export trained networks to BioImage.IO in torch_em.util.modelzoo.
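For example, here is a minimal sketch of tiled inference with torch_em.util.prediction.predict_with_halo. The untrained UNet2d and the random input image are placeholders for illustration only; in practice you would load your trained model and your image data:
import numpy as np
from torch_em.model import UNet2d
from torch_em.util.prediction import predict_with_halo

model = UNet2d(in_channels=1, out_channels=1)  # placeholder; use your trained model instead
image = np.random.rand(1024, 1024).astype("float32")  # placeholder input image
prediction = predict_with_halo(
    image, model,
    gpu_ids=["cpu"],  # the devices to use for prediction, e.g. [0] for the first GPU
    block_shape=(256, 256),  # the shape of the inner blocks for tiled prediction
    halo=(32, 32),  # the overlap added on each side of a block to avoid tiling artifacts
)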
You can find the scripts for training 2d U-Nets and 3d U-Nets for various segmentation tasks in experiments/unet-segmentation. We also provide two example notebooks that demonstrate U-Net training in more detail:
- 2D-UNet Example Notebook to train a 2d UNet for a segmentation task. Also available on Google Colab.
- 3D-UNet Example Notebook to train a 3d UNet for a segmentation task. Also available on Google Colab.
Advanced Features
torch_em also implements multiple advanced network architectures and training algorithms:
- Semi-supervised training or domain adaptation via FixMatch and MeanTeacher.
- Random forest based domain adaptation from Shallow2Deep.
  - This functionality is implemented in torch_em.shallow2deep.
  - Scripts for training Shallow2Deep models are located in experiments/shallow2deep.
- Training models for embedding prediction with sparse instance labels with SPOCO.
  - This functionality is implemented in torch_em.trainer.spoco_trainer.
  - Scripts for training models with SPOCO are located in experiments/spoco.
- Transformer-based segmentation via UNETR, with a choice of vision transformer backbones from Segment Anything or Masked Autoencoder.
  - This functionality is implemented in torch_em.model.unetr.
  - Scripts for training UNETR models are located in experiments/vision-transformer/unetr.
- Mamba-based segmentation via ViM-UNet.
  - This functionality is implemented in torch_em.model.vim.
  - Scripts for training ViM-UNet models are located in experiments/vision-mamba/vimunet.
torch_em also enables data parallel multi-GPU training via torch.distributed. This functionality is implemented in torch_em.multi_gpu_training. See scripts/run_multi_gpu_train.py for an example script.
Command Line Interface
torch_em provides the following command line scripts:
- torch_em.train_unet_2d to train a 2D U-Net.
- torch_em.train_unet_3d to train a 3D U-Net.
- torch_em.predict to run prediction with a trained model.
- torch_em.predict_with_tiling to run prediction with tiling.
- torch_em.export_bioimageio_model to export a model to the modelzoo format.
- torch_em.validate_checkpoint to evaluate a model from a trainer checkpoint.
For more details run <COMMAND> -h for any of these commands.
The folder scripts/cli contains some examples for how to use the CLI.
Projects using torch_em
Multiple research projects are built with torch_em:
- Probabilistic Domain Adaptation for Biomedical Image Segmentation | Code Repository
- Segment Anything for Microscopy | Code Repository
- ViM-UNet: Vision Mamba for Biomedical Segmentation | Code Repository
- SynapseNet: Deep Learning for Automatic Synapse Reconstruction | Code Repository
- MedicoSAM: Towards foundation models for medical image segmentation | Code Repository
- Parameter Efficient Fine-Tuning of Segment Anything Model | Code Repository
- Segment Anything for Histopathology | Code Repository
Biomedical Datasets
We provide PyTorch Datasets / DataLoaders for many different biomedical datasets, mostly for segmentation tasks.
They are implemented in torch_em.data.datasets. See also scripts/datasets for examples of how to visualize images from these datasets.
Available Datasets
All datasets in torch_em.data.datasets are implemented according to the following logic:
- The function get_..._data downloads the respective dataset. Note that some datasets cannot be downloaded automatically. In these cases the function will raise an error with a message that explains how to manually download the data.
- The function get_..._paths returns the filepaths to the downloaded inputs.
- The function get_..._dataset returns the PyTorch Dataset for the corresponding dataset.
- The function get_..._dataloader returns the PyTorch DataLoader for the corresponding dataset.
We provide ready-to-use light microscopy datasets in torch_em.data.datasets.light_microscopy, electron microscopy datasets in torch_em.data.datasets.electron_microscopy, histopathology datasets in torch_em.data.datasets.histopathology, and medical imaging datasets in torch_em.data.datasets.medical.
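Following this logic, a loader for one of the light microscopy datasets can be created as in the sketch below, which uses the DSB nucleus segmentation dataset as an example. The exact arguments differ from dataset to dataset, so check the signature of the respective function:
from torch_em.data.datasets import get_dsb_loader

loader = get_dsb_loader(
    path="./data/dsb",  # where the data will be downloaded to
    split="train",  # which split to load
    patch_shape=(256, 256),  # the patch shape for training
    batch_size=4,  # the batch size, forwarded to the PyTorch DataLoader
    download=True,  # download the data if it is not present yet
)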
Creating your own Dataset and DataLoader
The following tutorial walks you through the steps to create a torch_em-based dataloader for your data.
You can also find an interactive tutorial with examples in torch_em/notebooks/tutorial_data_loaders.ipynb.
To follow this tutorial you should first familiarize yourself with Datasets and DataLoaders in PyTorch, for example with the official PyTorch Tutorial.
Creating a Dataset
torch_em offers two dataset classes for segmentation training: torch_em.data.ImageCollectionDataset and torch_em.data.SegmentationDataset. Both datasets require image data and segmentation data (to be used as targets for training).
The ImageCollectionDataset supports images of different sizes, but only regular image formats such as tif, png or jpeg. The SegmentationDataset requires images of the same size, but also supports more complex data formats like hdf5 or zarr.
For an overview of the different input data supported by the two datasets see Supported Data Formats and Supported Data Structures.
The simplest way to create one of these datasets is to use the convenience function torch_em.default_segmentation_dataset. You can use its argument is_seg_dataset to determine whether to use the SegmentationDataset (True) or the ImageCollectionDataset (False). If this argument is not given, the function will attempt to derive the correct dataset type from the inputs.
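For example, here is a minimal sketch that creates a SegmentationDataset from a single hdf5 file; the file name and the internal keys raw and labels are placeholder assumptions:
import torch_em

dataset = torch_em.default_segmentation_dataset(
    raw_paths="data.h5",  # placeholder file; the container with the image data
    raw_key="raw",  # assumed internal key for the image data
    label_paths="data.h5",  # the container with the label data (here the same file)
    label_key="labels",  # assumed internal key for the label data
    patch_shape=(512, 512),  # the patch shape to extract for training
    ndim=2,  # train with 2d patches
)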
Alternatively, you can also directly instantiate one of the datasets:
from torch_em.data import ImageCollectionDataset, SegmentationDataset

# 1. choice: ImageCollectionDataset
dataset = ImageCollectionDataset(
    raw_image_paths=<SORTED_LIST_OF_IMAGE_PATHS>,  # paths to all images
    label_image_paths=<SORTED_LIST_OF_LABEL_PATHS>,  # paths to all labels
    patch_shape=<PATCH_SHAPE>,  # the shape of the patches extracted from the images
    # there are other optional parameters, see `torch_em.data.image_collection_dataset.py` for details.
)

# 2. choice: SegmentationDataset
dataset = SegmentationDataset(
    raw_path=<PATH_TO_IMAGE>,  # path to one image volume or multiple image volumes (of the same shape)
    raw_key=<IMAGE_KEY>,  # the key to access the images in hierarchical storage formats like zarr, hdf5, n5
    label_path=<PATH_TO_LABEL>,  # path to one label volume or multiple label volumes (of the same shape)
    label_key=<LABEL_KEY>,  # the key to access the labels in hierarchical storage formats like zarr, hdf5, n5
    patch_shape=<PATCH_SHAPE>,  # the shape of the patches extracted from the volumes
    ndim=<NDIM>,  # the dimensionality of the patches (2 for two-dimensional and 3 for three-dimensional)
    # there are other optional parameters, see `torch_em.data.segmentation_dataset.py` for details.
)
Creating a DataLoader
You can use the convenience function torch_em.default_segmentation_loader to directly create a DataLoader. It will call torch_em.default_segmentation_dataset internally.
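It accepts the dataset arguments shown above plus the loader settings. Here is a minimal sketch, again assuming an hdf5 file with the placeholder keys raw and labels:
import torch_em

loader = torch_em.default_segmentation_loader(
    raw_paths="data.h5", raw_key="raw",  # placeholder file and keys, as above
    label_paths="data.h5", label_key="labels",
    patch_shape=(512, 512), ndim=2,
    batch_size=4,  # forwarded to the underlying PyTorch DataLoader
)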
Alternatively, you can also create a DataLoader from a Dataset object, for example one you have created following the steps outlined in the previous section:
from torch_em.segmentation import get_data_loader

dataset = ...
loader = get_data_loader(
    dataset=dataset,
    batch_size=<BATCH_SIZE>,
    # there are other optional parameters, which work the same as for `torch.utils.data.DataLoader`
    # and can be passed with the PyTorch convention, e.g. `shuffle=True`, `num_workers=16`, etc.
)
You can now use the DataLoader for training your model, either with torch_em.default_segmentation_trainer or with any other PyTorch-based training logic.
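Here is a minimal training sketch with torch_em.default_segmentation_trainer; the model choice, the loader names, and the hyperparameters are illustrative assumptions:
import torch_em
from torch_em.model import UNet2d

model = UNet2d(in_channels=1, out_channels=1)  # an example model for binary segmentation
trainer = torch_em.default_segmentation_trainer(
    name="my-model",  # the name under which checkpoints and logs are saved
    model=model,
    train_loader=train_loader,  # assumed: a DataLoader created as shown above
    val_loader=val_loader,  # assumed: a corresponding validation DataLoader
    learning_rate=1e-4,
)
trainer.fit(iterations=10000)  # train for 10,000 iterations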
Supported Data Formats
torch_em uses elf and imageio to read image data. It can thus open files in the formats supported by elf: Zarr (.zarr), NIFTI (.nii, .nii.gz), HDF5 (.h5, .hdf5), N5 (.n5) and MRC (.mrc), as well as the formats supported by imageio (.tif, .png, .jpg, etc.).
torch_em.data.SegmentationDataset supports all of these formats, whereas torch_em.data.ImageCollectionDataset only supports the regular image formats that can be opened with imageio.
Supported Data Structures
The shapes given in the following are illustrative examples.
torch_em Datasets and DataLoaders can be created for:
- 2d images:
  - Single-channel inputs of:
    - the same size (all images have the same shape, e.g. (256, 256)):
      - use torch_em.data.SegmentationDataset (recommended) or torch_em.data.ImageCollectionDataset.
    - different sizes (images have different shapes, e.g. (256, 256), (378, 378), (512, 512), etc.):
      - use torch_em.data.ImageCollectionDataset.
  - Multi-channel inputs of:
    - the same size (i.e. all images have the same shape):
      - use torch_em.data.SegmentationDataset (recommended for inputs with channels first, see below) or torch_em.data.ImageCollectionDataset (for inputs in RGB format).
    - different sizes (i.e. images have different shapes, e.g. (3, 256, 256), (3, 378, 378), (3, 512, 512), etc.):
      - use torch_em.data.ImageCollectionDataset.
    - Note: multi-channel inputs are supported best if they have the channel dimension as first axis (e.g. convert RGB format (256, 256, 3) to channels-first format (3, 256, 256)). To handle inputs in channels-last / RGB format you can:
      - CASE 1: keep the inputs in RGB format and use torch_em.data.ImageCollectionDataset (or is_seg_dataset=False).
      - CASE 2: convert the inputs to channels-first, see the sketch after this list.
- 3d images:
  - Single-channel inputs of:
    - the same size (all volumes have the same shape, e.g. (100, 256, 256)):
      - use torch_em.data.SegmentationDataset.
    - the same size per slice but a different number of slices (volumes have shapes like (100, 256, 256), (200, 256, 256), (300, 256, 256), etc.):
      - use an individual torch_em.data.SegmentationDataset per volume.
    - different sizes (volumes have shapes like (100, 256, 256), (200, 378, 378), (300, 512, 512), etc.):
      - use an individual torch_em.data.SegmentationDataset per volume.
  - Multi-channel inputs of:
    - the same size (all volumes have the same shape, e.g. (100, 3, 256, 256)):
      - use torch_em.data.SegmentationDataset.
    - the same size per slice but a different number of slices (volumes have shapes like (100, 3, 256, 256), (200, 3, 256, 256), (300, 3, 256, 256), etc.):
      - use an individual torch_em.data.SegmentationDataset per volume.
    - different sizes (volumes have shapes like (100, 3, 256, 256), (200, 2, 378, 378), (300, 4, 512, 512), etc.):
      - use an individual torch_em.data.SegmentationDataset per volume.
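For CASE 2, here is a minimal sketch of a channels-first conversion; the file name is a placeholder:
import imageio.v3 as imageio

# load an RGB image with shape (height, width, 3) and move the channel axis to the front
image = imageio.imread("image.png")  # placeholder file
image = image.transpose((2, 0, 1))  # now the shape is (3, height, width)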
You can create a combined dataset out of multiple individual datasets using torch_em.data.ConcatDataset. You can also use torch_em.default_segmentation_dataset / torch_em.default_segmentation_loader and pass a list of file paths to the raw_paths and label_paths arguments. This will create multiple datasets internally and then combine them.
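For example, here is a minimal sketch that combines several volumes; the file names and keys are placeholder assumptions:
import torch_em

dataset = torch_em.default_segmentation_dataset(
    raw_paths=["vol1.h5", "vol2.h5"], raw_key="raw",  # placeholder files and keys
    label_paths=["vol1.h5", "vol2.h5"], label_key="labels",
    patch_shape=(32, 256, 256), ndim=3,  # extract 3d patches
)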
Note:
- If your data doesn't match one of the supported data formats, the DataLoader creation probably won't work. We recommend converting the data into one of the supported formats (we recommend Zarr / HDF5 / N5 for this purpose) and then moving ahead, see the sketch below.
- If your data doesn't match one of the supported data structures, the DataLoader might still work, but you will likely run into issues later, due to incorrectly formatted inputs in your dataloader.
- If you have suggestions (or requests) for additional data formats or data structures, let us know here.
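Here is a minimal sketch of such a conversion, writing an image and label tif pair into a single hdf5 file; the file names and keys are placeholder assumptions:
import h5py
import imageio.v3 as imageio

image = imageio.imread("image.tif")  # placeholder input image
labels = imageio.imread("labels.tif")  # placeholder label image
with h5py.File("data.h5", "w") as f:
    f.create_dataset("raw", data=image, compression="gzip")
    f.create_dataset("labels", data=labels, compression="gzip")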
Further Recommendations
- Most of the open-source datasets come with recommended train / val / test splits. In that case, the best practice is to create a function that automatically creates the dataset / dataloader for all three splits (see torch_em.data.datasets.dynamicnuclearnet for an example), or to create the three datasets / dataloaders one after the other.
- Some datasets offer only a training set and a test set. The best practice here is to create a "balanced" split internally (for train and val, if desired) and then create the datasets / dataloaders.
- Some datasets offer only one set of inputs for developing models. There are multiple ways to handle this case: either extend the approach from the previous point, or design your own heuristic for your use case.
- Some datasets offer only training images (without any form of labels). In this case, you can use torch_em.data.RawDataset or torch_em.data.RawImageCollectionDataset, see for example torch_em.data.datasets.neurips_cell_seg.get_neurips_cellseg_unsupervised_dataset. Below is a small snippet showing how to use the RawImageCollectionDataset.
from torch_em.data import RawImageCollectionDataset

dataset = RawImageCollectionDataset(
    raw_image_paths=<LIST_OF_IMAGE_PATHS>,  # paths to all images
    patch_shape=<PATCH_SHAPE>,  # the shape of the patches extracted from the images
    # there are other optional parameters, see `torch_em.data.raw_image_collection_dataset.py` for details.
)