torch_em.model.resnet3d

  1# This file implements 3d resnets, based on the implementations from torchvision:
  2# https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py
  3
  4from typing import Any, Callable, List, Optional, Type, Union
  5
  6import torch
  7import torch.nn as nn
  8from torch import Tensor
  9
 10# from torchvision.models._api import WeightsEnum
 11from torchvision.models._utils import _ovewrite_named_param
 12# from torchvision.utils import _log_api_usage_once
 13
 14
 15__all__ = [
 16    "ResNet3d",
 17    "resnet3d_18",
 18    "resnet3d_34",
 19    "resnet3d_50",
 20    "resnet3d_101",
 21    "resnet3d_152",
 22    "resnext3d_50_32x4d",
 23    "resnext3d_101_32x8d",
 24    "resnext3d_101_64x4d",
 25    "wide_resnet3d_50_2",
 26    "wide_resnet3d_101_2",
 27]
 28
 29
 30def conv3x3(in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1) -> nn.Conv3d:
 31    """@private
 32    """
 33    # 3x3x3 convolution with padding
 34    return nn.Conv3d(
 35        in_planes,
 36        out_planes,
 37        kernel_size=3,
 38        stride=stride,
 39        padding=dilation,
 40        groups=groups,
 41        bias=False,
 42        dilation=dilation,
 43    )
 44
 45
 46def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv3d:
 47    """@private
 48    """
 49    # 1x1x1 convolution
 50    return nn.Conv3d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
 51
 52
 53class BasicBlock(nn.Module):
 54    """@private
 55    """
 56    expansion: int = 1
 57
 58    def __init__(
 59        self,
 60        inplanes: int,
 61        planes: int,
 62        stride: int = 1,
 63        downsample: Optional[nn.Module] = None,
 64        groups: int = 1,
 65        base_width: int = 64,
 66        dilation: int = 1,
 67        norm_layer: Optional[Callable[..., nn.Module]] = None,
 68    ) -> None:
 69        super().__init__()
 70        if norm_layer is None:
 71            norm_layer = nn.BatchNorm3d
 72        if groups != 1 or base_width != 64:
 73            raise ValueError("BasicBlock only supports groups=1 and base_width=64")
 74        if dilation > 1:
 75            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
 76        # Both self.conv1 and self.downsample layers downsample the input when stride != 1
 77        self.conv1 = conv3x3(inplanes, planes, stride)
 78        self.bn1 = norm_layer(planes)
 79        self.relu = nn.ReLU(inplace=True)
 80        self.conv2 = conv3x3(planes, planes)
 81        self.bn2 = norm_layer(planes)
 82        self.downsample = downsample
 83        self.stride = stride
 84
 85    def forward(self, x: Tensor) -> Tensor:
 86        identity = x
 87
 88        out = self.conv1(x)
 89        out = self.bn1(out)
 90        out = self.relu(out)
 91
 92        out = self.conv2(out)
 93        out = self.bn2(out)
 94
 95        if self.downsample is not None:
 96            identity = self.downsample(x)
 97
 98        out += identity
 99        out = self.relu(out)
100
101        return out
102
103
104class Bottleneck(nn.Module):
105    """@private
106    """
107    # Bottleneck in torchvision places the stride for downsampling at the 3x3 convolution (self.conv2),
108    # while the original implementation places the stride at the first 1x1 convolution (self.conv1),
109    # according to "Deep residual learning for image recognition" (https://arxiv.org/abs/1512.03385).
110    # This variant is also known as ResNet V1.5 and improves accuracy according to
111    # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.
112
113    expansion: int = 4
114
115    def __init__(
116        self,
117        inplanes: int,
118        planes: int,
119        stride: int = 1,
120        downsample: Optional[nn.Module] = None,
121        groups: int = 1,
122        base_width: int = 64,
123        dilation: int = 1,
124        norm_layer: Optional[Callable[..., nn.Module]] = None,
125    ) -> None:
126        super().__init__()
127        if norm_layer is None:
128            norm_layer = nn.BatchNorm3d
129        width = int(planes * (base_width / 64.0)) * groups
130        # Both self.conv2 and self.downsample layers downsample the input when stride != 1
131        self.conv1 = conv1x1(inplanes, width)
132        self.bn1 = norm_layer(width)
133        self.conv2 = conv3x3(width, width, stride, groups, dilation)
134        self.bn2 = norm_layer(width)
135        self.conv3 = conv1x1(width, planes * self.expansion)
136        self.bn3 = norm_layer(planes * self.expansion)
137        self.relu = nn.ReLU(inplace=True)
138        self.downsample = downsample
139        self.stride = stride
140
141    def forward(self, x: Tensor) -> Tensor:
142        identity = x
143
144        out = self.conv1(x)
145        out = self.bn1(out)
146        out = self.relu(out)
147
148        out = self.conv2(out)
149        out = self.bn2(out)
150        out = self.relu(out)
151
152        out = self.conv3(out)
153        out = self.bn3(out)
154
155        if self.downsample is not None:
156            identity = self.downsample(x)
157
158        out += identity
159        out = self.relu(out)
160
161        return out
162
163
164class ResNet3d(nn.Module):
165    """@private
166    """
167    def __init__(
168        self,
169        block: Type[Union[BasicBlock, Bottleneck]],
170        layers: List[int],
171        in_channels: int,
172        out_channels: int,
173        zero_init_residual: bool = False,
174        groups: int = 1,
175        width_per_group: int = 64,
176        replace_stride_with_dilation: Optional[List[bool]] = None,
177        norm_layer: Optional[Callable[..., nn.Module]] = None,
178        stride_conv1: bool = True,
179    ) -> None:
180        super().__init__()
181        # _log_api_usage_once(self)
182        if norm_layer is None:
183            norm_layer = nn.BatchNorm3d
184        self._norm_layer = norm_layer
185
186        self.in_channels = in_channels
187        self.out_channels = out_channels
188
189        self.inplanes = 64
190        self.dilation = 1
191        if replace_stride_with_dilation is None:
192            # each element in the tuple indicates if we should replace
193            # the 2x2x2 stride with a dilated convolution instead
194            replace_stride_with_dilation = [False, False, False]
195        if len(replace_stride_with_dilation) != 3:
196            raise ValueError(
197                "replace_stride_with_dilation should be None "
198                f"or a 3-element tuple, got {replace_stride_with_dilation}"
199            )
200        self.groups = groups
201        self.base_width = width_per_group
202        self.conv1 = nn.Conv3d(
203            in_channels, self.inplanes, kernel_size=7, stride=2 if stride_conv1 else 1, padding=3, bias=False
204        )
205        self.bn1 = norm_layer(self.inplanes)
206        self.relu = nn.ReLU(inplace=True)
207        self.maxpool = nn.MaxPool3d(kernel_size=3, stride=2, padding=1)
208        self.layer1 = self._make_layer(block, 64, layers[0])
209        self.layer2 = self._make_layer(block, 128, layers[1], stride=2, dilate=replace_stride_with_dilation[0])
210        self.layer3 = self._make_layer(block, 256, layers[2], stride=2, dilate=replace_stride_with_dilation[1])
211        self.layer4 = self._make_layer(block, 512, layers[3], stride=2, dilate=replace_stride_with_dilation[2])
212        self.avgpool = nn.AdaptiveAvgPool3d((1, 1, 1))
213        self.fc = nn.Linear(512 * block.expansion, out_channels)
214
215        for m in self.modules():
216            if isinstance(m, nn.Conv3d):
217                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
218            elif isinstance(m, (nn.BatchNorm3d, nn.GroupNorm)):
219                nn.init.constant_(m.weight, 1)
220                nn.init.constant_(m.bias, 0)
221
222        # Zero-initialize the last BN in each residual branch,
223        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
224        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
225        if zero_init_residual:
226            for m in self.modules():
227                if isinstance(m, Bottleneck) and m.bn3.weight is not None:
228                    nn.init.constant_(m.bn3.weight, 0)  # type: ignore[arg-type]
229                elif isinstance(m, BasicBlock) and m.bn2.weight is not None:
230                    nn.init.constant_(m.bn2.weight, 0)  # type: ignore[arg-type]
231
232    def _make_layer(
233        self,
234        block: Type[Union[BasicBlock, Bottleneck]],
235        planes: int,
236        blocks: int,
237        stride: int = 1,
238        dilate: bool = False,
239    ) -> nn.Sequential:
240        norm_layer = self._norm_layer
241        downsample = None
242        previous_dilation = self.dilation
243        if dilate:
244            self.dilation *= stride
245            stride = 1
246        if stride != 1 or self.inplanes != planes * block.expansion:
247            downsample = nn.Sequential(
248                conv1x1(self.inplanes, planes * block.expansion, stride),
249                norm_layer(planes * block.expansion),
250            )
251
252        layers = []
253        layers.append(
254            block(
255                self.inplanes, planes, stride, downsample, self.groups, self.base_width, previous_dilation, norm_layer
256            )
257        )
258        self.inplanes = planes * block.expansion
259        for _ in range(1, blocks):
260            layers.append(
261                block(
262                    self.inplanes,
263                    planes,
264                    groups=self.groups,
265                    base_width=self.base_width,
266                    dilation=self.dilation,
267                    norm_layer=norm_layer,
268                )
269            )
270
271        return nn.Sequential(*layers)
272
273    def _forward_impl(self, x: Tensor) -> Tensor:
274        # See note [TorchScript super()]
275        x = self.conv1(x)
276        x = self.bn1(x)
277        x = self.relu(x)
278        x = self.maxpool(x)
279
280        x = self.layer1(x)
281        x = self.layer2(x)
282        x = self.layer3(x)
283        x = self.layer4(x)
284
285        x = self.avgpool(x)
286        x = torch.flatten(x, 1)
287        x = self.fc(x)
288
289        return x
290
291    def forward(self, x: Tensor) -> Tensor:
292        return self._forward_impl(x)
293
294
295def _resnet(
296    block: Type[Union[BasicBlock, Bottleneck]],
297    layers: List[int],
298    weights: Any,
299    progress: bool,
300    **kwargs: Any,
301) -> ResNet3d:
302    if weights is not None:
303        _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))
304
305    model = ResNet3d(block, layers, **kwargs)
306
307    if weights is not None:
308        model.load_state_dict(weights.get_state_dict(progress=progress))
309
310    return model
311
312
313def resnet3d_18(in_channels: int, out_channels: int, **kwargs: Any) -> ResNet3d:
314    """Get a residual network for 3d data with 18 layers.
315
316    The implementation of this network is based on torchvision:
317    https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py
318
319    Args:
320        in_channels: The number of input channels.
321        out_channels: The number of output channels.
322        kwargs: Additional keyword arguments for the ResNet.
323
324    Returns:
325        The 3D ResNet.
326    """
327    return _resnet(
328        BasicBlock, [2, 2, 2, 2], weights=None, progress=False,
329        in_channels=in_channels, out_channels=out_channels, **kwargs
330    )
331
332
333def resnet3d_34(in_channels: int, out_channels: int, **kwargs: Any) -> ResNet3d:
334    """Get a residual network for 3d data with 34 layers.
335
336    The implementation of this network is based on torchvision:
337    https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py
338
339    Args:
340        in_channels: The number of input channels.
341        out_channels: The number of output channels.
342        kwargs: Additional keyword arguments for the ResNet.
343
344    Returns:
345        The 3D ResNet.
346    """
347    return _resnet(
348        BasicBlock, [3, 4, 6, 3], weights=None, progress=False,
349        in_channels=in_channels, out_channels=out_channels, **kwargs
350    )
351
352
353def resnet3d_50(in_channels: int, out_channels: int, **kwargs: Any) -> ResNet3d:
354    """Get a residual network for 3d data with 50 layers.
355
356    The implementation of this network is based on torchvision:
357    https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py
358
359    Args:
360        in_channels: The number of input channels.
361        out_channels: The number of output channels.
362        kwargs: Additional keyword arguments for the ResNet.
363
364    Returns:
365        The 3D ResNet.
366    """
367    return _resnet(
368        Bottleneck, [3, 4, 6, 3], weights=None, progress=False,
369        in_channels=in_channels, out_channels=out_channels, **kwargs
370    )
371
372
373def resnet3d_101(in_channels: int, out_channels: int, **kwargs: Any) -> ResNet3d:
374    """Get a residual network for 3d data with 101 layers.
375
376    The implementation of this network is based on torchvision:
377    https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py
378
379    Args:
380        in_channels: The number of input channels.
381        out_channels: The number of output channels.
382        kwargs: Additional keyword arguments for the ResNet.
383
384    Returns:
385        The 3D ResNet.
386    """
387    return _resnet(
388        Bottleneck, [3, 4, 23, 3], weights=None, progress=False,
389        in_channels=in_channels, out_channels=out_channels, **kwargs
390    )
391
392
393def resnet3d_152(in_channels: int, out_channels: int, **kwargs: Any) -> ResNet3d:
394    """Get a residual network for 3d data with 152 layers.
395
396    The implementation of this network is based on torchvision:
397    https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py
398
399    Args:
400        in_channels: The number of input channels.
401        out_channels: The number of output channels.
402        kwargs: Additional keyword arguments for the ResNet.
403
404    Returns:
405        The 3D ResNet.
406    """
407    return _resnet(
408        Bottleneck, [3, 8, 36, 3], weights=None, progress=False,
409        in_channels=in_channels, out_channels=out_channels, **kwargs
410    )
411
412
413def resnext3d_50_32x4d(in_channels: int, out_channels: int, **kwargs: Any) -> ResNet3d:
414    """Get a residual network for 3d data with 50 layers and ResNext layer design.
415
416    The implementation of this network is based on torchvision:
417    https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py
418
419    Args:
420        in_channels: The number of input channels.
421        out_channels: The number of output channels.
422        kwargs: Additional keyword arguments for the ResNet.
423
424    Returns:
425        The 3D ResNext.
426    """
427    _ovewrite_named_param(kwargs, "groups", 32)
428    _ovewrite_named_param(kwargs, "width_per_group", 4)
429    return _resnet(
430        Bottleneck, [3, 4, 6, 3], weights=None, progress=False,
431        in_channels=in_channels, out_channels=out_channels, **kwargs
432    )
433
434
435def resnext3d_101_32x8d(in_channels: int, out_channels: int, **kwargs: Any) -> ResNet3d:
436    """Get a residual network for 3d data with 101 layers and ResNext layer design.
437
438    The implementation of this network is based on torchvision:
439    https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py
440
441    Args:
442        in_channels: The number of input channels.
443        out_channels: The number of output channels.
444        kwargs: Additional keyword arguments for the ResNet.
445
446    Returns:
447        The 3D ResNext.
448    """
449    _ovewrite_named_param(kwargs, "groups", 32)
450    _ovewrite_named_param(kwargs, "width_per_group", 8)
451    return _resnet(
452        Bottleneck, [3, 4, 23, 3], weights=None, progress=False,
453        in_channels=in_channels, out_channels=out_channels, **kwargs
454    )
455
456
457def resnext3d_101_64x4d(in_channels: int, out_channels: int, **kwargs: Any) -> ResNet3d:
458    """Get a residual network for 3d data with 101 layers and ResNext layer design.
459
460    The implementation of this network is based on torchvision:
461    https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py
462
463    Args:
464        in_channels: The number of input channels.
465        out_channels: The number of output channels.
466        kwargs: Additional keyword arguments for the ResNet.
467
468    Returns:
469        The 3D ResNext.
470    """
471    _ovewrite_named_param(kwargs, "groups", 64)
472    _ovewrite_named_param(kwargs, "width_per_group", 4)
473    return _resnet(
474        Bottleneck, [3, 4, 23, 3], weights=None, progress=False,
475        in_channels=in_channels, out_channels=out_channels, **kwargs
476    )
477
478
479def wide_resnet3d_50_2(in_channels: int, out_channels: int, **kwargs: Any) -> ResNet3d:
480    """Get a wide residual network for 3d data with 50 layers.
481
482    The implementation of this network is based on torchvision:
483    https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py
484
485    Args:
486        in_channels: The number of input channels.
487        out_channels: The number of output channels.
488        kwargs: Additional keyword arguments for the ResNet.
489
490    Returns:
491        The wide 3D ResNet.
492    """
493    _ovewrite_named_param(kwargs, "width_per_group", 64 * 2)
494    return _resnet(
495        Bottleneck, [3, 4, 6, 3], weights=None, progress=False,
496        in_channels=in_channels, out_channels=out_channels, **kwargs
497    )
498
499
500def wide_resnet3d_101_2(in_channels: int, out_channels: int, **kwargs: Any) -> ResNet3d:
501    """Get a wide residual network for 3d data with 101 layers.
502
503    The implementation of this network is based on torchvision:
504    https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py
505
506    Args:
507        in_channels: The number of input channels.
508        out_channels: The number of output channels.
509        kwargs: Additional keyword arguments for the ResNet.
510
511    Returns:
512        The wide 3D ResNet.
513    """
514    _ovewrite_named_param(kwargs, "width_per_group", 64 * 2)
515    return _resnet(
516        Bottleneck, [3, 4, 23, 3], weights=None, progress=False,
517        in_channels=in_channels, out_channels=out_channels, **kwargs
518    )
def resnet3d_18(in_channels: int, out_channels: int, **kwargs: Any) -> torch_em.model.resnet3d.ResNet3d:

Get a residual network for 3d data with 18 layers.

The implementation of this network is based on torchvision: https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py

Arguments:
  • in_channels: The number of input channels.
  • out_channels: The number of output channels.
  • kwargs: Additional keyword arguments for the ResNet.
Returns:
  • The 3D ResNet.
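
A minimal usage sketch (an illustration, not part of the module: it assumes torch_em is installed and imports the factory from torch_em.model.resnet3d; the channel counts and volume size below are arbitrary):

import torch
from torch_em.model.resnet3d import resnet3d_18

# Build an 18-layer 3D ResNet for single-channel volumes and two output classes.
model = resnet3d_18(in_channels=1, out_channels=2)

# The network expects 5D input: (batch, channels, depth, height, width).
x = torch.randn(1, 1, 64, 64, 64)
logits = model(x)
print(logits.shape)  # torch.Size([1, 2])

# Additional keyword arguments are forwarded to the ResNet3d constructor,
# e.g. stride_conv1=False keeps stride 1 in the initial 7x7x7 convolution.
model_full_res = resnet3d_18(in_channels=1, out_channels=2, stride_conv1=False)

The deeper variants (resnet3d_34, resnet3d_50, resnet3d_101, resnet3d_152) follow the same call pattern.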

def resnet3d_34(in_channels: int, out_channels: int, **kwargs: Any) -> torch_em.model.resnet3d.ResNet3d:

Get a residual network for 3d data with 34 layers.

The implementation of this network is based on torchvision: https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py

Arguments:
  • in_channels: The number of input channels.
  • out_channels: The number of output channels.
  • kwargs: Additional keyword arguments for the ResNet.
Returns:
  • The 3D ResNet.

def resnet3d_50(in_channels: int, out_channels: int, **kwargs: Any) -> torch_em.model.resnet3d.ResNet3d:

Get a residual network for 3d data with 50 layers.

The implementation of this network is based on torchvision: https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py

Arguments:
  • in_channels: The number of input channels.
  • out_channels: The number of output channels.
  • kwargs: Additional keyword arguments for the ResNet.
Returns:
  • The 3D ResNet.
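
A short sketch of what the Bottleneck-based variants imply for the classification head (hypothetical usage under the same import assumption as above): with Bottleneck.expansion = 4, the final linear layer receives 512 * 4 = 2048 features, so swapping it for an identity turns the network into a pooled feature extractor.

import torch
import torch.nn as nn
from torch_em.model.resnet3d import resnet3d_50

model = resnet3d_50(in_channels=1, out_channels=10)

# Bottleneck.expansion == 4, so the head is Linear(512 * 4, out_channels).
assert model.fc.in_features == 2048

# Replace the head with an identity to get 2048-d pooled features instead of logits.
model.fc = nn.Identity()
features = model(torch.randn(1, 1, 32, 64, 64))
print(features.shape)  # torch.Size([1, 2048])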

def resnet3d_101(in_channels: int, out_channels: int, **kwargs: Any) -> torch_em.model.resnet3d.ResNet3d:

Get a residual network for 3d data with 101 layers.

The implementation of this network is based on torchvision: https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py

Arguments:
  • in_channels: The number of input channels.
  • out_channels: The number of output channels.
  • kwargs: Additional keyword arguments for the ResNet.
Returns:
  • The 3D ResNet.

def resnet3d_152(in_channels: int, out_channels: int, **kwargs: Any) -> torch_em.model.resnet3d.ResNet3d:

Get a residual network for 3d data with 152 layers.

The implementation of this network is based on torchvision: https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py

Arguments:
  • in_channels: The number of input channels.
  • out_channels: The number of output channels.
  • kwargs: Additional keyword arguments for the ResNet.
Returns:
  • The 3D ResNet.

def resnext3d_50_32x4d(in_channels: int, out_channels: int, **kwargs: Any) -> torch_em.model.resnet3d.ResNet3d:

Get a residual network for 3d data with 50 layers and ResNext layer design.

The implementation of this network is based on torchvision: https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py

Arguments:
  • in_channels: The number of input channels.
  • out_channels: The number of output channels.
  • kwargs: Additional keyword arguments for the ResNet.
Returns:
  • The 3D ResNext.
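
A brief sketch of the ResNext variant (same import assumption as the examples above). Note that groups=32 and width_per_group=4 are fixed by the factory via _ovewrite_named_param, so passing conflicting values for these keywords raises an error.

import torch
from torch_em.model.resnet3d import resnext3d_50_32x4d

# groups=32 and width_per_group=4 are set internally; the caller only chooses
# the channel arguments (and any other ResNet3d keyword arguments).
model = resnext3d_50_32x4d(in_channels=3, out_channels=5)
out = model(torch.randn(1, 3, 32, 32, 32))
print(out.shape)  # torch.Size([1, 5])

The 101-layer ResNext factories work the same way, only with different group and width presets.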

def resnext3d_101_32x8d(in_channels: int, out_channels: int, **kwargs: Any) -> torch_em.model.resnet3d.ResNet3d:

Get a residual network for 3d data with 101 layers and ResNext layer design.

The implementation of this network is based on torchvision: https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py

Arguments:
  • in_channels: The number of input channels.
  • out_channels: The number of output channels.
  • kwargs: Additional keyword arguments for the ResNet.
Returns:
  • The 3D ResNext.

def resnext3d_101_64x4d(in_channels: int, out_channels: int, **kwargs: Any) -> torch_em.model.resnet3d.ResNet3d:

Get a residual network for 3d data with 101 layers and ResNext layer design.

The implementation of this network is based on torchvision: https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py

Arguments:
  • in_channels: The number of input channels.
  • out_channels: The number of output channels.
  • kwargs: Additional keyword arguments for the ResNet.
Returns:
  • The 3D ResNext.

def wide_resnet3d_50_2(in_channels: int, out_channels: int, **kwargs: Any) -> torch_em.model.resnet3d.ResNet3d:

Get a wide residual network for 3d data with 50 layers.

The implementation of this network is based on torchvision: https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py

Arguments:
  • in_channels: The number of input channels.
  • out_channels: The number of output channels.
  • kwargs: Additional keyword arguments for the ResNet.
Returns:
  • The wide 3D ResNet.
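
For comparison, a sketch of the widened variant (same assumptions as above): width_per_group is doubled to 128, which doubles the channel width of the 3x3x3 convolution inside every Bottleneck block, while the classifier input stays at 512 * 4 = 2048 features.

import torch
from torch_em.model.resnet3d import wide_resnet3d_50_2

model = wide_resnet3d_50_2(in_channels=1, out_channels=3)

# First block of layer1: width = int(64 * (128 / 64)) = 128 channels in conv2,
# versus 64 in the standard resnet3d_50.
print(model.layer1[0].conv2.in_channels)  # 128
# The bottleneck output (and hence the classifier input) is unchanged.
print(model.fc.in_features)               # 2048

print(model(torch.randn(1, 1, 32, 32, 32)).shape)  # torch.Size([1, 3])

wide_resnet3d_101_2 differs only in its layer configuration ([3, 4, 23, 3]).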

def wide_resnet3d_101_2(in_channels: int, out_channels: int, **kwargs: Any) -> torch_em.model.resnet3d.ResNet3d:

Get a wide residual network for 3d data with 101 layers.

The implementation of this network is based on torchvision: https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py

Arguments:
  • in_channels: The number of input channels.
  • out_channels: The number of output channels.
  • kwargs: Additional keyword arguments for the ResNet.
Returns:
  • The wide 3D ResNet.