# torch_em

`torch_em` is a library for deep learning on microscopy images. It is built on top of PyTorch.

We are working on the documentation and will extend and improve it soon!
## Datasets in torch-em

We provide PyTorch Datasets / DataLoaders for many different biomedical datasets, mostly for segmentation tasks. They are implemented in `torch_em.data.datasets`. See `scripts/datasets` for examples of how to visualize images from these datasets.
### Available Datasets

All datasets in `torch_em.data.datasets` are implemented according to the following logic:
- The function `get_..._data` downloads the respective dataset. Note that some datasets cannot be downloaded automatically; in these cases the function will raise an error with a message that explains how to download the data.
- The function `get_..._dataset` returns the PyTorch Dataset for the corresponding dataset.
- The function `get_..._dataloader` returns the PyTorch DataLoader for the corresponding dataset.
#### Light Microscopy

We provide several light microscopy datasets. See `torch_em.data.datasets.light_microscopy` for an overview.

#### Electron Microscopy

We provide several electron microscopy datasets. See `torch_em.data.datasets.electron_microscopy` for an overview.

#### Histopathology

We provide several histopathology datasets. See `torch_em.data.datasets.histopathology` for an overview.

#### Medical Imaging

We provide several medical imaging datasets. See `torch_em.data.datasets.medical` for an overview.
## How to create your own dataloader?

Let's say you have a specific dataset of interest and want to create a torch-em-based PyTorch dataloader for it. We will walk you through how this can be done. See `torch_em/notebooks/tutorial_data_loaders.ipynb` for an extensive tutorial with some examples.
Supported Data Formats
torch-em
and elf
currently support Zarr (.zarr
), NIFTI (.nii
, .nii.gz
), HDF5 (.h5
, .hdf5
), N5 (.n5
), MRC (.mrc
) and all imageio supported formats (eg. .tif
, .png
, .jpg
, etc.).
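For the container formats (Zarr / HDF5 / N5), raw data and labels typically live in one file under separate internal keys. Here is a minimal sketch using HDF5; the key names `raw` and `labels` are just examples, you can choose them freely:

```python
import os
import tempfile

import h5py
import numpy as np

# write a small volume together with its labels into one HDF5 container
path = os.path.join(tempfile.mkdtemp(), "example_data.h5")
raw = np.random.rand(8, 64, 64).astype("float32")  # the image data
labels = np.zeros((8, 64, 64), dtype="uint32")     # the segmentation labels

with h5py.File(path, "w") as f:
    f.create_dataset("raw", data=raw, compression="gzip")
    f.create_dataset("labels", data=labels, compression="gzip")

# the data can then be read back via the same keys
with h5py.File(path, "r") as f:
    raw_loaded = f["raw"][:]
print(raw_loaded.shape)  # (8, 64, 64)
```

The same layout works for Zarr and N5 via the `zarr` library, with the keys passed to the dataset classes described below.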
### Supported Data Structures

Recommended example input shapes are given for each of the cases below.
#### 2d images

- Mono-channel inputs of:
  - ✅ same size (i.e. all images have shape (256, 256), for example)
    - use `SegmentationDataset` (recommended) or `ImageCollectionDataset`
  - ✅ different sizes (i.e. images have shapes like (256, 256), (378, 378), (512, 512), etc.)
    - use `ImageCollectionDataset`
- Multi-channel inputs of:
  > NOTE: It's important to convert the images to be channels-first (see the expected shapes below)
  - ✅ same size (i.e. all images have shape (3, 256, 256), for example)
    - use `SegmentationDataset` (recommended) or `ImageCollectionDataset`
  - ✅ different sizes (i.e. images have shapes like (3, 256, 256), (3, 378, 378), (3, 512, 512), etc.)
    - use `ImageCollectionDataset`
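The channels-first requirement above can be met with a simple numpy transpose; note that imageio typically loads RGB images channels-last:

```python
import numpy as np

# imageio typically returns RGB images channels-last, i.e. (height, width, channels)
image = np.random.rand(256, 256, 3)

# convert to channels-first, i.e. (channels, height, width), as expected above
image_cfirst = np.transpose(image, (2, 0, 1))
print(image_cfirst.shape)  # (3, 256, 256)
```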
#### 3d images

- Mono-channel inputs of:
  - ✅ same size (i.e. all volumes have shape (100, 256, 256), for example)
    - use `SegmentationDataset`
  - ✅ same shape per slice with different z-stack sizes (i.e. volumes have shapes like (100, 256, 256), (200, 256, 256), (300, 256, 256), etc.)
    - use `SegmentationDataset` per volume
  - ✅ different sizes (i.e. volumes have shapes like (100, 256, 256), (200, 378, 378), (300, 512, 512), etc.)
    - use `SegmentationDataset` per volume
- Multi-channel inputs of:
  - ✅ same size (i.e. all volumes have shape (100, 3, 256, 256), for example)
    - use `SegmentationDataset`
  - ✅ same shape per slice with different z-stack sizes (i.e. volumes have shapes like (100, 3, 256, 256), (200, 3, 256, 256), (300, 3, 256, 256), etc.)
    - use `SegmentationDataset` per volume
  - ✅ different sizes (i.e. volumes have shapes like (100, 3, 256, 256), (200, 2, 378, 378), (300, 4, 512, 512), etc.)
    - use `SegmentationDataset` per volume
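The decision rules above can be summarized in a small helper. This is purely illustrative and not part of torch-em:

```python
def recommend_dataset(shapes, ndim):
    """Suggest a dataset class from the rules above, given the input shapes (illustration only)."""
    same_shape = len(set(shapes)) == 1
    if ndim == 2:
        # 2d: SegmentationDataset needs uniform shapes, ImageCollectionDataset handles mixed sizes
        return "SegmentationDataset" if same_shape else "ImageCollectionDataset"
    # 3d: SegmentationDataset throughout; one dataset per volume if the shapes differ
    return "SegmentationDataset" if same_shape else "SegmentationDataset per volume"

print(recommend_dataset([(256, 256), (378, 378)], ndim=2))            # ImageCollectionDataset
print(recommend_dataset([(100, 256, 256), (100, 256, 256)], ndim=3))  # SegmentationDataset
```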
NOTE:
- If your data is not in one of the supported data formats, creating the dataloader will not work. We recommend converting the data into one of the currently supported data formats (Zarr / HDF5 / N5 are good choices) and then moving ahead.
- If your data does not follow one of the supported data structures, you might run into issues that lead to incorrectly formatted inputs in your dataloader (we recommend checking the examples under `Supported Data Structures` above).
- If you have suggestions (or requests) for additional data formats or data structures, let us know here.
### Create the dataset object

Once you have decided on your choice of dataset class from above, here is an example of the important parameters expected by your custom dataset.

```python
from torch_em.data import ImageCollectionDataset, SegmentationDataset

# 1. choice: ImageCollectionDataset
dataset = ImageCollectionDataset(
    raw_image_paths=<SORTED_LIST_OF_IMAGE_PATHS>,  # paths to all images
    label_image_paths=<SORTED_LIST_OF_LABEL_PATHS>,  # paths to all labels
    patch_shape=<PATCH_SHAPE>,  # the patch shape to be extracted from the images
    # there are other optional parameters, see `torch_em.data.image_collection_dataset.py` for details.
)

# 2. choice: SegmentationDataset
dataset = SegmentationDataset(
    raw_path=<PATH_TO_IMAGE>,  # path to one image volume or multiple image volumes (of the same shape)
    raw_key=<IMAGE_KEY>,  # the internal key for accessing images in hierarchical storage formats like zarr, hdf5, n5
    label_path=<PATH_TO_LABEL>,  # path to one label volume or multiple label volumes (of the same shape)
    label_key=<LABEL_KEY>,  # the internal key for accessing labels in hierarchical storage formats like zarr, hdf5, n5
    patch_shape=<PATCH_SHAPE>,  # the patch shape to be extracted from the volume
    ndim=<NDIM>,  # the dimensionality of the desired patches (2 for two-dimensional, 3 for three-dimensional)
    # there are other optional parameters, see `torch_em.data.segmentation_dataset.py` for details.
)
```
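To make the role of `patch_shape` concrete, here is a minimal numpy sketch of extracting a random patch from a volume. This is an illustration of the idea, not the torch-em implementation:

```python
import numpy as np

def sample_patch(volume, patch_shape, seed=0):
    """Extract one random patch of the given shape from a volume (illustration only)."""
    rng = np.random.default_rng(seed)
    # pick a random valid start coordinate along every axis
    starts = [rng.integers(0, s - p + 1) for s, p in zip(volume.shape, patch_shape)]
    slices = tuple(slice(st, st + p) for st, p in zip(starts, patch_shape))
    return volume[slices]

volume = np.random.rand(100, 256, 256)  # a mono-channel 3d volume
patch = sample_patch(volume, (32, 64, 64))
print(patch.shape)  # (32, 64, 64)
```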
### Create the dataloader object

Now that we have created our dataset object, let's create the dataloader object to start the training.

```python
from torch_em.segmentation import get_data_loader

dataset = ...  # the dataset object created above

loader = get_data_loader(
    dataset=dataset,
    batch_size=batch_size,
    # there are other optional parameters, which work the same as for `torch.utils.data.DataLoader`.
    # feel free to pass them with the PyTorch convention, e.g. `shuffle=True`, `num_workers=16`, etc.
)
```
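Since the optional arguments behave as in `torch.utils.data.DataLoader`, their effect can be sketched with plain PyTorch; the tensors here are random stand-ins for image / label patches:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# random stand-ins for 8 image / label patches of shape (1, 64, 64)
dataset = TensorDataset(torch.randn(8, 1, 64, 64), torch.zeros(8, 1, 64, 64))

# `batch_size`, `shuffle`, `num_workers`, etc. work exactly as in `torch.utils.data.DataLoader`
loader = DataLoader(dataset, batch_size=2, shuffle=True, num_workers=0)

x, y = next(iter(loader))
print(x.shape)  # torch.Size([2, 1, 64, 64])
```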
### Recommendations

1. Most open-source datasets come with recommended train / val / test splits. In that case, the best practice is to create a function that automatically creates the dataset / dataloader for all three splits (see `torch_em.data.datasets.dynamicnuclearnet.py` for inspiration), or to create the three datasets / dataloaders one after the other.
2. Some datasets offer a training set and a test set. The best practice is to create a "balanced" split internally (for train and val, if desired) and then create the datasets / dataloaders.
3. Some datasets offer only one set of inputs for developing models. There are multiple ways to handle this case: either extend in the direction of 2., or design your own heuristic for your use case.
4. Some datasets offer only training images (without any form of labels). In this case, you can use `RawImageCollectionDataset` as follows (for inspiration, take `torch_em.data.datasets.neurips_cell_seg.py` -> `get_neurips_cellseg_unsupervised_dataset` as a reference):

```python
from torch_em.data import RawImageCollectionDataset

dataset = RawImageCollectionDataset(
    raw_image_paths=<LIST_OF_IMAGE_PATHS>,  # paths to all images
    # there are other optional parameters, see `torch_em.data.raw_image_collection_dataset.py` for details.
)
```
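For the internal train / val split mentioned above, a simple sketch of splitting a single list of file paths could look like this. The function and the fraction are illustrative choices, not part of torch-em:

```python
import random

def train_val_split(paths, val_fraction=0.2, seed=42):
    """Shuffle a list of file paths deterministically and split it into train / val subsets."""
    paths = sorted(paths)  # sort first so the split is reproducible across runs
    random.Random(seed).shuffle(paths)
    n_val = max(1, int(len(paths) * val_fraction))
    return paths[n_val:], paths[:n_val]

train_paths, val_paths = train_val_split([f"im_{i}.tif" for i in range(10)])
print(len(train_paths), len(val_paths))  # 8 2
```

The two path lists can then be passed to separate dataset / dataloader objects for training and validation.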