torch_em.data.tensor_dataset
```python
from typing import Callable, List, Optional, Tuple, Union

import numpy as np
import torch

from .image_collection_dataset import ImageCollectionDataset


class TensorDataset(ImageCollectionDataset):
    """A dataset for in-memory images and segmentation labels.

    The images and labels may be either numpy arrays or tensors.

    Args:
        images: The list of images.
        labels: The list of label images.
        label_transform: Transformation applied to the label data of a sample,
            before applying augmentations via `transform`.
        label_transform2: Transformation applied to the label data of a sample,
            after applying augmentations via `transform`.
        transform: Transformation applied to both the raw data and label data of a sample.
            This can be used to implement data augmentations.
        dtype: The return data type of the raw data.
        label_dtype: The return data type of the label data.
        n_samples: The length of this dataset. If None, the length will be set to `len(raw_image_paths)`.
        sampler: Sampler for rejecting samples according to a defined criterion.
            The sampler must be a callable that accepts the raw data and label data (as numpy arrays) as input.
        with_padding: Whether to pad samples to `patch_shape` if their shape is smaller.
        with_channels: Whether the raw data has channels.
    """
    def __init__(
        self,
        images: List[Union[np.ndarray, torch.Tensor]],
        labels: List[Union[np.ndarray, torch.Tensor]],
        patch_shape: Tuple[int, ...],
        raw_transform: Optional[Callable] = None,
        label_transform: Optional[Callable] = None,
        label_transform2: Optional[Callable] = None,
        transform: Optional[Callable] = None,
        dtype: torch.dtype = torch.float32,
        label_dtype: torch.dtype = torch.float32,
        n_samples: Optional[int] = None,
        sampler: Optional[Callable] = None,
        with_padding: bool = True,
        with_channels: bool = False,
    ) -> None:
        self.raw_images = images
        self.label_images = labels
        self.patch_shape = patch_shape
        self.with_channels = with_channels
        self._check_inputs()
        self._ndim = len(self.patch_shape)

        self.with_label_channels = False
        self.have_tensor_data = True

        self.raw_transform = raw_transform
        self.label_transform = label_transform
        self.label_transform2 = label_transform2
        self.transform = transform
        self.sampler = sampler
        self.with_padding = with_padding

        self.dtype = dtype
        self.label_dtype = label_dtype

        if n_samples is None:
            self._len = len(self.raw_images)
            self.sample_random_index = False
        else:
            self._len = n_samples
            self.sample_random_index = True

    def _check_inputs(self):
        ndim = len(self.patch_shape)
        if len(self.raw_images) != len(self.label_images):
            raise ValueError(
                f"Number of images and labels does not match: {len(self.raw_images)}, {len(self.label_images)}"
            )
        for image, labels in zip(self.raw_images, self.label_images):
            im_shape = image.shape
            if self.with_channels and len(im_shape) != ndim + 1:
                raise ValueError("Image shape does not match the patch shape")
            elif not self.with_channels and len(im_shape) != ndim:
                raise ValueError("Image shape does not match the patch shape")

            if self.with_channels and im_shape[1:] != labels.shape:
                raise ValueError("Image and label shape does not match")
            elif not self.with_channels and im_shape != labels.shape:
                raise ValueError("Image and label shape does not match")
```
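The shape convention is enforced by `_check_inputs`: without channels, each image must have exactly `len(patch_shape)` dimensions and the same shape as its label image; with `with_channels=True`, the image carries one additional leading channel axis and only its spatial shape (`image.shape[1:]`) must match the label. A minimal sketch of inputs that pass these checks (the arrays and sizes are illustrative, not part of the library):

```python
import numpy as np
from torch_em.data.tensor_dataset import TensorDataset

# Without channels: image and label share the same spatial shape,
# and the number of dimensions equals len(patch_shape).
image = np.random.rand(256, 256).astype("float32")
label = np.random.randint(0, 5, size=(256, 256)).astype("int64")
ds = TensorDataset([image], [label], patch_shape=(128, 128))

# With channels: the image gets a leading channel axis; only its
# spatial shape image.shape[1:] has to match the label shape.
image_c = np.random.rand(3, 256, 256).astype("float32")
ds_c = TensorDataset([image_c], [label], patch_shape=(128, 128), with_channels=True)
```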
class TensorDataset(ImageCollectionDataset):
A dataset for in-memory images and segmentation labels.
The images and labels may be either numpy arrays or tensors.
Arguments:
- images: The list of images.
- labels: The list of label images.
- patch_shape: The patch shape for a training sample.
- raw_transform: Transformation applied to the raw data of a sample, before applying augmentations via `transform`.
- label_transform: Transformation applied to the label data of a sample, before applying augmentations via `transform`.
- label_transform2: Transformation applied to the label data of a sample, after applying augmentations via `transform`.
- transform: Transformation applied to both the raw data and label data of a sample. This can be used to implement data augmentations.
- dtype: The return data type of the raw data.
- label_dtype: The return data type of the label data.
- n_samples: The length of this dataset. If None, the length will be set to `len(images)`.
- sampler: Sampler for rejecting samples according to a defined criterion. The sampler must be a callable that accepts the raw data and label data (as numpy arrays) as input.
- with_padding: Whether to pad samples to `patch_shape` if their shape is smaller.
- with_channels: Whether the raw data has channels.
```python
TensorDataset(
    images: List[Union[numpy.ndarray, torch.Tensor]],
    labels: List[Union[numpy.ndarray, torch.Tensor]],
    patch_shape: Tuple[int, ...],
    raw_transform: Optional[Callable] = None,
    label_transform: Optional[Callable] = None,
    label_transform2: Optional[Callable] = None,
    transform: Optional[Callable] = None,
    dtype: torch.dtype = torch.float32,
    label_dtype: torch.dtype = torch.float32,
    n_samples: Optional[int] = None,
    sampler: Optional[Callable] = None,
    with_padding: bool = True,
    with_channels: bool = False,
)
```
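A minimal usage sketch for training-style iteration; it assumes the patch loading inherited from `ImageCollectionDataset` yields a `(raw, label)` pair of tensors per index (the data, patch shape, and batch size below are illustrative):

```python
import numpy as np
from torch.utils.data import DataLoader
from torch_em.data.tensor_dataset import TensorDataset

# Four in-memory images with matching label images.
images = [np.random.rand(512, 512).astype("float32") for _ in range(4)]
labels = [np.random.randint(0, 10, size=(512, 512)).astype("int64") for _ in range(4)]

# Passing n_samples makes the dataset draw images at a random index,
# so one epoch yields 100 patches instead of len(images) == 4.
dataset = TensorDataset(images, labels, patch_shape=(256, 256), n_samples=100)

loader = DataLoader(dataset, batch_size=8, shuffle=True)
raw_batch, label_batch = next(iter(loader))  # assumes a (raw, label) pair per sample
```

With `n_samples=None` the length of the dataset equals `len(images)` and indices map one-to-one to the stored images; setting `n_samples` enables random index sampling instead.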