torch_em.shallow2deep.pseudolabel_training

  1import os
  2from typing import Callable, Dict, Optional, Sequence, Tuple, Union
  3
  4import torch
  5from torch_em.data import ConcatDataset, PseudoLabelDataset
  6from torch_em.segmentation import (
  7    get_data_loader, get_raw_transform, is_segmentation_dataset, samples_to_datasets, _get_default_transform
  8)
  9from .shallow2deep_model import Shallow2DeepModel
 10
 11
 12def check_paths(raw_paths):
 13    """@private
 14    """
 15    def _check_path(path):
 16        if not os.path.exists(path):
 17            raise ValueError(f"Could not find path {path}")
 18
 19    if isinstance(raw_paths, str):
 20        _check_path(raw_paths)
 21    else:
 22        for rp in raw_paths:
 23            _check_path(rp)
 24
 25
 26def _load_pseudolabel_dataset(raw_paths, raw_key, **kwargs):
 27    rois = kwargs.pop("rois", None)
 28    if isinstance(raw_paths, str):
 29        if rois is not None:
 30            assert len(rois) == 3 and all(isinstance(roi, slice) for roi in rois)
 31        ds = PseudoLabelDataset(raw_paths, raw_key, roi=rois, labeler_device="cpu", **kwargs)
 32    else:
 33        assert len(raw_paths) > 0
 34        if rois is not None:
 35            assert len(rois) == len(raw_paths), f"{len(rois)}, {len(raw_paths)}"
 36            assert all(isinstance(roi, tuple) for roi in rois)
 37        n_samples = kwargs.pop("n_samples", None)
 38
 39        samples_per_ds = (
 40            [None] * len(raw_paths) if n_samples is None else samples_to_datasets(n_samples, raw_paths, raw_key)
 41        )
 42        ds = []
 43        for i, raw_path in enumerate(raw_paths):
 44            roi = None if rois is None else rois[i]
 45            dset = PseudoLabelDataset(
 46                raw_path, raw_key, roi=roi, labeler_device="cpu", n_samples=samples_per_ds[i], **kwargs
 47            )
 48            ds.append(dset)
 49        ds = ConcatDataset(*ds)
 50    return ds
 51
 52
 53def get_pseudolabel_dataset(
 54    raw_paths: Union[str, Sequence[str]],
 55    raw_key: Optional[str],
 56    checkpoint: str,
 57    rf_config: Dict,
 58    patch_shape: Tuple[int, ...],
 59    raw_transform: Optional[Callable] = None,
 60    transform: Optional[Callable] = None,
 61    rois: Optional[Union[Tuple[slice, ...], Sequence[Tuple[slice, ...]]]] = None,
 62    n_samples: Optional[int] = None,
 63    ndim: Optional[int] = None,
 64    is_raw_dataset: Optional[bool] = None,
 65    pseudo_labeler_device: str = "cpu",
 66) -> torch.utils.data.Dataset:
 67    """Get a pseudo-label dataset for training from a Shallow2Deep model.
 68
 69    Args:
 70        raw_paths: The raw paths for training the model. May also be a single file.
 71        raw_key: The internal dataset name for the raw data. Set to None for a regular image file like tif.
 72        checkpoint: The checkpoint for the trained Shallow2Deep model.
 73        rf_config: The configuration for the random forest used for the Shallow2Deep model.
 74        patch_shape: The patch shape for training.
 75        raw_transform: The transformation to apply to the raw data.
 76        transform: The transformation to implement augmentations.
 77        rois: The region of interest for the training data.
 78        n_samples: The length of this dataset.
 79        ndim: The dimensionality of the dataset.
 80        is_raw_dataset: Whether this is a segmentation or image collection dataset.
 81            If None, will be derived from the data.
 82        pseudo_labeler_device: The device for the pseudo labeling model.
 83
 84    Returns:
 85        The dataset with pseudo-labeler.
 86    """
 87    check_paths(raw_paths)
 88    if is_raw_dataset is None:
 89        is_raw_dataset = is_segmentation_dataset(raw_paths, raw_key, raw_paths, raw_key)
 90
 91    # we always use a raw transform in the convenience function
 92    if raw_transform is None:
 93        raw_transform = get_raw_transform()
 94
 95    # we always use augmentations in the convenience function
 96    if transform is None:
 97        transform = _get_default_transform(
 98            raw_paths if isinstance(raw_paths, str) else raw_paths[0], raw_key, is_raw_dataset, ndim
 99        )
100
101    pseudo_labeler = Shallow2DeepModel(checkpoint, rf_config, pseudo_labeler_device)
102    if is_raw_dataset:
103        ds = _load_pseudolabel_dataset(
104            raw_paths, raw_key,
105            patch_shape=patch_shape,
106            pseudo_labeler=pseudo_labeler,
107            raw_transform=raw_transform,
108            transform=transform,
109            rois=rois, n_samples=n_samples, ndim=ndim,
110        )
111    else:
112        raise NotImplementedError("Image collection dataset for shallow2deep not implemented yet.")
113    return ds
114
115
116# TODO add options for confidence module and consistency
def get_pseudolabel_loader(
    raw_paths: Union[str, Sequence[str]],
    raw_key: Optional[str],
    checkpoint: str,
    rf_config: Dict,
    batch_size: int,
    patch_shape: Tuple[int, ...],
    raw_transform: Optional[Callable] = None,
    transform: Optional[Callable] = None,
    rois: Optional[Union[Tuple[slice, ...], Sequence[Tuple[slice, ...]]]] = None,
    n_samples: Optional[int] = None,
    ndim: Optional[int] = None,
    is_raw_dataset: Optional[bool] = None,
    pseudo_labeler_device: str = "cpu",
    **loader_kwargs,
) -> torch.utils.data.DataLoader:
    """Get a pseudo-label dataloader for training from a Shallow2Deep model.

    The dataset is created with `get_pseudolabel_dataset` and then wrapped with `get_data_loader`.

    Args:
        raw_paths: The raw paths for training the model. May also be a single file.
        raw_key: The internal dataset name for the raw data. Set to None for a regular image file like tif.
        checkpoint: The checkpoint for the trained Shallow2Deep model.
        rf_config: The configuration for the random forest used for the Shallow2Deep model.
        batch_size: The batch size for the data loader.
        patch_shape: The patch shape for training.
        raw_transform: The transformation to apply to the raw data.
        transform: The transformation to implement augmentations.
        rois: The region of interest for the training data.
        n_samples: The length of this dataset.
        ndim: The dimensionality of the dataset.
        is_raw_dataset: Whether this is a segmentation or image collection dataset.
            If None, will be derived from the data.
        pseudo_labeler_device: The device for the pseudo labeling model.
        loader_kwargs: Keyword arguments for the data loader.

    Returns:
        The dataloader with pseudo-labeler.
    """
    # Everything except the batch size and the loader options configures the dataset.
    dataset_settings = {
        "raw_paths": raw_paths,
        "raw_key": raw_key,
        "checkpoint": checkpoint,
        "rf_config": rf_config,
        "patch_shape": patch_shape,
        "raw_transform": raw_transform,
        "transform": transform,
        "rois": rois,
        "n_samples": n_samples,
        "ndim": ndim,
        "is_raw_dataset": is_raw_dataset,
        "pseudo_labeler_device": pseudo_labeler_device,
    }
    dataset = get_pseudolabel_dataset(**dataset_settings)
    loader = get_data_loader(dataset, batch_size=batch_size, **loader_kwargs)
    return loader
def get_pseudolabel_dataset( raw_paths: Union[str, Sequence[str]], raw_key: Optional[str], checkpoint: str, rf_config: Dict, patch_shape: Tuple[int, ...], raw_transform: Optional[Callable] = None, transform: Optional[Callable] = None, rois: Union[Tuple[slice, ...], Sequence[Tuple[slice, ...]], NoneType] = None, n_samples: Optional[int] = None, ndim: Optional[int] = None, is_raw_dataset: Optional[bool] = None, pseudo_labeler_device: str = 'cpu') -> torch.utils.data.dataset.Dataset:
 54def get_pseudolabel_dataset(
 55    raw_paths: Union[str, Sequence[str]],
 56    raw_key: Optional[str],
 57    checkpoint: str,
 58    rf_config: Dict,
 59    patch_shape: Tuple[int, ...],
 60    raw_transform: Optional[Callable] = None,
 61    transform: Optional[Callable] = None,
 62    rois: Optional[Union[Tuple[slice, ...], Sequence[Tuple[slice, ...]]]] = None,
 63    n_samples: Optional[int] = None,
 64    ndim: Optional[int] = None,
 65    is_raw_dataset: Optional[bool] = None,
 66    pseudo_labeler_device: str = "cpu",
 67) -> torch.utils.data.Dataset:
 68    """Get a pseudo-label dataset for training from a Shallow2Deep model.
 69
 70    Args:
 71        raw_paths: The raw paths for training the model. May also be a single file.
 72        raw_key: The internal dataset name for the raw data. Set to None for a regular image file like tif.
 73        checkpoint: The checkpoint for the trained Shallow2Deep model.
 74        rf_config: The configuration for the random forest used for the Shallow2Deep model.
 75        patch_shape: The patch shape for training.
 76        raw_transform: The transformation to apply to the raw data.
 77        transform: The transformation to implement augmentations.
 78        rois: The region of interest for the training data.
 79        n_samples: The length of this dataset.
 80        ndim: The dimensionality of the dataset.
 81        is_raw_dataset: Whether this is a segmentation or image collection dataset.
 82            If None, will be derived from the data.
 83        pseudo_labeler_device: The device for the pseudo labeling model.
 84
 85    Returns:
 86        The dataset with pseudo-labeler.
 87    """
 88    check_paths(raw_paths)
 89    if is_raw_dataset is None:
 90        is_raw_dataset = is_segmentation_dataset(raw_paths, raw_key, raw_paths, raw_key)
 91
 92    # we always use a raw transform in the convenience function
 93    if raw_transform is None:
 94        raw_transform = get_raw_transform()
 95
 96    # we always use augmentations in the convenience function
 97    if transform is None:
 98        transform = _get_default_transform(
 99            raw_paths if isinstance(raw_paths, str) else raw_paths[0], raw_key, is_raw_dataset, ndim
100        )
101
102    pseudo_labeler = Shallow2DeepModel(checkpoint, rf_config, pseudo_labeler_device)
103    if is_raw_dataset:
104        ds = _load_pseudolabel_dataset(
105            raw_paths, raw_key,
106            patch_shape=patch_shape,
107            pseudo_labeler=pseudo_labeler,
108            raw_transform=raw_transform,
109            transform=transform,
110            rois=rois, n_samples=n_samples, ndim=ndim,
111        )
112    else:
113        raise NotImplementedError("Image collection dataset for shallow2deep not implemented yet.")
114    return ds

Get a pseudo-label dataset for training from a Shallow2Deep model.

Arguments:
  • raw_paths: The raw paths for training the model. May also be a single file.
  • raw_key: The internal dataset name for the raw data. Set to None for a regular image file like tif.
  • checkpoint: The checkpoint for the trained Shallow2Deep model.
  • rf_config: The configuration for the random forest used for the Shallow2Deep model.
  • patch_shape: The patch shape for training.
  • raw_transform: The transformation to apply to the raw data.
  • transform: The transformation to implement augmentations.
  • rois: The region of interest for the training data.
  • n_samples: The length of this dataset.
  • ndim: The dimensionality of the dataset.
  • is_raw_dataset: Whether this is a segmentation or image collection dataset. If None, will be derived from the data.
  • pseudo_labeler_device: The device for the pseudo labeling model.
Returns:

The dataset with pseudo-labeler.

def get_pseudolabel_loader( raw_paths: Union[str, Sequence[str]], raw_key: Optional[str], checkpoint: str, rf_config: Dict, batch_size: int, patch_shape: Tuple[int, ...], raw_transform: Optional[Callable] = None, transform: Optional[Callable] = None, rois: Union[Tuple[slice, ...], Sequence[Tuple[slice, ...]], NoneType] = None, n_samples: Optional[int] = None, ndim: Optional[int] = None, is_raw_dataset: Optional[bool] = None, pseudo_labeler_device: str = 'cpu', **loader_kwargs) -> torch.utils.data.dataloader.DataLoader:
def get_pseudolabel_loader(
    raw_paths: Union[str, Sequence[str]],
    raw_key: Optional[str],
    checkpoint: str,
    rf_config: Dict,
    batch_size: int,
    patch_shape: Tuple[int, ...],
    raw_transform: Optional[Callable] = None,
    transform: Optional[Callable] = None,
    rois: Optional[Union[Tuple[slice, ...], Sequence[Tuple[slice, ...]]]] = None,
    n_samples: Optional[int] = None,
    ndim: Optional[int] = None,
    is_raw_dataset: Optional[bool] = None,
    pseudo_labeler_device: str = "cpu",
    **loader_kwargs,
) -> torch.utils.data.DataLoader:
    """Get a pseudo-label dataloader for training from a Shallow2Deep model.

    The dataset is created with `get_pseudolabel_dataset` and then wrapped with `get_data_loader`.

    Args:
        raw_paths: The raw paths for training the model. May also be a single file.
        raw_key: The internal dataset name for the raw data. Set to None for a regular image file like tif.
        checkpoint: The checkpoint for the trained Shallow2Deep model.
        rf_config: The configuration for the random forest used for the Shallow2Deep model.
        batch_size: The batch size for the data loader.
        patch_shape: The patch shape for training.
        raw_transform: The transformation to apply to the raw data.
        transform: The transformation to implement augmentations.
        rois: The region of interest for the training data.
        n_samples: The length of this dataset.
        ndim: The dimensionality of the dataset.
        is_raw_dataset: Whether this is a segmentation or image collection dataset.
            If None, will be derived from the data.
        pseudo_labeler_device: The device for the pseudo labeling model.
        loader_kwargs: Keyword arguments for the data loader.

    Returns:
        The dataloader with pseudo-labeler.
    """
    # Everything except the batch size and the loader options configures the dataset.
    dataset_settings = {
        "raw_paths": raw_paths,
        "raw_key": raw_key,
        "checkpoint": checkpoint,
        "rf_config": rf_config,
        "patch_shape": patch_shape,
        "raw_transform": raw_transform,
        "transform": transform,
        "rois": rois,
        "n_samples": n_samples,
        "ndim": ndim,
        "is_raw_dataset": is_raw_dataset,
        "pseudo_labeler_device": pseudo_labeler_device,
    }
    dataset = get_pseudolabel_dataset(**dataset_settings)
    loader = get_data_loader(dataset, batch_size=batch_size, **loader_kwargs)
    return loader

Get a pseudo-label dataloader for training from a Shallow2Deep model.

Arguments:
  • raw_paths: The raw paths for training the model. May also be a single file.
  • raw_key: The internal dataset name for the raw data. Set to None for a regular image file like tif.
  • checkpoint: The checkpoint for the trained Shallow2Deep model.
  • rf_config: The configuration for the random forest used for the Shallow2Deep model.
  • batch_size: The batch size for the data loader.
  • patch_shape: The patch shape for training.
  • raw_transform: The transformation to apply to the raw data.
  • transform: The transformation to implement augmentations.
  • rois: The region of interest for the training data.
  • n_samples: The length of this dataset.
  • ndim: The dimensionality of the dataset.
  • is_raw_dataset: Whether this is a segmentation or image collection dataset. If None, will be derived from the data.
  • pseudo_labeler_device: The device for the pseudo labeling model.
  • loader_kwargs: Keyword arguments for the data loader.
Returns:

The dataloader with pseudo-labeler.