torch_em.data.tensor_dataset

 1from typing import Callable, List, Optional, Tuple, Union
 2
 3import numpy as np
 4import torch
 5
 6from .image_collection_dataset import ImageCollectionDataset
 7
 8
 9class TensorDataset(ImageCollectionDataset):
10    """A dataset for in-memory images and segmentation labels.
11
12    The images and labels may be either numpy arrays or tensors.
13
14    Args:
15        images: The list of images.
16        labels: The list of label images.
17        label_transform: Transformation applied to the label data of a sample,
18            before applying augmentations via `transform`.
19        label_transform2: Transformation applied to the label data of a sample,
20            after applying augmentations via `transform`.
21        transform: Transformation applied to both the raw data and label data of a sample.
22            This can be used to implement data augmentations.
23        dtype: The return data type of the raw data.
24        label_dtype: The return data type of the label data.
25        n_samples: The length of this dataset. If None, the length will be set to `len(raw_image_paths)`.
26        sampler: Sampler for rejecting samples according to a defined criterion.
27            The sampler must be a callable that accepts the raw data and label data (as numpy arrays) as input.
28        with_padding: Whether to pad samples to `patch_shape` if their shape is smaller.
29        with_channels: Whether the raw data has channels.
30    """
31    def __init__(
32        self,
33        images: List[Union[np.ndarray, torch.Tensor]],
34        labels: List[Union[np.ndarray, torch.Tensor]],
35        patch_shape: Tuple[int, ...],
36        raw_transform: Optional[Callable] = None,
37        label_transform: Optional[Callable] = None,
38        label_transform2: Optional[Callable] = None,
39        transform: Optional[Callable] = None,
40        dtype: torch.dtype = torch.float32,
41        label_dtype: torch.dtype = torch.float32,
42        n_samples: Optional[int] = None,
43        sampler: Optional[Callable] = None,
44        with_padding: bool = True,
45        with_channels: bool = False,
46    ) -> None:
47        self.raw_images = images
48        self.label_images = labels
49        self.patch_shape = patch_shape
50        self.with_channels = with_channels
51        self._check_inputs()
52        self._ndim = len(self.patch_shape)
53
54        self.with_label_channels = False
55        self.have_tensor_data = True
56
57        self.raw_transform = raw_transform
58        self.label_transform = label_transform
59        self.label_transform2 = label_transform2
60        self.transform = transform
61        self.sampler = sampler
62        self.with_padding = with_padding
63
64        self.dtype = dtype
65        self.label_dtype = label_dtype
66
67        if n_samples is None:
68            self._len = len(self.raw_images)
69            self.sample_random_index = False
70        else:
71            self._len = n_samples
72            self.sample_random_index = True
73
74    def _check_inputs(self):
75        ndim = len(self.patch_shape)
76        if len(self.raw_images) != len(self.label_images):
77            raise ValueError(
78                f"Number of images and labels does not match: {len(self.raw_images)}, {len(self.label_images)}"
79            )
80        for image, labels in zip(self.raw_images, self.label_images):
81            im_shape = image.shape
82            if self.with_channels and len(im_shape) != ndim + 1:
83                raise ValueError("Image shape does not match the patch shape")
84            elif not self.with_channels and len(im_shape) != ndim:
85                raise ValueError("Image shape does not match the patch shape")
86
87            if self.with_channels and im_shape[1:] != labels.shape:
88                raise ValueError("Image and label shape does not match")
89            elif not self.with_channels and im_shape != labels.shape:
90                raise ValueError("Image and label shape does not match")
class TensorDataset(torch_em.data.image_collection_dataset.ImageCollectionDataset):
class TensorDataset(ImageCollectionDataset):
    """An in-memory dataset of images with matching segmentation labels.

    Images and labels are accepted as numpy arrays or as torch tensors.

    Args:
        images: The list of images.
        labels: The list of label images, one per entry of `images`.
        patch_shape: The patch shape. The number of its entries is the number of
            (non-channel) dimensions that every image must have.
        raw_transform: Transformation for the raw data of a sample.
        label_transform: Transformation applied to the label data of a sample,
            before applying augmentations via `transform`.
        label_transform2: Transformation applied to the label data of a sample,
            after applying augmentations via `transform`.
        transform: Transformation applied to both the raw data and label data of a sample.
            This can be used to implement data augmentations.
        dtype: The return data type of the raw data.
        label_dtype: The return data type of the label data.
        n_samples: The length of this dataset. If None, the length will be set to `len(images)`.
        sampler: Sampler for rejecting samples according to a defined criterion.
            The sampler must be a callable that accepts the raw data and label data (as numpy arrays) as input.
        with_padding: Whether to pad samples to `patch_shape` if their shape is smaller.
        with_channels: Whether the raw data has channels (expected as the first axis).

    Raises:
        ValueError: If `images` and `labels` differ in length, or if any image has an
            unexpected number of dimensions or a shape that differs from its label image.
    """
    def __init__(
        self,
        images: List[Union[np.ndarray, torch.Tensor]],
        labels: List[Union[np.ndarray, torch.Tensor]],
        patch_shape: Tuple[int, ...],
        raw_transform: Optional[Callable] = None,
        label_transform: Optional[Callable] = None,
        label_transform2: Optional[Callable] = None,
        transform: Optional[Callable] = None,
        dtype: torch.dtype = torch.float32,
        label_dtype: torch.dtype = torch.float32,
        n_samples: Optional[int] = None,
        sampler: Optional[Callable] = None,
        with_padding: bool = True,
        with_channels: bool = False,
    ) -> None:
        # Inputs needed by the validation step are stored before it runs.
        self.raw_images = images
        self.label_images = labels
        self.patch_shape = patch_shape
        self.with_channels = with_channels
        self._check_inputs()
        self._ndim = len(self.patch_shape)

        # Constant for this dataset type: no label channels, data lives in memory.
        self.with_label_channels = False
        self.have_tensor_data = True

        # Per-sample processing configuration.
        self.raw_transform = raw_transform
        self.label_transform = label_transform
        self.label_transform2 = label_transform2
        self.transform = transform
        self.sampler = sampler
        self.with_padding = with_padding

        # Output data types.
        self.dtype = dtype
        self.label_dtype = label_dtype

        # Without `n_samples` the length equals the number of images; with it,
        # the given length is used and `sample_random_index` is enabled.
        if n_samples is not None:
            self._len = n_samples
            self.sample_random_index = True
        else:
            self._len = len(self.raw_images)
            self.sample_random_index = False

    def _check_inputs(self):
        """Check the stored images and labels for consistency.

        Raises:
            ValueError: If the list lengths differ, an image has the wrong number of
                dimensions for `patch_shape`, or an image and its labels differ in shape.
        """
        ndim = len(self.patch_shape)
        image_count = len(self.raw_images)
        label_count = len(self.label_images)
        if image_count != label_count:
            raise ValueError(
                f"Number of images and labels does not match: {image_count}, {label_count}"
            )

        for image, labels in zip(self.raw_images, self.label_images):
            im_shape = image.shape
            if self.with_channels:
                # Channel data: one extra leading axis that the labels do not have.
                required_ndim, label_like_shape = ndim + 1, im_shape[1:]
            else:
                required_ndim, label_like_shape = ndim, im_shape

            if len(im_shape) != required_ndim:
                raise ValueError("Image shape does not match the patch shape")
            if label_like_shape != labels.shape:
                raise ValueError("Image and label shape does not match")

A dataset for in-memory images and segmentation labels.

The images and labels may be either numpy arrays or tensors.

Arguments:
  • images: The list of images.
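  • labels: The list of label images.
  • patch_shape: The patch shape. Its length determines the number of (non-channel) dimensions each image must have.
  • raw_transform: Transformation for the raw data of a sample.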
  • label_transform: Transformation applied to the label data of a sample, before applying augmentations via transform.
  • label_transform2: Transformation applied to the label data of a sample, after applying augmentations via transform.
  • transform: Transformation applied to both the raw data and label data of a sample. This can be used to implement data augmentations.
  • dtype: The return data type of the raw data.
  • label_dtype: The return data type of the label data.
  • n_samples: The length of this dataset. If None, the length will be set to len(images).
  • sampler: Sampler for rejecting samples according to a defined criterion. The sampler must be a callable that accepts the raw data and label data (as numpy arrays) as input.
  • with_padding: Whether to pad samples to patch_shape if their shape is smaller.
  • with_channels: Whether the raw data has channels.
TensorDataset( images: List[Union[numpy.ndarray, torch.Tensor]], labels: List[Union[numpy.ndarray, torch.Tensor]], patch_shape: Tuple[int, ...], raw_transform: Optional[Callable] = None, label_transform: Optional[Callable] = None, label_transform2: Optional[Callable] = None, transform: Optional[Callable] = None, dtype: torch.dtype = torch.float32, label_dtype: torch.dtype = torch.float32, n_samples: Optional[int] = None, sampler: Optional[Callable] = None, with_padding: bool = True, with_channels: bool = False)
32    def __init__(
33        self,
34        images: List[Union[np.ndarray, torch.Tensor]],
35        labels: List[Union[np.ndarray, torch.Tensor]],
36        patch_shape: Tuple[int, ...],
37        raw_transform: Optional[Callable] = None,
38        label_transform: Optional[Callable] = None,
39        label_transform2: Optional[Callable] = None,
40        transform: Optional[Callable] = None,
41        dtype: torch.dtype = torch.float32,
42        label_dtype: torch.dtype = torch.float32,
43        n_samples: Optional[int] = None,
44        sampler: Optional[Callable] = None,
45        with_padding: bool = True,
46        with_channels: bool = False,
47    ) -> None:
48        self.raw_images = images
49        self.label_images = labels
50        self.patch_shape = patch_shape
51        self.with_channels = with_channels
52        self._check_inputs()
53        self._ndim = len(self.patch_shape)
54
55        self.with_label_channels = False
56        self.have_tensor_data = True
57
58        self.raw_transform = raw_transform
59        self.label_transform = label_transform
60        self.label_transform2 = label_transform2
61        self.transform = transform
62        self.sampler = sampler
63        self.with_padding = with_padding
64
65        self.dtype = dtype
66        self.label_dtype = label_dtype
67
68        if n_samples is None:
69            self._len = len(self.raw_images)
70            self.sample_random_index = False
71        else:
72            self._len = n_samples
73            self.sample_random_index = True
raw_images
label_images
patch_shape
with_channels
with_label_channels
have_tensor_data
raw_transform
label_transform
label_transform2
transform
sampler
with_padding
dtype
label_dtype