torch_em.transform.nnunet_raw
import json

import numpy as np


class nnUNetRawTransform:
    """Apply nnUNet-style intensity normalization to multi-channel raw inputs.

    Adapted from nnUNetv2's `ImageNormalization`:
    - https://github.com/MIC-DKFZ/nnUNet/tree/master/nnunetv2/preprocessing/normalization

    The normalization scheme of every channel and the foreground intensity statistics
    are read from an nnUNet plans file. Only the `CTNormalization` scheme is implemented
    here; any other scheme listed in the plans makes `__call__` raise.

    (Current Support - CT and PET): The inputs should be of dimension 2 * (H * W * D).
    - The first channel should be CT volume
    - The second channel should be PET volume

    Here's an example for how to use this class:
    ```python
    # Initialize the raw transform.
    raw_transform = nnUNetRawTransform(plans_file=".../nnUNetPlans.json")

    # Apply transformation on the inputs (the modalities are stacked along a new first axis).
    patient_vol = np.stack([ct_vol, pet_vol])
    patient_transformed = raw_transform(patient_vol)
    ```

    Args:
        plans_file: Path to the nnUNet plans file (JSON).
        expected_dtype: The dtype the inputs are cast to before normalization.
        tolerance: Lower bound for the standard deviation used as divisor.
        model_name: The key of the entry in the plans' "configurations" to read the schemes from.
    """
    def __init__(
        self,
        plans_file: str,
        expected_dtype: type = np.float32,
        tolerance: float = 1e-8,
        model_name: str = "3d_fullres"
    ):
        self.expected_dtype = expected_dtype
        self.tolerance = tolerance

        json_file = self.load_json(plans_file)
        self.intensity_properties = json_file["foreground_intensity_properties_per_channel"]
        self.per_channel_scheme = json_file["configurations"][model_name]["normalization_schemes"]

    def load_json(self, _file: str):
        """Return the parsed content of the JSON file at `_file`."""
        # source: `batchgenerators.utilities.file_and_folder_operations`
        # JSON is UTF-8 by specification, so do not depend on the platform's default encoding.
        with open(_file, 'r', encoding="utf-8") as f:
            a = json.load(f)
        return a

    def ct_transform(self, channel, properties):
        """Clip `channel` to the 0.5 / 99.5 percentiles and standardize it with mean and std.

        Args:
            channel: The array of a single channel.
            properties: Mapping with the keys 'mean', 'std', 'percentile_00_5' and 'percentile_99_5'.

        Returns:
            The clipped and standardized channel.
        """
        mean = properties['mean']
        std = properties['std']
        lower_bound = properties['percentile_00_5']
        upper_bound = properties['percentile_99_5']

        transformed_channel = np.clip(channel, lower_bound, upper_bound)
        # The tolerance guards against a division by zero for a (near) constant channel.
        transformed_channel = (transformed_channel - mean) / max(std, self.tolerance)
        return transformed_channel

    def __call__(
        self,
        raw: np.ndarray
    ) -> np.ndarray:  # the transformed raw inputs
        """Returns the raw inputs after applying the pre-processing from nnUNet.

        Args:
            raw: The raw array inputs
                Expected a float array of shape M * (H * W * D) (where, M is the number of modalities)

        Returns:
            The transformed raw inputs (the same shape as inputs)

        Raises:
            AssertionError: If the number of channels does not match the number of schemes in the plans.
            NotImplementedError: If a channel uses a known nnUNet scheme that is not implemented here.
            ValueError: If a channel uses an unknown scheme.
        """
        # Raised explicitly instead of via `assert`, so that the check survives `python -O`
        # (otherwise `zip` below would silently drop the surplus channels).
        # The exception type is kept for backward compatibility.
        if raw.shape[0] != len(self.per_channel_scheme):
            raise AssertionError("Number of channels & transforms from data plan must match")

        raw = raw.astype(self.expected_dtype)

        normalized_channels = []
        for idxx, (channel_transform, channel) in enumerate(zip(self.per_channel_scheme, raw)):
            # The intensity properties are keyed by the channel index as string.
            properties = self.intensity_properties[str(idxx)]

            # get the correct transformation function, this can for example be a method of this class
            if channel_transform == "CTNormalization":
                channel = self.ct_transform(channel, properties)
            elif channel_transform in [
                "ZScoreNormalization", "NoNormalization", "RescaleTo01Normalization", "RGBTo01Normalization"
            ]:
                raise NotImplementedError(f"{channel_transform} is not supported by nnUNetRawTransform yet.")
            else:
                raise ValueError(f"Transform is not known: {channel_transform}.")

            normalized_channels.append(channel)

        return np.stack(normalized_channels)
class
nnUNetRawTransform:
class nnUNetRawTransform:
    """Apply nnUNet-style intensity normalization to multi-channel raw inputs.

    Adapted from nnUNetv2's `ImageNormalization`:
    - https://github.com/MIC-DKFZ/nnUNet/tree/master/nnunetv2/preprocessing/normalization

    The normalization scheme of every channel and the foreground intensity statistics
    are read from an nnUNet plans file. Only the `CTNormalization` scheme is implemented
    here; any other scheme listed in the plans makes `__call__` raise.

    (Current Support - CT and PET): The inputs should be of dimension 2 * (H * W * D).
    - The first channel should be CT volume
    - The second channel should be PET volume

    Here's an example for how to use this class:
    ```python
    # Initialize the raw transform.
    raw_transform = nnUNetRawTransform(plans_file=".../nnUNetPlans.json")

    # Apply transformation on the inputs (the modalities are stacked along a new first axis).
    patient_vol = np.stack([ct_vol, pet_vol])
    patient_transformed = raw_transform(patient_vol)
    ```

    Args:
        plans_file: Path to the nnUNet plans file (JSON).
        expected_dtype: The dtype the inputs are cast to before normalization.
        tolerance: Lower bound for the standard deviation used as divisor.
        model_name: The key of the entry in the plans' "configurations" to read the schemes from.
    """
    def __init__(
        self,
        plans_file: str,
        expected_dtype: type = np.float32,
        tolerance: float = 1e-8,
        model_name: str = "3d_fullres"
    ):
        self.expected_dtype = expected_dtype
        self.tolerance = tolerance

        json_file = self.load_json(plans_file)
        self.intensity_properties = json_file["foreground_intensity_properties_per_channel"]
        self.per_channel_scheme = json_file["configurations"][model_name]["normalization_schemes"]

    def load_json(self, _file: str):
        """Return the parsed content of the JSON file at `_file`."""
        # source: `batchgenerators.utilities.file_and_folder_operations`
        # JSON is UTF-8 by specification, so do not depend on the platform's default encoding.
        with open(_file, 'r', encoding="utf-8") as f:
            a = json.load(f)
        return a

    def ct_transform(self, channel, properties):
        """Clip `channel` to the 0.5 / 99.5 percentiles and standardize it with mean and std.

        Args:
            channel: The array of a single channel.
            properties: Mapping with the keys 'mean', 'std', 'percentile_00_5' and 'percentile_99_5'.

        Returns:
            The clipped and standardized channel.
        """
        mean = properties['mean']
        std = properties['std']
        lower_bound = properties['percentile_00_5']
        upper_bound = properties['percentile_99_5']

        transformed_channel = np.clip(channel, lower_bound, upper_bound)
        # The tolerance guards against a division by zero for a (near) constant channel.
        transformed_channel = (transformed_channel - mean) / max(std, self.tolerance)
        return transformed_channel

    def __call__(
        self,
        raw: np.ndarray
    ) -> np.ndarray:  # the transformed raw inputs
        """Returns the raw inputs after applying the pre-processing from nnUNet.

        Args:
            raw: The raw array inputs
                Expected a float array of shape M * (H * W * D) (where, M is the number of modalities)

        Returns:
            The transformed raw inputs (the same shape as inputs)

        Raises:
            AssertionError: If the number of channels does not match the number of schemes in the plans.
            NotImplementedError: If a channel uses a known nnUNet scheme that is not implemented here.
            ValueError: If a channel uses an unknown scheme.
        """
        # Raised explicitly instead of via `assert`, so that the check survives `python -O`
        # (otherwise `zip` below would silently drop the surplus channels).
        # The exception type is kept for backward compatibility.
        if raw.shape[0] != len(self.per_channel_scheme):
            raise AssertionError("Number of channels & transforms from data plan must match")

        raw = raw.astype(self.expected_dtype)

        normalized_channels = []
        for idxx, (channel_transform, channel) in enumerate(zip(self.per_channel_scheme, raw)):
            # The intensity properties are keyed by the channel index as string.
            properties = self.intensity_properties[str(idxx)]

            # get the correct transformation function, this can for example be a method of this class
            if channel_transform == "CTNormalization":
                channel = self.ct_transform(channel, properties)
            elif channel_transform in [
                "ZScoreNormalization", "NoNormalization", "RescaleTo01Normalization", "RGBTo01Normalization"
            ]:
                raise NotImplementedError(f"{channel_transform} is not supported by nnUNetRawTransform yet.")
            else:
                raise ValueError(f"Transform is not known: {channel_transform}.")

            normalized_channels.append(channel)

        return np.stack(normalized_channels)
Apply transformation on the raw inputs.
Adapted from nnUNetv2's `ImageNormalization`:
- https://github.com/MIC-DKFZ/nnUNet/tree/master/nnunetv2/preprocessing/normalization
You can use this class to apply the necessary raw transformations on input modalities.
(Current Support - CT and PET): The inputs should be of dimension 2 * (H * W * D). - The first channel should be CT volume - The second channel should be PET volume
Here's an example for how to use this class:
# Initialize the raw transform.
raw_transform = nnUNetRawTransform(plans_file=".../nnUNetPlans.json")
# Apply transformation on the inputs.
patient_vol = np.stack([ct_vol, pet_vol])
patient_transformed = raw_transform(patient_vol)
nnUNetRawTransform( plans_file: str, expected_dtype: type = <class 'numpy.float32'>, tolerance: float = 1e-08, model_name: str = '3d_fullres')
def __init__(
    self,
    plans_file: str,
    expected_dtype: type = np.float32,
    tolerance: float = 1e-8,
    model_name: str = "3d_fullres"
):
    """Read the normalization settings from an nnUNet plans file.

    Args:
        plans_file: Path to the plans file (JSON), loaded via `self.load_json`.
        expected_dtype: Stored as `self.expected_dtype`.
        tolerance: Stored as `self.tolerance`.
        model_name: The key of the entry in the plans' "configurations"
            from which the "normalization_schemes" are taken.
    """
    self.expected_dtype, self.tolerance = expected_dtype, tolerance

    plans = self.load_json(plans_file)

    # Foreground intensity statistics, one entry per channel.
    self.intensity_properties = plans["foreground_intensity_properties_per_channel"]

    # Normalization schemes of the selected configuration.
    selected_configuration = plans["configurations"][model_name]
    self.per_channel_scheme = selected_configuration["normalization_schemes"]
def
ct_transform(self, channel, properties):
def ct_transform(self, channel, properties):
    """Clip a channel to its 0.5 / 99.5 percentiles, then standardize it.

    Args:
        channel: The array of a single channel.
        properties: Mapping with the keys 'mean', 'std', 'percentile_00_5' and 'percentile_99_5'.

    Returns:
        The clipped channel, shifted by the mean and divided by the std
        (the divisor is bounded from below by `self.tolerance`).
    """
    mean, std = properties['mean'], properties['std']
    clip_range = (properties['percentile_00_5'], properties['percentile_99_5'])

    # Bounding the divisor avoids a division by zero for a (near) constant channel.
    divisor = max(std, self.tolerance)
    clipped = np.clip(channel, *clip_range)
    return (clipped - mean) / divisor