torch_em.transform.nnunet_raw

 1import json
 2import numpy as np
 3
 4
 5class nnUNetRawTransform:
 6    """Apply transformation on the raw inputs.
 7    Adapted from nnUNetv2's `ImageNormalization`:
 8        - https://github.com/MIC-DKFZ/nnUNet/tree/master/nnunetv2/preprocessing/normalization
 9
10    You can use this class to apply the necessary raw transformations on input modalities.
11
12    (Current Support - CT and PET): The inputs should be of dimension 2 * (H * W * D).
13        - The first channel should be CT volume
14        - The second channel should be PET volume
15
16    Here's an example for how to use this class:
17    ```python
18    # Initialize the raw transform.
19    raw_transform = nnUNetRawTransform(plans_file=".../nnUNetPlans.json")
20
21    # Apply transformation on the inputs.
22    patient_vol = np.concatenate(ct_vol, pet_vol)
23    patient_transformed = raw_transform(patient_vol)
24    ```
25    """
26    def __init__(
27            self,
28            plans_file: str,
29            expected_dtype: type = np.float32,
30            tolerance: float = 1e-8,
31            model_name: str = "3d_fullres"
32    ):
33        self.expected_dtype = expected_dtype
34        self.tolerance = tolerance
35
36        json_file = self.load_json(plans_file)
37        self.intensity_properties = json_file["foreground_intensity_properties_per_channel"]
38        self.per_channel_scheme = json_file["configurations"][model_name]["normalization_schemes"]
39
40    def load_json(self, _file: str):
41        # source: `batchgenerators.utilities.file_and_folder_operations`
42        with open(_file, 'r') as f:
43            a = json.load(f)
44        return a
45
46    def ct_transform(self, channel, properties):
47        mean = properties['mean']
48        std = properties['std']
49        lower_bound = properties['percentile_00_5']
50        upper_bound = properties['percentile_99_5']
51
52        transformed_channel = np.clip(channel, lower_bound, upper_bound)
53        transformed_channel = (transformed_channel - mean) / max(std, self.tolerance)
54        return transformed_channel
55
56    def __call__(
57            self,
58            raw: np.ndarray
59    ) -> np.ndarray:  # the transformed raw inputs
60        """Returns the raw inputs after applying the pre-processing from nnUNet.
61
62        Args:
63            raw: The raw array inputs
64                Expectd a float array of shape M * (H * W * D) (where, M is the number of modalities)
65        Returns:
66            The transformed raw inputs (the same shape as inputs)
67        """
68        assert raw.shape[0] == len(self.per_channel_scheme), "Number of channels & transforms from data plan must match"
69
70        raw = raw.astype(self.expected_dtype)
71
72        normalized_channels = []
73        for idxx, (channel_transform, channel) in enumerate(zip(self.per_channel_scheme, raw)):
74            properties = self.intensity_properties[str(idxx)]
75
76            # get the correct transformation function, this can for example be a method of this class
77            if channel_transform == "CTNormalization":
78                channel = self.ct_transform(channel, properties)
79            elif channel_transform in [
80                "ZScoreNormalization", "NoNormalization", "RescaleTo01Normalization", "RGBTo01Normalization"
81            ]:
82                raise NotImplementedError(f"{channel_transform} is not supported by nnUNetRawTransform yet.")
83            else:
84                raise ValueError(f"Transform is not known: {channel_transform}.")
85
86            normalized_channels.append(channel)
87
88        return np.stack(normalized_channels)
class nnUNetRawTransform:
 6class nnUNetRawTransform:
 7    """Apply transformation on the raw inputs.
 8    Adapted from nnUNetv2's `ImageNormalization`:
 9        - https://github.com/MIC-DKFZ/nnUNet/tree/master/nnunetv2/preprocessing/normalization
10
11    You can use this class to apply the necessary raw transformations on input modalities.
12
13    (Current Support - CT and PET): The inputs should be of dimension 2 * (H * W * D).
14        - The first channel should be CT volume
15        - The second channel should be PET volume
16
17    Here's an example for how to use this class:
18    ```python
19    # Initialize the raw transform.
20    raw_transform = nnUNetRawTransform(plans_file=".../nnUNetPlans.json")
21
22    # Apply transformation on the inputs.
23    patient_vol = np.concatenate(ct_vol, pet_vol)
24    patient_transformed = raw_transform(patient_vol)
25    ```
26    """
27    def __init__(
28            self,
29            plans_file: str,
30            expected_dtype: type = np.float32,
31            tolerance: float = 1e-8,
32            model_name: str = "3d_fullres"
33    ):
34        self.expected_dtype = expected_dtype
35        self.tolerance = tolerance
36
37        json_file = self.load_json(plans_file)
38        self.intensity_properties = json_file["foreground_intensity_properties_per_channel"]
39        self.per_channel_scheme = json_file["configurations"][model_name]["normalization_schemes"]
40
41    def load_json(self, _file: str):
42        # source: `batchgenerators.utilities.file_and_folder_operations`
43        with open(_file, 'r') as f:
44            a = json.load(f)
45        return a
46
47    def ct_transform(self, channel, properties):
48        mean = properties['mean']
49        std = properties['std']
50        lower_bound = properties['percentile_00_5']
51        upper_bound = properties['percentile_99_5']
52
53        transformed_channel = np.clip(channel, lower_bound, upper_bound)
54        transformed_channel = (transformed_channel - mean) / max(std, self.tolerance)
55        return transformed_channel
56
57    def __call__(
58            self,
59            raw: np.ndarray
60    ) -> np.ndarray:  # the transformed raw inputs
61        """Returns the raw inputs after applying the pre-processing from nnUNet.
62
63        Args:
64            raw: The raw array inputs
65                Expectd a float array of shape M * (H * W * D) (where, M is the number of modalities)
66        Returns:
67            The transformed raw inputs (the same shape as inputs)
68        """
69        assert raw.shape[0] == len(self.per_channel_scheme), "Number of channels & transforms from data plan must match"
70
71        raw = raw.astype(self.expected_dtype)
72
73        normalized_channels = []
74        for idxx, (channel_transform, channel) in enumerate(zip(self.per_channel_scheme, raw)):
75            properties = self.intensity_properties[str(idxx)]
76
77            # get the correct transformation function, this can for example be a method of this class
78            if channel_transform == "CTNormalization":
79                channel = self.ct_transform(channel, properties)
80            elif channel_transform in [
81                "ZScoreNormalization", "NoNormalization", "RescaleTo01Normalization", "RGBTo01Normalization"
82            ]:
83                raise NotImplementedError(f"{channel_transform} is not supported by nnUNetRawTransform yet.")
84            else:
85                raise ValueError(f"Transform is not known: {channel_transform}.")
86
87            normalized_channels.append(channel)
88
89        return np.stack(normalized_channels)

Apply transformation on the raw inputs. Adapted from nnUNetv2's ImageNormalization: - https://github.com/MIC-DKFZ/nnUNet/tree/master/nnunetv2/preprocessing/normalization

You can use this class to apply the necessary raw transformations on input modalities.

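(Current Support - CT and PET): The inputs should be of shape (2, H, W, D). - The first channel should be the CT volume - The second channel should be the PET volume. Only the CTNormalization scheme is implemented so far, so every channel listed in the plans' normalization_schemes must use it; any other scheme raises an error.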

Here's an example for how to use this class:

# Initialize the raw transform.
raw_transform = nnUNetRawTransform(plans_file=".../nnUNetPlans.json")

# Apply transformation on the inputs.
patient_vol = np.stack([ct_vol, pet_vol])
patient_transformed = raw_transform(patient_vol)
nnUNetRawTransform( plans_file: str, expected_dtype: type = <class 'numpy.float32'>, tolerance: float = 1e-08, model_name: str = '3d_fullres')
def __init__(
        self,
        plans_file: str,
        expected_dtype: type = np.float32,
        tolerance: float = 1e-8,
        model_name: str = "3d_fullres"
):
    """Read the per-channel normalization settings from an nnU-Net plans file.

    Args:
        plans_file: Path to the nnU-Net plans JSON file (e.g. `nnUNetPlans.json`).
        expected_dtype: The dtype the raw inputs are cast to before normalization.
        tolerance: Lower bound for the standard deviation used as divisor,
            which avoids a division by zero.
        model_name: Name of the configuration in the plans' `configurations`
            whose `normalization_schemes` are used. Configurations that only
            inherit them via `inherits_from` are resolved through their parents.

    Raises:
        KeyError: If the plans lack the intensity properties, the requested
            configuration (or one of its parents), or the normalization schemes.
        ValueError: If the `inherits_from` chain is cyclic.
    """
    self.expected_dtype = expected_dtype
    self.tolerance = tolerance

    json_file = self.load_json(plans_file)
    self.intensity_properties = json_file["foreground_intensity_properties_per_channel"]

    # nnU-Net v2 plans may define a configuration that only overrides parts of a
    # parent via "inherits_from" (e.g. "3d_cascade_fullres" -> "3d_fullres"),
    # so walk up that chain until the normalization schemes are defined.
    configurations = json_file["configurations"]
    config_name, visited = model_name, []
    while True:
        if config_name not in configurations:
            raise KeyError(
                f"Configuration '{config_name}' not found in '{plans_file}'. "
                f"Available configurations: {sorted(configurations)}."
            )
        config = configurations[config_name]
        if "normalization_schemes" in config or "inherits_from" not in config:
            break
        visited.append(config_name)
        config_name = config["inherits_from"]
        if config_name in visited:
            raise ValueError(
                f"Cyclic 'inherits_from' chain in '{plans_file}': {visited + [config_name]}."
            )
    self.per_channel_scheme = config["normalization_schemes"]
expected_dtype
tolerance
intensity_properties
per_channel_scheme
def load_json(self, _file: str):
def load_json(self, _file: str):
    """Load and return the parsed content of the JSON file at `_file`.

    Adapted from `batchgenerators.utilities.file_and_folder_operations`.
    """
    # JSON text is UTF-8 (RFC 8259); don't depend on the platform's default encoding.
    with open(_file, "r", encoding="utf-8") as f:
        return json.load(f)
def ct_transform(self, channel, properties):
def ct_transform(self, channel, properties):
    """Clip a channel to the foreground percentile window and standardize it.

    Mirrors nnU-Net's `CTNormalization`.

    Args:
        channel: The array of a single input channel.
        properties: Foreground intensity statistics of this channel from the plans;
            must provide 'mean', 'std', 'percentile_00_5' and 'percentile_99_5'.

    Returns:
        A new array `(clip(channel, p0.5, p99.5) - mean) / max(std, tolerance)`.
    """
    mean, std = properties['mean'], properties['std']
    low, high = properties['percentile_00_5'], properties['percentile_99_5']

    windowed = np.clip(channel, low, high)
    # `tolerance` keeps a (near) zero std from causing a division by zero.
    return (windowed - mean) / max(std, self.tolerance)