torch_em.transform.nnunet_raw
import json
from typing import Union

import numpy as np


class nnUNetRawTransform:
    """Apply nnUNet transformation to raw inputs.

    Adapted from nnUNetv2's `ImageNormalization`:
    https://github.com/MIC-DKFZ/nnUNet/tree/master/nnunetv2/preprocessing/normalization

    You can use this class to apply the necessary raw transformations on input modalities.
    The normalization scheme of every channel is read from the nnUNet data plan. Currently only the
    `CTNormalization` scheme is implemented (intended for CT and PET data); channels with any other
    scheme raise an error. The inputs should have shape (M, H, W, D), with one channel per modality,
    in the same order as in the data plan. For CT + PET data:
    - The first channel should be the CT volume
    - The second channel should be the PET volume

    Here's an example for how to use this class:
    ```python
    # Initialize the raw transform.
    raw_transform = nnUNetRawTransform(plans_file=".../nnUNetPlans.json")

    # Apply transformation on the inputs: stack the modalities along a new first axis.
    patient_vol = np.stack([ct_vol, pet_vol])
    patient_transformed = raw_transform(patient_vol)
    ```

    Args:
        plans_file: The file with the nnUNet data plan.
        expected_dtype: The dtype the input data is cast to before normalization.
        tolerance: Lower bound for the standard deviation used in the normalization (avoids division by zero).
        model_name: The name of the nnUNet configuration in the data plan (e.g. "3d_fullres")
            whose normalization schemes are used.
    """
    def __init__(
        self,
        plans_file: str,
        expected_dtype: Union[np.dtype, str] = np.float32,
        tolerance: float = 1e-8,
        model_name: str = "3d_fullres",
    ):
        self.expected_dtype = expected_dtype
        self.tolerance = tolerance

        json_file = self.load_json(plans_file)
        # Per-channel intensity statistics, keyed by the channel index as a string ("0", "1", ...).
        self.intensity_properties = json_file["foreground_intensity_properties_per_channel"]
        # One normalization scheme name per input channel, for the selected configuration.
        self.per_channel_scheme = json_file["configurations"][model_name]["normalization_schemes"]

    def load_json(self, _file: str):
        """@private
        """
        # source: `batchgenerators.utilities.file_and_folder_operations`
        # JSON is UTF-8 by specification; don't depend on the platform's default encoding.
        with open(_file, "r", encoding="utf-8") as f:
            return json.load(f)

    def ct_transform(self, channel, properties):
        """@private
        """
        mean = properties['mean']
        std = properties['std']
        lower_bound = properties['percentile_00_5']
        upper_bound = properties['percentile_99_5']

        # Clip to the 0.5 / 99.5 percentiles of the data plan, then z-score with the plan's statistics.
        transformed_channel = np.clip(channel, lower_bound, upper_bound)
        transformed_channel = (transformed_channel - mean) / max(std, self.tolerance)
        return transformed_channel

    def __call__(self, raw: np.ndarray) -> np.ndarray:
        """Returns the raw inputs after applying the pre-processing from nnUNet.

        Args:
            raw: The raw array inputs. Expected a float array of shape (M, H, W, D), with M = number of modalities.
                Array-likes (e.g. nested lists) are converted to a numpy array first.

        Returns:
            The transformed raw inputs (the same shape as inputs), with dtype `expected_dtype`.

        Raises:
            NotImplementedError: If a channel uses a known nnUNet scheme other than `CTNormalization`.
            ValueError: If a channel uses an unknown normalization scheme.
        """
        raw = np.asarray(raw)
        assert raw.shape[0] == len(self.per_channel_scheme), "Number of channels & transforms from data plan must match"

        # astype returns a copy, so the caller's array is never modified.
        raw = raw.astype(self.expected_dtype)

        normalized_channels = []
        for channel_idx, (channel_transform, channel) in enumerate(zip(self.per_channel_scheme, raw)):
            properties = self.intensity_properties[str(channel_idx)]

            # get the correct transformation function, this can for example be a method of this class
            if channel_transform == "CTNormalization":
                channel = self.ct_transform(channel, properties)
            elif channel_transform in [
                "ZScoreNormalization", "NoNormalization", "RescaleTo01Normalization", "RGBTo01Normalization"
            ]:
                raise NotImplementedError(f"{channel_transform} is not supported by nnUNetRawTransform yet.")
            else:
                raise ValueError(f"Transform is not known: {channel_transform}.")

            normalized_channels.append(channel)

        return np.stack(normalized_channels)
class
nnUNetRawTransform:
class nnUNetRawTransform:
    """Apply nnUNet transformation to raw inputs.

    Adapted from nnUNetv2's `ImageNormalization`:
    https://github.com/MIC-DKFZ/nnUNet/tree/master/nnunetv2/preprocessing/normalization

    You can use this class to apply the necessary raw transformations on input modalities.
    The normalization scheme of every channel is read from the nnUNet data plan. Currently only the
    `CTNormalization` scheme is implemented (intended for CT and PET data); channels with any other
    scheme raise an error. The inputs should have shape (M, H, W, D), with one channel per modality,
    in the same order as in the data plan. For CT + PET data:
    - The first channel should be the CT volume
    - The second channel should be the PET volume

    Here's an example for how to use this class:
    ```python
    # Initialize the raw transform.
    raw_transform = nnUNetRawTransform(plans_file=".../nnUNetPlans.json")

    # Apply transformation on the inputs: stack the modalities along a new first axis.
    patient_vol = np.stack([ct_vol, pet_vol])
    patient_transformed = raw_transform(patient_vol)
    ```

    Args:
        plans_file: The file with the nnUNet data plan.
        expected_dtype: The dtype the input data is cast to before normalization.
        tolerance: Lower bound for the standard deviation used in the normalization (avoids division by zero).
        model_name: The name of the nnUNet configuration in the data plan (e.g. "3d_fullres")
            whose normalization schemes are used.
    """
    def __init__(
        self,
        plans_file: str,
        expected_dtype: Union[np.dtype, str] = np.float32,
        tolerance: float = 1e-8,
        model_name: str = "3d_fullres",
    ):
        self.expected_dtype = expected_dtype
        self.tolerance = tolerance

        json_file = self.load_json(plans_file)
        # Per-channel intensity statistics, keyed by the channel index as a string ("0", "1", ...).
        self.intensity_properties = json_file["foreground_intensity_properties_per_channel"]
        # One normalization scheme name per input channel, for the selected configuration.
        self.per_channel_scheme = json_file["configurations"][model_name]["normalization_schemes"]

    def load_json(self, _file: str):
        """@private
        """
        # source: `batchgenerators.utilities.file_and_folder_operations`
        # JSON is UTF-8 by specification; don't depend on the platform's default encoding.
        with open(_file, "r", encoding="utf-8") as f:
            return json.load(f)

    def ct_transform(self, channel, properties):
        """@private
        """
        mean = properties['mean']
        std = properties['std']
        lower_bound = properties['percentile_00_5']
        upper_bound = properties['percentile_99_5']

        # Clip to the 0.5 / 99.5 percentiles of the data plan, then z-score with the plan's statistics.
        transformed_channel = np.clip(channel, lower_bound, upper_bound)
        transformed_channel = (transformed_channel - mean) / max(std, self.tolerance)
        return transformed_channel

    def __call__(self, raw: np.ndarray) -> np.ndarray:
        """Returns the raw inputs after applying the pre-processing from nnUNet.

        Args:
            raw: The raw array inputs. Expected a float array of shape (M, H, W, D), with M = number of modalities.
                Array-likes (e.g. nested lists) are converted to a numpy array first.

        Returns:
            The transformed raw inputs (the same shape as inputs), with dtype `expected_dtype`.

        Raises:
            NotImplementedError: If a channel uses a known nnUNet scheme other than `CTNormalization`.
            ValueError: If a channel uses an unknown normalization scheme.
        """
        raw = np.asarray(raw)
        assert raw.shape[0] == len(self.per_channel_scheme), "Number of channels & transforms from data plan must match"

        # astype returns a copy, so the caller's array is never modified.
        raw = raw.astype(self.expected_dtype)

        normalized_channels = []
        for channel_idx, (channel_transform, channel) in enumerate(zip(self.per_channel_scheme, raw)):
            properties = self.intensity_properties[str(channel_idx)]

            # get the correct transformation function, this can for example be a method of this class
            if channel_transform == "CTNormalization":
                channel = self.ct_transform(channel, properties)
            elif channel_transform in [
                "ZScoreNormalization", "NoNormalization", "RescaleTo01Normalization", "RGBTo01Normalization"
            ]:
                raise NotImplementedError(f"{channel_transform} is not supported by nnUNetRawTransform yet.")
            else:
                raise ValueError(f"Transform is not known: {channel_transform}.")

            normalized_channels.append(channel)

        return np.stack(normalized_channels)
Apply nnUNet transformation to raw inputs.
Adapted from nnUNetv2's ImageNormalization:
https://github.com/MIC-DKFZ/nnUNet/tree/master/nnunetv2/preprocessing/normalization
You can use this class to apply the necessary raw transformations on input modalities. Currently supports transformations for CT and PET data. The inputs should have shape (2, H, W, D), i.e. the two modalities stacked along the first axis.
- The first channel should be CT volume
- The second channel should be PET volume
Here's an example for how to use this class:
# Initialize the raw transform.
raw_transform = nnUNetRawTransform(plans_file=".../nnUNetPlans.json")
# Apply transformation on the inputs.
patient_vol = np.stack([ct_vol, pet_vol])
patient_transformed = raw_transform(patient_vol)
Arguments:
- plans_file: The file with the nnUNet data plan.
- expected_dtype: The expected dtype of the input data.
- tolerance: The numerical tolerance.
- model_name: The name of the model.
nnUNetRawTransform( plans_file: str, expected_dtype: Union[numpy.dtype, str] = <class 'numpy.float32'>, tolerance: float = 1e-08, model_name: str = '3d_fullres')
def __init__(
    self,
    plans_file: str,
    expected_dtype: Union[np.dtype, str] = np.float32,
    tolerance: float = 1e-8,
    model_name: str = "3d_fullres",
):
    # Keep the cast target and the std lower bound used during normalization.
    self.expected_dtype = expected_dtype
    self.tolerance = tolerance

    # Parse the nnUNet data plan once and keep only the parts needed at call time.
    plan = self.load_json(plans_file)
    # Per-channel intensity statistics, keyed by the channel index as a string.
    self.intensity_properties = plan["foreground_intensity_properties_per_channel"]
    # Normalization scheme names (one per channel) of the selected configuration.
    selected_configuration = plan["configurations"][model_name]
    self.per_channel_scheme = selected_configuration["normalization_schemes"]