torch_em.model.unet
from typing import List, Optional, Union

import numpy as np
import torch
import torch.nn as nn


#
# Model Internal Post-processing
#
# Note: these are mainly for bioimage.io models, where postprocessing has to be done
# inside of the model unless it's defined in the general spec.


class AccumulateChannels(nn.Module):
    """@private
    """
    def __init__(
        self,
        invariant_channels,
        accumulate_channels,
        accumulator
    ):
        super().__init__()
        self.invariant_channels = invariant_channels
        self.accumulate_channels = accumulate_channels
        assert accumulator in ("mean", "min", "max")
        self.accumulator = getattr(torch, accumulator)

    def _accumulate(self, x, c0, c1):
        res = self.accumulator(x[:, c0:c1], dim=1, keepdim=True)
        # torch.min and torch.max return a (values, indices) tuple, torch.mean returns a tensor.
        if not torch.is_tensor(res):
            res = res.values
        assert torch.is_tensor(res)
        return res

    def forward(self, x):
        if self.invariant_channels is None:
            c0, c1 = self.accumulate_channels
            return self._accumulate(x, c0, c1)
        else:
            i0, i1 = self.invariant_channels
            c0, c1 = self.accumulate_channels
            return torch.cat([x[:, i0:i1], self._accumulate(x, c0, c1)], dim=1)


def affinities_to_boundaries(aff_channels, accumulator="max"):
    """@private
    """
    return AccumulateChannels(None, aff_channels, accumulator)


def affinities_with_foreground_to_boundaries(aff_channels, fg_channel=(0, 1), accumulator="max"):
    """@private
    """
    return AccumulateChannels(fg_channel, aff_channels, accumulator)


def affinities_to_boundaries2d():
    """@private
    """
    return affinities_to_boundaries((0, 2))


def affinities_with_foreground_to_boundaries2d():
    """@private
    """
    return affinities_with_foreground_to_boundaries((1, 3))


def affinities_to_boundaries3d():
    """@private
    """
    return affinities_to_boundaries((0, 3))


def affinities_with_foreground_to_boundaries3d():
    """@private
    """
    return affinities_with_foreground_to_boundaries((1, 4))


def affinities_to_boundaries_anisotropic():
    """@private
    """
    return AccumulateChannels(None, (1, 3), "max")


POSTPROCESSING = {
    "affinities_to_boundaries_anisotropic": affinities_to_boundaries_anisotropic,
    "affinities_to_boundaries2d": affinities_to_boundaries2d,
    "affinities_with_foreground_to_boundaries2d": affinities_with_foreground_to_boundaries2d,
    "affinities_to_boundaries3d": affinities_to_boundaries3d,
    "affinities_with_foreground_to_boundaries3d": affinities_with_foreground_to_boundaries3d,
}
"""@private
"""


#
# Base Implementations
#

class UNetBase(nn.Module):
    """Base class for implementing a U-Net.

    Args:
        encoder: The encoder of the U-Net.
        base: The base layer of the U-Net.
        decoder: The decoder of the U-Net.
        out_conv: The output convolution applied after the last decoder layer.
        final_activation: The activation applied after the output convolution or last decoder layer.
        postprocessing: A postprocessing function to apply after the U-Net output.
        check_shape: Whether to check the input shape to the U-Net forward call.
    """
    def __init__(
        self,
        encoder: nn.Module,
        base: nn.Module,
        decoder: nn.Module,
        out_conv: Optional[nn.Module] = None,
        final_activation: Optional[Union[nn.Module, str]] = None,
        postprocessing: Optional[Union[nn.Module, str]] = None,
        check_shape: bool = True,
    ):
        super().__init__()
        if len(encoder) != len(decoder):
            raise ValueError(f"Incompatible depth of encoder (depth={len(encoder)}) and decoder (depth={len(decoder)})")

        self.encoder = encoder
        self.base = base
        self.decoder = decoder

        if out_conv is None:
            self.return_decoder_outputs = False
            self._out_channels = self.decoder.out_channels
        elif isinstance(out_conv, nn.ModuleList):
            if len(out_conv) != len(self.decoder):
                raise ValueError(f"Invalid length of out_conv, expected {len(decoder)}, got {len(out_conv)}")
            self.return_decoder_outputs = True
            self._out_channels = [None if conv is None else conv.out_channels for conv in out_conv]
        else:
            self.return_decoder_outputs = False
            self._out_channels = out_conv.out_channels
        self.out_conv = out_conv
        self.check_shape = check_shape
        self.final_activation = self._get_activation(final_activation)
        self.postprocessing = self._get_postprocessing(postprocessing)

    @property
    def in_channels(self):
        return self.encoder.in_channels

    @property
    def out_channels(self):
        return self._out_channels

    @property
    def depth(self):
        return len(self.encoder)

    def _get_activation(self, activation):
        return_activation = None
        if activation is None:
            return None
        if isinstance(activation, nn.Module):
            return activation
        if isinstance(activation, str):
            return_activation = getattr(nn, activation, None)
        if return_activation is None:
            raise ValueError(f"Invalid activation: {activation}")
        return return_activation()

    def _get_postprocessing(self, postprocessing):
        if postprocessing is None:
            return None
        elif isinstance(postprocessing, nn.Module):
            return postprocessing
        elif postprocessing in POSTPROCESSING:
            return POSTPROCESSING[postprocessing]()
        else:
            raise ValueError(f"Invalid postprocessing: {postprocessing}")

    # Load encoder / decoder / base states for pretraining.
    def load_encoder_state(self, state):
        self.encoder.load_state_dict(state)

    def load_decoder_state(self, state):
        self.decoder.load_state_dict(state)

    def load_base_state(self, state):
        self.base.load_state_dict(state)

    def _apply_default(self, x):
        self.encoder.return_outputs = True
        self.decoder.return_outputs = False

        x, encoder_out = self.encoder(x)
        x = self.base(x)
        x = self.decoder(x, encoder_inputs=encoder_out[::-1])

        if self.out_conv is not None:
            x = self.out_conv(x)
        if self.final_activation is not None:
            x = self.final_activation(x)
        if self.postprocessing is not None:
            x = self.postprocessing(x)

        return x

    def _apply_with_side_outputs(self, x):
        self.encoder.return_outputs = True
        self.decoder.return_outputs = True

        x, encoder_out = self.encoder(x)
        x = self.base(x)
        x = self.decoder(x, encoder_inputs=encoder_out[::-1])

        x = [xx if conv is None else conv(xx) for xx, conv in zip(x, self.out_conv)]
        if self.final_activation is not None:
            x = [self.final_activation(xx) for xx in x]

        if self.postprocessing is not None:
            x = [self.postprocessing(xx) for xx in x]

        # We reverse the list to have the full shape output as first element.
        return x[::-1]

    def _check_shape(self, x):
        spatial_shape = tuple(x.shape)[2:]
        depth = len(self.encoder)
        factor = [2**depth] * len(spatial_shape)
        if any(sh % fac != 0 for sh, fac in zip(spatial_shape, factor)):
            msg = f"Invalid shape for U-Net: {spatial_shape} is not divisible by {factor}"
            raise ValueError(msg)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply U-Net to input data.

        Args:
            x: The input data.

        Returns:
            The output of the U-Net.
        """
        # Cast input data to float, hotfix for modelzoo deployment issues, leaving it here for reference.
        # x = x.float()
        if getattr(self, "check_shape", True):
            self._check_shape(x)
        if self.return_decoder_outputs:
            return self._apply_with_side_outputs(x)
        else:
            return self._apply_default(x)


def _update_conv_kwargs(kwargs, scale_factor):
    # If the scale factor is a scalar or all entries are the same we don't need to update the kwargs.
    if isinstance(scale_factor, int) or scale_factor.count(scale_factor[0]) == len(scale_factor):
        return kwargs
    else:  # Otherwise set an anisotropic kernel.
        kernel_size = kwargs.get("kernel_size", 3)
        padding = kwargs.get("padding", 1)

        # Bail out if kernel size or padding aren't scalars, because it's
        # unclear what to do in this case.
        if not (isinstance(kernel_size, int) and isinstance(padding, int)):
            return kwargs

        kernel_size = tuple(1 if factor == 1 else kernel_size for factor in scale_factor)
        padding = tuple(0 if factor == 1 else padding for factor in scale_factor)
        kwargs.update({"kernel_size": kernel_size, "padding": padding})
    return kwargs


class Encoder(nn.Module):
    """@private
    """
    def __init__(
        self,
        features,
        scale_factors,
        conv_block_impl,
        pooler_impl,
        anisotropic_kernel=False,
        **conv_block_kwargs
    ):
        super().__init__()
        if len(features) != len(scale_factors) + 1:
            raise ValueError(f"Incompatible number of features {len(features)} and scale_factors {len(scale_factors)}")

        conv_kwargs = [conv_block_kwargs] * len(scale_factors)
        if anisotropic_kernel:
            conv_kwargs = [_update_conv_kwargs(kwargs, scale_factor)
                           for kwargs, scale_factor in zip(conv_kwargs, scale_factors)]

        self.blocks = nn.ModuleList(
            [conv_block_impl(inc, outc, **kwargs)
             for inc, outc, kwargs in zip(features[:-1], features[1:], conv_kwargs)]
        )
        self.poolers = nn.ModuleList(
            [pooler_impl(factor) for factor in scale_factors]
        )
        self.return_outputs = True

        self.in_channels = features[0]
        self.out_channels = features[-1]

    def __len__(self):
        return len(self.blocks)

    def forward(self, x):
        # Collect the per-level block outputs, which serve as skip connections for the decoder.
        encoder_out = []
        for block, pooler in zip(self.blocks, self.poolers):
            x = block(x)
            encoder_out.append(x)
            x = pooler(x)

        if self.return_outputs:
            return x, encoder_out
        else:
            return x


class Decoder(nn.Module):
    """@private
    """
    def __init__(
        self,
        features,
        scale_factors,
        conv_block_impl,
        sampler_impl,
        anisotropic_kernel=False,
        **conv_block_kwargs
    ):
        super().__init__()
        if len(features) != len(scale_factors) + 1:
            raise ValueError(f"Incompatible number of features {len(features)} and scale_factors {len(scale_factors)}")

        conv_kwargs = [conv_block_kwargs] * len(scale_factors)
        if anisotropic_kernel:
            conv_kwargs = [_update_conv_kwargs(kwargs, scale_factor)
                           for kwargs, scale_factor in zip(conv_kwargs, scale_factors)]

        self.blocks = nn.ModuleList(
            [conv_block_impl(inc, outc, **kwargs)
             for inc, outc, kwargs in zip(features[:-1], features[1:], conv_kwargs)]
        )
        self.samplers = nn.ModuleList(
            [sampler_impl(factor, inc, outc) for factor, inc, outc
             in zip(scale_factors, features[:-1], features[1:])]
        )
        self.return_outputs = False

        self.in_channels = features[0]
        self.out_channels = features[-1]

    def __len__(self):
        return len(self.blocks)

    # FIXME this prevents traces from being valid for other input sizes, need to find
    # a solution to traceable cropping
    def _crop(self, x, shape):
        shape_diff = [(xsh - sh) // 2 for xsh, sh in zip(x.shape, shape)]
        crop = tuple([slice(sd, xsh - sd) for sd, xsh in zip(shape_diff, x.shape)])
        return x[crop]
        # # Implementation with torch.narrow, does not fix the tracing warnings!
        # for dim, (sh, sd) in enumerate(zip(shape, shape_diff)):
        #     x = torch.narrow(x, dim, sd, sh)
        # return x

    def _concat(self, x1, x2):
        return torch.cat([x1, self._crop(x2, x1.shape)], dim=1)

    def forward(self, x, encoder_inputs):
        if len(encoder_inputs) != len(self.blocks):
            raise ValueError(f"Invalid number of encoder_inputs: expected {len(self.blocks)}, got {len(encoder_inputs)}")

        decoder_out = []
        for block, sampler, from_encoder in zip(self.blocks, self.samplers, encoder_inputs):
            x = sampler(x)
            x = block(self._concat(x, from_encoder))
            decoder_out.append(x)

        if self.return_outputs:
            return decoder_out + [x]
        else:
            return x


def get_norm_layer(norm, dim, channels, n_groups=32):
    """@private
    """
    if norm is None:
        return None
    if norm == "InstanceNorm":
        return nn.InstanceNorm2d(channels) if dim == 2 else nn.InstanceNorm3d(channels)
    elif norm == "InstanceNormTrackStats":
        kwargs = {"affine": True, "track_running_stats": True, "momentum": 0.01}
        return nn.InstanceNorm2d(channels, **kwargs) if dim == 2 else nn.InstanceNorm3d(channels, **kwargs)
    elif norm == "GroupNorm":
        return nn.GroupNorm(min(n_groups, channels), channels)
    elif norm == "BatchNorm":
        return nn.BatchNorm2d(channels) if dim == 2 else nn.BatchNorm3d(channels)
    else:
        raise ValueError(
            f"Invalid norm: expect one of 'InstanceNorm', 'InstanceNormTrackStats', 'GroupNorm' or 'BatchNorm', "
            f"got {norm}"
        )


class ConvBlock(nn.Module):
    """@private
    """
    def __init__(self, in_channels, out_channels, dim, kernel_size=3, padding=1, norm="InstanceNorm"):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels

        conv = nn.Conv2d if dim == 2 else nn.Conv3d

        if norm is None:
            self.block = nn.Sequential(
                conv(in_channels, out_channels,
                     kernel_size=kernel_size, padding=padding),
                nn.ReLU(inplace=True),
                conv(out_channels, out_channels,
                     kernel_size=kernel_size, padding=padding),
                nn.ReLU(inplace=True)
            )
        else:
            self.block = nn.Sequential(
                get_norm_layer(norm, dim, in_channels),
                conv(in_channels, out_channels,
                     kernel_size=kernel_size, padding=padding),
                nn.ReLU(inplace=True),
                get_norm_layer(norm, dim, out_channels),
                conv(out_channels, out_channels,
                     kernel_size=kernel_size, padding=padding),
                nn.ReLU(inplace=True)
            )

    def forward(self, x):
        return self.block(x)


class Upsampler(nn.Module):
    """@private
    """
    def __init__(self, scale_factor, in_channels, out_channels, dim, mode):
        super().__init__()
        self.mode = mode
        self.scale_factor = scale_factor

        conv = nn.Conv2d if dim == 2 else nn.Conv3d
        self.conv = conv(in_channels, out_channels, 1)

    def forward(self, x):
        x = nn.functional.interpolate(x, scale_factor=self.scale_factor, mode=self.mode, align_corners=False)
        x = self.conv(x)
        return x


#
# 2d unet implementations
#

class ConvBlock2d(ConvBlock):
    """@private
    """
    def __init__(self, in_channels, out_channels, **kwargs):
        super().__init__(in_channels, out_channels, dim=2, **kwargs)


class Upsampler2d(Upsampler):
    """@private
    """
    def __init__(self, scale_factor, in_channels, out_channels, mode="bilinear"):
        super().__init__(scale_factor, in_channels, out_channels, dim=2, mode=mode)


class UNet2d(UNetBase):
    """A 2D U-Net network for segmentation and other image-to-image tasks.

    The number of features for each level of the U-Net is computed as follows: initial_features * gain ** level.
    The number of levels is determined by the depth argument. By default the U-Net uses two convolutional layers
    per level, max-pooling for downsampling and linear interpolation for upsampling.
    These implementations can be changed by providing arguments for `conv_block_impl`, `pooler_impl`
    and `sampler_impl`, respectively.

    Args:
        in_channels: The number of input image channels.
        out_channels: The number of output image channels.
        depth: The number of encoder / decoder levels of the U-Net.
        initial_features: The initial number of features, corresponding to the features of the first conv block.
        gain: The gain factor for increasing the features after each level.
        final_activation: The activation applied after the output convolution or last decoder layer.
        return_side_outputs: Whether to return the outputs after each decoder level.
        conv_block_impl: The implementation of the convolutional block.
        pooler_impl: The implementation of the pooling layer.
        postprocessing: A postprocessing function to apply after the U-Net output.
        check_shape: Whether to check the input shape to the U-Net forward call.
        conv_block_kwargs: The keyword arguments for the convolutional block.
    """
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        depth: int = 4,
        initial_features: int = 32,
        gain: int = 2,
        final_activation: Optional[Union[nn.Module, str]] = None,
        return_side_outputs: bool = False,
        conv_block_impl: nn.Module = ConvBlock2d,
        pooler_impl: nn.Module = nn.MaxPool2d,
        sampler_impl: nn.Module = Upsampler2d,
        postprocessing: Optional[Union[nn.Module, str]] = None,
        check_shape: bool = True,
        **conv_block_kwargs,
    ):
        features_encoder = [in_channels] + [initial_features * gain ** i for i in range(depth)]
        features_decoder = [initial_features * gain ** i for i in range(depth + 1)][::-1]
        scale_factors = depth * [2]

        if return_side_outputs:
            if isinstance(out_channels, int) or out_channels is None:
                out_channels = [out_channels] * depth
            if len(out_channels) != depth:
                raise ValueError(f"Invalid number of out_channels, expected {depth}, got {len(out_channels)}")
            out_conv = nn.ModuleList(
                [nn.Conv2d(feat, outc, 1) for feat, outc in zip(features_decoder[1:], out_channels)]
            )
        else:
            out_conv = None if out_channels is None else nn.Conv2d(features_decoder[-1], out_channels, 1)

        super().__init__(
            encoder=Encoder(
                features=features_encoder,
                scale_factors=scale_factors,
                conv_block_impl=conv_block_impl,
                pooler_impl=pooler_impl,
                **conv_block_kwargs
            ),
            decoder=Decoder(
                features=features_decoder,
                scale_factors=scale_factors[::-1],
                conv_block_impl=conv_block_impl,
                sampler_impl=sampler_impl,
                **conv_block_kwargs
            ),
            base=conv_block_impl(
                features_encoder[-1], features_encoder[-1] * gain,
                **conv_block_kwargs
            ),
            out_conv=out_conv,
            final_activation=final_activation,
            postprocessing=postprocessing,
            check_shape=check_shape,
        )
        self.init_kwargs = {"in_channels": in_channels, "out_channels": out_channels, "depth": depth,
                            "initial_features": initial_features, "gain": gain,
                            "final_activation": final_activation, "return_side_outputs": return_side_outputs,
                            "conv_block_impl": conv_block_impl, "pooler_impl": pooler_impl,
                            "sampler_impl": sampler_impl, "postprocessing": postprocessing, **conv_block_kwargs}


#
# 3d unet implementations
#

class ConvBlock3d(ConvBlock):
    """@private
    """
    def __init__(self, in_channels, out_channels, **kwargs):
        super().__init__(in_channels, out_channels, dim=3, **kwargs)


class Upsampler3d(Upsampler):
    """@private
    """
    def __init__(self, scale_factor, in_channels, out_channels, mode="trilinear"):
        super().__init__(scale_factor, in_channels, out_channels, dim=3, mode=mode)


class AnisotropicUNet(UNetBase):
    """A 3D U-Net network for segmentation and other image-to-image tasks.

    The number of features for each level of the U-Net is computed as follows: initial_features * gain ** level.
    The number of levels is determined by the length of the scale_factors argument.
    The scale factors determine the pooling factors for each level. By specifying [1, 2, 2] the pooling
    is done in an anisotropic fashion, i.e. only across the xy-plane,
    by specifying [2, 2, 2] it is done in an isotropic fashion.

    By default the U-Net uses two convolutional layers per level.
    This can be changed by providing an argument for `conv_block_impl`.

    Args:
        in_channels: The number of input image channels.
        out_channels: The number of output image channels.
        scale_factors: The factors for max pooling for the levels of the U-Net.
        initial_features: The initial number of features, corresponding to the features of the first conv block.
        gain: The gain factor for increasing the features after each level.
        final_activation: The activation applied after the output convolution or last decoder layer.
        return_side_outputs: Whether to return the outputs after each decoder level.
        conv_block_impl: The implementation of the convolutional block.
        anisotropic_kernel: Whether to use an anisotropic kernel in addition to the anisotropic scaling factor.
        postprocessing: A postprocessing function to apply after the U-Net output.
        check_shape: Whether to check the input shape to the U-Net forward call.
        conv_block_kwargs: The keyword arguments for the convolutional block.
    """
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        scale_factors: List[List[int]],
        initial_features: int = 32,
        gain: int = 2,
        final_activation: Optional[Union[str, nn.Module]] = None,
        return_side_outputs: bool = False,
        conv_block_impl: nn.Module = ConvBlock3d,
        anisotropic_kernel: bool = False,
        postprocessing: Optional[Union[str, nn.Module]] = None,
        check_shape: bool = True,
        **conv_block_kwargs,
    ):
        depth = len(scale_factors)
        features_encoder = [in_channels] + [initial_features * gain ** i for i in range(depth)]
        features_decoder = [initial_features * gain ** i for i in range(depth + 1)][::-1]

        if return_side_outputs:
            if isinstance(out_channels, int) or out_channels is None:
                out_channels = [out_channels] * depth
            if len(out_channels) != depth:
                raise ValueError(f"Invalid number of out_channels, expected {depth}, got {len(out_channels)}")
            out_conv = nn.ModuleList(
                [nn.Conv3d(feat, outc, 1) for feat, outc in zip(features_decoder[1:], out_channels)]
            )
        else:
            out_conv = None if out_channels is None else nn.Conv3d(features_decoder[-1], out_channels, 1)

        super().__init__(
            encoder=Encoder(
                features=features_encoder,
                scale_factors=scale_factors,
                conv_block_impl=conv_block_impl,
                pooler_impl=nn.MaxPool3d,
                anisotropic_kernel=anisotropic_kernel,
                **conv_block_kwargs
            ),
            decoder=Decoder(
                features=features_decoder,
                scale_factors=scale_factors[::-1],
                conv_block_impl=conv_block_impl,
                sampler_impl=Upsampler3d,
                anisotropic_kernel=anisotropic_kernel,
                **conv_block_kwargs
            ),
            base=conv_block_impl(
                features_encoder[-1], features_encoder[-1] * gain, **conv_block_kwargs
            ),
            out_conv=out_conv,
            final_activation=final_activation,
            postprocessing=postprocessing,
            check_shape=check_shape,
        )
        self.init_kwargs = {"in_channels": in_channels, "out_channels": out_channels, "scale_factors": scale_factors,
                            "initial_features": initial_features, "gain": gain,
                            "final_activation": final_activation, "return_side_outputs": return_side_outputs,
                            "conv_block_impl": conv_block_impl, "anisotropic_kernel": anisotropic_kernel,
                            "postprocessing": postprocessing, **conv_block_kwargs}

    def _check_shape(self, x):
        spatial_shape = tuple(x.shape)[2:]
        scale_factors = self.init_kwargs.get("scale_factors", [[2, 2, 2]] * len(self.encoder))
        factor = [int(np.prod([sf[i] for sf in scale_factors])) for i in range(3)]
        if len(spatial_shape) != len(factor):
            msg = f"Invalid shape for U-Net: dimensions don't agree {len(spatial_shape)} != {len(factor)}"
            raise ValueError(msg)
        if any(sh % fac != 0 for sh, fac in zip(spatial_shape, factor)):
            msg = f"Invalid shape for U-Net: {spatial_shape} is not divisible by {factor}"
            raise ValueError(msg)


class UNet3d(AnisotropicUNet):
    """A 3D U-Net network for segmentation and other image-to-image tasks.

    This class uses the same implementation as `AnisotropicUNet`, with isotropic scaling in each level.

    Args:
        in_channels: The number of input image channels.
        out_channels: The number of output image channels.
        depth: The number of encoder / decoder levels of the U-Net.
        initial_features: The initial number of features, corresponding to the features of the first conv block.
        gain: The gain factor for increasing the features after each level.
        final_activation: The activation applied after the output convolution or last decoder layer.
        return_side_outputs: Whether to return the outputs after each decoder level.
        conv_block_impl: The implementation of the convolutional block.
        postprocessing: A postprocessing function to apply after the U-Net output.
        check_shape: Whether to check the input shape to the U-Net forward call.
        conv_block_kwargs: The keyword arguments for the convolutional block.
    """
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        depth: int = 4,
        initial_features: int = 32,
        gain: int = 2,
        final_activation: Optional[Union[str, nn.Module]] = None,
        return_side_outputs: bool = False,
        conv_block_impl: nn.Module = ConvBlock3d,
        postprocessing: Optional[Union[str, nn.Module]] = None,
        check_shape: bool = True,
        **conv_block_kwargs,
    ):
        scale_factors = depth * [2]
        super().__init__(in_channels, out_channels, scale_factors,
                         initial_features=initial_features, gain=gain,
                         final_activation=final_activation,
                         return_side_outputs=return_side_outputs,
                         anisotropic_kernel=False,
                         postprocessing=postprocessing,
                         conv_block_impl=conv_block_impl,
                         check_shape=check_shape,
                         **conv_block_kwargs)
        self.init_kwargs = {"in_channels": in_channels, "out_channels": out_channels, "depth": depth,
                            "initial_features": initial_features, "gain": gain,
                            "final_activation": final_activation, "return_side_outputs": return_side_outputs,
                            "conv_block_impl": conv_block_impl, "postprocessing": postprocessing, **conv_block_kwargs}
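
The note at the top of the source explains that the postprocessing hook mainly exists for bioimage.io models, where postprocessing has to happen inside the model. The sketch below shows how one of the POSTPROCESSING entries changes the network output; the channel counts and input size are illustrative assumptions.

import torch
from torch_em.model.unet import UNet2d

# Assume 3 output channels: one foreground channel followed by two affinity channels.
net = UNet2d(
    in_channels=1, out_channels=3, final_activation="Sigmoid",
    postprocessing="affinities_with_foreground_to_boundaries2d",
)
x = torch.rand(1, 1, 256, 256)
with torch.no_grad():
    y = net(x)
# The postprocessing keeps the foreground channel and max-projects the two affinity
# channels into a single boundary channel, so the returned tensor has two channels.
print(y.shape)  # torch.Size([1, 2, 256, 256])
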
class UNetBase(nn.Module):
Base class for implementing a U-Net.
Arguments:
- encoder: The encoder of the U-Net.
- base: The base layer of the U-Net.
- decoder: The decoder of the U-Net.
- out_conv: The output convolution applied after the last decoder layer.
- final_activation: The activation applied after the output convolution or last decoder layer.
- postprocessing: A postprocessing function to apply after the U-Net output.
- check_shape: Whether to check the input shape to the U-Net forward call.
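
As a construction sketch: the concrete networks in this module wire UNetBase together from the helper classes shown in the source above (Encoder, Decoder, ConvBlock2d, Upsampler2d). The feature counts and input size below are illustrative assumptions; the wiring mirrors what UNet2d.__init__ does for a depth of 2.

import torch
import torch.nn as nn
from torch_em.model.unet import UNetBase, Encoder, Decoder, ConvBlock2d, Upsampler2d

features_encoder = [1, 32, 64]    # input channels followed by the per-level features
features_decoder = [128, 64, 32]  # reversed feature progression for the decoder
scale_factors = [2, 2]            # one pooling / upsampling factor per level

net = UNetBase(
    encoder=Encoder(features_encoder, scale_factors, ConvBlock2d, nn.MaxPool2d),
    base=ConvBlock2d(64, 128),
    decoder=Decoder(features_decoder, scale_factors[::-1], ConvBlock2d, Upsampler2d),
    out_conv=nn.Conv2d(32, 2, 1),  # map the last decoder features to 2 output channels
    final_activation="Sigmoid",
)
x = torch.rand(1, 1, 64, 64)      # each spatial axis must be divisible by 2 ** depth = 4
print(net(x).shape)               # torch.Size([1, 2, 64, 64])
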
def forward(self, x: torch.Tensor) -> torch.Tensor

Apply U-Net to input data.
Arguments:
- x: The input data.
Returns:
The output of the U-Net.
class UNet2d(UNetBase):
A 2D U-Net network for segmentation and other image-to-image tasks.
The number of features for each level of the U-Net is computed as follows: initial_features * gain ** level.
The number of levels is determined by the depth argument. By default the U-Net uses two convolutional layers
per level, max-pooling for downsampling and linear interpolation for upsampling.
These implementations can be changed by providing arguments for conv_block_impl, pooler_impl and sampler_impl, respectively.
Arguments:
- in_channels: The number of input image channels.
- out_channels: The number of output image channels.
- depth: The number of encoder / decoder levels of the U-Net.
- initial_features: The initial number of features, corresponding to the features of the first conv block.
- gain: The gain factor for increasing the features after each level.
- final_activation: The activation applied after the output convolution or last decoder layer.
- return_side_outputs: Whether to return the outputs after each decoder level.
- conv_block_impl: The implementation of the convolutional block.
- pooler_impl: The implementation of the pooling layer.
- postprocessing: A postprocessing function to apply after the U-Net output.
- check_shape: Whether to check the input shape to the U-Net forward call.
- conv_block_kwargs: The keyword arguments for the convolutional block.
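
A minimal usage sketch for UNet2d; the channel counts and input size are illustrative assumptions.

import torch
from torch_em.model.unet import UNet2d

# With depth=4 every spatial axis of the input must be divisible by 2 ** 4 = 16.
model = UNet2d(in_channels=1, out_channels=2, depth=4, initial_features=32, final_activation="Sigmoid")
x = torch.rand(1, 1, 256, 256)
with torch.no_grad():
    y = model(x)
print(y.shape)  # torch.Size([1, 2, 256, 256])

# With return_side_outputs=True the forward pass returns one prediction per decoder
# level, with the full-resolution output as the first list entry.
model = UNet2d(in_channels=1, out_channels=2, depth=4, return_side_outputs=True)
with torch.no_grad():
    outputs = model(x)
print([tuple(out.shape) for out in outputs])
# [(1, 2, 256, 256), (1, 2, 128, 128), (1, 2, 64, 64), (1, 2, 32, 32)]
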
class AnisotropicUNet(UNetBase):
A 3D U-Net network for segmentation and other image-to-image tasks.
The number of features for each level of the U-Net is computed as follows: initial_features * gain ** level. The number of levels is determined by the length of the scale_factors argument. The scale factors determine the pooling factors for each level: specifying [1, 2, 2] pools anisotropically, i.e. only across the xy-plane, while [2, 2, 2] pools isotropically.
By default the U-Net uses two convolutional layers per level. This can be changed by providing an argument for conv_block_impl.
Arguments:
- in_channels: The number of input image channels.
- out_channels: The number of output image channels.
- scale_factors: The factors for max pooling for the levels of the U-Net.
- initial_features: The initial number of features, corresponding to the features of the first conv block.
- gain: The gain factor for increasing the features after each level.
- final_activation: The activation applied after the output convolution or last decoder layer.
- return_side_outputs: Whether to return the outputs after each decoder level.
- conv_block_impl: The implementation of the convolutional block.
- anisotropic_kernel: Whether to use an anisotropic kernel in addition to anisotropic scaling factor.
- postprocessing: A postprocessing function to apply after the U-Net output.
- check_shape: Whether to check the input shape to the U-Net forward call.
- conv_block_kwargs: The keyword arguments for the convolutional block.
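
A minimal usage sketch for AnisotropicUNet; the scale factors, channel counts and input shape are illustrative assumptions.

import torch
from torch_em.model.unet import AnisotropicUNet

# [1, 2, 2] only pools in-plane in the first level, [2, 2, 2] pools isotropically.
model = AnisotropicUNet(in_channels=1, out_channels=2, scale_factors=[[1, 2, 2], [2, 2, 2], [2, 2, 2]])

# The total downscaling per axis is the product of the scale factors,
# here (4, 8, 8), so the input shape must be divisible accordingly.
x = torch.rand(1, 1, 16, 128, 128)
with torch.no_grad():
    y = model(x)
print(y.shape)  # torch.Size([1, 2, 16, 128, 128])
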
class UNet3d(AnisotropicUNet):
A 3D U-Net network for segmentation and other image-to-image tasks.
This class uses the same implementation as AnisotropicUNet, with isotropic scaling in each level.
Arguments:
- in_channels: The number of input image channels.
- out_channels: The number of output image channels.
- depth: The number of encoder / decoder levels of the U-Net.
- initial_features: The initial number of features, corresponding to the features of the first conv block.
- gain: The gain factor for increasing the features after each level.
- final_activation: The activation applied after the output convolution or last decoder layer.
- return_side_outputs: Whether to return the outputs after each decoder level.
- conv_block_impl: The implementation of the convolutional block.
- postprocessing: A postprocessing function to apply after the U-Net output.
- check_shape: Whether to check the input shape to the U-Net forward call.
- conv_block_kwargs: The keyword arguments for the convolutional block.
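
A minimal usage sketch for UNet3d; the channel counts and input shape are illustrative assumptions.

import torch
from torch_em.model.unet import UNet3d

# With depth=3 every spatial axis of the input must be divisible by 2 ** 3 = 8.
model = UNet3d(in_channels=1, out_channels=3, depth=3, initial_features=16)
x = torch.rand(1, 1, 32, 64, 64)
with torch.no_grad():
    y = model(x)
print(y.shape)  # torch.Size([1, 3, 32, 64, 64])
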